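Python Code Similarity with AST and the Smith-Waterman Algorithm
Code similarity is computed here from two ingredients: the abstract syntax tree (AST) of each program and the Smith-Waterman (SW) algorithm.
AST (abstract syntax tree)
An AST (Abstract Syntax Tree) is a tree representation of the abstract syntactic structure of source code; each node in the tree stands for a construct that occurs in the source. Typically, when source code is translated or compiled, the parser first builds a parse tree, and the AST is then derived from that parse tree.
Generating the AST
Python's ast library generates the AST of a piece of source code.
The simplest example: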
import ast
root_node = ast.parse("print('hello world')")
print(root_node)
Output: <_ast.Module object at 0x7f702f13a550>
This only returns an object, so the tree structure is not directly visible. The astpretty package prints the tree in a readable form.
import ast
import astpretty
root_node = ast.parse("print('hello world')")
astpretty.pprint(root_node)
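Output:
Module(
    body=[
        Expr(
            lineno=1,
            col_offset=0,
            value=Call(
                lineno=1,
                col_offset=0,
                func=Name(lineno=1, col_offset=0, id='print', ctx=Load()),
                args=[Str(lineno=1, col_offset=6, s='hello world')],
                keywords=[],
            ),
        ),
    ],
)
The structure of the tree is now easier to see. This output comes from an older Python; on Python 3.8 and later the literal appears as a Constant node rather than Str.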
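If installing astpretty is not an option, the standard library can produce a similar view. The sketch below assumes Python 3.9 or later, where ast.dump accepts an indent argument; it is an alternative to the astpretty call above, not what the original example uses.

```python
import ast

root_node = ast.parse("print('hello world')")

# ast.dump returns the tree as a string; indent=4 (Python 3.9+) puts each
# field on its own line, similar to the astpretty output.
dumped = ast.dump(root_node, indent=4)
print(dumped)
```

Unlike astpretty, ast.dump leaves out lineno and col_offset unless include_attributes=True is passed.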
AST diagram
The printed AST is already fairly readable, but a diagram makes the structure clearer still.
Test code, test.py:
def func():
    a = 2 + 3 * 4 + 5
    return a
The instaviz library displays the AST as a diagram:
import ast
import astpretty
import instaviz
from test import func
code = open("./test.py", "r").read()
code_node = ast.parse(code)
astpretty.pprint(code_node)
instaviz.show(func)
The printed AST structure:
Module(
    body=[
        FunctionDef(
            lineno=1,
            col_offset=0,
            name='func',
            args=arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
            body=[
                Assign(
                    lineno=2,
                    col_offset=4,
                    targets=[Name(lineno=2, col_offset=4, id='a', ctx=Store())],
                    value=BinOp(
                        lineno=2,
                        col_offset=18,
                        left=BinOp(
                            lineno=2,
                            col_offset=8,
                            left=Num(lineno=2, col_offset=8, n=2),
                            op=Add(),
                            right=BinOp(
                                lineno=2,
                                col_offset=12,
                                left=Num(lineno=2, col_offset=12, n=3),
                                op=Mult(),
                                right=Num(lineno=2, col_offset=16, n=4),
                            ),
                        ),
                        op=Add(),
                        right=Num(lineno=2, col_offset=20, n=5),
                    ),
                ),
                Return(
                    lineno=3,
                    col_offset=4,
                    value=Name(lineno=3, col_offset=11, id='a', ctx=Load()),
                ),
            ],
            decorator_list=[],
            returns=None,
        ),
    ],
)
Bottle v0.12.21 server starting up (using WSGIRefServer())...
Listening on http://localhost:8080/
Hit Ctrl-C to quit.
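Open localhost:8080 in a browser to see the generated AST diagram.
Traversing the AST
The tree is traversed with the NodeVisitor class from the ast library. Its definition is: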
class NodeVisitor(object):
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found. This function may return a value
    which is forwarded by the `visit` method.

    This class is meant to be subclassed, with the subclass adding visitor
    methods.

    Per default the visitor functions for the nodes are ``'visit_'`` +
    class name of the node. So a `TryFinally` node visit function would
    be `visit_TryFinally`. This behavior can be changed by overriding
    the `visit` method. If no visitor function exists for a node
    (return value `None`) the `generic_visit` visitor is used instead.

    Don't use the `NodeVisitor` if you want to apply changes to nodes during
    traversing. For this a special visitor exists (`NodeTransformer`) that
    allows modifications.
    """

    def visit(self, node):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for field, value in iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)
            elif isinstance(value, AST):
                self.visit(value)
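The code above visits nodes recursively. isinstance() tests whether an object is of a given type; generic_visit uses it to check whether each field value is a list or an AST node. A list is iterated and every AST item in it is visited; a single AST node is passed to visit directly.
Take the astpretty.pprint output above as an example. The outermost Module is an AST node, so visit is called on it. Some field values are neither a list nor an AST node (name='func', for instance), so nothing is done with them. The value of the args field is an arguments node, which is an AST, so it is visited directly; the value of body is a list, so its items are iterated, and so on.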
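The three cases that generic_visit distinguishes can be seen by classifying the fields of a single node. This is a small sketch using ast.iter_fields on the FunctionDef from test.py; the source string and the kinds dictionary are illustrative assumptions, not part of the original code.

```python
import ast

source = "def func():\n    a = 2 + 3 * 4 + 5\n    return a\n"
func_node = ast.parse(source).body[0]  # the FunctionDef node

# Classify each field the same way generic_visit does.
kinds = {}
for field, value in ast.iter_fields(func_node):
    if isinstance(value, list):
        kinds[field] = "list"    # iterated item by item
    elif isinstance(value, ast.AST):
        kinds[field] = "AST"     # visited directly
    else:
        kinds[field] = "other"   # ignored (strings, None, ...)

print(kinds)
```

The exact set of fields depends on the Python version, but name, args, and body are always present.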
Sample code
import ast

class CodeVisitor(ast.NodeVisitor):
    def __init__(self):
        self.seq = []  # node type names, in the order they are recorded

    def generic_visit(self, node):
        # Visit the children first, then record this node (post-order).
        ast.NodeVisitor.generic_visit(self, node)
        self.seq.append(type(node).__name__)

    def visit_FunctionDef(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        self.seq.append(type(node).__name__)

    def visit_Assign(self, node):
        # Records only the Assign node itself; its children are not visited.
        self.seq.append(type(node).__name__)

code = open("./hello.py", "r").read()
code_node = ast.parse(code)
visitor = CodeVisitor()
visitor.visit(code_node)
print(visitor.seq)
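Judging by the output, hello.py contains the same func definition as test.py above.
The output is as follows, shown one node name per line. The listing is the full traversal, including the nodes beneath Assign, which corresponds to a visit_Assign that also calls generic_visit; with visit_Assign exactly as written above, the children of Assign are skipped.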
arguments
Store
Name
Num
Add
Num
Mult
Num
BinOp
BinOp
Add
Num
BinOp
Assign
Load
Name
Return
FunctionDef
Module
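This is the result of a depth-first traversal of the tree, with each node recorded after its children (post-order). It is the AST sequence used from here on. On Python 3.8 and later, Num appears as Constant.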
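The same sequence can be produced from a source string instead of a file. This is a minimal sketch of the full post-order traversal; SequenceVisitor and ast_sequence are hypothetical names introduced for illustration, and unlike CodeVisitor this version does not special-case Assign.

```python
import ast


class SequenceVisitor(ast.NodeVisitor):
    """Collects node type names in post-order (children before parent)."""

    def __init__(self):
        self.seq = []

    def generic_visit(self, node):
        super().generic_visit(node)           # descend into the children first
        self.seq.append(type(node).__name__)  # then record this node


def ast_sequence(source):
    """Return the post-order list of node type names for a source string."""
    visitor = SequenceVisitor()
    visitor.visit(ast.parse(source))
    return visitor.seq


seq = ast_sequence("def func():\n    a = 2 + 3 * 4 + 5\n    return a\n")
print(seq)
```

Because only node type names are kept, renaming variables or changing literal values leaves the sequence unchanged, which is what makes it a reasonable input for structural comparison.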
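The Smith-Waterman algorithm
This algorithm is used to compare the AST sequences of two programs.
Basic concepts
The Smith-Waterman algorithm performs local sequence alignment (as opposed to global alignment) and is used to find similar regions between two nucleotide or protein sequences. Rather than aligning the sequences end to end, it looks for the segments of the two sequences that are most similar to each other.
Let the two sequences to be aligned be $A = a_1 a_2 \ldots a_n$ and $B = b_1 b_2 \ldots b_m$.
Create a scoring matrix $H$ with $n+1$ rows and $m+1$ columns, with every entry initialized to 0.
Let $s(a, b)$ be the similarity score between two sequence elements, and let $W_k$ be the penalty for a gap of length $k$.
After choosing the scoring rule, fill in the matrix with the following recurrence:
$$H_{i,j} = \max\left\{\, H_{i-1,j-1} + s(a_i, b_j),\ \max_{k \ge 1}\{H_{i-k,j} - W_k\},\ \max_{l \ge 1}\{H_{i,j-l} - W_l\},\ 0 \,\right\}, \quad 1 \le i \le n,\ 1 \le j \le m$$
where
$H_{i-1,j-1} + s(a_i, b_j)$ is the score of aligning $a_i$ with $b_j$;
$H_{i-k,j} - W_k$ is the score when $a_i$ is at the end of a gap of length $k$;
$H_{i,j-l} - W_l$ is the score when $b_j$ is at the end of a gap of length $l$;
0 means that $a_i$ and $b_j$ have no similarity up to this point.
The algorithm can be simplified by using the constant gap penalty model, in which the penalty for each gap position is fixed at a single value.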
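Under that simplification, and assuming every gap position costs the same constant $W$ (so a gap of length $k$ costs $kW$ in total), each cell depends only on its three neighbours. The recurrence then reduces to the form that the code in the next section evaluates:

```latex
H_{i,j} = \max\left\{\, H_{i-1,j-1} + s(a_i, b_j),\; H_{i-1,j} - W,\; H_{i,j-1} - W,\; 0 \,\right\}
```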
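Computing the scores
Set the gap penalty and the scoring rule first, then fill in the matrix as described above.
A = "GGTTGACTA"  # DNA sequence
B = "TGTTACGG"
n, m = len(A), len(B)  # lengths of the two sequences
W = 2  # gap penalty
# scoring rule
def score(a, b):
    if a == b:
        return 3
    else:
        return -3
Next, compute the scoring matrix H, an (n+1) x (m+1) integer matrix of zeros (created with numpy.zeros in the complete code below); the first row and first column stay 0.
for i in range(1, n+1):
    for j in range(1, m+1):
        s = score(A[i-1], B[j-1])
        L = H[i-1, j-1] + s  # upper-left element + score
        P = H[i-1, j] - W  # element above - W
        Q = H[i, j-1] - W  # element to the left - W
        H[i, j] = max(L, P, Q, 0)
Both indices run from 1 (that is, from the second row and column), and each cell is scored with the formula above.
Complete code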
import numpy

A = "GGTTGACTA"  # DNA sequence
B = "TGTTACGG"
n, m = len(A), len(B)  # lengths of the two sequences
W = 2  # gap penalty

def score(a, b):
    if a == b:
        return 3
    else:
        return -3

H = numpy.zeros([n+1, m+1], int)
for i in range(1, n+1):
    for j in range(1, m+1):
        s = score(A[i-1], B[j-1])
        L = H[i-1, j-1] + s  # upper-left element + score
        P = H[i-1, j] - W  # element above - W
        Q = H[i, j-1] - W  # element to the left - W
        H[i, j] = max(L, P, Q, 0)
print(H)
Output:
[[ 0  0  0  0  0  0  0  0  0]
 [ 0  0  3  1  0  0  0  3  3]
 [ 0  0  3  1  0  0  0  3  6]
 [ 0  3  1  6  4  2  0  1  4]
 [ 0  3  1  4  9  7  5  3  2]
 [ 0  1  6  4  7  6  4  8  6]
 [ 0  0  4  3  5 10  8  6  5]
 [ 0  0  2  1  3  8 13 11  9]
 [ 0  3  1  5  4  6 11 10  8]
 [ 0  1  0  3  2  7  9  8  7]]
This is the final scoring matrix.
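The same matrix can be filled without numpy, which is convenient for checking individual cells by hand. This is a minimal sketch using nested lists; score_matrix and its keyword parameters are hypothetical names, with defaults set to the values used above (match 3, mismatch -3, gap 2).

```python
def score_matrix(a, b, match=3, mismatch=-3, gap=2):
    """Fill the Smith-Waterman matrix with a constant per-position gap penalty."""
    n, m = len(a), len(b)
    h = [[0] * (m + 1) for _ in range(n + 1)]  # row 0 and column 0 stay 0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            s = match if a[i - 1] == b[j - 1] else mismatch
            h[i][j] = max(
                h[i - 1][j - 1] + s,  # align a[i-1] with b[j-1]
                h[i - 1][j] - gap,    # step from the cell above
                h[i][j - 1] - gap,    # step from the cell to the left
                0,                    # start a new local alignment
            )
    return h


H = score_matrix("GGTTGACTA", "TGTTACGG")
best = max(max(row) for row in H)
print(best, H[7][6])
```

The largest value, 13, sits at row 7, column 6, the same cell as in the printed matrix.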
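Traceback
Starting from the highest-scoring element of H, step back to the position that score came from, and repeat until an element with score 0 is reached.
First locate the highest-scoring element of the matrix.
If a_i = b_j, step back to the upper-left cell. If a_i ≠ b_j, step back to whichever of the upper-left, upper, and left cells has the largest value; when several cells share the largest value, the priority order is upper-left, upper, left.
Write out the aligned strings along the traceback path:
if the step goes to the upper-left cell, add a_i to aligned string A1 and b_j to aligned string B1;
if the step goes to the upper cell, add a_i to A1 and _ to B1;
if the step goes to the left cell, add _ to A1 and b_j to B1.
In effect, this starts from the largest element and repeatedly finds the element that produced it. For example, the largest element is 13, and it was produced from its upper-left neighbour (10 + 3 = 13).
Recording this path while the elements are being computed is more efficient, so the code that fills the scoring matrix is rewritten as follows: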
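path = {}
for i in range(0, n+1):
    for j in range(0, m+1):
        if i == 0 or j == 0:
            path[point(i, j)] = None
        else:
            s = score(A[i-1], B[j-1])
            L = H[i-1, j-1] + s
            P = H[i-1, j] - W
            Q = H[i, j-1] - W
            H[i, j] = max(L, P, Q, 0)
            # record where H[i, j] came from (priority: upper-left, upper, left)
            path[point(i, j)] = None
            if L == H[i, j]:
                path[point(i, j)] = point(i-1, j-1)
            elif P == H[i, j]:
                path[point(i, j)] = point(i-1, j)
            elif Q == H[i, j]:
                path[point(i, j)] = point(i, j-1)
A dictionary, path, stores the route. After H[i, j] is computed, the code checks whether it came from the upper-left, upper, or left element and records that position in the dictionary as a string; point(i, j), defined in the complete code below, formats a coordinate as the string '[i,j]'.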
numpy.argwhere(H == numpy.max(H))
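numpy.max returns the largest value in H, and numpy.argwhere returns the coordinates of every element equal to it as an array of [row, column] pairs. If several elements share the maximum, all of their coordinates are returned; if there is only one, the array holds a single pair.
end = numpy.argwhere(H == numpy.max(H))
for pos in end:
    key = point(pos[0], pos[1])
    traceback(path[key], [key])
For each maximum, look up its coordinate in path to get the coordinate it came from, then keep tracing back, recording the coordinates along the way, until a 0 value is reached.
The traceback function is as follows: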
def traceback(value, result):
    if value:
        result.append(value)
        value = path[value]  # look one step further back
        x = int((value.split(',')[0]).strip('['))
        y = int((value.split(',')[1]).strip(']'))
    else:
        return
    if H[x, y] == 0:  # termination condition
        print(result)  # the recorded path, from the maximum back to the start
        xx = 0
        yy = 0
        s1 = ''
        s2 = ''
        md = ''
        for item in range(len(result) - 1, -1, -1):
            position = result[item]
            x = int((position.split(',')[0]).strip('['))
            y = int((position.split(',')[1]).strip(']'))
            if x == xx:  # same row as the previous cell: came from the left
                s1 += '-'
                s2 += B[y - 1]
                md += ' '
            elif y == yy:  # same column as the previous cell: came from above
                s1 += A[x - 1]
                s2 += '-'
                md += ' '
            else:  # came from the upper-left
                s1 += A[x - 1]
                s2 += B[y - 1]
                md += '|'
            xx = x  # remember the position
            yy = y
        # print the best alignment
        print('s1: %s' % s1)
        print('    ' + md)
        print('s2: %s' % s2)
    else:
        traceback(value, result)
Complete code:
import numpy

A = "GGTTGACTA"  # DNA sequence
B = "TGTTACGG"
n, m = len(A), len(B)  # lengths of the two sequences
W = 2  # gap penalty

# scoring rule
def score(a, b):
    if a == b:
        return 3
    else:
        return -3

# format a coordinate as the string '[x,y]'
def point(x, y):
    return '[' + str(x) + ',' + str(y) + ']'

# traceback
def traceback(value, result):
    if value:
        result.append(value)
        value = path[value]  # look one step further back
        x = int((value.split(',')[0]).strip('['))
        y = int((value.split(',')[1]).strip(']'))
    else:
        return
    if H[x, y] == 0:  # termination condition
        xx = 0
        yy = 0
        s1 = ''
        s2 = ''
        md = ''
        for item in range(len(result) - 1, -1, -1):
            position = result[item]  # take out the coordinate
            x = int((position.split(',')[0]).strip('['))
            y = int((position.split(',')[1]).strip(']'))
            if x == xx:  # same row as the previous cell: came from the left
                s1 += '-'
                s2 += B[y - 1]
                md += ' '
            elif y == yy:  # same column as the previous cell: came from above
                s1 += A[x - 1]
                s2 += '-'
                md += ' '
            else:  # came from the upper-left
                s1 += A[x - 1]
                s2 += B[y - 1]
                md += '|'
            xx = x
            yy = y
        # print the best alignment
        print('s1: %s' % s1)
        print('    ' + md)
        print('s2: %s' % s2)
    else:  # not at the end yet, keep tracing back
        traceback(value, result)

H = numpy.zeros([n+1, m+1], int)
path = {}
for i in range(0, n+1):
    for j in range(0, m+1):
        if i == 0 or j == 0:
            path[point(i, j)] = None
        else:
            s = score(A[i-1], B[j-1])
            L = H[i-1, j-1] + s
            P = H[i-1, j] - W
            Q = H[i, j-1] - W
            H[i, j] = max(L, P, Q, 0)
            # record where H[i, j] came from (priority: upper-left, upper, left)
            path[point(i, j)] = None
            if L == H[i, j]:
                path[point(i, j)] = point(i-1, j-1)
            elif P == H[i, j]:
                path[point(i, j)] = point(i-1, j)
            elif Q == H[i, j]:
                path[point(i, j)] = point(i, j-1)

end = numpy.argwhere(H == numpy.max(H))
for pos in end:
    key = point(pos[0], pos[1])
    traceback(path[key], [key])
Output:
s1: GTTGAC
    ||| ||
s2: GTT-AC
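For comparison, the traceback can also be written as a loop over integer coordinates, with no path dictionary and no string parsing. This is a sketch under stated assumptions rather than the code used in this article: local_align is a hypothetical name, it follows only the first cell that reaches the maximum, and it breaks ties in the order upper-left, upper, left.

```python
def local_align(a, b, match=3, mismatch=-3, gap=2):
    """Return one best local alignment of a and b as two gapped strings."""
    n, m = len(a), len(b)
    h = [[0] * (m + 1) for _ in range(n + 1)]
    best, best_pos = 0, (0, 0)
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            s = match if a[i - 1] == b[j - 1] else mismatch
            h[i][j] = max(h[i - 1][j - 1] + s, h[i - 1][j] - gap, h[i][j - 1] - gap, 0)
            if h[i][j] > best:  # keep the first cell that reaches the maximum
                best, best_pos = h[i][j], (i, j)

    # Walk back from the best cell until a 0 is reached.
    i, j = best_pos
    out_a, out_b = [], []
    while h[i][j] > 0:
        s = match if a[i - 1] == b[j - 1] else mismatch
        if h[i][j] == h[i - 1][j - 1] + s:  # came from the upper-left cell
            out_a.append(a[i - 1])
            out_b.append(b[j - 1])
            i, j = i - 1, j - 1
        elif h[i][j] == h[i - 1][j] - gap:  # came from the cell above
            out_a.append(a[i - 1])
            out_b.append("-")
            i -= 1
        else:                               # came from the cell to the left
            out_a.append("-")
            out_b.append(b[j - 1])
            j -= 1
    return "".join(reversed(out_a)), "".join(reversed(out_b))


s1, s2 = local_align("GGTTGACTA", "TGTTACGG")
print(s1)
print(s2)
```

A loop also avoids Python's recursion limit, which a recursive traceback could hit on long sequences.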
Computing the similarity
Now combine the two parts.
First, wrap the AST-sequence generation in a class:
class CodeVisitor(ast.NodeVisitor):
    def __init__(self):
        self.seq = []  # node type names, in the order they are recorded

    def generic_visit(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        self.seq.append(type(node).__name__)

    def visit_FunctionDef(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        self.seq.append(type(node).__name__)

    def visit_Assign(self, node):
        # Records only the Assign node itself; its children are not visited.
        self.seq.append(type(node).__name__)

class CodeParse(object):
    def __init__(self, fileA, fileB):
        self.visitorB = None
        self.visitorA = None
        self.codeA = open(fileA, encoding="utf-8").read()
        self.codeB = open(fileB, encoding="utf-8").read()
        self.nodeA = ast.parse(self.codeA)
        self.nodeB = ast.parse(self.codeB)
        self.seqA = []
        self.seqB = []
        self.work()

    def work(self):
        # Build the AST sequence of each file.
        self.visitorA = CodeVisitor()
        self.visitorA.visit(self.nodeA)
        self.seqA = self.visitorA.seq
        self.visitorB = CodeVisitor()
        self.visitorB.visit(self.nodeB)
        self.seqB = self.visitorB.seq
Next, wrap the SW algorithm in a class:
class CalculateSimilarity(object):
    def __init__(self, A, B, W, M, N):
        self.A = A  # first sequence
        self.B = B  # second sequence
        self.W = W  # gap penalty
        self.M = M  # score for a match
        self.N = N  # score for a mismatch
        self.similarity = []  # one value per traceback path
        self.SimthWaterman(self.A, self.B, self.W)

    def score(self, a, b):
        if a == b:
            return self.M
        else:
            return self.N

    def traceback(self, A, B, H, path, value, result):
        if value:
            temp = value[0]  # follow the highest-priority predecessor
            result.append(temp)
            value = path[temp]
            x = int((temp.split(',')[0]).strip('['))
            y = int((temp.split(',')[1]).strip(']'))
        else:
            return
        if H[x, y] == 0:  # termination condition
            xx = 0
            yy = 0
            sim = 0
            # walk from the start of the alignment to its end,
            # skipping the final 0-valued cell
            for item in range(len(result) - 2, -1, -1):
                position = result[item]
                x = int((position.split(',')[0]).strip('['))
                y = int((position.split(',')[1]).strip(']'))
                if x == xx:  # came from the left: gap
                    pass
                elif y == yy:  # came from above: gap
                    pass
                else:  # diagonal step: a pair of aligned elements
                    sim = sim + 1
                xx = x
                yy = y
            self.similarity.append(sim * 2 / (len(A) + len(B)))
        else:
            self.traceback(A, B, H, path, value, result)

    def SimthWaterman(self, A, B, W):
        n, m = len(A), len(B)
        H = numpy.zeros([n + 1, m + 1], int)
        path = {}
        for i in range(0, n + 1):
            for j in range(0, m + 1):
                if i == 0 or j == 0:
                    path[point(i, j)] = []
                else:
                    s = self.score(A[i - 1], B[j - 1])
                    L = H[i - 1, j - 1] + s
                    P = H[i - 1, j] - W
                    Q = H[i, j - 1] - W
                    H[i, j] = max(L, P, Q, 0)
                    # record every predecessor that produces H[i, j];
                    # H holds integers, so compare against the floored values
                    path[point(i, j)] = []
                    if math.floor(L) == H[i, j]:
                        path[point(i, j)].append(point(i - 1, j - 1))
                    if math.floor(P) == H[i, j]:
                        path[point(i, j)].append(point(i - 1, j))
                    if math.floor(Q) == H[i, j]:
                        path[point(i, j)].append(point(i, j - 1))
        end = numpy.argwhere(H == numpy.max(H))
        for pos in end:
            key = point(pos[0], pos[1])
            value = path[key]
            result = [key]
            self.traceback(A, B, H, path, value, result)

    def Answer(self):  # mean over all traceback paths
        if not self.similarity:  # no alignment was found
            return 0
        return sum(self.similarity) / len(self.similarity)
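The value appended to self.similarity is twice the number of diagonal steps on the traceback path, divided by the combined length of the two sequences. The sketch below applies the same formula to an already-aligned pair of strings instead of a path; alignment_similarity and its parameters are hypothetical names, not part of the class above.

```python
def alignment_similarity(aligned_a, aligned_b, len_a, len_b, gap="-"):
    """Return 2 * (columns without a gap) / (len_a + len_b)."""
    paired = sum(1 for x, y in zip(aligned_a, aligned_b) if x != gap and y != gap)
    return 2 * paired / (len_a + len_b)


# The alignment of "GGTTGACTA" (length 9) and "TGTTACGG" (length 8)
# found in the traceback section has 5 paired columns and 1 gap.
sim = alignment_similarity("GTTGAC", "GTT-AC", 9, 8)
print(sim)  # 10 / 17
```

The result is 1.0 only when the two sequences align element for element over their full length, and it drops as gaps and unaligned ends grow.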
Complete code
import math
import numpy
import ast

# format a coordinate as the string '[x,y]'
def point(x, y):
    return '[' + str(x) + ',' + str(y) + ']'

class CodeVisitor(ast.NodeVisitor):
    def __init__(self):
        self.seq = []  # node type names, in the order they are recorded

    def generic_visit(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        self.seq.append(type(node).__name__)

    def visit_FunctionDef(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        self.seq.append(type(node).__name__)

    def visit_Assign(self, node):
        # Records only the Assign node itself; its children are not visited.
        self.seq.append(type(node).__name__)

class CodeParse(object):
    def __init__(self, fileA, fileB):
        self.visitorB = None
        self.visitorA = None
        self.codeA = open(fileA, encoding="utf-8").read()
        self.codeB = open(fileB, encoding="utf-8").read()
        self.nodeA = ast.parse(self.codeA)
        self.nodeB = ast.parse(self.codeB)
        self.seqA = []
        self.seqB = []
        self.work()

    def work(self):
        # Build the AST sequence of each file.
        self.visitorA = CodeVisitor()
        self.visitorA.visit(self.nodeA)
        self.seqA = self.visitorA.seq
        self.visitorB = CodeVisitor()
        self.visitorB.visit(self.nodeB)
        self.seqB = self.visitorB.seq
class CalculateSimilarity(object):
    def __init__(self, A, B, W, M, N):
        self.A = A  # first sequence
        self.B = B  # second sequence
        self.W = W  # gap penalty
        self.M = M  # score for a match
        self.N = N  # score for a mismatch
        self.similarity = []  # one value per traceback path
        self.SimthWaterman(self.A, self.B, self.W)

    def score(self, a, b):
        if a == b:
            return self.M
        else:
            return self.N

    def traceback(self, A, B, H, path, value, result):
        if value:
            temp = value[0]  # follow the highest-priority predecessor
            result.append(temp)
            value = path[temp]
            x = int((temp.split(',')[0]).strip('['))
            y = int((temp.split(',')[1]).strip(']'))
        else:
            return
        if H[x, y] == 0:  # termination condition
            xx = 0
            yy = 0
            sim = 0
            # walk from the start of the alignment to its end,
            # skipping the final 0-valued cell
            for item in range(len(result) - 2, -1, -1):
                position = result[item]
                x = int((position.split(',')[0]).strip('['))
                y = int((position.split(',')[1]).strip(']'))
                if x == xx:  # came from the left: gap
                    pass
                elif y == yy:  # came from above: gap
                    pass
                else:  # diagonal step: a pair of aligned elements
                    sim = sim + 1
                xx = x
                yy = y
            self.similarity.append(sim * 2 / (len(A) + len(B)))
        else:
            self.traceback(A, B, H, path, value, result)

    def SimthWaterman(self, A, B, W):
        n, m = len(A), len(B)
        H = numpy.zeros([n + 1, m + 1], int)
        path = {}
        for i in range(0, n + 1):
            for j in range(0, m + 1):
                if i == 0 or j == 0:
                    path[point(i, j)] = []
                else:
                    s = self.score(A[i - 1], B[j - 1])
                    L = H[i - 1, j - 1] + s
                    P = H[i - 1, j] - W
                    Q = H[i, j - 1] - W
                    H[i, j] = max(L, P, Q, 0)
                    # record every predecessor that produces H[i, j];
                    # H holds integers, so compare against the floored values
                    path[point(i, j)] = []
                    if math.floor(L) == H[i, j]:
                        path[point(i, j)].append(point(i - 1, j - 1))
                    if math.floor(P) == H[i, j]:
                        path[point(i, j)].append(point(i - 1, j))
                    if math.floor(Q) == H[i, j]:
                        path[point(i, j)].append(point(i, j - 1))
        end = numpy.argwhere(H == numpy.max(H))
        for pos in end:
            key = point(pos[0], pos[1])
            value = path[key]
            result = [key]
            self.traceback(A, B, H, path, value, result)

    def Answer(self):  # mean over all traceback paths
        if not self.similarity:  # no alignment was found
            return 0
        return sum(self.similarity) / len(self.similarity)
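def main():
    AST = CodeParse("test1.py", "test2.py")
    # arguments: gap penalty W = 1, match score M = 1, mismatch score N = -1/3
    RES = CalculateSimilarity(AST.seqA, AST.seqB, 1, 1, -1/3)
    print(RES.Answer())

if __name__ == "__main__":
    main()
Pass the names of two Python files to CodeParse, hand the resulting sequences to CalculateSimilarity, and Answer returns the final similarity value.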