Evaluating Python default values at call time
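As everyone knows, Python evaluates a default value when the function is defined, so it is computed only once. On every later call that omits the argument, that same object is bound to the parameter. So when a list is wanted as the default, beginners usually write: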
def foo(a=[]):
    a.append(3)
    print a
This is wrong. Calling foo() twice prints:
[3]
[3, 3]
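The sharing can be observed directly, because the default object is stored on the function in its `__defaults__` tuple. The following is a minimal sketch in Python 3 syntax; the names `append_three`, `first` and `second` are illustrative and not from the original post.

```python
def append_three(a=[]):
    # The [] above was built once, when the def statement ran.
    a.append(3)
    return a

first = append_three()
second = append_three()

# Both calls returned the very same list object, which is also the
# object stored in the function's __defaults__ tuple.
print(first is second)
print(append_three.__defaults__)
```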
It should be written like this:
def baz(a=None):
    a = a or []
    a.append(3)
    print a
Calling it twice prints:
[3]
[3]
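One thing to keep in mind with `a = a or []`: it replaces any falsy argument, including an empty list the caller passed in on purpose. A common alternative is to test `a is None`. The sketch below (Python 3 syntax; `with_or`, `with_is_none` and the `mine_*` variables are hypothetical names for illustration) shows the difference.

```python
def with_or(a=None):
    a = a or []      # replaces ANY falsy argument, including a caller's []
    a.append(3)
    return a

def with_is_none(a=None):
    if a is None:    # replaces only the "argument omitted" marker
        a = []
    a.append(3)
    return a

mine_or = []
with_or(mine_or)         # appends to a fresh list, not to mine_or
mine_none = []
with_is_none(mine_none)  # appends to the caller's list

print(mine_or)
print(mine_none)
```

Which behavior is wanted depends on whether the function is supposed to mutate a list handed in by the caller.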
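This is pretty clumsy. Python has default-value syntax, yet we end up writing the same boilerplate as in JS or Lua. Can't the default expression be evaluated on every call, at run time, the way C++ does it?
A decorator can emulate this: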
import functools
import copy

def compute_default_value_for_each_call(func):
    defaults = func.__defaults__
    if not defaults:
        # Nothing to protect; hand back the function unchanged.
        return func
    # Keep a pristine copy of the defaults as they were at definition time.
    defaults = tuple(copy.copy(x) for x in defaults)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Restore the defaults only if an earlier call modified them.
        if func.__defaults__ != defaults:
            func.__defaults__ = tuple(copy.copy(y) for y in defaults)
        return func(*args, **kwargs)
    return wrapper

@compute_default_value_for_each_call
def foo(b, a=[]):
    if b:
        a.append(3)
    return a
With this, calling foo(1) twice gives:
[3]
[3]
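One caveat: the saved defaults are made with `copy.copy`, which is shallow, so a nested mutable default still shares its inner objects with the saved copy, and the inequality check then never fires. The sketch below is in Python 3 syntax. `per_call_defaults` is a hypothetical variant of the decorator above, parameterised on the copy function only so the two behaviors can be compared side by side; it is not part of the original post.

```python
import copy
import functools

def per_call_defaults(copier):
    """Hypothetical variant of the decorator, parameterised on the copy function."""
    def decorate(func):
        if not func.__defaults__:
            return func
        pristine = tuple(copier(x) for x in func.__defaults__)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Restore the defaults only if they no longer equal the saved copy.
            if func.__defaults__ != pristine:
                func.__defaults__ = tuple(copier(x) for x in pristine)
            return func(*args, **kwargs)
        return wrapper
    return decorate

@per_call_defaults(copy.copy)
def shallow(a=[[]]):
    a[0].append(3)
    return a

@per_call_defaults(copy.deepcopy)
def deep(a=[[]]):
    a[0].append(3)
    return a

# Snapshot each result, since later calls may mutate the returned object.
shallow_results = [copy.deepcopy(shallow()) for _ in range(2)]
deep_results = [copy.deepcopy(deep()) for _ in range(2)]

print(shallow_results)  # the inner list is shared with the saved copy, so the mutation leaks
print(deep_results)     # each call starts from an untouched [[]]
```

Using `copy.deepcopy` avoids the leak at the cost of a more expensive copy whenever a reset is needed.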
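The decorator optimizes for the case where the defaults were not modified: it copies them again only when they differ from the saved copy. When the function leaves its defaults untouched (for example, it only prints a), this is much faster: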
@compute_default_value_for_each_call
def foo(b, a=[]):
    if b:
        a.append(3)
    return a

def foo2(b, a=None):
    a = a or []
    if b:
        a.append(3)
    return a

import timeit
print timeit.timeit('foo(1)', setup='from __main__ import foo')
print timeit.timeit('foo(0)', setup='from __main__ import foo')
print timeit.timeit('foo2(1)', setup='from __main__ import foo2')
print timeit.timeit('foo2(0)', setup='from __main__ import foo2')
Results (total time in seconds for 1,000,000 calls, in the same order as the statements above):
4.32704997063
0.630109071732
0.445858955383
0.26370882988
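Performance is acceptable....
Anyone who finds this approach too slow can use the ast module to rewrite the function's source directly and get the same performance as a = a or []. That is a bit more work to implement, though; I'll try it when I get the chance.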
====================================================
The version that rewrites the function's code:
import inspect
import dis
import ast
import re
import traceback
import timeit


def compute_default_value_for_each_call(func):
    source = inspect.getsource(func)
    # Strip the leading indentation so nested functions can be parsed.
    m = re.search(r'^\s+', source)
    if m:
        n = len(m.group(0))
        lines = [line[n:] for line in source.splitlines()]
        source = '\n'.join(lines)
    root = ast.parse(source)
    # Drop this decorator (assumed to be the last one listed) so the
    # recompiled function is not decorated again.
    root.body[0].decorator_list.pop()
    arg_names = [arg.id for arg in root.body[0].args.args]
    default_nodes = root.body[0].args.defaults
    # Defaults belong to the trailing arguments.
    arg_names = arg_names[len(arg_names) - len(default_nodes):]
    body = root.body[0].body
    n = len(arg_names)
    # Prepend "name = name or <default expression>" for each defaulted argument.
    for i in reversed(xrange(n)):
        arg_name = arg_names[i]
        default_node = default_nodes[i]
        lineno = default_node.lineno
        col_offset = default_node.col_offset
        body.insert(0, ast.Assign(
            targets=[ast.Name(id=arg_name, ctx=ast.Store(),
                              lineno=lineno, col_offset=col_offset)],
            value=ast.BoolOp(
                op=ast.Or(), lineno=lineno, col_offset=col_offset,
                values=[ast.Name(id=arg_name, ctx=ast.Load(),
                                 lineno=lineno, col_offset=col_offset),
                        default_node]),
            lineno=lineno, col_offset=col_offset))
    # Every default in the signature becomes None.
    root.body[0].args.defaults = [
        ast.Name(id='None', ctx=ast.Load(),
                 lineno=old.lineno, col_offset=old.col_offset)
        for old in default_nodes]
    root.body[0].body = body
    l = {}
    exec compile(root, '<string>', mode='exec') in globals(), l
    func = l[func.__name__]
    return func


def root2():
    def main():
        @compute_default_value_for_each_call
        def root(a=([])):
            a.append(3)
            return a
        print root()
        print root()
    main()

print 'used in local function'
root2()


def foo():
    return 42

bar_count = 0

def bar():
    global bar_count
    if bar_count:
        raise RuntimeError
    bar_count += 1

@compute_default_value_for_each_call
def f1(a=foo()):
    return a

def f2(a=None):
    a = a or foo()
    return a

@compute_default_value_for_each_call
def f3(a=bar()):
    return a

def f4(a=None):
    a = a or bar()
    return a

print 'f1:'
dis.dis(f1)
print 'f2:'
dis.dis(f2)
print 'f2 running time:'
print timeit.timeit('f2', setup='from __main__ import f2')
print 'f1 running time:'
print timeit.timeit('f1', setup='from __main__ import f1')
print 'f2 running time:'
print timeit.timeit('f2', setup='from __main__ import f2')
print 'f3:'
dis.dis(f3)
print 'f4:'
dis.dis(f4)

try:
    f3()
except:
    print traceback.format_exc()

try:
    f4()
except:
    print traceback.format_exc()
Output:
used in local function
[3]
[3]
f1:
  2           0 LOAD_FAST                0 (a)
              3 JUMP_IF_TRUE_OR_POP     12
              6 LOAD_GLOBAL              0 (foo)
              9 CALL_FUNCTION            0
        >>   12 STORE_FAST               0 (a)

  3          15 LOAD_FAST                0 (a)
             18 RETURN_VALUE
f2:
 83           0 LOAD_FAST                0 (a)
              3 JUMP_IF_TRUE_OR_POP     12
              6 LOAD_GLOBAL              0 (foo)
              9 CALL_FUNCTION            0
        >>   12 STORE_FAST               0 (a)

 84          15 LOAD_FAST                0 (a)
             18 RETURN_VALUE
f2 running time:
0.0288169384003
f1 running time:
0.0251071453094
f2 running time:
0.025267124176
f3:
  2           0 LOAD_FAST                0 (a)
              3 JUMP_IF_TRUE_OR_POP     12
              6 LOAD_GLOBAL              0 (bar)
              9 CALL_FUNCTION            0
        >>   12 STORE_FAST               0 (a)

  3          15 LOAD_FAST                0 (a)
             18 RETURN_VALUE
f4:
 93           0 LOAD_FAST                0 (a)
              3 JUMP_IF_TRUE_OR_POP     12
              6 LOAD_GLOBAL              0 (bar)
              9 CALL_FUNCTION            0
        >>   12 STORE_FAST               0 (a)

 94          15 LOAD_FAST                0 (a)
             18 RETURN_VALUE
Traceback (most recent call last):
  File "./b/b.py", line 114, in <module>
    f3()
  File "<string>", line 2, in f3
  File "./b/b.py", line 73, in bar
    raise RuntimeError
RuntimeError

Traceback (most recent call last):
  File "./b/b.py", line 118, in <module>
    f4()
  File "./b/b.py", line 93, in f4
    a = a or bar()
  File "./b/b.py", line 73, in bar
    raise RuntimeError
RuntimeError
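As you can see, the decorated functions compile to exactly the same bytecode as the hand-written ones, so the performance is the same, and the traceback is still quite usable: the only thing missing is the source line for the frame compiled from '<string>'. (Note that the timeit statements above name f1 and f2 without calling them, so those timings measure only a name lookup; the identical bytecode is the real evidence.)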
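The code above targets Python 2 (`exec ... in`, `xrange`, `arg.id`). For readers on Python 3, here is a rough sketch of the same rewrite, assuming Python 3.8 or later. To stay self-contained it takes the function's source as a string instead of calling `inspect.getsource`, and it does not strip decorators. The names `build_with_call_time_defaults` and `SRC` are hypothetical, and this is not the original author's code.

```python
import ast

def build_with_call_time_defaults(source):
    """Compile a one-function source string so that each positional default
    becomes None and the body starts with `name = name or <default>`."""
    tree = ast.parse(source)
    fn = tree.body[0]          # assumed to be the function definition
    defaults = fn.args.defaults
    if defaults:
        positional = fn.args.posonlyargs + fn.args.args
        # Defaults belong to the trailing positional parameters.
        names = [p.arg for p in positional[len(positional) - len(defaults):]]
        prologue = [
            ast.Assign(
                targets=[ast.Name(id=name, ctx=ast.Store())],
                value=ast.BoolOp(
                    op=ast.Or(),
                    values=[ast.Name(id=name, ctx=ast.Load()), default],
                ),
            )
            for name, default in zip(names, defaults)
        ]
        fn.args.defaults = [ast.Constant(value=None) for _ in defaults]
        # Note: a docstring, if present, would no longer be the first statement.
        fn.body = prologue + fn.body
        ast.fix_missing_locations(tree)
    namespace = {}
    exec(compile(tree, "<rewritten>", "exec"), namespace)
    return namespace[fn.name]

SRC = """
def root(a=[]):
    a.append(3)
    return a
"""

root = build_with_call_time_defaults(SRC)
print(root())
print(root())
print(root.__defaults__)
```

As in the original, the rewritten body uses `or`, so a falsy argument passed by the caller is also replaced by the default expression.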