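Python Code Similarity Calculation (Based on AST and the SW Algorithm)

Code similarity is computed here with the AST and the Smith-Waterman (SW) algorithm.

AST (Abstract Syntax Tree)

AST stands for Abstract Syntax Tree: a tree representation of the abstract syntactic structure of source code, in which every node represents a construct that occurs in the source. Typically, during translation and compilation, the parser first builds a parse tree, and the AST is then derived from that parse tree.

Generating the AST

Python's ast library is used to generate the AST of the source code.

The simplest example: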

import ast
root_node = ast.parse("print('hello world')")
print(root_node)

Output: <_ast.Module object at 0x7f702f13a550>

The call returns an object, so the tree structure is not directly visible. The astpretty package can print the tree in a readable form.

import ast
import astpretty
root_node = ast.parse("print('hello world')")
astpretty.pprint(root_node)

Output:

Module(
    body=[
        Expr(
            lineno=1,
            col_offset=0,
            value=Call(
                lineno=1,
                col_offset=0,
                func=Name(lineno=1, col_offset=0, id='print', ctx=Load()),
                args=[Str(lineno=1, col_offset=6, s='hello world')],
                keywords=[],
            ),
        ),
    ],
)

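The structure of the tree is now easier to see. (This output comes from an older Python release; from Python 3.8 on, ast.parse produces Constant nodes where Str and Num appear here.)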
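If astpretty is not installed, the standard library can print a similar view. The following is a minimal sketch that uses ast.dump, which returns the tree as a single string; the variable names are only illustrative.

```python
import ast

tree = ast.parse("print('hello world')")

# ast.dump serializes the whole tree into one string, starting at the root.
text = ast.dump(tree)
print(text)
```

On Python 3.9 and later, ast.dump(tree, indent=4) should give an indented layout closer to the astpretty output.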

AST diagram

The AST printed by the program above is already fairly clear, but rendering it as an image is clearer still.

Test code, test.py:

def func():
    a = 2 + 3 * 4 + 5
    return a

The instaviz library can display a diagram of the AST structure.

import ast
import astpretty
import instaviz
from test import func

code = open("./test.py", "r").read()
code_node = ast.parse(code)
astpretty.pprint(code_node)

instaviz.show(func)

The printed AST structure:

Module(
    body=[
        FunctionDef(
            lineno=1,
            col_offset=0,
            name='func',
            args=arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
            body=[
                Assign(
                    lineno=2,
                    col_offset=4,
                    targets=[Name(lineno=2, col_offset=4, id='a', ctx=Store())],
                    value=BinOp(
                        lineno=2,
                        col_offset=18,
                        left=BinOp(
                            lineno=2,
                            col_offset=8,
                            left=Num(lineno=2, col_offset=8, n=2),
                            op=Add(),
                            right=BinOp(
                                lineno=2,
                                col_offset=12,
                                left=Num(lineno=2, col_offset=12, n=3),
                                op=Mult(),
                                right=Num(lineno=2, col_offset=16, n=4),
                            ),
                        ),
                        op=Add(),
                        right=Num(lineno=2, col_offset=20, n=5),
                    ),
                ),
                Return(
                    lineno=3,
                    col_offset=4,
                    value=Name(lineno=3, col_offset=11, id='a', ctx=Load()),
                ),
            ],
            decorator_list=[],
            returns=None,
        ),
    ],
)
Bottle v0.12.21 server starting up (using WSGIRefServer())...
Listening on http://localhost:8080/
Hit Ctrl-C to quit.

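Open localhost:8080 to see the generated AST diagram.

Traversing the AST

The tree is traversed with the NodeVisitor class from the ast library. Its definition is as follows: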

class NodeVisitor(object):
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found.  This function may return a value
    which is forwarded by the `visit` method.

    This class is meant to be subclassed, with the subclass adding visitor
    methods.

    Per default the visitor functions for the nodes are ``'visit_'`` +
    class name of the node.  So a `TryFinally` node visit function would
    be `visit_TryFinally`.  This behavior can be changed by overriding
    the `visit` method.  If no visitor function exists for a node
    (return value `None`) the `generic_visit` visitor is used instead.

    Don't use the `NodeVisitor` if you want to apply changes to nodes during
    traversing.  For this a special visitor exists (`NodeTransformer`) that
    allows modifications.
    """

    def visit(self, node):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for field, value in iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)
            elif isinstance(value, AST):
                self.visit(value)

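The code above visits nodes recursively. isinstance() checks whether an object is of a given type; generic_visit uses it to tell whether a field value is a list or an AST node. A list is iterated and every AST item in it is visited, while a single AST node is passed to visit directly.

Take the astpretty.pprint output above as an example. The outermost Module is an AST node, so visit is called on it. Some fields hold values that are neither a list nor an AST node, so nothing is done for them. The value of the args field is an arguments node, which is an AST node, so it is visited directly; the value of body is a list, so it is iterated, and so on.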
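The dispatch rule ('visit_' plus the class name, with generic_visit as the fallback) can be seen in a small sketch. NameCollector is a hypothetical class written only for this illustration; it is not part of the ast library.

```python
import ast


class NameCollector(ast.NodeVisitor):
    """Hypothetical visitor that records every identifier it meets."""

    def __init__(self):
        self.names = []

    def visit_Name(self, node):
        # visit() dispatches here for every ast.Name node,
        # because the method name is 'visit_' + 'Name'.
        self.names.append(node.id)
        # Continue into the children (for a Name, only its ctx).
        self.generic_visit(node)


collector = NameCollector()
# Assign, BinOp and Add have no visit_* method here,
# so they fall back to generic_visit, which walks their fields.
collector.visit(ast.parse("a = b + c"))
print(collector.names)  # ['a', 'b', 'c']
```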

Example code

import ast
import astpretty


class CodeVisitor(ast.NodeVisitor):
    def __init__(self):
        self.seq = []

    def generic_visit(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        self.seq.append(type(node).__name__)

    def visit_FunctionDef(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        self.seq.append(type(node).__name__)

    def visit_Assign(self, node):
        ast.NodeVisitor.generic_visit(self, node)  # visit the children first
        self.seq.append(type(node).__name__)


code = open("./hello.py", "r").read()
code_node = ast.parse(code)
visitor = CodeVisitor()
visitor.visit(code_node)
print(visitor.seq)

hello.py has the same contents as test.py above.

The resulting sequence, one node type per line:

arguments
Store
Name
Num
Add
Num
Mult
Num
BinOp
BinOp
Add
Num
BinOp
Assign
Load
Name
Return
FunctionDef
Module

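This is the result of a depth-first traversal of the tree in which each node is recorded after its children; it is the AST sequence used from here on.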
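The same children-before-parent order can also be produced without subclassing NodeVisitor. This is a minimal sketch, assuming a helper named post_order (a hypothetical name) built on ast.iter_child_nodes; on newer Python versions the literals appear as Constant instead of Num.

```python
import ast


def post_order(node):
    """Return node type names with every child listed before its parent."""
    names = []
    for child in ast.iter_child_nodes(node):
        names.extend(post_order(child))
    names.append(type(node).__name__)
    return names


source = "def func():\n    a = 2 + 3 * 4 + 5\n    return a\n"
seq = post_order(ast.parse(source))
print(seq)
```

The root node is always last in such a sequence, which is why Module closes the list above.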

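The Smith-Waterman Algorithm

Reference: Smith-Waterman algorithm (Baidu Baike)

This algorithm is used to compare the AST sequences of two pieces of code.

Basic concepts

The Smith-Waterman algorithm performs local sequence alignment (as opposed to global alignment) and is used to find similar regions between two nucleotide or protein sequences. Its goal is not to align the two sequences in full, but to find the segments of them that are highly similar.

Let the two sequences to be aligned be $A = a_1 a_2 \dots a_n$ and $B = b_1 b_2 \dots b_m$.

Create the score matrix $H$ with $n+1$ rows and $m+1$ columns, and initialize every entry to 0.

Define: $s(a, b)$ is the similarity score between two sequence elements, and $W_k$ is the penalty for a gap of length $k$.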

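The scoring rule is then set up as follows.

The matrix is filled in according to this recurrence, for $1 \le i \le n$ and $1 \le j \le m$:

$$H[i,j] = \max\Big\{\, H[i-1,j-1] + s(a_i, b_j),\ \max_{k \ge 1}\big(H[i-k,j] - W_k\big),\ \max_{l \ge 1}\big(H[i,j-l] - W_l\big),\ 0 \,\Big\}$$

Here $H[i-1,j-1] + s(a_i, b_j)$ is the score obtained by aligning $a_i$ with $b_j$;

$H[i-k,j] - W_k$ is the score when $a_i$ is at the end of a gap of length $k$;

$H[i,j-l] - W_l$ is the score when $b_j$ is at the end of a gap of length $l$;

0 means that $a_i$ and $b_j$ show no similarity up to this point.

The algorithm can be simplified by using the constant gap-weight model of Smith-Waterman, in which the gap penalty is a fixed value.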
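As a sketch of what this simplification means, assume the fixed penalty $W$ is charged for every gap position, so that a gap of length $k$ costs $kW$ (an assumption, though it matches the code that follows). The maxima over gap lengths then reduce to the two immediate neighbours:

```latex
H[i,j] = \max
\begin{cases}
H[i-1,j-1] + s(a_i, b_j) \\
H[i-1,j] - W \\
H[i,j-1] - W \\
0
\end{cases}
```

The reduction holds because $H[i-1,j] \ge H[i-k,j] - (k-1)W$ by the same recurrence, so a longer gap never beats extending from the adjacent cell.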

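Computing the scores

The matrix is set up as described above. First, define the gap penalty and the scoring rule.

A = "GGTTGACTA"         # DNA sequence
B = "TGTTACGG"
n, m = len(A), len(B)  # lengths of the two sequences
W = 2  # gap penalty

# Scoring rule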
def score(a, b):
    if a == b:
        return 3
    else:
        return -3

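Next, compute the score matrix. The first row and the first column are initialized to 0.

import numpy

H = numpy.zeros([n+1, m+1], int)  # (n+1) x (m+1) matrix of zeros
for i in range(1, n+1):
    for j in range(1, m+1):
        s = score(A[i-1], B[j-1])
        L = H[i-1, j-1] + s  # upper-left cell + score
        P = H[i-1, j] - W    # cell above - W
        Q = H[i, j-1] - W    # cell to the left - W
        H[i, j] = max(L, P, Q, 0)

Both indices are iterated starting from 1 (that is, from the second row and column), and each score is computed with the recurrence described above.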
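To make the recurrence concrete, here is a small worked calculation of two cells by hand, using the article's values (match 3, mismatch -3, gap penalty 2). The variable names h12 and h13 are only illustrative.

```python
match, mismatch, gap = 3, -3, 2

# H[1][2]: A[0] = 'G' against B[1] = 'G'. All three neighbours are border zeros.
h12 = max(0 + match,  # upper-left + match score
          0 - gap,    # above - gap
          0 - gap,    # left - gap
          0)

# H[1][3]: A[0] = 'G' against B[2] = 'T'. The left neighbour is H[1][2].
h13 = max(0 + mismatch,  # upper-left + mismatch score
          0 - gap,       # above - gap
          h12 - gap,     # left - gap
          0)

print(h12, h13)  # 3 1
```

These agree with the entries in the second row of the matrix printed by the full program.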

Full code

import numpy

A = "GGTTGACTA"         # DNA sequence
B = "TGTTACGG"
n, m = len(A), len(B)  # lengths of the two sequences
W = 2  # gap penalty


def score(a, b):
    if a == b:
        return 3
    else:
        return -3


H = numpy.zeros([n+1, m+1], int)
for i in range(1, n+1):
    for j in range(1, m+1):
        s = score(A[i-1], B[j-1])
        L = H[i-1, j-1] + s
        P = H[i-1, j] - W
        Q = H[i, j-1] - W
        H[i, j] = max(L, P, Q, 0)

print(H)

The output is:

[[ 0  0  0  0  0  0  0  0  0]
 [ 0  0  3  1  0  0  0  3  3]
 [ 0  0  3  1  0  0  0  3  6]
 [ 0  3  1  6  4  2  0  1  4]
 [ 0  3  1  4  9  7  5  3  2]
 [ 0  1  6  4  7  6  4  8  6]
 [ 0  0  4  3  5 10  8  6  5]
 [ 0  0  2  1  3  8 13 11  9]
 [ 0  3  1  5  4  6 11 10  8]
 [ 0  1  0  3  2  7  9  8  7]]

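This is the final score matrix.

Traceback

Starting from the highest-scoring element of H, move back to the previous position according to where the score came from, and repeat until an element with score 0 is reached.

First, find the position of the highest-scoring element in the matrix.

If ai = bj, trace back to the upper-left cell. If ai ≠ bj, trace back to whichever of the upper-left, upper, and left cells has the largest value; when several cells tie for the largest value, the priority order is upper-left, upper, left.

Write out the matched strings according to the traceback path:
if the path moves to the upper-left cell, append ai to the matched string A1 and bj to the matched string B1;
if the path moves to the upper cell, append ai to A1 and _ to B1;
if the path moves to the left cell, append _ to A1 and bj to B1.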


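In effect, this procedure starts at the largest element and repeatedly finds the cell that produced it. For example, the largest element is 13, and it was produced from its upper-left neighbour (10 + 3 = 13).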
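The idea of "finding the cell that produced the value" can be sketched as an iterative loop that needs no stored path. This is an illustrative variant, not the article's implementation: local_align is a hypothetical name, plain lists replace numpy, and only the first maximum found is traced.

```python
def local_align(a, b, match=3, mismatch=-3, gap=2):
    """Return one best local alignment of a and b as two gapped strings."""
    n, m = len(a), len(b)
    h = [[0] * (m + 1) for _ in range(n + 1)]
    best, bi, bj = 0, 0, 0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            s = match if a[i - 1] == b[j - 1] else mismatch
            h[i][j] = max(h[i - 1][j - 1] + s,
                          h[i - 1][j] - gap,
                          h[i][j - 1] - gap,
                          0)
            if h[i][j] > best:  # remember the first highest-scoring cell
                best, bi, bj = h[i][j], i, j

    s1, s2 = [], []
    i, j = bi, bj
    while i > 0 and j > 0 and h[i][j] > 0:
        s = match if a[i - 1] == b[j - 1] else mismatch
        if h[i][j] == h[i - 1][j - 1] + s:   # produced by the upper-left cell
            s1.append(a[i - 1])
            s2.append(b[j - 1])
            i, j = i - 1, j - 1
        elif h[i][j] == h[i - 1][j] - gap:   # produced by the cell above
            s1.append(a[i - 1])
            s2.append('-')
            i -= 1
        else:                                # produced by the cell to the left
            s1.append('-')
            s2.append(b[j - 1])
            j -= 1
    # The loop collected the alignment backwards, so reverse it.
    return ''.join(reversed(s1)), ''.join(reversed(s2))


print(local_align("GGTTGACTA", "TGTTACGG"))  # ('GTTGAC', 'GTT-AC')
```

Re-deriving the source of each cell costs a few comparisons per step, which is the work that a stored path avoids.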

Recording this route while the cells are being computed is more efficient, so the code that computes the score matrix is rewritten as follows:

path = {}
for i in range(0, n+1):
    for j in range(0, m+1):
        if i == 0 or j == 0:
            path[point(i, j)] = None
        else:
            s = score(A[i-1], B[j-1])
            L = H[i-1, j-1] + s
            P = H[i-1, j] - W
            Q = H[i, j-1] - W
            H[i, j] = max(L, P, Q, 0)

            # Record the predecessor; on ties the priority is upper-left, above, left
            path[point(i, j)] = None
            if L == H[i, j]:
                path[point(i, j)] = point(i-1, j-1)
            elif P == H[i, j]:
                path[point(i, j)] = point(i-1, j)
            elif Q == H[i, j]:
                path[point(i, j)] = point(i, j-1)

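A dictionary, path, stores the traceback route. After H[i,j] has been computed, the code checks whether the value was produced by the upper-left, upper, or left cell, and stores that cell's position in the dictionary as a string.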
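Because positions are stored as strings, they have to be parsed back into integers later. The sketch below shows the point encoding (its definition appears in the full listing further down), a hypothetical inverse named parse_point, and, as an alternative design, tuples used directly as dictionary keys.

```python
def point(x, y):
    # Same encoding as in the article: "[x,y]".
    return '[' + str(x) + ',' + str(y) + ']'


def parse_point(key):
    """Hypothetical inverse of point(): '[3,12]' -> (3, 12)."""
    x, y = key.strip('[]').split(',')
    return int(x), int(y)


print(parse_point(point(3, 12)))  # (3, 12)

# Alternative: tuples are hashable, so they can serve as keys without any parsing.
path = {(1, 1): (0, 0)}
print(path[(1, 1)])  # (0, 0)
```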

numpy.argwhere(H == numpy.max(H))

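numpy.max returns the largest value in H, and numpy.argwhere returns an array with the coordinates of every cell equal to it. If several cells share the maximum, all of their coordinates are included; if there is only one, the array contains a single coordinate pair.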

end = numpy.argwhere(H == numpy.max(H))
for pos in end:
    key = point(pos[0], pos[1])
    traceback(path[key], [key])

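Find the coordinates of the maximum, look up its predecessor in path, and keep following predecessors, recording the coordinates along the way, until a cell with value 0 is reached.

The traceback function is as follows: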

def traceback(value, result):
    if value:
        result.append(value)
        value = path[value]
        x = int((value.split(',')[0]).strip('['))
        y = int((value.split(',')[1]).strip(']'))
    else:
        return
    if H[x, y] == 0:  # termination condition
        print(result)
        xx = 0
        yy = 0
        s1 = ''
        s2 = ''
        md = ''
        for item in range(len(result) - 1, -1, -1):
            position = result[item]
            x = int((position.split(',')[0]).strip('['))
            y = int((position.split(',')[1]).strip(']'))
            if x == xx: # same row: reached from the left cell
                s1 += '-'
                s2 += B[y - 1]
                md += ' '
            elif y == yy: # same column: reached from the cell above
                s1 += A[x - 1]
                s2 += '-'
                md += ' '
            else: # otherwise: reached from the upper-left cell
                s1 += A[x - 1]
                s2 += B[y - 1]
                md += '|'
            xx = x # remember the current position
            yy = y
        # Print the best-matching alignment
        print('s1: %s' % s1)
        print('    ' + md)
        print('s2: %s' % s2)
    else:
        traceback(value, result)

Full code:

import numpy

A = "GGTTGACTA"         # DNA sequence
B = "TGTTACGG"
n, m = len(A), len(B)  # lengths of the two sequences
W = 2  # gap penalty

# Scoring
def score(a, b):
    if a == b:
        return 3
    else:
        return -3

# Encode a coordinate as a string
def point(x, y):
    return '[' + str(x) + ',' + str(y) + ']'

# Traceback
def traceback(value, result):
    if value:
        result.append(value)
        value = path[value]
        x = int((value.split(',')[0]).strip('['))
        y = int((value.split(',')[1]).strip(']'))
    else:
        return
    if H[x, y] == 0:  # termination condition
        xx = 0
        yy = 0
        s1 = ''
        s2 = ''
        md = ''
        for item in range(len(result) - 1, -1, -1):
            position = result[item] # fetch the coordinate
            x = int((position.split(',')[0]).strip('['))
            y = int((position.split(',')[1]).strip(']'))
            if x == xx: # same row: reached from the left cell
                s1 += '-'
                s2 += B[y - 1]
                md += ' '
            elif y == yy: # same column: reached from the cell above
                s1 += A[x - 1]
                s2 += '-'
                md += ' '
            else:   # otherwise: reached from the upper-left cell
                s1 += A[x - 1]
                s2 += B[y - 1]
                md += '|'
            xx = x
            yy = y
        # Print the best-matching alignment
        print('s1: %s' % s1)
        print('    ' + md)
        print('s2: %s' % s2)
    else: # not at the end yet: keep tracing back
        traceback(value, result)


H = numpy.zeros([n+1, m+1], int)
path = {}
for i in range(0, n+1):
    for j in range(0, m+1):
        if i == 0 or j == 0:
            path[point(i, j)] = None
        else:
            s = score(A[i-1], B[j-1])
            L = H[i-1, j-1] + s
            P = H[i-1, j] - W
            Q = H[i, j-1] - W
            H[i, j] = max(L, P, Q, 0)

            # Record the predecessor; on ties the priority is upper-left, above, left
            path[point(i, j)] = None
            if L == H[i, j]:
                path[point(i, j)] = point(i-1, j-1)
            elif P == H[i, j]:
                path[point(i, j)] = point(i-1, j)
            elif Q == H[i, j]:
                path[point(i, j)] = point(i, j-1)

end = numpy.argwhere(H == numpy.max(H))
for pos in end:
    key = point(pos[0], pos[1])
    traceback(path[key], [key])

Output:

s1: GTTGAC
    ||| ||
s2: GTT-AC

Computing the similarity

The two parts are now combined.

First, wrap the generation of the AST sequence:

class CodeVisitor(ast.NodeVisitor):
    def __init__(self):
        self.seq = []

    def generic_visit(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        self.seq.append(type(node).__name__)

    def visit_FunctionDef(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        self.seq.append(type(node).__name__)

    def visit_Assign(self, node):
        ast.NodeVisitor.generic_visit(self, node)  # visit the children first
        self.seq.append(type(node).__name__)


class CodeParse(object):
    def __init__(self, fileA, fileB):
        self.visitorB = None
        self.visitorA = None
        self.codeA = open(fileA, encoding="utf-8").read()
        self.codeB = open(fileB, encoding="utf-8").read()
        self.nodeA = ast.parse(self.codeA)
        self.nodeB = ast.parse(self.codeB)
        self.seqA = ""
        self.seqB = ""
        self.work()

    def work(self):
        self.visitorA = CodeVisitor()
        self.visitorA.visit(self.nodeA)
        self.seqA = self.visitorA.seq
        self.visitorB = CodeVisitor()
        self.visitorB.visit(self.nodeB)
        self.seqB = self.visitorB.seq

Next, wrap the SW algorithm in a class:

class CalculateSimilarity(object):
    def __init__(self, A, B, W, M, N):
        self.A = A
        self.B = B
        self.W = W
        self.M = M
        self.N = N
        self.similarity = []
        self.SimthWaterman(self.A, self.B, self.W)

    def score(self,a, b):
        if a == b:
            return self.M
        else:
            return self.N

    def traceback(self,A, B, H, path, value, result):
        if value:
            temp = value[0]
            result.append(temp)
            value = path[temp]
            x = int((temp.split(',')[0]).strip('['))
            y = int((temp.split(',')[1]).strip(']'))
        else:
            return
        if H[x, y] == 0:  # termination condition
            xx = 0
            yy = 0
            sim = 0
            for item in range(len(result) - 2, -1, -1):
                position = result[item]
                x = int((position.split(',')[0]).strip('['))
                y = int((position.split(',')[1]).strip(']'))
                if x == xx:
                    pass
                elif y == yy:
                    pass
                else:
                    sim = sim + 1
                xx = x
                yy = y
            self.similarity.append(sim * 2 / (len(A) + len(B)))

        else:
            self.traceback(A, B, H, path, value, result)

    def SimthWaterman(self, A, B, W):
        n, m = len(A), len(B)
        H = numpy.zeros([n + 1, m + 1], int)
        path = {}
        for i in range(0, n + 1):
            for j in range(0, m + 1):
                if i == 0 or j == 0:
                    path[point(i, j)] = []
                else:
                    s = self.score(A[i - 1], B[j - 1])
                    L = H[i - 1, j - 1] + s
                    P = H[i - 1, j] - W
                    Q = H[i, j - 1] - W
                    H[i, j] = max(L, P, Q, 0)

                    # Record every predecessor that produced this value
                    path[point(i, j)] = []
                    if math.floor(L) == H[i, j]:
                        path[point(i, j)].append(point(i - 1, j - 1))
                    if math.floor(P) == H[i, j]:
                        path[point(i, j)].append(point(i - 1, j))
                    if math.floor(Q) == H[i, j]:
                        path[point(i, j)].append(point(i, j - 1))

        end = numpy.argwhere(H == numpy.max(H))
        for pos in end:
            key = point(pos[0], pos[1])
            value = path[key]
            result = [key]
            self.traceback(A, B, H, path, value, result)

    def Answer(self): # take the mean
        return sum(self.similarity) / len(self.similarity)
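The value appended to self.similarity is sim * 2 / (len(A) + len(B)), where sim counts the diagonal steps of one traceback. As a sketch of that formula, the hypothetical helper below computes the same quantity from two aligned strings, assuming '-' marks a gap and never occurs in the sequences themselves.

```python
def alignment_similarity(s1, s2, len_a, len_b):
    """Return 2 * (columns without a gap) / (len_a + len_b)."""
    # A column without a gap corresponds to one diagonal step in the traceback.
    paired = sum(1 for x, y in zip(s1, s2) if x != '-' and y != '-')
    return 2 * paired / (len_a + len_b)


# The DNA example: 5 paired columns, sequence lengths 9 and 8.
sim = alignment_similarity("GTTGAC", "GTT-AC", 9, 8)
print(round(sim, 3))  # 0.588
```

The score reaches 1.0 only when both sequences have the same length and every element is paired.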

Full code

import math
import numpy
import ast

Similarity = []


def point(x, y):
    return '[' + str(x) + ',' + str(y) + ']'


class CodeVisitor(ast.NodeVisitor):
    def __init__(self):
        self.seq = []

    def generic_visit(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        self.seq.append(type(node).__name__)

    def visit_FunctionDef(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        self.seq.append(type(node).__name__)

    def visit_Assign(self, node):
        ast.NodeVisitor.generic_visit(self, node)  # visit the children first
        self.seq.append(type(node).__name__)


class CodeParse(object):
    def __init__(self, fileA, fileB):
        self.visitorB = None
        self.visitorA = None
        self.codeA = open(fileA, encoding="utf-8").read()
        self.codeB = open(fileB, encoding="utf-8").read()
        self.nodeA = ast.parse(self.codeA)
        self.nodeB = ast.parse(self.codeB)
        self.seqA = ""
        self.seqB = ""
        self.work()

    def work(self):
        self.visitorA = CodeVisitor()
        self.visitorA.visit(self.nodeA)
        self.seqA = self.visitorA.seq
        self.visitorB = CodeVisitor()
        self.visitorB.visit(self.nodeB)
        self.seqB = self.visitorB.seq


class CalculateSimilarity(object):
    def __init__(self, A, B, W, M, N):
        self.A = A
        self.B = B
        self.W = W
        self.M = M
        self.N = N
        self.similarity = []
        self.SimthWaterman(self.A, self.B, self.W)

    def score(self,a, b):
        if a == b:
            return self.M
        else:
            return self.N

    def traceback(self,A, B, H, path, value, result):
        if value:
            temp = value[0]
            result.append(temp)
            value = path[temp]
            x = int((temp.split(',')[0]).strip('['))
            y = int((temp.split(',')[1]).strip(']'))
        else:
            return
        if H[x, y] == 0:  # termination condition
            xx = 0
            yy = 0
            sim = 0
            for item in range(len(result) - 2, -1, -1):
                position = result[item]
                x = int((position.split(',')[0]).strip('['))
                y = int((position.split(',')[1]).strip(']'))
                if x == xx:
                    pass
                elif y == yy:
                    pass
                else:
                    sim = sim + 1
                xx = x
                yy = y
            self.similarity.append(sim * 2 / (len(A) + len(B)))

        else:
            self.traceback(A, B, H, path, value, result)

    def SimthWaterman(self, A, B, W):
        n, m = len(A), len(B)
        H = numpy.zeros([n + 1, m + 1], int)
        path = {}
        for i in range(0, n + 1):
            for j in range(0, m + 1):
                if i == 0 or j == 0:
                    path[point(i, j)] = []
                else:
                    s = self.score(A[i - 1], B[j - 1])
                    L = H[i - 1, j - 1] + s
                    P = H[i - 1, j] - W
                    Q = H[i, j - 1] - W
                    H[i, j] = max(L, P, Q, 0)

                    # Record every predecessor that produced this value
                    path[point(i, j)] = []
                    if math.floor(L) == H[i, j]:
                        path[point(i, j)].append(point(i - 1, j - 1))
                    if math.floor(P) == H[i, j]:
                        path[point(i, j)].append(point(i - 1, j))
                    if math.floor(Q) == H[i, j]:
                        path[point(i, j)].append(point(i, j - 1))

        end = numpy.argwhere(H == numpy.max(H))
        for pos in end:
            key = point(pos[0], pos[1])
            value = path[key]
            result = [key]
            self.traceback(A, B, H, path, value, result)

    def Answer(self): # take the mean
        return sum(self.similarity) / len(self.similarity)


def main():
    AST = CodeParse("test1.py","test2.py")
    RES = CalculateSimilarity(AST.seqA, AST.seqB, 1, 1, -1/3)
    print(RES.Answer())


if __name__ == "__main__":
    main()

Pass two Python file names to CodeParse to compute the final similarity value.

posted @ 2022-05-31 11:27  N3ptune