稍微复杂的分类器(加入了Normalization)

class Classifier:
    """Nearest-neighbour classifier over normalized numeric features.

    Training rows are read from a tab-separated file whose first line is a
    header.  Every remaining row is stored in ``self.data`` as the tuple
    ``(category, feature_vector, [name])`` where *name* is the first column,
    *category* the second and *feature_vector* the next ``self.dimension``
    columns parsed as integers.  After loading, each feature is rescaled to
    ``(value - median) / absolute_standard_deviation`` of its column.

    The code runs unchanged on Python 2 and Python 3.
    """

    def __init__(self, filename):
        """Load the training rows from *filename* and normalize them."""
        self.data = []
        # Number of numeric feature columns read from every row.
        self.dimension = 2
        self.getData(filename)
        # Per-column median and absolute standard deviation of the raw data;
        # filled in by normalizeColumn() and reused by classify().
        self.medians = []
        self.asds = []
        self.normalizeColumn()

    def _normalizeValue(self, value, d):
        """Return *value* of feature column *d* rescaled with the stored statistics."""
        # A constant column has a deviation of 0; dividing by 1 keeps the
        # plain offset from the median instead of raising ZeroDivisionError.
        scale = self.asds[d] or 1.0
        # float() guards against integer truncation on Python 2.
        return (value - self.medians[d]) / float(scale)

    def normalizeColumn(self):
        """Normalize every feature column of ``self.data`` in place.

        For each of the ``self.dimension`` columns the median and the
        absolute standard deviation are appended to ``self.medians`` and
        ``self.asds``; every stored feature is then replaced by
        ``(value - median) / asd``.

        Calling the method again once the statistics exist is a no-op:
        re-running it would rescale already-normalized data and leave
        ``classify`` with statistics that no longer describe raw input.
        """
        if self.medians or self.asds:
            return

        for d in range(self.dimension):
            column = [item[1][d] for item in self.data]
            median = self.getMedian(column)
            self.medians.append(median)
            self.asds.append(self.getAbsoluteStandardDeviation(column, median))

        for item in self.data:
            vector = item[1]
            for d in range(self.dimension):
                vector[d] = self._normalizeValue(vector[d], d)

    def getData(self, filename):
        """Append the rows of the tab-separated file *filename* to ``self.data``.

        The first line is treated as a header and skipped, and blank lines
        are ignored.  Raises ``ValueError`` if a feature column is not an
        integer and ``IndexError`` if a row has no category column.
        """
        # The context manager closes the file even if a row fails to parse.
        with open(filename) as f:
            f.readline()  # skip the header line
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                # Columns are tab-separated so that names may contain spaces.
                rawData = stripped.split('\t')
                # A real list (not a lazy map object) is required because the
                # features are overwritten in place during normalization.
                features = [int(v) for v in rawData[2:2 + self.dimension]]
                self.data.append((rawData[1], features, rawData[0:1]))

    def getMedian(self, data):
        """Return the median of the numbers in *data*.

        For an odd number of values the middle element is returned as is;
        for an even number the mean of the two middle elements is returned
        as a float.  Raises ``IndexError`` if *data* is empty.
        """
        length = len(data)
        sortedData = sorted(data)
        middle = length // 2  # floor division: a valid index on Python 2 and 3
        if length % 2 != 0:
            return sortedData[middle]
        return (sortedData[middle] + sortedData[middle - 1]) / 2.0

    def getAbsoluteStandardDeviation(self, alist, median):
        """Return the mean absolute distance of the values in *alist* from *median*.

        The result is always a float.  Raises ``ZeroDivisionError`` if
        *alist* is empty.
        """
        # float() avoids silent integer truncation on Python 2.
        return sum(abs(x - median) for x in alist) / float(len(alist))

    def manhattan(self, v1, v2):
        """Return the Manhattan distance between *v1* and *v2*.

        Only the first ``len(v1)`` components of *v2* are used.
        """
        return sum(abs(v1[i] - v2[i]) for i in range(len(v1)))

    def computeNearestNeighbor(self, itemName, itemVector):
        """Return ``(category, distance)`` pairs sorted by ascending distance.

        *itemVector* must already be normalized; *itemName* is accepted for
        interface compatibility and is not used.
        """
        distances = [(entry[0], self.manhattan(entry[1], itemVector))
                     for entry in self.data]
        distances.sort(key=lambda pair: pair[1])
        return distances

    def classify(self, itemName, itemVector):
        """Return the category of the training row nearest to *itemVector*.

        *itemVector* holds raw (un-normalized) feature values in any indexable
        or iterable form; it is copied before being normalized, so the
        caller's object is left untouched.  Raises ``IndexError`` if it has
        fewer than ``self.dimension`` values or if no training data is loaded.
        """
        raw = list(itemVector)
        normalized = [self._normalizeValue(raw[d], d)
                      for d in range(self.dimension)]
        return self.computeNearestNeighbor(itemName, normalized)[0][0]
        
