Python使用技巧--python装饰器的使用
一、函数装饰器的运用
示例一:编写计时装饰器
1.简易版装饰器:该装饰器支持装饰普通方法,也支持类中的方法,但是不支持传入装饰器参数;
def timer(func):
"""
用于对函数计时,不支持传入装饰器参数;
该装饰器支持装饰普通方法,也支持类中的方法
"""
def onCall(*args, **kargs):
start = time.time()
result = func(*args, **kargs)
elapsed = time.time() - start
onCall.alltime += elapsed
format = '%s run time: %.5f s;total run time:%.5f s'
values = (func.__name__, elapsed, onCall.alltime)
print(format % values)
return result
onCall.alltime = 0
return onCall
运行测试:
# 测试普通函数
@timer
def test():
time.sleep(1)
print("run test")
test()
test()
# 输出
run test
test run time: 1.00509 s;total run time:1.00509 s
run test
test run time: 1.00507 s;total run time:2.01015 s
#测试类中的方法
class Test():
@timer
def test_method(self):
time.sleep(1)
print("run test_method")
a = Test()
a.test_method()
a.test_method()
#输出
run test_method
test_method run time: 1.00507 s;total run time:1.00507 s
run test_method
test_method run time: 1.00509 s;total run time:2.01016 s
示例一 加强版
该装饰器支持装饰普通方法,支持类中的方法,同时也支持传入装饰器参数
def timer(label='', trace=True):
"""
用于对函数计时,支持传入装饰器参数;
该装饰器支持装饰普通方法,也支持类中的方法
"""
def onDecorator(func):
def onCall(*args, **kargs):
start = time.time()
result = func(*args, **kargs)
elapsed = time.time() - start
onCall.alltime += elapsed
if trace:
format = '%s%s run time: %.5f s;total run time:%.5f s'
values = (label, func.__name__, elapsed, onCall.alltime)
print(format % values)
return result
onCall.alltime = 0
return onCall
return onDecorator
运行测试:
# 测试普通函数
@timer(label="Fun test===>")
def test():
time.sleep(1)
print("run test")
test()
test()
#输出:
run test
Fun test===>test run time: 1.00525 s;total run time:1.00525 s
run test
Fun test===>test run time: 1.00097 s;total run time:2.00623 s
# 测试类中的方法
class Test():
@timer(label="Class method test===>")
def test_method(self):
time.sleep(1)
print("run test_method")
a = Test()
a.test_method()
a.test_method()
#输出:
run test_method
Class method test===>test_method run time: 1.00361 s;total run time:1.00361 s
run test_method
Class method test===>test_method run time: 1.00507 s;total run time:2.00868 s
示例二:参数验证装饰器
1.简易版:只有范围验证的功能
# -*-coding:utf-8-*-
def rangetest(trace=True, **argchecks): # Validate ranges for both+defaults
def onDecorator(func): # onCall remembers func and argchecks
if not __debug__: # True if "python -O main.py args..."
return func # Wrap if debugging; else use original
else:
import sys
code = func.__code__
allargs = code.co_varnames[:code.co_argcount]
funcname = func.__name__
def onCall(*pargs, **kargs):
# All pargs match first N expected args by position
# The rest must be in kargs or be omitted defaults
positionals = list(allargs)
positionals = positionals[:len(pargs)]
for (argname, (low, high)) in argchecks.items():
# For all args to be checked
if argname in kargs:
# Was passed by name
if kargs[argname] < low or kargs[argname] > high:
errmsg = '{0} argument "{1}" not in {2}..{3}'
errmsg = errmsg.format(funcname, argname, low, high)
raise TypeError(errmsg)
elif argname in positionals:
# Was passed by position
position = positionals.index(argname)
if pargs[position] < low or pargs[position] > high:
errmsg = '{0} argument "{1}" not in {2}..{3}'
errmsg = errmsg.format(funcname, argname, low, high)
raise TypeError(errmsg)
else:
# Assume not passed: default
if trace:
print('Argument "{0}" defaulted'.format(argname))
return func(*pargs, **kargs) # OK: run original call
return onCall
return onDecorator
运行测试:
测试普通函数
@rangetest(age=(0, 120)) # persinfo = rangetest(..)(persinfo)
def persinfo(name, age):
print('%s is %s years old' % (name, age))
@rangetest(M=(1, 12), D=(1, 31), Y=(0, 2009))
def birthday(M, D, Y):
print('birthday = {0}/{1}/{2}'.format(M, D, Y))
persinfo('Bob', 40)
birthday(5, D=1, Y=1963)
# 输出
Bob is 40 years old
birthday = 5/1/1963
persinfo('Bob', 150)
# 输出
Traceback (most recent call last):
File "/Users/edwin/PycharmProjects/testProject/test.py", line 64, in <module>
persinfo('Bob', 150)
File "/Users/edwin/PycharmProjects/testProject/test.py", line 37, in onCall
raise TypeError(errmsg)
TypeError: persinfo argument "age" not in 0..120
测试类中的方法
# 测试类中的方法
class Person:
def __init__(self, name, job, pay):
self.job = job
self.pay = pay
# giveRaise = rangetest(..)(giveRaise)
@rangetest(percent=(0.0, 1.0)) # percent passed by name or position
def giveRaise(self, percent):
self.pay = int(self.pay * (1 + percent))
bob = Person('Bob Smith', 'dev', 100000)
sue = Person('Sue Jones', 'dev', 100000)
bob.giveRaise(0.10)
sue.giveRaise(percent=0.20)
print(bob.pay, sue.pay)
# 输出
110000 120000
bob.giveRaise(1.10)
# 输出
Traceback (most recent call last):
File "/Users/edwin/PycharmProjects/testProject/test.py", line 84, in <module>
bob.giveRaise(1.10)
File "/Users/edwin/PycharmProjects/testProject/test.py", line 34, in onCall
raise TypeError(errmsg)
TypeError: giveRaise argument "percent" not in 0.0..1.0
示例二加强版
可以处理范围测试,类型测试,值测试
def rangetest(trace=True, **argchecks):
return argtest(argchecks, lambda arg, vals: arg < vals[0] or arg > vals[1], trace=trace)
def typetest(trace=True, **argchecks):
return argtest(argchecks, lambda arg, type: not isinstance(arg, type), trace=trace)
def valuetest(trace=True, **argchecks):
return argtest(argchecks, lambda arg, tester: not tester(arg), trace=trace)
def argtest(argchecks, failif, trace): # Validate ranges for both+defaults
def onDecorator(func): # onCall remembers func and argchecks
if not __debug__: # True if "python -O main.py args..."
return func # Wrap if debugging; else use original
else:
code = func.__code__
allargs = code.co_varnames[:code.co_argcount]
funcname = func.__name__
def onError(argname, criteria):
errfmt = '%s argument "%s" not %s'
raise TypeError(errfmt % (funcname, argname, criteria))
def onCall(*pargs, **kargs):
# All pargs match first N expected args by position
# The rest must be in kargs or be omitted defaults
positionals = list(allargs)
positionals = positionals[:len(pargs)]
for (argname, criteria) in argchecks.items():
# 关键字参数检查
if argname in kargs:
# Was passed by name
if failif(kargs[argname], criteria):
onError(argname, criteria)
# 位置参数检查
elif argname in positionals:
# Was passed by position
position = positionals.index(argname)
if failif(pargs[position], criteria):
onError(argname, criteria)
else:
# Assume not passed: default
if trace:
print('Argument "{0}" defaulted'.format(argname))
return func(*pargs, **kargs) # OK: run original call
return onCall
return onDecorator
运行测试:
import sys
def fails(test):
try:
result = test()
except:
print("[%s]" % sys.exc_info()[1])
else:
print('?%s?' % result)
@rangetest(M=(1, 12), D=(1, 31), Y=(1900, 2013))
def date(M, D, Y):
print('date = {0}/{1}/{2}'.format(M, D, Y))
date(5, 1, 1960)
date(M=5, D=1, Y=1960)
fails(lambda: date(1, 2, 3))
print("---------------------------------------------")
@typetest(a=int, c=float)
def sum(a, b, c, d):
print(a+b+c+d)
sum(1, 2, 3.0, 4)
fails(lambda: sum('spam', 2, 3, 4))
print("---------------------------------------------")
@valuetest(word1=str.islower, word2=(lambda x: x[0].isupper()))
def msg(word1, word2):
print("%s %s" % (word1, word2))
msg('edwin', 'Edwin')
fails(lambda: msg('Edwin', 'EdWin'))
print("---------------------------------------------")
@rangetest(X=(1, 10))
@typetest(Z=str)
def nester(X, Y, Z):
print("%s %s %s" % (X, Y, Z))
nester(1, 2, "edwin")
fails(lambda: nester(1, 2, 1))
输出:
date = 5/1/1960
date = 5/1/1960
[date argument "Y" not (1900, 2013)]
---------------------------------------------
10.0
[sum argument "a" not <class 'int'>]
---------------------------------------------
edwin Edwin
[msg argument "word1" not <method 'islower' of 'str' objects>]
---------------------------------------------
Argument "X" defaulted
1 2 edwin
Argument "X" defaulted
[nester argument "Z" not <class 'str'>]
二、类装饰器的运用
示例一:实现单例功能的类装饰器
只适合用于python3环境,因为nonlocal语句仅在python3.x中可用
def singleton(aClass):
"""
管理一个类只能创建一个实例
只适合用于python3环境,因为nonlocal语句仅在python3.x中可用
:param aClass: 装饰的类
:return:
"""
instance = None
def onCall(*args):
nonlocal instance
if instance == None:
instance = aClass(*args)
return instance
return onCall
适合用于python2和python3环境
def singleton(aClass):
"""
管理一个类只能创建一个实例
适合用于python2和python3环境
:param aClass: 装饰的类
:return:
"""
def onCall(*args):
if onCall.instance == None:
onCall.instance = aClass(*args)
return onCall.instance
onCall.instance = None
return onCall
运行测试:
if __name__ == '__main__':
@singleton # Person = singleton(Person)
class Person:
def __init__(self, name, hours, rate):
self.name = name
self.hours = hours
self.rate = rate
def pay(self):
return self.hours * self.rate
bob = Person('Bob', 40, 10) # Really calls onCall
print(bob.name, bob.pay())
sue = Person('Sue', 50, 20) # Same, single object
print(sue.name, sue.pay())
输出:
Bob 400
Bob 400
示例二:类属性访问装饰器
只适合用于python2的装饰器
traceMe = False
def trace(*args):
if traceMe: print('[' + ' '.join(map(str, args)) + ']')
def accessControl(failIf):
def onDecorator(aClass):
if not __debug__:
return aClass
else:
class onInstance:
def __init__(self, *args, **kargs):
self.__wrapped = aClass(*args, **kargs)
def __getattr__(self, attr):
trace('get:', attr)
if failIf(attr):
raise TypeError('private attribute fetch: ' + attr)
else:
return getattr(self.__wrapped, attr)
def __setattr__(self, attr, value):
trace('set:', attr, value)
if attr == '_onInstance__wrapped':
self.__dict__[attr] = value
elif failIf(attr):
raise TypeError('private attribute change: ' + attr)
else:
setattr(self.__wrapped, attr, value)
return onInstance
return onDecorator
def Private(*attributes):
return accessControl(failIf=(lambda attr: attr in attributes))
def Public(*attributes):
return accessControl(failIf=(lambda attr: attr not in attributes))
测试:
if __name__ == '__main__':
import sys
@Private('age') # Person = Private('age')(Person)
class Person: # Person = onInstance with state
def __init__(self, name, age):
self.name = name
self.age = age # Inside accesses run normally
def __add__(self, other):
self.age += other
def __str__(self):
return '%s: %s' % (self.name, self.age)
X = Person('Bob', 40)
print(X.name) # Outside accesses validated
X.name = 'Sue'
print(X.name)
X + 10
print(X)
try:
t = X.age
except:
print("Error:[%s]" % sys.exc_info()[1])
这里当运行到X+10这条语句时,便会报错:
Traceback (most recent call last):
File "/Users/edwin/PycharmProjects/testProject/test.py", line 75, in <module>
X + 10
TypeError: unsupported operand type(s) for +: 'onInstance' and 'int'
异常分析:
当在python2下运行时,代理类(onInstance)是一个经典类,但是当在python3下运行时,代理类是一个新式类。(python3中只有新式类,没有经典类)。当通过内置操作隐式地运行时(X+10),在经典类中,会触发代理类(onInstance)中__getattr__的调用,在新式类中,不会触发代理类(onInstance)中__getattr__的调用,从而不会调用到Person类中的__add__。详细细节请参考另一文章<<Python使用技巧--拦截内置运算属性>>。。
注意:在python2的默认经典类中,__getattr__会拦截内置函数对__add__和__str__这样的运算符重载方法的隐式访问,但是在python3的新式类中不会拦截(包括python2的新式类)。
解决方法:
在代理类中重新定义__add__这些运算符重载方法。
示例二改进版
适合用于python2和python3的装饰器
以下装饰器使用了一个混合技巧来为包装器类添加一些运算符重载方法的重定义,这样在python3.x中它会正确地将内置操作委托到使用这些方法的主体类上。
# -*-coding:utf-8-*-
traceMe = False
def trace(*args):
if traceMe: print('[' + ' '.join(map(str, args)) + ']')
def accessControl(failIf):
def onDecorator(aClass):
if not __debug__:
return aClass
else:
class onInstance(BuiltinsMixin):
def __init__(self, *args, **kargs):
self.__wrapped = aClass(*args, **kargs)
def __getattr__(self, attr):
trace('get:', attr)
if failIf(attr):
raise TypeError('private attribute fetch: ' + attr)
else:
return getattr(self.__wrapped, attr)
def __setattr__(self, attr, value):
trace('set:', attr, value)
if attr == '_onInstance__wrapped':
self.__dict__[attr] = value
elif failIf(attr):
raise TypeError('private attribute change: ' + attr)
else:
setattr(self.__wrapped, attr, value)
return onInstance
return onDecorator
def Private(*attributes):
return accessControl(failIf=(lambda attr: attr in attributes))
def Public(*attributes):
return accessControl(failIf=(lambda attr: attr not in attributes))
class BuiltinsMixin():
def reroute(self, attr, *args, **kargs):
return self.__class__.__getattr__(self, attr)(*args, **kargs)
def __add__(self, other):
return self.reroute('__add__', other)
def __str__(self):
return self.reroute('__str__')
def __getitem__(self, index):
return self.reroute('__getitem__', index)
def __call__(self, *args, **kargs):
return self.reroute('__call__', *args, **kargs)
测试:
if __name__ == '__main__':
import sys
@Private('age') # Person = Private('age')(Person)
class Person: # Person = onInstance with state
def __init__(self, name, age):
self.name = name
self.age = age # Inside accesses run normally
def __add__(self, other):
self.age += other
def __str__(self):
return '%s: %s' % (self.name, self.age)
X = Person('Bob', 40)
print(X.name) # Outside accesses validated
X.name = 'Sue'
print(X.name)
X + 10
print(X)
try:
t = X.age
except:
print("Error:[%s]" % sys.exc_info()[1])
输出:
Bob
Sue
Sue: 50
Error:[private attribute fetch: age]
三、编写装饰器的注意事项
1.保持多个装饰的实例
我们都是知道,编写装饰器的时候,可以使用函数,也可以使用类来编写,但是当使用类来编写的时候,我们需要注意装饰的实例被覆盖。
以下装饰器实现属性调用的追踪。
class Tracer:
def __init__(self, aClass):
self.fetches = 0
self.aClass = aClass
def __call__(self, *args):
self.wrapped = self.aClass(*args)
return self
def __getattr__(self, attrname):
print('Trace: ' + attrname)
self.fetches += 1
return getattr(self.wrapped, attrname)
测试
@Tracer
class Person: # Person = Tracer(Person)
def __init__(self, name): # Wrapper bound to Person
self.name = name
bob = Person('Bob')
print(bob.name)
Sue = Person('Sue') #bob实例被sue实例
print(sue.name)
print(bob.name) # bob实例的name='Sue'!
分析:
每个实例构建调用会触发__call__,这会覆盖前面的实例。直接效果是Tracer只保存了一个实例,即最后创建的那个实例。
改进:基于函数的装饰器可用于多个实例,因为每个实例构造调用都会创建一个新的Wrapper实例,而不是覆盖一个单个共享的Tracer实例的状态。
def Tracer(aClass):
class Wrapper:
def __init__(self):
self.fetches = 0
self.aClass = aClass
def __call__(self, *args, **kwargs):
self.wrapped = self.aClass(*args, **kwargs)
return self
def __getattr__(self, attrname):
print('Trace: ' + attrname)
self.fetches += 1
return getattr(self.wrapped, attrname)
return Wrapper
2.对类方法进行装饰
我们编写以下一个装饰器:
class tracer:
def __init__(self, func):
self.calls = 0
self.func = func
def __call__(self, *args):
self.calls += 1
print('call %s to %s' % (self.calls, self.func.__name__))
self.func(*args)
装饰普通函数没问题:
@tracer
def spam(a, b, c): # spam = tracer(spam)
print(a + b + c) # Wraps spam in a decorator object
spam(1, 2, 3)
spam('a', 'b', 'c')
输出:
call 1 to spam
6
call 2 to spam
abc
当装饰类中的方法,就失效了。
if __name__ == '__main__':
class Person:
def __init__(self, name, pay):
self.name = name
self.pay = pay
@tracer
def giveRaise(self, percent): # giveRaise = tracer(giverRaise)
self.pay *= (1.0 + percent)
@tracer
def lastName(self): # lastName = tracer(lastName)
return self.name.split()[-1]
bob = Person('Bob Smith', 50000)
bob.giveRaise(0.25) # Runs tracer.__call__(???, .25)
print(bob.lastName()) # Runs tracer.__call__(???)
输出:
Traceback (most recent call last):
File "/Users/edwin/PycharmProjects/testProject/test.py", line 26, in <module>
bob.giveRaise(0.25)
File "/Users/edwin/PycharmProjects/testProject/test.py", line 10, in __call__
self.func(*args)
TypeError: giveRaise() missing 1 required positional argument: 'percent'
call 1 to giveRaise
分析:
这里用__call__把被装饰方法名称重绑定到一个类实例对象的时候,python只向self传递了tracer实例,它根本没有在参数列表中传递Person主体。因此tracer不知道我们要利用方法调用处理的Person实例的任何信息,导致没办法创建一个带有实例的绑定方法,也没办法正确地分发调用。这是一个非常值得注意的细节。
改进:使用嵌套函数装饰方法
def tracer(func):
calls = 0
def onCall(*args, **kwargs):
nonlocal calls
calls += 1
print('call %s to %s' % (calls, func.__name__))
return func(*args, **kwargs)
return onCall
测试:
if __name__ == '__main__':
class Person:
def __init__(self, name, pay):
self.name = name
self.pay = pay
@tracer
def giveRaise(self, percent): # giveRaise = tracer(giverRaise)
self.pay *= (1.0 + percent)
@tracer
def lastName(self): # lastName = tracer(lastName)
return self.name.split()[-1]
print('methods...')
bob = Person('Bob Smith', 50000)
sue = Person('Sue Jones', 100000)
print(bob.name, sue.name)
sue.giveRaise(.10) # Runs onCall(sue, .10)
print(sue.pay)
print(bob.lastName(), sue.lastName()) # Runs onCall(bob), lastName in scopes
输出:
methods...
Bob Smith Sue Jones
call 1 to giveRaise
110000.00000000001
call 1 to lastName
call 2 to lastName
Smith Jones
浙公网安备 33010602011771号