在 Python 数据类中定义的可调用类型(比如函数)字段,在设置默认值时,可能会遭遇意外行为。
Python 数据类的字段最好用 dataclasses.field 定义默认值。
下午 huaji 在写实验代码的时候分享说遇到了如下 bug:
@dataclass
class A():
f = torch.exp
g = torch.nn.functional.normalize 结果,A().f(tensor) 可以正常工作,但是 A().g(tensor) 却把这个 A 类型的数据类当作第一个参数传给了 normalize。这是为什么?
首先考虑 g 的行为。在 Python 的 class 定义里,使用 def 定义的函数除非有特定装饰器,否则都会被转换成方法。既然是方法,那么第一个参数必须是 self。而 Python 在 class 定义内部的 def 定义实际上等价于以下赋值形式:
class Class:
def method(self, ...): ...
# 实际上是
def function(self, ...): ...
class Class:
method = function 所以当我们把 torch.nn.functional.normalize 赋值给 g 的时候,它也首先被转换成了方法——调用 type(A().g) 显示的确是 method。即使 A 是数据类,这种转换也会照常进行!
但更匪夷所思地,为什么 torch.exp 和 torch.nn.functional.normalize 的行为不一致?假如我们用 type 检查这两个函数的类型的话,会发现前者的类型是 builtin_function_or_method,而后者是 function。进一步地,假如用 inspect.isfunction 检验,会发现前者返回 False 而后者返回 True。但假如你用 repr 查看,它们都被描述成 <function ...>。
没错,在 Python 里面内置函数不是函数。
torch.exp 在运行时是一个 C++ 实现的外部函数。所以 class 不会把它转换成方法。同样地,还有其他一些我们当作函数使用但根本不是函数的可调用对象。
- 外部定义的函数。比如
torch.exp和torch.tensor - 内置函数。比如
map和all;特别地,某些标准库模块的函数可能有 Python 实现和 C 实现的两种版本,这取决于平台实现 - 类型。比如
str和torch.nn.Module numpy.ufunc。numpy的 API 会是这种类型numpy.vectorize向量化过的函数functools.partial创建的偏函数
所以,如果想要在数据类定义这种可调用对象的字段的话,还是用 dataclasses.field 最好,可以绕过 Python 类定义语法的隐式转换。但这时候 Callable 类型注解就不能省略。
from collections.abc import Callable
class A:
f: Callable = field(default=torch.log)
g: Callable = field(default=torch.nn.functional.normalize) 当然通过 staticmethod 关闭转换也是可以的。
class A:
f: Callable = staticmethod(torch.log)
g: Callable = staticmethod(torch.nn.functional.normalize)