TP3 exo2 fixes.
[TP_AA.git] / TP3 / exo2 / tp3_exo2.py
index 2a831877d1afa81b9794af3a9cc89541789eb242..4ca092ddd0015126efe66d2a3b02fceb8b444433 100755 (executable)
@@ -28,7 +28,7 @@ def generateData2(n):
     Generates a 2D linearly separable dataset with 2n samples.
     The third element of the sample is the label
     """
-    xb = (rand(n) * 2 - 1) / 2 - 0.5
+    xb = (rand(n) * 2 - 1) / 2 + 0.5
     yb = (rand(n) * 2 - 1) / 2
     xr = (rand(n) * 2 - 1) / 2 + 1.5
     yr = (rand(n) * 2 - 1) / 2 - 0.5
@@ -117,7 +117,7 @@ def kg(x, y, sigma=10):
 
 def perceptron_k(X, Y, k):
     coeffs = []
-    support_set = []
+    support_set = np.array([])
     # Go in the loop at least one time
     classification_error = 1
     while not classification_error == 0:
@@ -125,20 +125,31 @@ def perceptron_k(X, Y, k):
         for i in range(X.shape[0]):
             if Y[i] * f_from_k(coeffs, support_set, k, X[i]) <= 0:
                 classification_error += 1
-                support_set.append([Y[i], X[i]])
+                np.append(support_set, X[i])
                 coeffs.append(1)
             else:
                 coeffs[len(coeffs) - 1] = coeffs[len(coeffs) - 1] + 1
-    return coeffs, support_set
+    return np.array(coeffs), support_set
 
 
-print(perceptron_k(X, Y, k1))
-# print(perceptron_k(X, Y, kg))
+def f(w, x, y):
+    return w[0] + w[1] * x + w[2] * y + w[3] * x**2 + w[4] * x * y + w[5] * y**2
 
-X = apply_plongement(X, plongement_phi)
-w = perceptron_nobias(X, Y)
-print(w)
 
 pl.scatter(X[:, 0], X[:, 1], c=Y, s=training_set_size)
 pl.title(u"Perceptron - hyperplan")
+
+# coeffs, support_set = perceptron_k(X, Y, k1)
+# coeffs, support_set = perceptron_k(X, Y, kg)
+res = training_set_size
+# for c, X in zip(coeffs, support_set):
+#     pl.plot(X[0], 'xr')
+
+X = apply_plongement(X, plongement_phi)
+w = perceptron_nobias(X, Y)
+for x in range(res):
+    for y in range(res):
+        if abs(f(w, -3 / 2 + 3 * x / res, -3 / 2 + 3 * y / res)) < 0.01:
+            pl.plot(-3 / 2 + 3 * x / res, -3 / 2 + 3 * y / res, 'xb')
+
 pl.show()