Genetic Algorithms with python 学习笔记ch5

利用遗传算法解决图着色问题

我们使用4种颜色为美国地图着色,并确保任意两个相邻的州颜色都不相同。实现过程中不会过多讨论地图本身的表示方式,只需简单地表示出各个州之间的相邻关系即可。
graphColoringTests.py 的完整代码如下:

import csv
import unittest
import datetime
import genetic

def load_data(localFileName):
    """Read a state-adjacency CSV file into a dictionary.

    Each row has the form ``STATE,NEIGHBOUR1;NEIGHBOUR2;...``: the first
    column is a node name, the second a ``;``-separated list of its
    neighbours.

    :param localFileName: path of the CSV file to read.
    :returns: dict mapping each node name to the list of neighbour names.
        A row with no second column maps to ``['']`` (the same value an
        empty second column produces), which ``build_rules`` skips.
        Blank rows are ignored.
    """
    # The csv module requires newline='' so it can handle line endings itself.
    with open(localFileName, mode='r', newline='', encoding='utf-8') as infile:
        reader = csv.reader(infile)
        lookup = {row[0]: (row[1] if len(row) > 1 else '').split(';')
                  for row in reader if row}
    return lookup

class Rule:
    """An undirected adjacency constraint between two nodes (states).

    ``__init__`` normalises the pair so that ``Node`` always holds the larger
    name (by ``<`` comparison) and ``Adjacent`` the smaller one. As a result,
    ``Rule(a, b)`` and ``Rule(b, a)`` compare equal and hash the same, so each
    undirected edge can be used as a single dictionary key.
    """

    def __init__(self, node, adjacent):
        # Canonical ordering: swap so that Node >= Adjacent.
        if node < adjacent:
            node, adjacent = adjacent, node
        self.Node = node
        self.Adjacent = adjacent

    def __eq__(self, other):
        # Compare the two fields pairwise. Node must match Node and
        # Adjacent must match Adjacent, otherwise distinct rules can
        # never be equal.
        if not isinstance(other, Rule):
            return NotImplemented
        return self.Node == other.Node and \
            self.Adjacent == other.Adjacent

    def __hash__(self):
        # Consistent with __eq__: equal rules have equal fields, so equal hashes.
        return hash(self.Node) * 397 ^ hash(self.Adjacent)

    def __str__(self):
        return self.Node + " -> " + self.Adjacent

    def IsValid(self, genes, nodeIndexLookup):
        """Return True when the two endpoints carry different genes (colours).

        :param genes: indexable gene sequence, one gene per node.
        :param nodeIndexLookup: maps each node name to its index in ``genes``.
        """
        index = nodeIndexLookup[self.Node]
        adjacentStateIndex = nodeIndexLookup[self.Adjacent]
        return genes[index] != genes[adjacentStateIndex]

def build_rules(items):
    """Build the distinct adjacency rules from an adjacency mapping.

    :param items: mapping of node name -> iterable of neighbour names (as
        returned by ``load_data``). Empty-string neighbours are ignored.
    :returns: the keys view of the rule-count dictionary, i.e. the distinct
        ``Rule`` objects (deduplication relies on ``Rule.__eq__``/``__hash__``).

    An edge listed once in each direction is counted twice. Any rule whose
    count is not 2 is printed as "not bidirectional".
    """
    rulesAdded = {}

    for state, adjacent in items.items():
        for adjacentState in adjacent:
            if adjacentState == '':
                continue
            rule = Rule(state, adjacentState)
            rulesAdded[rule] = rulesAdded.get(rule, 0) + 1

    # The check and the return must run only after ALL states are processed;
    # inside the outer loop they would return after the first state.
    for k, v in rulesAdded.items():
        if v != 2:
            print("rule {} is not bidirectional".format(k))

    return rulesAdded.keys()

