06bc95b7fe9c2630538d914e3a8620c38bf03a7b
3 # -*- coding: utf-8 -*-
6 from mpl_toolkits
.mplot3d
import Axes3D
9 data
= np
.loadtxt("dataRegLin2D.txt")
16 ones
= np
.ones((sample
.shape
[0], 1))
17 new_sample
= np
.append(sample
, ones
, axis
=-1)
23 return np
.array(new_sample
)
def train_regression(X, Y):
    """Fit linear-regression weights by ordinary least squares.

    Solves the normal equations ``(X^T X) w = X^T Y`` for ``w``.

    Args:
        X: Design matrix with one sample per row. A 1-D array is treated as
            a single-feature column.
        Y: Target values, one entry (or row) per sample.

    Returns:
        The weight array ``w`` with one entry (or row) per column of ``X``.

    Raises:
        numpy.linalg.LinAlgError: If ``X^T X`` is singular.
    """
    X = np.asanyarray(X)
    if X.ndim == 1:
        # A 1-D input would make X^T X a scalar, which the linear-algebra
        # routines reject; interpret it as an (n_samples, 1) column instead.
        X = X.reshape(-1, 1)
    X_t = np.transpose(X)
    # Solving the system directly is more accurate and cheaper than forming
    # the explicit inverse of X^T X and multiplying it out.
    return np.linalg.solve(np.dot(X_t, X), np.dot(X_t, Y))
32 return np
.dot(w
[:len(w
) - 1], x
) + w
[-1]
35 def error(X
, Y
, w
, idx
):
37 for i
in range(len(X
)):
38 y
= predict(X
[i
, :idx
], w
)
# Panel 1: 3-D scatter of both features against the target, and a fit on
# the full design matrix.
ax = fig.add_subplot(131, projection='3d')
ax.scatter(X[:, 0], X[:, 1], Y)
w1 = train_regression(X, Y)
print(error(X, Y, w1, 2))

# Panel 2: first feature only.
# NOTE(review): the original passed the 1-D slice X[:, 0] to
# train_regression, which makes X^T X a scalar and fails inside
# np.linalg.inv. The fit now uses the feature column together with the last
# column of X, which is assumed to be the constant bias column (the w1 fit
# above relies on the same layout) -- confirm against how X is built.
ax = fig.add_subplot(132)
ax.scatter(X[:, 0], Y[:])
w2 = train_regression(X[:, [0, -1]], Y)
print(error(X[:, 0].reshape((len(X), 1)), Y, w2, 1))

# Panel 3: second feature only, fitted the same way as panel 2.
ax = fig.add_subplot(133)
ax.scatter(X[:, 1], Y[:])
w3 = train_regression(X[:, [1, -1]], Y)
print(error(X[:, 1].reshape((len(X), 1)), Y, w3, 1))