创新工场“工场很忙”问题基于遗传算法的实现
“工场很忙”是创新工场在微博上发布的一道编程题。程序从 iw.in 读取面试数据（每行“学生 项目”，以“0 0”结束），计算出每个项目的面试顺序，并写入 iw.out。
程序用 Python 编写，采用遗传算法求解，在 Python 2.7 下运行通过。
File.py文件
# -*- coding: utf-8 -*-


def Read(fileName="iw.in"):
    '''Read the students' interview data from the input file.

    Each line holds "<student> <project>" separated by whitespace. Reading
    stops at the first "0 0" line (or at end of file); blank lines are
    skipped. Every accepted pair is echoed to stdout.

    :type fileName : str
    :param fileName : path of the input file (defaults to "iw.in")
    Returns : dict  student name -> list of project names, in file order
    Raises : IOError/OSError if the file cannot be opened;
             ValueError for a non-blank line with fewer than two fields
    '''
    source = {}
    print("输入初始数据:")
    # ``with`` closes the file exactly once, even if reading fails.
    with open(fileName) as f:
        lines = f.readlines()
    for lineNo, line in enumerate(lines, 1):
        # Split on whitespace instead of indexing characters so that ids
        # longer than one character (e.g. "12 3") are parsed correctly.
        fields = line.split()
        if not fields:
            continue
        if len(fields) < 2:
            raise ValueError("%s:%d: expected '<student> <project>', got %r"
                             % (fileName, lineNo, line))
        student, project = fields[0], fields[1]
        if student == '0' and project == '0':
            break
        print("%s %s" % (student, project))
        source.setdefault(student, []).append(project)
    return source


def Write(result=None, fileName="iw.out"):
    '''Write the computed schedule to the output file.

    One line is written per key of ``result``, in sorted key order. Each
    value of that key is written followed by a single space (the trailing
    space is kept for compatibility with earlier output). Every line is
    also echoed to stdout.

    :type result : dict
    :param result : key -> list of strings to write (None means empty)
    :type fileName : str
    :param fileName : path of the output file (defaults to "iw.out")
    Raises : IOError/OSError if the file cannot be written
    '''
    if result is None:
        result = {}
    print("输出结果:")
    with open(fileName, 'w') as f:
        for key in sorted(result):
            line = "".join(value + " " for value in result[key])
            print(line)
            f.write(line + '\n')


if __name__ == "__main__":
    print('这是“工场很忙”题目使用遗传算法的实现中用于读取文件输入数据和写文件输出结果!')
Heredity.py文件
# -*- coding: utf-8 -*-
import random
from File import *

# A chromosome ("encode") is a flat list of length encodeLen made of
# projectNamesLen consecutive segments of encodeLenPer slots each.
# Segment i belongs to the i-th project in _ProjectOrder(projectNames);
# position i*encodeLenPer+j holds the student that project i interviews in
# time slot j, or '0' when that slot is idle.

dataSource = {}       # input: student name -> list of project names
dataOut = {}          # result: project name -> interview order (students / '0')
projectNames = {}     # project name -> list of students interviewing for it
projectNamesLen = 0   # number of projects
encodeLen = 0         # chromosome length (projectNamesLen * encodeLenPer)
encodeLenPer = 0      # number of time slots per project
encodeSlc = []        # per chromosome position: list of values allowed there
initPplLen = 50       # population size
maxGent = 500         # number of generations
ps = 0.1              # selection: fraction of the worst individuals replaced by the best
pc = 0.1              # crossover probability for each pair of individuals
pm = 0.05             # mutation probability for each individual


def _ProjectOrder(proNams):
    '''Return the project names in the fixed order that maps chromosome
    segments to projects.

    Every function that converts between segment index and project name
    uses this one helper, so the mapping is always consistent. Sorting makes
    it deterministic and avoids indexing dict.keys(), which fails on Python 3.
    '''
    return sorted(proNams)


def GetAllProjects(dataSource={}):
    '''Invert the student -> projects mapping.

    :type dataSource : dict
    :param dataSource : student name -> list of project names
    Returns : dict     project name -> list of student names
              integer  number of projects
    '''
    proNams = {}
    for student in dataSource:
        for project in dataSource[student]:
            proNams.setdefault(project, []).append(student)
    return proNams, len(proNams)