class GraphColoringTests(unittest.TestCase):
    """End-to-end run of the four-colour US map problem."""

    def test(self):
        """Colour the states with the genetic engine and print the result.

        Loads ``adjacent_states.csv``, turns it into adjacency rules, and asks
        ``genetic.get_best`` for a gene string with one colour initial per
        state (states in sorted order) that satisfies every rule.
        """
        states = load_data("adjacent_states.csv")
        rules = build_rules(states)
        optimalValue = len(rules)

        # Gene position i corresponds to the i-th state in sorted order.
        orderedStates = sorted(states)
        stateIndexLookup = dict((name, position)
                                for position, name in enumerate(orderedStates))

        colors = ["Orange", "Yellow", "Green", "Blue"]
        # Each gene is the first letter of a colour name.
        colorLookup = {name[0]: name for name in colors}
        geneset = list(colorLookup)

        startTime = datetime.datetime.now()

        def fnDisplay(candidateToShow):
            display(candidateToShow, startTime)

        def fnGetFitness(candidateGenes):
            return get_fitness(candidateGenes, rules, stateIndexLookup)

        best = genetic.get_best(fnGetFitness, len(orderedStates),
                                optimalValue, geneset, fnDisplay)
        self.assertFalse(optimalValue > best.Fitness)

        for position, name in enumerate(orderedStates):
            print(name + " is " + colorLookup[best.Genes[position]])

def display(candidate, startTime):
    """Print one progress line for *candidate*.

    The line holds the concatenated genes, the fitness, and the time
    elapsed since *startTime*, separated by tabs.
    """
    elapsed = datetime.datetime.now() - startTime
    geneText = ''.join(str(gene) for gene in candidate.Genes)
    line = "{}\t{}\t{}".format(geneText, candidate.Fitness, elapsed)
    print(line)

def get_fitness(genes, rules, stateIndexLookup):
    """Return the number of rules that *genes* satisfies.

    Each rule's ``IsValid(genes, stateIndexLookup)`` decides whether that
    rule is satisfied. The count is maximal when every rule passes.
    """
    satisfied = 0
    for rule in rules:
        if rule.IsValid(genes, stateIndexLookup):
            satisfied += 1
    return satisfied

第一部分为读取文件。文件中的每一行格式为:州1,州2;州3;州4,表示州1与第二列中由分号分隔开的各个州相邻(相邻州的个数不固定)。下面的函数 load_data 读取该文件,并以字典形式返回各州之间的相邻关系:键为州名,值为与之相邻的州名列表。

def load_data(localFileName):
    """Read a state-adjacency CSV file into a dictionary.

    Each row has the form ``STATE,NEIGHBOUR1;NEIGHBOUR2;...``: the first
    column is a node name, the second a ``;``-separated list of its
    neighbours.

    :param localFileName: path of the CSV file to read.
    :returns: dict mapping each node name to the list of neighbour names.
        A row with no second column maps to ``['']`` (the same value an
        empty second column produces), which ``build_rules`` skips.
        Blank rows are ignored.
    """
    # The csv module requires newline='' so it can handle line endings itself.
    with open(localFileName, mode='r', newline='', encoding='utf-8') as infile:
        reader = csv.reader(infile)
        lookup = {row[0]: (row[1] if len(row) > 1 else '').split(';')
                  for row in reader if row}
    return lookup

下面所描述的 Rule 类主要表示一对相邻关系。

class Rule:
    """An undirected adjacency constraint between two nodes (states).

    ``__init__`` normalises the pair so that ``Node`` always holds the larger
    name (by ``<`` comparison) and ``Adjacent`` the smaller one. As a result,
    ``Rule(a, b)`` and ``Rule(b, a)`` compare equal and hash the same, so each
    undirected edge can be used as a single dictionary key.
    """

    def __init__(self, node, adjacent):
        # Canonical ordering: swap so that Node >= Adjacent.
        if node < adjacent:
            node, adjacent = adjacent, node
        self.Node = node
        self.Adjacent = adjacent

    def __eq__(self, other):
        # Compare the two fields pairwise. Node must match Node and
        # Adjacent must match Adjacent, otherwise distinct rules can
        # never be equal.
        if not isinstance(other, Rule):
            return NotImplemented
        return self.Node == other.Node and \
            self.Adjacent == other.Adjacent

    def __hash__(self):
        # Consistent with __eq__: equal rules have equal fields, so equal hashes.
        return hash(self.Node) * 397 ^ hash(self.Adjacent)

    def __str__(self):
        return self.Node + " -> " + self.Adjacent

    def IsValid(self, genes, nodeIndexLookup):
        """Return True when the two endpoints carry different genes (colours).

        :param genes: indexable gene sequence, one gene per node.
        :param nodeIndexLookup: maps each node name to its index in ``genes``.
        """
        index = nodeIndexLookup[self.Node]
        adjacentStateIndex = nodeIndexLookup[self.Adjacent]
        return genes[index] != genes[adjacentStateIndex]

