感知机学习算法及其实现

感知机学习算法及其实现


这篇文章给出了感知机学习算法,并用Python实现之

Python 代码如下:

# coding:utf-8

import numpy as np
import re

class Perceptron():
    """
    Perceptron for two integer features, trained with the primal update rule.

    Each training sample is a list ``[y, x1, x2]`` where ``y`` is the label
    (+1 or -1). The model is ``sign(w . x + b)``.
    """

    def __init__(self):
        """
        Set the initial weights w0 and bias b0 to zero and the
        learning rate (also called step size) to 1.
        """
        # np.zeros runs on both Python 2 and 3 (xrange exists only on 2);
        # dtype=int keeps the integer weights the original list produced.
        self.w = np.zeros(2, dtype=int)
        self.b = 0
        self.learning_rate = 1

    def readDataFromFile(self, filename):
        """
        Read training data from a file.

        :param filename: file holding one sample per line, formatted as
                         "+1 3 3\\n" (label, then the two features)
        :return: list of ``[label, x1, x2]`` integer lists
        """
        inputs = []
        with open(filename) as fi:
            for line in fi:
                # split() with no argument collapses runs of whitespace and
                # drops the line ending, so repeated spaces, tabs or "\r\n"
                # no longer produce empty fields that int() rejects.
                fields = line.split()
                if not fields:
                    # Blank line (e.g. a trailing newline at end of file).
                    continue
                inputs.append([int(fields[0]), int(fields[1]), int(fields[2])])
        return inputs

    def misClassified(self, data):
        """
        Tell whether a sample is misclassified by the current model.

        :param data: one sample, ``[label, x1, x2]``
        :return: True if ``label * (w . x + b) <= 0`` (a point lying exactly
                 on the separating line counts as misclassified), else False
        """
        x = np.array(data[1:])
        # bool() turns the numpy boolean into a plain Python bool.
        return bool(data[0] * (np.dot(self.w, x) + self.b) <= 0)

    def train(self, inputs, max_epochs=None):
        """
        Sweep over the training data, updating the weights on every
        misclassified sample, until a full sweep makes no update.

        :param inputs: training data, a list of ``[label, x1, x2]`` samples
        :param max_epochs: optional upper bound on the number of full sweeps.
                           The default None keeps sweeping until convergence,
                           which never terminates when the data is not
                           linearly separable.
        :return: True if a sweep finished without any update, False if
                 ``max_epochs`` sweeps were used up first
        """
        epoch = 0
        while max_epochs is None or epoch < max_epochs:
            epoch += 1
            updated = False
            for data in inputs:
                if not self.misClassified(data):
                    continue
                updated = True
                y = data[0]
                x = np.array(data[1:])
                # Rebind instead of "+=": an in-place add cannot store a
                # float result in the integer weight array, so a fractional
                # learning_rate would raise a casting error.
                self.w = self.w + self.learning_rate * y * x
                self.b += self.learning_rate * y
            if not updated:
                return True
        return False

    def test(self, testfile, outfile):
        """
        Read the samples in the test file and write a predicted label for
        each of them.

        :param testfile: file holding one unlabeled sample per line ("x1 x2")
        :param outfile:  file that receives one prediction per line,
                         either "+1" or "-1"
        """
        with open(testfile) as tf:
            with open(outfile, 'w') as of:
                for line in tf:
                    fields = line.split()
                    if not fields:
                        # Skip blank lines rather than failing in int('').
                        continue
                    x = np.array([int(fields[0]), int(fields[1])])
                    value = np.dot(self.w, x) + self.b
                    # A point exactly on the boundary is reported as "+1".
                    pre = '+1' if value >= 0 else '-1'
                    of.write(pre + '\n')


if __name__ == '__main__':
    # Script entry point: train on ../data/trainSet.txt, print the learned
    # parameters, then write predictions for ../data/testSet.txt to
    # ../data/predict.txt.
    # os is imported here because this script block is its only user.
    import os

    # os.path.join picks the platform's separator; the original literals
    # hard-coded backslashes and relied on invalid escapes such as "\d".
    data_dir = os.path.join("..", "data")

    per = Perceptron()
    inputs = per.readDataFromFile(os.path.join(data_dir, "trainSet.txt"))
    per.train(inputs)
    # A parenthesised single-argument print runs on both Python 2 and 3.
    print("weight vector: " + str(per.w))
    print("bias: " + str(per.b))
    per.test(os.path.join(data_dir, "testSet.txt"),
             os.path.join(data_dir, "predict.txt"))

posted @ 2015-08-27 18:51  TinaYo  阅读(398)  评论(0)  编辑  收藏  举报