#print getData("athletesTrainingSet")
def unitTest():
    """Check getMedian and getAbsoluteStandardDeviation against known values.

    A Classifier is built from 'athletesTrainingSet.txt' only to obtain an
    instance to call the two helpers on, so that file must be readable.
    Raises AssertionError on the first mismatch; prints a confirmation
    message when every check passes.
    """
    # (sample, expected median, expected absolute standard deviation)
    cases = [
        ([54, 72, 78, 49, 65, 63, 75, 67, 54], 65, 8),
        ([54, 72, 78, 49, 65, 63, 75, 67, 54, 68], 66, 7.5),
        ([69], 69, 0),
        ([69, 72], 70.5, 1.5),
    ]
    classifier = Classifier('athletesTrainingSet.txt')

    # Compute everything first, then verify: medians before deviations.
    medians = [classifier.getMedian(case[0]) for case in cases]
    deviations = [
        classifier.getAbsoluteStandardDeviation(case[0], median)
        for case, median in zip(cases, medians)
    ]

    for median, case in zip(medians, cases):
        assert round(median, 3) == case[1]

    for deviation, case in zip(deviations, cases):
        assert round(deviation, 3) == case[2]

    print("getMedian and getAbsoluteStandardDeviation work correctly")
#unitTest()
    
def myTest(trainingFile,testFile):
    """Train on *trainingFile*, classify every row of *testFile*, print a report.

    *testFile* is tab-separated with the name in the first column, the true
    category in the second and two integer features in the third and
    fourth.  Every non-blank line is treated as data (no header line is
    skipped).  The report consists of the fraction of correctly classified
    rows and the list of misclassified rows as
    ``(category, [features], [name])`` tuples holding the raw feature values.
    """
    classifier = Classifier(trainingFile)

    testData = []
    # The context manager closes the file even if a row fails to parse.
    with open(testFile) as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            rawData = stripped.split('\t')
            # Build a real list: map() is a one-shot iterator on Python 3.
            features = [int(v) for v in rawData[2:4]]
            testData.append((rawData[1], features, rawData[0:1]))

    errorList = []
    for data in testData:
        # Hand over a copy so the reported rows keep their raw values even
        # if classify() normalizes its argument in place.
        if classifier.classify(data[2], list(data[1])) != data[0]:
            errorList.append(data)

    print("result:")
    if testData:
        ratio = 1 - len(errorList) / float(len(testData))
        print("correct ratio is:  %s" % ratio)
    else:
        # Avoid ZeroDivisionError when the test file holds no rows.
        print("correct ratio is:  n/a (no test rows)")
    print("error list: %s" % (errorList,))
    
    
# Script entry point: train on the athletes training set and print the
# accuracy report for the athletes test set.  This is a top-level statement,
# so it also runs when the module is imported.
myTest('athletesTrainingSet.txt','athletesTestSet.txt')

版权声明:本文为博主原创文章,未经博主允许不得转载。

posted on 2013-08-15 15:49  江小鱼2015  阅读(235)  评论(0)  编辑  收藏  举报

导航