在调用装饰器函数时传递参数

30 阅读2分钟

最近在使用装饰器,我创建了一个用于过滤目标函数结果的装饰器类,该目标函数默认返回一个特定序列:

class Filter(object):
    def __init__(self, id=None):
        self.id = id

    def __call__(self, func):
        def wrapper(*args):
            entity_ids = func(*args)
            result = {}
            for k, v in entity_ids.items():
                if self.id:
                    if '_' + str(self.id) in k:
                        result.update({k: v})
            return result
        return wrapper

我在其他类的某些方法中使用此装饰器,如下所示:

class SomeClass(object):
    @Filter(id=None)
    def get_ids(*args):
        return result_sequence

现在的问题是,当调用类方法时,如何为装饰器定义参数:

>>>sc = SomeClass()
>>>sc.get_ids(*args)  # 我想在这里传递 Filter 的 id 关键字参数

2、解决方案 方案一 在类定义中应用 Filter 装饰器时,可以在那里传递 id 参数:

@Filter(id=None)

如果 id 应为其他值,则需要在那里传递该值。Filter() 对象在 @Filter(id=None) 行中创建,然后调用。您也可以将代码重写为:

class SomeClass(object):
    def get_ids(*args):
        return result_sequence
    get_ids = Filter(id=None)(get_ids)

因为这是 Python 处理装饰器时所做的。Filter.call() 方法的返回值替换了 get_ids,您在此时无法再为 Filter() 对象指定参数。SomeClass.get_ids() 现在是装饰器返回的嵌套 wrapper() 函数。

如果您想在调用装饰的方法时指定 id,则需要将 wrapper() 签名更改为接受(可选)的额外 id 参数。由于您已经支持 *args,因此您唯一的选择是添加 **kwargs 通用参数来支持可选的关键字参数:

def wrapper(*args, **kwargs):
    id = kwargs.get('id', self.id)
    entity_ids = func(*args)
    result = {}
    for k, v in entity_ids.items():
        if id:
            if '_' + str(id) in k:
                result.update({k: v})
    return result

在这里,id 关键字参数覆盖了装饰器类上设置的 id 值,而不是直接使用 self.id:

sc.get_ids(*args, id='foo')

您可能还想将任何关键字参数传递给包装函数;在这种情况下,我会使用:

def wrapper(*args, **kwargs):
    id = kwargs.pop('id', self.id)
    entity_ids = func(*args, **kwargs)
    result = {}
    for k, v in entity_ids.items():
        if id:
            if '_' + str(id) in k:
                result.update({k: v})
    return result

在这里,id 关键字参数在将剩余的关键字参数传递给包装函数之前被删除。

方案二 如果您想能够在调用装饰的函数时覆盖 id 参数,可以使用关键字参数来做到这一点:

class Filter(object):
    def __init__(self, id=None):
        self.id = id

    def __call__(self, func):
        def wrapper(*args, **kwargs):
            id = kwargs.get("id", self.id)
            entity_ids = func(*args)
            result = {}
            for k, v in entity_ids.items():
                if id:
                    if '_' + str(id) in k:
                        result.update({k: v})
            return result
        return wrapper

但请注意,这意味着:

  1. 如果您想重载默认值,则必须以关键字参数的形式传递 id
  2. 您无法通过这种方式将关键字参数传递给装饰的函数(但无论如何您也不会传递关键字参数)。

作为旁注(这里有点题外话),包装器函数的实现可以稍作改进:

def wrapper(*args, **kwargs):
    id = kwargs.get("id", self.id)
    if not id:
       # no need to go further
       return {}

    id = "_%s" % id
    entity_ids = func(*args)
    result = dict(
        (k, v) for k, v in entity_ids.items()
        if id in k
        )
    return result