绘制层次结构图

绘制层次结构图

Word的SmartArt挺好,先来个免费的不美观的版本。

基于
matplotlib,networkx,graphviz,pydot

按需修改
输入内容
input_data 为输入的文本。

外观
rankdir 为指定方向。
mpatches.Rectangle 为节点外形。

比例
缩放matplotlib窗口,调整节点长宽。
调整字体大小,当前为 plt.text(fontsize=10)。

import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.patches as mpatches

plt.rcParams['font.sans-serif'] = ['SimHei']  # 正常显示中文
plt.rcParams['axes.unicode_minus'] = False    # 正常显示负号


def parse_hierarchy(input_str):
    """
    解析组织架构的层级输入,生成树结构。
    """
    lines = input_str.strip().split("\n")
    root = None
    stack = []  # 用于追踪当前的父层级节点
    hierarchy = {}

    for line in lines:
        # 计算当前行的缩进级别
        stripped_line = line.lstrip("- ").strip()
        level = (len(line) - len(line.lstrip("- "))) // 2

        # 创建当前节点
        if root is None:
            root = stripped_line
            hierarchy[root] = []
            stack.append((root, level))
        else:
            while stack and stack[-1][1] >= level:  # 回退到上一层级节点
                stack.pop()
            if stack:
                parent, _ = stack[-1]
                hierarchy[parent].append(stripped_line)
            else:
                # 如果栈为空但仍有节点,则说明输入格式有问题
                raise ValueError(f"错误的层级结构:无法找到父节点来连接 {stripped_line}")
            hierarchy[stripped_line] = []
            stack.append((stripped_line, level))

    return root, hierarchy


def plot_organization_chart(root, hierarchy, rankdir = "TB"):
    G = nx.DiGraph()

    def add_edges(parent, children):
        for child in children:
            G.add_edge(parent, child)
            add_edges(child, hierarchy.get(child, []))

    add_edges(root, hierarchy[root])
    
    # 创建一个 Pydot 的图对象
    dot = nx.drawing.nx_pydot.to_pydot(G)

    # 设置图的方向
    dot.set_rankdir(rankdir)  # 'LR' 为从左到右,'TB' 为从上到下

    # Pydot 的图对象 倒腾到 G
    G = nx.drawing.nx_pydot.from_pydot(dot)
    
    # 使用层次布局定位节点
    pos = nx.drawing.nx_pydot.graphviz_layout(G, prog='dot')

    # 缩小节点间距
    def scale_pos(pos, scale_x=1.0, scale_y=1.0):
        return {node: (x * scale_x, y * scale_y) for node, (x, y) in pos.items()}

    # 调整缩放比例以减少节点之间的间距
    pos = scale_pos(pos, scale_x=0.5, scale_y=0.5)

    plt.figure(figsize=(10, 6))

    # 创建长方形节点
    def draw_rect_node(node, pos):
        x, y = pos[node]
        rect_width, rect_height = 40, 15  # 矩形宽度和高度
        rect = mpatches.Rectangle((x - rect_width / 2, y - rect_height / 2),
                                  rect_width, rect_height,
                                  edgecolor='black', facecolor='lightblue', alpha=0.8)
        plt.gca().add_patch(rect)  # 添加矩形到当前轴
        plt.text(x, y, node, ha='center', va='center', fontsize=10, fontweight='bold', color='black')

    # 绘制节点和矩形
    rect_width, rect_height = 40, 15
    for node in G.nodes():
        draw_rect_node(node, pos)

    for parent, child in G.edges():
        x0, y0 = pos[parent]
        x1, y1 = pos[child]

        if rankdir == "TB" or rankdir == "BT":
        
            # 计算起点和终点,避开矩形区域
            if y0 > y1:  # 从上到下
                start_y = y0 - rect_height / 2
                end_y = y1 + rect_height / 2
            else:  # 从下到上
                start_y = y0 + rect_height / 2
                end_y = y1 - rect_height / 2

            # 保持 x 坐标不变
            start_x = x0
            end_x = x1

            y_mid = (start_y + end_y) / 2  # 中间的水平线 y 坐标

            # 绘制边
            plt.plot([start_x, start_x], [start_y, y_mid], "k-", linewidth=0.8)  # 垂直线
            plt.plot([start_x, end_x], [y_mid, y_mid], "k-", linewidth=0.8)      # 水平线
            plt.plot([end_x, end_x], [y_mid, end_y], "k-", linewidth=0.8)        # 垂直线
        else:
            # 计算起点和终点,避开矩形区域
            if x0 < x1:  # 从左到右
                start_x = x0 + rect_width / 2
                end_x = x1 - rect_width / 2
            else:  # 从右到左 (虽然这里不太可能出现,但为了代码的完整性,还是加上)
                start_x = x0 - rect_width / 2
                end_x = x1 + rect_width / 2

            # 保持 y 坐标不变
            start_y = y0
            end_y = y1

            x_mid = (start_x + end_x) / 2  # 中间的垂直线 x 坐标

            # 绘制边
            plt.plot([start_x, x_mid], [start_y, start_y], "k-", linewidth=0.8)  # 水平线
            plt.plot([x_mid, x_mid], [start_y, end_y], "k-", linewidth=0.8)      # 垂直线
            plt.plot([x_mid, end_x], [end_y, end_y], "k-", linewidth=0.8)        # 水平线


    plt.title("Organization Chart", fontsize=14)
    plt.axis("off")
    plt.tight_layout()
    plt.show()


# 输入的层次结构文本
input_data = """
顶层节点
- 一级节点1
- - 二级节点1
- - 二级节点2
- - 二级节点3
- - 二级节点4
- - - 三级节点1
- - - 三级节点2
- - - 三级节点3
- 一级节点2
- - 二级节点5
- - 二级节点6
"""

try:
    root, hierarchy = parse_hierarchy(input_data)
    plot_organization_chart(root, hierarchy, "LR")
except ValueError as e:
    print(f"输入解析错误:{e}")
posted @ 2024-11-15 22:43  lusonixs  阅读(18)  评论(0编辑  收藏  举报