python 在调用时计算默认值

大家都知道python的默认值是在函数定义时计算出来的, 也就是说默认值只会计算一次, 之后函数调用时, 如果参数没有给出,
同一个值会赋值给变量, 这会导致, 如果我们想要一个list默认值, 新手通常这么写:

def foo(a=[]):
 a.append(3)
 print a

其实是错误的,两次调用会这样的结果:

[3]
[3, 3]

其实应该这么写

def baz(a=None):
  a = a or []
a.append(3) print a

两次调用输出以下结果:

[3]
[3]

 

 

这样好挫啊, 搞的虽然有默认值用法,但是我们却需要写的和js,lua一样, 我们不能像c++一样, 在函数运行时每次执行默认值么.
用decorator可以模拟下

import functools
import copy
def
compute_default_value_for_each_call(func): defaults = func.__defaults__ if not defaults: return None defaults = tuple(copy.copy(x) for x in defaults) @functools.wraps(func) def wrapper(*args, **kwargs): if func.__defaults__ != defaults: func.__defaults__ = tuple(copy.copy(y) for y in defaults) return func(*args, **kwargs) return wrapper @compute_default_value_for_each_call def foo(b, a=[]): if b: a.append(3) return a import timeit

这样两次调用foo(1), 结果为:

[3]
[3]

这个decorator有对未修改默认参数值做优化, 在我们不对默认值修改的情况下(比如打印变量a的内容), 性能有很大提升:

@compute_default_value_for_each_call
def foo(b, a=[]):
    if b:
        a.append(3)
    return a


def foo2(b, a=None):
    a = a or []
    if b:
        a.append(3)
    return a

import timeit

print timeit.timeit('foo(1)', setup='from __main__ import foo')
print timeit.timeit('foo(0)', setup='from __main__ import foo')
print timeit.timeit('foo2(1)', setup='from __main__ import foo2')
print timeit.timeit('foo2(0)', setup='from __main__ import foo2')

执行结果(调用1000000次的总时间)

4.32704997063
0.630109071732
0.445858955383
0.26370882988

性能上还过得去....

觉得这种方法性能低的同学, 还可以使用ast模块, 直接修改函数源代码, 达到a = a or []这样的性能. 但是实现起来略麻烦, 有机会我再试试.

 

====================================================

修改函数代码版本如下:

import inspect
import dis
import ast
import re
import traceback
import timeit


def compute_default_value_for_each_call(func):
    source = inspect.getsource(func)
    m = re.search(r'^\s+', source)
    if m:
        m.group(0)
        n = len(m.group(0))
        lines = [line[n:] for line in str.splitlines(source)]
        source = '\n'.join(lines)
    root = ast.parse(source)
    root.body[0].decorator_list.pop()
    arg_names = [arg.id for arg in root.body[0].args.args]
    default_nodes = root.body[0].args.defaults
    arg_names = arg_names[len(arg_names) - len(default_nodes):]
    body = root.body[0].body
    n = len(arg_names)
    for i in reversed(xrange(n)):
        arg_name = arg_names[i]
        default_node = default_nodes[i]
        lineno = default_node.lineno
        col_offset = default_node.col_offset
        body.insert(0, ast.Assign(targets=[ast.Name(id=arg_name, ctx=ast.Store(),
                                                    lineno=lineno, col_offset=col_offset)],
                                  value=ast.BoolOp(op=ast.Or(), lineno=lineno, col_offset=col_offset,
                                  values=[ast.Name(id=arg_name, ctx=ast.Load(), lineno=lineno, col_offset=col_offset),
                                          default_node]),
                                  lineno=lineno, col_offset=col_offset))
    root.body[0].args.defaults = [ast.Name(id='None', ctx=ast.Load(), lineno=old.lineno, col_offset=old.col_offset)
                                  for old in default_nodes]
    root.body[0].body = body
    l = {}
    exec compile(root, '<string>', mode='exec') in globals(), l
    func = l[func.__name__]
    return func


def root2():
    def main():

        @compute_default_value_for_each_call
        def root(a=([])):
            a.append(3)
            return a
        print root()
        print root()
    main()

print 'used in local function'
root2()


def foo():
    return 42


bar_count = 0


def bar():
    global bar_count
    if bar_count:
        raise RuntimeError
    bar_count += 1


@compute_default_value_for_each_call
def f1(a=foo()):
    return a


def f2(a=None):
    a = a or foo()
    return a


@compute_default_value_for_each_call
def f3(a=bar()):
    return a


def f4(a=None):
    a = a or bar()
    return a

print 'f1:'
dis.dis(f1)
print 'f2:'
dis.dis(f2)

print 'f2 running time:'
print timeit.timeit('f2', setup='from __main__ import f2')
print 'f1 running time:'
print timeit.timeit('f1', setup='from __main__ import f1')
print 'f2 running time:'
print timeit.timeit('f2', setup='from __main__ import f2')

print 'f3:'
dis.dis(f3)
print 'f4:'
dis.dis(f4)

try:
    f3()
except:
    print traceback.format_exc()
try:
    f4()
except:
    print traceback.format_exc()

 

输出:

used in local function
[3]
[3]
f1:
  2           0 LOAD_FAST                0 (a)
              3 JUMP_IF_TRUE_OR_POP     12
              6 LOAD_GLOBAL              0 (foo)
              9 CALL_FUNCTION            0
        >>   12 STORE_FAST               0 (a)

  3          15 LOAD_FAST                0 (a)
             18 RETURN_VALUE        
f2:
 83           0 LOAD_FAST                0 (a)
              3 JUMP_IF_TRUE_OR_POP     12
              6 LOAD_GLOBAL              0 (foo)
              9 CALL_FUNCTION            0
        >>   12 STORE_FAST               0 (a)

 84          15 LOAD_FAST                0 (a)
             18 RETURN_VALUE        
f2 running time:
0.0288169384003
f1 running time:
0.0251071453094
f2 running time:
0.025267124176
f3:
  2           0 LOAD_FAST                0 (a)
              3 JUMP_IF_TRUE_OR_POP     12
              6 LOAD_GLOBAL              0 (bar)
              9 CALL_FUNCTION            0
        >>   12 STORE_FAST               0 (a)

  3          15 LOAD_FAST                0 (a)
             18 RETURN_VALUE        
f4:
 93           0 LOAD_FAST                0 (a)
              3 JUMP_IF_TRUE_OR_POP     12
              6 LOAD_GLOBAL              0 (bar)
              9 CALL_FUNCTION            0
        >>   12 STORE_FAST               0 (a)

 94          15 LOAD_FAST                0 (a)
             18 RETURN_VALUE        
Traceback (most recent call last):
  File "./b/b.py", line 114, in <module>
    f3()
  File "<string>", line 2, in f3
  File "./b/b.py", line 73, in bar
    raise RuntimeError
RuntimeError

Traceback (most recent call last):
  File "./b/b.py", line 118, in <module>
    f4()
  File "./b/b.py", line 93, in f4
    a = a or bar()
  File "./b/b.py", line 73, in bar
    raise RuntimeError
RuntimeError

可以看到性能完全一样了, 在调试信息也还不错

 

posted @ 2014-07-24 22:59  福尔摩喵  阅读(1127)  评论(4编辑  收藏  举报