使用例2.1的数据点作为输入。
class Perception:
def __init__(self, learningrate):
self.lr = learningrate # 学习率
self.wh = np.array([0.0, 0.0]) #初始化wh = [0.0, 0.0]
self.b = 0
def train(self,inputs, target):
inputs = np.asarray(inputs)
if self.check(inputs,target): # 如果被误分类,则更新wh,b
self.wh += self.lr * np.dot(inputs, target)
self.b += target * self.lr
print(self.wh, self.b)
#判断是否被正确分类
def check(self, inputs, target):
flag = False
res = 0.0
res += (np.dot(inputs, self.wh)+self.b)*target
if res <=0:
flag = True
return flag
例2.1的迭代过程为 程序输出迭代过程为:
scikit-learn实现
from sklearn.linear_model import Perceptron
# max_iter为最大迭代次数,eta0为学习率
perceptron = Perceptron(max_iter=1000, eta0=1)
# coef_init设置w初始向量,intercept_init设置初始参数b
p_fit = perceptron.fit(x, y, coef_init=np.array([[0.0, 0.0]]).reshape(-1,1),intercept_init=np.array([0]),sample_weight=np.array([1,1,1]))
# 对未知数据进行预测
p_fit.predict(np.array([4,4]).reshape(1,-1))
Ref:
1.统计学习方法-李航
2.scikit-learn官方文档