def GetEncodeLen(dataSource={}, proNams={}, proNamLen=0):
    '''Compute the chromosome length and the number of slots per project.

    The slot count is the larger of "most projects any student attends" and
    "most students any project interviews".

    :type dataSource : dict
    :param dataSource : student name -> list of project names
    :type proNams : dict
    :param proNams : project name -> list of student names
    :type proNamLen : integer
    :param proNamLen : number of projects
    Returns : integer  chromosome length
              integer  slots per project
    '''
    # ``or [0]`` keeps max() valid for empty input (Python 2 has no default=).
    maxStudLen = max([len(v) for v in dataSource.values()] or [0])
    maxProLen = max([len(v) for v in proNams.values()] or [0])
    length = max(maxStudLen, maxProLen)
    return length * proNamLen, length


def InitEncodeSlc(encodeLen=0, perLen=0, proNams={}):
    '''Build, for every chromosome position, the list of values allowed there.

    A position in project p's segment may hold any student of p or '0'.

    :type encodeLen : integer
    :param encodeLen : chromosome length
    :type perLen : integer
    :param perLen : slots per project
    :type proNams : dict
    :param proNams : project name -> list of student names
    Returns : list  one fresh list of allowed values per position
    '''
    order = _ProjectOrder(proNams)
    return [proNams[order[i // perLen]] + ['0'] for i in range(encodeLen)]


def GetOneSelect(index=0, encodeSlc=[]):
    '''Pick a random allowed value for chromosome position ``index``.

    :type index : integer
    :param index : chromosome position
    :type encodeSlc : list
    :param encodeSlc : allowed values per position
    Returns : str  the chosen student name or '0'
    '''
    return random.choice(encodeSlc[index])


def CheckEncodeInvalidate(ec=[], encodeLen=0, ecPerLen=0, proNams={}):
    '''Check that a chromosome is a feasible schedule.

    Despite the name, this returns True when the chromosome is VALID:
    it is long enough, no student appears twice in one project's segment,
    no student is in two projects in the same slot, and every project
    schedules each of its students.

    :type ec : list
    :param ec : chromosome
    :type encodeLen : integer
    :param encodeLen : expected chromosome length
    :type ecPerLen : integer
    :param ecPerLen : slots per project
    :type proNams : dict
    :param proNams : project name -> list of student names
    Returns : boolean  True if valid, False otherwise
    '''
    if len(ec) < encodeLen:
        return False
    order = _ProjectOrder(proNams)
    proLen = len(order)
    for i in range(proLen):
        seen = set()
        for j in range(ecPerLen):
            student = ec[i * ecPerLen + j]
            if student == '0':
                continue
            if student in seen:
                return False
            seen.add(student)
            # The same student cannot be interviewed by a later project in
            # the same slot (earlier projects were checked on their turn).
            for k in range(i + 1, proLen):
                if ec[k * ecPerLen + j] == student:
                    return False
        # Segment values come only from this project's students, so a full
        # count means every student of the project has been scheduled.
        if len(seen) < len(proNams[order[i]]):
            return False
    return True


def GetOneEncode():
    '''Generate one random valid chromosome.

    Uses rejection sampling: random chromosomes are drawn until one passes
    CheckEncodeInvalidate. This can take many attempts on larger inputs.
    '''
    ec = []
    while not CheckEncodeInvalidate(ec, encodeLen, encodeLenPer, projectNames):
        ec = [GetOneSelect(i, encodeSlc) for i in range(encodeLen)]
    return ec


def Init():
    '''Derive the module-level encoding tables from ``dataSource``.'''
    print('初始化数据!')
    global projectNames, projectNamesLen, encodeLen, encodeLenPer, encodeSlc
    projectNames, projectNamesLen = GetAllProjects(dataSource)
    encodeLen, encodeLenPer = GetEncodeLen(dataSource, projectNames, projectNamesLen)
    encodeSlc = InitEncodeSlc(encodeLen, encodeLenPer, projectNames)


def InitPupulation(pplLen=0):
    '''Create the initial population of random valid chromosomes.

    :type pplLen : integer
    :param pplLen : population size
    Returns : list  list of chromosomes
    '''
    population = []
    print('初始化种群( %d ):' % pplLen)
    for i in range(1, pplLen + 1):
        population.append(GetOneEncode())
        print('第 %d 个个体: %s' % (i, population[-1]))
    return population


def Sort(population=[], populationFitness=[], pplLen=0):
    '''Sort the first ``pplLen`` individuals by ascending fitness, in place.

    Both lists are reordered together. The sort is stable, so individuals
    with equal fitness keep their relative order.

    :type population : list
    :param population : chromosomes
    :type populationFitness : list
    :param populationFitness : fitness of each chromosome (lower is better)
    :type pplLen : integer
    :param pplLen : number of leading entries to sort
    '''
    order = sorted(range(pplLen), key=lambda k: populationFitness[k])
    population[:pplLen] = [population[k] for k in order]
    populationFitness[:pplLen] = [populationFitness[k] for k in order]


def Select(population=[], populationFitness=[], pplLen=0):
    '''Selection: replace the worst ``ps`` fraction by copies of the best.

    Sorts both input lists in place (see Sort) and returns a new list.

    :type population : list
    :param population : chromosomes
    :type populationFitness : list
    :param populationFitness : fitness of each chromosome
    :type pplLen : integer
    :param pplLen : population size
    Returns : list  the new population
    '''
    Sort(population, populationFitness, pplLen)
    count = int(pplLen * ps)
    # The duplicates share list objects with the originals. This is safe
    # because Cross/Variation rebind population entries instead of mutating
    # them (VariationPopulation works on a copy).
    return population[0:pplLen - count] + population[0:count]


def CrossPopulationOfProject(populationLeft=[], populationRight=[], index=0, encodeLen=0, encodeLenPer=0):
    '''Swap project ``index``'s segment between two chromosomes.

    :type populationLeft : list
    :param populationLeft : first chromosome
    :type populationRight : list
    :param populationRight : second chromosome
    :type index : integer
    :param index : project segment to swap
    :type encodeLen : integer
    :param encodeLen : chromosome length
    :type encodeLenPer : integer
    :param encodeLenPer : slots per project
    Returns : list  new first chromosome
              list  new second chromosome
    '''
    start = index * encodeLenPer
    end = start + encodeLenPer
    newLeft = populationLeft[:start] + populationRight[start:end] + populationLeft[end:encodeLen]
    newRight = populationRight[:start] + populationLeft[start:end] + populationRight[end:encodeLen]
    return newLeft, newRight


def CrossPopulation(populationLeft=[], populationRight=[]):
    '''Cross two chromosomes at the first project segment that yields two
    valid children.

    :type populationLeft : list
    :param populationLeft : first chromosome
    :type populationRight : list
    :param populationRight : second chromosome
    Returns : list, list  the children, or the unchanged parents if no
                          segment swap produces two valid chromosomes
    '''
    for crossIndex in range(projectNamesLen):
        left, right = CrossPopulationOfProject(populationLeft, populationRight,
                                               crossIndex, encodeLen, encodeLenPer)
        if (CheckEncodeInvalidate(left, encodeLen, encodeLenPer, projectNames)
                and CheckEncodeInvalidate(right, encodeLen, encodeLenPer, projectNames)):
            return left, right
    return populationLeft, populationRight


def Cross(population=[], pplLen=0):
    '''Crossover step: each pair (0,1), (2,3), ... is crossed with
    probability ``pc``. With an odd size, the last individual stays unpaired.

    :type population : list
    :param population : chromosomes (entries are rebound in place)
    :type pplLen : integer
    :param pplLen : population size
    '''
    for i in range(0, pplLen - 1, 2):
        if random.random() < pc:
            population[i], population[i + 1] = CrossPopulation(population[i], population[i + 1])


def VariationPopulation(population=[]):
    '''Mutate one random position of a chromosome, keeping it valid.

    Makes up to encodeLen + 1 attempts, each on a fresh copy.

    :type population : list
    :param population : chromosome to mutate (not modified)
    Returns : list  the mutated copy, or the original if no attempt was valid
    '''
    if encodeLen <= 0:
        # randint(0, -1) would raise: there is nothing to mutate.
        return population
    for _ in range(encodeLen + 1):
        pop = population[:]
        index = random.randint(0, encodeLen - 1)
        pop[index] = GetOneSelect(index, encodeSlc)
        if CheckEncodeInvalidate(pop, encodeLen, encodeLenPer, projectNames):
            return pop
    return population


def Variation(population=[], pplLen=0):
    '''Mutation step: each individual is mutated with probability ``pm``.

    :type population : list
    :param population : chromosomes (entries are rebound in place)
    :type pplLen : integer
    :param pplLen : population size
    '''
    for i in range(pplLen):
        if random.random() < pm:
            population[i] = VariationPopulation(population[i])


def CalculateStudentTime(encode=[]):
    '''Total student time: for each student, from the start of their first
    interview slot to the end of their last, summed over all students.

    :type encode : list
    :param encode : chromosome
    Returns : integer
    '''
    time = 0
    for student in dataSource:
        slots = [i for i in range(encodeLenPer)
                 for j in range(projectNamesLen)
                 if encode[i + j * encodeLenPer] == student]
        if slots:
            # Slots are collected in ascending order of i, so this is the span
            # (consecutive slots count fully).
            time += slots[-1] + 1 - slots[0]
    return time


def CalculateBossTime(encode=[]):
    '''Total boss time: for each project, the span from its first to its
    last non-idle slot, summed over all projects.

    :type encode : list
    :param encode : chromosome
    Returns : integer
    '''
    time = 0
    for i in range(projectNamesLen):
        segment = encode[i * encodeLenPer:(i + 1) * encodeLenPer]
        used = [j for j, student in enumerate(segment) if student != '0']
        if used:
            time += used[-1] + 1 - used[0]
    return time


def CalculateHRTime(encode=[]):
    '''HR time: from slot 0 to the end of the last slot in which any project
    is interviewing.

    :type encode : list
    :param encode : chromosome
    Returns : integer
    '''
    HREnd = 0
    for i in range(encodeLenPer):
        if any(encode[i + j * encodeLenPer] != '0' for j in range(projectNamesLen)):
            HREnd = i + 1
    return HREnd


def CalculateFitness(population=[], pplLen=0):
    '''Fitness of each individual = 4*student time + 2*boss time + HR time
    (lower is better).

    :type population : list
    :param population : chromosomes
    :type pplLen : integer
    :param pplLen : population size
    Returns : list  fitness values, same order as ``population``
    '''
    # Weight the boss TIME by 2; previously the chromosome list itself was
    # multiplied, which left boss time weighted by 1.
    return [CalculateStudentTime(population[i]) * 4
            + CalculateBossTime(population[i]) * 2
            + CalculateHRTime(population[i])
            for i in range(pplLen)]


def EncodeToResult(encode=[]):
    '''Convert a chromosome into a project name -> interview order dict.

    :type encode : list
    :param encode : chromosome
    Returns : dict  project name -> list of students ('0' = idle slot)
    '''
    order = _ProjectOrder(projectNames)
    result = {}
    for i in range(projectNamesLen):
        result[order[i]] = encode[i * encodeLenPer:(i + 1) * encodeLenPer]
    return result


def Heredity():
    '''Run the genetic algorithm and return the best schedule found.

    Returns : dict  project name -> interview order
    '''
    population = InitPupulation(initPplLen)
    print('遗传过程开始(总 %d 代):' % maxGent)
    for i in range(maxGent):
        populationFitness = CalculateFitness(population, initPplLen)
        population = Select(population, populationFitness, initPplLen)
        print('第 %d 代种群,最好个体: %s' % (i, population[0]))
        Cross(population, initPplLen)
        Variation(population, initPplLen)
    print('遗传过程结束!')
    # Cross/Variation changed the population after the last evaluation, so
    # re-evaluate it before picking the best individual.
    populationFitness = CalculateFitness(population, initPplLen)
    Sort(population, populationFitness, initPplLen)
    return EncodeToResult(population[0])


def Start():
    '''Entry point: read the input file, solve, and write the output file.'''
    global dataSource, dataOut
    dataSource = Read()
    Init()
    dataOut = Heredity()
    Write(dataOut)


if __name__ == "__main__":
    print('这是‘工场很忙’题目使用遗传算法实现过程!')
    print('测试数据:')
    for pair in ('1 1', '1 2', '1 3', '2 1', '3 1', '3 2', '0 0'):
        print(pair)
    dataSource = {"1": ["1", "2", "3"], "2": ["1"], "3": ["1", "2"]}
    Init()
    dataOut = Heredity()
    print('计算结果为:')
    for key in sorted(dataOut):
        print(''.join(value + ' ' for value in dataOut[key]))