下面的函数 build_rules 为每一对相邻的州建立一条相邻规则(Rule),并检查每条相邻关系是否是双向的:即 A 的相邻列表中有 B 时,B 的相邻列表中也应有 A,否则打印提示信息。

def build_rules(items):
    """Build the distinct adjacency rules from an adjacency mapping.

    :param items: mapping of node name -> iterable of neighbour names (as
        returned by ``load_data``). Empty-string neighbours are ignored.
    :returns: the keys view of the rule-count dictionary, i.e. the distinct
        ``Rule`` objects (deduplication relies on ``Rule.__eq__``/``__hash__``).

    An edge listed once in each direction is counted twice. Any rule whose
    count is not 2 is printed as "not bidirectional".
    """
    rulesAdded = {}

    for state, adjacent in items.items():
        for adjacentState in adjacent:
            if adjacentState == '':
                continue
            rule = Rule(state, adjacentState)
            rulesAdded[rule] = rulesAdded.get(rule, 0) + 1

    # The check and the return must run only after ALL states are processed;
    # inside the outer loop they would return after the first state.
    for k, v in rulesAdded.items():
        if v != 2:
            print("rule {} is not bidirectional".format(k))

    return rulesAdded.keys()

下面这个类表示测试图着色问题的具体流程:
首先读取文件,然后建立邻接关系,再调用遗传算法引擎(genetic.get_best)求最优解。

class GraphColoringTests(unittest.TestCase):
    """End-to-end run of the four-colour US map problem."""

    def test(self):
        """Colour the states with the genetic engine and print the result.

        Loads ``adjacent_states.csv``, turns it into adjacency rules, and asks
        ``genetic.get_best`` for a gene string with one colour initial per
        state (states in sorted order) that satisfies every rule.
        """
        states = load_data("adjacent_states.csv")
        rules = build_rules(states)
        optimalValue = len(rules)

        # Gene position i corresponds to the i-th state in sorted order.
        orderedStates = sorted(states)
        stateIndexLookup = dict((name, position)
                                for position, name in enumerate(orderedStates))

        colors = ["Orange", "Yellow", "Green", "Blue"]
        # Each gene is the first letter of a colour name.
        colorLookup = {name[0]: name for name in colors}
        geneset = list(colorLookup)

        startTime = datetime.datetime.now()

        def fnDisplay(candidateToShow):
            display(candidateToShow, startTime)

        def fnGetFitness(candidateGenes):
            return get_fitness(candidateGenes, rules, stateIndexLookup)

        best = genetic.get_best(fnGetFitness, len(orderedStates),
                                optimalValue, geneset, fnDisplay)
        self.assertFalse(optimalValue > best.Fitness)

        for position, name in enumerate(orderedStates):
            print(name + " is " + colorLookup[best.Genes[position]])

另一个重要的部分是计算适应值。这里的适应值就是着色有效的相邻州对数:如果两个相邻的州颜色不同,那么这一对州的着色是有效的;当有效的对数等于所有相邻关系的总数时,即得到最优解。

def get_fitness(genes, rules, stateIndexLookup):
    """Return the number of rules that *genes* satisfies.

    Each rule's ``IsValid(genes, stateIndexLookup)`` decides whether that
    rule is satisfied. The count is maximal when every rule passes.
    """
    satisfied = 0
    for rule in rules:
        if rule.IsValid(genes, stateIndexLookup):
            satisfied += 1
    return satisfied
posted @ 2020-07-31 23:00  idella  阅读(149)  评论(0编辑  收藏  举报