Python将数据库的父子关系表画成树形结构

如何像下图一样将关系型数据库的上下级关系画成树形图

     

 

测试数据准备

 为了程序的通用性,也方便进行验证,本例采用最通用的sql写法,数据库采用SQLiter3, 如果你的数据库是ORACLE, MS-SQL, MYSQL,不用修改任何代码,只需要在调用的时候传入相应的db连接即可

def sampledata():
    db=sqlite3.connect('dbname.db')
    cur=db.cursor()
    cur.execute("create table if not exists relation(mother, child)");
    cur.execute("INSERT INTO relation(mother, child) VALUES('1000', '1100')");
    cur.execute("INSERT INTO relation(mother, child) VALUES('1000', '1200')");
    cur.execute("INSERT INTO relation(mother, child) VALUES('1000', '1300')");
    cur.execute("INSERT INTO relation(mother, child) VALUES('1000', '1400')");

    cur.execute("INSERT INTO relation(mother, child) VALUES('1200', '1210')");
    cur.execute("INSERT INTO relation(mother, child) VALUES('1200', '1220')");
    cur.execute("INSERT INTO relation(mother, child) VALUES('1200', '1230')");

    cur.execute("INSERT INTO relation(mother, child) VALUES('1220', '1221')");
    cur.execute("INSERT INTO relation(mother, child) VALUES('1220', '1222')");

    db.commit();

 

看起来是这样的

 程序编写

 节点信息类(相当于C/C++里面的结构体)

value就是当前节点的值

isleaf记录的是当前节点是否是叶子节点

leafcounts记录的是当前节点下面有多少个叶子节点,这个主要是为了在排版的时候知道占用多宽,

maxlevel记录的是当前节点下面最深的深度是多少,这个主要是为了在排版的时候知道生成多长的图片

class decisionnode:
    def __init__(self, value, isleaf=False, leafcounts=0, maxlevel=1):
        self.childs=[]
        self.value=value
        self.isleaf=isleaf
        self.leafcounts=leafcounts
        self.maxlevel=maxlevel
    
    def addchild(self, child):
        self.childs.append(child)

 关系树类编写

 gentree 方法用于生成整棵树,需要传值:db:数据库连接,mathervalue:根节点的数值 tablename:数据库表名 childcol:子节点在表中的对应的字段名称 mothercol:上级节点在表中的字段名称

 draweachnode 方法用于化每个节点的值和两个节点之间的连线

 drawTree 方法用于画整棵树,它会调用draweachnode然后利用draweachnode的递归画完整棵树

class RelationTree:
    def __init__(self, basewidth=100, basedepth=100):
        self.basewidth = basewidth
        self.basedepth = basedepth
        self.root=None
        
    def gentree(self, db, mothervalue, tablename, childcol, mothercol):
        self.root=decisionnode(mothervalue)
        cur=db.cursor()
        def swap_gentree(node):
            cur.execute("select %s from %s where %s = '%s'" % \
                (childcol, tablename, mothercol, node.value));
            results=cur.fetchall()
            
            #如果是叶子节点,则直接返回
            if not results:
                return decisionnode(node.value, isleaf=True, maxlevel=1)
            
            #程序运行到这里,说明是非叶子节点
            #对非叶子节点进行其下包含的叶子节点进行统计(leafcounts)
            #该节点之下最深的深度maxlevel收集
            maxlevel=1
            for each in results:
                entrynode=swap_gentree(decisionnode(each[0]))
                if(entrynode.isleaf):
                    node.leafcounts += 1
                else:
                    node.leafcounts += entrynode.leafcounts
                
                if (entrynode.maxlevel > maxlevel):
                    maxlevel = entrynode.maxlevel
                node.addchild(entrynode)
            
            node.maxlevel = maxlevel+1
            return node
        swap_gentree(self.root)



    def draweachnode(self, tree, draw, x, y):
        draw.text((x,y), tree.value, (0,0,0))
        
        if not tree.childs:
            return
        
        childs_leafcounts=[child.leafcounts if child.leafcounts else 1 for child in tree.childs]

        leafpoint=x-sum(childs_leafcounts)*self.basewidth/2

        cumpoint=0
        for childtree, point in zip(tree.childs, childs_leafcounts):
            centerpoint=leafpoint+self.basewidth*cumpoint+self.basewidth*point/2
            cumpoint += point
            draw.line((x,y, centerpoint, y+self.basedepth), (255,0,0))
            self.draweachnode(childtree, draw, centerpoint, y+self.basedepth)
            

    def drawTree(self, filename='tree.jpg'):
        width=self.root.leafcounts * self.basewidth + self.basewidth
        depth=self.root.maxlevel * self.basedepth + self.basedepth
        img=Image.new(mode="RGB", size=(width, depth), color=(255,255,255))
        draw=ImageDraw.Draw(img)
        self.draweachnode(self.root, draw, width/2, 20)
        
        img.save(filename)

 

 测试验收

确认运行代码的环境是否已经安装以下包

pillow (如果没有请pip install pillow (这个模块实际上就是PIL))

sqlite3 (如果没有请pip install sqlite3)

 

新建一个文件 drawtree.py

在文件开头导入需要的模块

import sqlite3
from PIL import Image, ImageDraw

然后将上面的代码复制进去

然后根据下面操作

 

 到这里,就会在工作目录下生成tree.jpg了,效果就像文章开头那样,

为了程序的通用性,本例的程序已经写成很通用的代码了,如果你用的是其他数据库

则不要用我的样本数据,也就是不要运行drawtree.sampledata这个函数

然后将db=sqlite3.connect('daname,db')换成你数据库对应的连接,例如你的数据库是oracle,则用db=cx_Oracle.connect('usernae/password@tnsname')

 

posted @ 2017-06-11 18:45  littlepai  阅读(11831)  评论(1编辑  收藏  举报