Python 数据类可调用对象字段陷阱

comp
@python

在 Python 数据类中定义的可调用类型(比如函数)字段,在设置默认值时,可能会遭遇意外行为。

TL;DR

Python 数据类的字段最好用 dataclasses.field 定义默认值。

下午 huaji 在写实验代码的时候分享说遇到了如下 bug:

@dataclass
class A():
    f = torch.exp
    g = torch.nn.functional.normalize

结果,A().f(tensor) 可以正常工作,但是 A().g(tensor) 却把这个 A 类型的数据类当作第一个参数传给了 normalize。这是为什么?

首先考虑 g 的行为。在 Python 的 class 定义里,使用 def 定义的函数除非有特定装饰器,否则都会被转换成方法。既然是方法,那么第一个参数必须是 self。而 Python 在 class 定义内部的 def 定义实际上等价于以下赋值形式:

class Class:
   def method(self, ...): ...

# 实际上是

def function(self, ...): ...
class Class:
   method = function

所以当我们把 torch.nn.functional.normalize 赋值给 g 的时候,它也首先被转换成了方法——调用 type(A().g) 显示的确是 method。即使 A 是数据类,这种转换也会照常进行!

但更匪夷所思地,为什么 torch.exptorch.nn.functional.normalize 的行为不一致?假如我们用 type 检查这两个函数的类型的话,会发现前者的类型是 builtin_function_or_method,而后者是 function。进一步地,假如用 inspect.isfunction 检验,会发现前者返回 False 而后者返回 True。但假如你用 repr 查看,它们都被描述成 <function ...>

没错,在 Python 里面内置函数不是函数

torch.exp 在运行时是一个 C++ 实现的外部函数。所以 class 不会把它转换成方法。同样地,还有其他一些我们当作函数使用但根本不是函数的可调用对象

  • 外部定义的函数。比如 torch.exptorch.tensor
  • 内置函数。比如 mapall;特别地,某些标准库模块的函数可能有 Python 实现和 C 实现的两种版本,这取决于平台实现
  • 类型。比如 strtorch.nn.Module
  • numpy.ufuncnumpy 的 API 会是这种类型
  • numpy.vectorize 向量化过的函数
  • functools.partial 创建的偏函数

所以,如果想要在数据类定义这种可调用对象的字段的话,还是用 dataclasses.field 最好,可以绕过 Python 类定义语法的隐式转换。但这时候 Callable 类型注解就不能省略。

from collections.abc import Callable
class A:
    f: Callable = field(default=torch.log)
    g: Callable = field(default=torch.nn.functional.normalize)

当然通过 staticmethod 关闭转换也是可以的。

class A:
    f: Callable = staticmethod(torch.log)
    g: Callable = staticmethod(torch.nn.functional.normalize)