最近在使用装饰器,我创建了一个用于过滤目标函数结果的装饰器类,该目标函数默认返回一个特定序列:
class Filter(object):
def __init__(self, id=None):
self.id = id
def __call__(self, func):
def wrapper(*args):
entity_ids = func(*args)
result = {}
for k, v in entity_ids.items():
if self.id:
if '_' + str(self.id) in k:
result.update({k: v})
return result
return wrapper
我在其他类的某些方法中使用此装饰器,如下所示:
class SomeClass(object):
@Filter(id=None)
def get_ids(*args):
return result_sequence
现在的问题是,当调用类方法时,如何为装饰器定义参数:
>>>sc = SomeClass()
>>>sc.get_ids(*args) # 我想在这里传递 Filter 的 id 关键字参数
2、解决方案 方案一 在类定义中应用 Filter 装饰器时,可以在那里传递 id 参数:
@Filter(id=None)
如果 id 应为其他值,则需要在那里传递该值。Filter() 对象在 @Filter(id=None) 行中创建,然后调用。您也可以将代码重写为:
class SomeClass(object):
def get_ids(*args):
return result_sequence
get_ids = Filter(id=None)(get_ids)
因为这是 Python 处理装饰器时所做的。Filter.call() 方法的返回值替换了 get_ids,您在此时无法再为 Filter() 对象指定参数。SomeClass.get_ids() 现在是装饰器返回的嵌套 wrapper() 函数。
如果您想在调用装饰的方法时指定 id,则需要将 wrapper() 签名更改为接受(可选)的额外 id 参数。由于您已经支持 *args,因此您唯一的选择是添加 **kwargs 通用参数来支持可选的关键字参数:
def wrapper(*args, **kwargs):
id = kwargs.get('id', self.id)
entity_ids = func(*args)
result = {}
for k, v in entity_ids.items():
if id:
if '_' + str(id) in k:
result.update({k: v})
return result
在这里,id 关键字参数覆盖了装饰器类上设置的 id 值,而不是直接使用 self.id:
sc.get_ids(*args, id='foo')
您可能还想将任何关键字参数传递给包装函数;在这种情况下,我会使用:
def wrapper(*args, **kwargs):
id = kwargs.pop('id', self.id)
entity_ids = func(*args, **kwargs)
result = {}
for k, v in entity_ids.items():
if id:
if '_' + str(id) in k:
result.update({k: v})
return result
在这里,id 关键字参数在将剩余的关键字参数传递给包装函数之前被删除。
方案二
如果您想能够在调用装饰的函数时覆盖 id 参数,可以使用关键字参数来做到这一点:
class Filter(object):
def __init__(self, id=None):
self.id = id
def __call__(self, func):
def wrapper(*args, **kwargs):
id = kwargs.get("id", self.id)
entity_ids = func(*args)
result = {}
for k, v in entity_ids.items():
if id:
if '_' + str(id) in k:
result.update({k: v})
return result
return wrapper
但请注意,这意味着:
- 如果您想重载默认值,则必须以关键字参数的形式传递
id。 - 您无法通过这种方式将关键字参数传递给装饰的函数(但无论如何您也不会传递关键字参数)。
作为旁注(这里有点题外话),包装器函数的实现可以稍作改进:
def wrapper(*args, **kwargs):
id = kwargs.get("id", self.id)
if not id:
# no need to go further
return {}
id = "_%s" % id
entity_ids = func(*args)
result = dict(
(k, v) for k, v in entity_ids.items()
if id in k
)
return result