利用朴素贝叶斯对名字进行性别预测

完整代码

#-*-coding:utf-8-*-
import pandas as pd
import math
from collections import defaultdict

# load the data and preprocess the data

# Training set, parsed with pandas' default CSV settings. Code below reads
# its 'gender' column (see loadData and getBase).
train = pd.read_csv("./data/train.txt")
# Test set, parsed the same way. getResult reads its 'name' column and adds
# a 'pred' column to it.
test = pd.read_csv("./data/test.txt")
def loadData(data=None):
	"""Split the training data into male and female subsets.

	The rest of this file treats ``gender == 1`` as male: ``getBase`` uses
	``train['gender'].mean()`` as the male prior, and ``getResult`` stores
	``int(male_score > female_score)`` as the prediction. The split below
	follows the same encoding (1 -> male, 0 -> female).

	Args:
		data: DataFrame with a ``gender`` column. Defaults to the
			module-level ``train`` frame.

	Returns:
		A tuple ``(names_male, names_female, totals)``. The first two are
		the filtered DataFrames; ``totals`` maps ``'m'`` and ``'f'`` to the
		number of rows in each subset.
	"""
	if data is None:
		data = train

	names_male = data[data['gender'] == 1]
	names_female = data[data['gender'] == 0]

	# Row counts per gender; used later to normalise character frequencies.
	totals = {
		'f': len(names_female),
		'm': len(names_male),
	}
	return names_male, names_female, totals

# Calculate the relative frequency of each character in the names of each gender

def _name_column(names):
	"""Return the iterable of name strings held by ``names``.

	A DataFrame with a ``name`` column yields that column; iterating the
	DataFrame itself would yield its column labels instead of the names.
	Any other iterable is returned unchanged.
	"""
	columns = getattr(names, 'columns', None)
	if columns is not None and 'name' in columns:
		return names['name']
	return names


def _char_freq(names, total):
	"""Return a mapping of character -> occurrences divided by ``total``."""
	freq = defaultdict(float)
	for name in _name_column(names):
		for char in name:
			freq[char] += 1.0 / total
	return freq


def calFreq(names_male,names_female,totals):
	"""Compute per-gender character frequencies.

	Args:
		names_male: male names, either a DataFrame with a ``name`` column
			or an iterable of strings.
		names_female: female names, in the same form.
		totals: dict with keys ``'m'`` and ``'f'`` giving the number of
			male and female names; each occurrence adds ``1 / total``.

	Returns:
		A tuple ``(freq_list_m, freq_list_f)`` of defaultdicts mapping each
		character to its occurrence count divided by the gender's total.
	"""
	freq_list_f = _char_freq(names_female, totals['f'])
	# Male occurrences go to the male table (not the female one).
	freq_list_m = _char_freq(names_male, totals['m'])
	return freq_list_m, freq_list_f

# Laplace smoothing: avoid a zero probability for characters that never appear in the training data
def LaplaceSmooth(char, freq_list,total,alpha=1.0):
	"""Return the Laplace-smoothed frequency of ``char``.

	Args:
		char: the character to look up.
		freq_list: mapping of character -> relative frequency
			(occurrences divided by ``total``).
		total: number of names the frequencies were normalised by.
		alpha: smoothing constant added to every count.

	Returns:
		``(count + alpha) / (total + distinct_chars * alpha)`` where
		``count`` is the raw occurrence count recovered from the relative
		frequency and ``distinct_chars`` is ``len(freq_list)``.
	"""
	# Scale the relative frequency back to a raw count. ``.get`` is used so
	# that looking up an unseen character does not insert it into a
	# defaultdict and inflate ``distinct_chars`` on later calls.
	count = freq_list.get(char, 0) * total
	distinct_chars = len(freq_list)
	freq_smooth = (count + alpha) / (total + distinct_chars * alpha)
	return freq_smooth

# Log-odds of a character, computed from its smoothed frequency

def GetLogProb(char, frequency_list, total):
    """Return the log-odds ``log(p) - log(1 - p)`` for ``char``.

    ``p`` is the value ``LaplaceSmooth`` returns for the character, using
    that function's default ``alpha``.
    """
    p = LaplaceSmooth(char, frequency_list, total)
    log_present = math.log(p)
    log_absent = math.log(1 - p)
    return log_present - log_absent

def getBase(freq_list_m,freq_list_f,train):
	base_f = math.log(1 - train['gender'].mean())
	base_f += sum([math.log(1 - freq_list_f[char]) for char in freq_list_f])
	base_m = math.log(train['gender'].mean())
	base_m += sum([math.log(1 - freq_list_m[char]) for char in freq_list_m])
	bases = {'f': base_f, 'm': base_m}
	return bases

def calLogProb(name, bases,totals, freq_list_m,freq_list_f):
	"""Score ``name`` under both genders.

	Starts from ``bases['m']`` / ``bases['f']`` and, for every character of
	``name``, adds the value ``GetLogProb`` returns for that character with
	the matching frequency table and total.

	Returns:
		A dict with keys ``'male'`` and ``'female'`` holding the two scores.
	"""
	scores = {'male': bases['m'], 'female': bases['f']}
	for ch in name:
		scores['male'] += GetLogProb(ch, freq_list_m, totals['m'])
		scores['female'] += GetLogProb(ch, freq_list_f, totals['f'])
	return scores

def getGender(logProbs):
	"""Return whether the 'male' score is strictly greater than the 'female' score."""
	male_score = logProbs['male']
	female_score = logProbs['female']
	return male_score > female_score

def getResult(bases, totals, freq_list_m, freq_list_f):
	"""Predict a gender for every name in the module-level ``test`` frame.

	Each name in ``test['name']`` is scored with ``calLogProb`` and the
	``getGender`` outcome is stored as an int. The predictions are written
	to ``test['pred']``, the first 20 rows are printed, and the list of
	predictions is returned.
	"""
	result = [
		int(getGender(calLogProb(name, bases, totals, freq_list_m, freq_list_f)))
		for name in test['name']
	]
	test['pred'] = result
	print(test.head(20))
	return result
def main():
	"""Run the pipeline: split the data, build frequencies and bases, then predict."""
	males, females, counts = loadData()
	freq_m, freq_f = calFreq(males, females, counts)
	bases = getBase(freq_m, freq_f, train)
	# The returned predictions are not used here; getResult also stores
	# them on the test frame and prints a preview.
	getResult(bases, counts, freq_m, freq_f)


# Run the whole pipeline as soon as this file is executed.
main()



posted @ 2019-01-30 14:19  今夕何夕兮  阅读(599)  评论(0编辑  收藏  举报