在 Python 数据类中定义的可调用类型(比如函数)字段,在设置默认值时,可能会遭遇意外行为。
Python 数据类的字段最好用 dataclasses.field
定义默认值。
下午 huaji 在写实验代码的时候分享说遇到了如下 bug:
@dataclass
class A():
f = torch.exp
g = torch.nn.functional.normalize
结果,A().f(tensor)
可以正常工作,但是 A().g(tensor)
却把这个 A
类型的数据类当作第一个参数传给了 normalize
。这是为什么?
首先考虑 g
的行为。在 Python 的 class
定义里,使用 def
定义的函数除非有特定装饰器,否则都会被转换成方法。既然是方法,那么第一个参数必须是 self
。而 Python 在 class
定义内部的 def
定义实际上等价于以下赋值形式:
class Class:
def method(self, ...): ...
# 实际上是
def function(self, ...): ...
class Class:
method = function
所以当我们把 torch.nn.functional.normalize
赋值给 g
的时候,它也首先被转换成了方法——调用 type(A().g)
显示的确是 method
。即使 A
是数据类,这种转换也会照常进行!
但更匪夷所思地,为什么 torch.exp
和 torch.nn.functional.normalize
的行为不一致?假如我们用 type
检查这两个函数的类型的话,会发现前者的类型是 builtin_function_or_method
,而后者是 function
。进一步地,假如用 inspect.isfunction
检验,会发现前者返回 False
而后者返回 True
。但假如你用 repr
查看,它们都被描述成 <function ...>
。
没错,在 Python 里面内置函数不是函数。
torch.exp
在运行时是一个 C++ 实现的外部函数。所以 class
不会把它转换成方法。同样地,还有其他一些我们当作函数使用但根本不是函数的可调用对象。
- 外部定义的函数。比如
torch.exp
和torch.tensor
- 内置函数。比如
map
和all
;特别地,某些标准库模块的函数可能有 Python 实现和 C 实现的两种版本,这取决于平台实现 - 类型。比如
str
和torch.nn.Module
numpy.ufunc
。numpy
的 API 会是这种类型numpy.vectorize
向量化过的函数functools.partial
创建的偏函数
所以,如果想要在数据类定义这种可调用对象的字段的话,还是用 dataclasses.field
最好,可以绕过 Python 类定义语法的隐式转换。但这时候 Callable
类型注解就不能省略。
from collections.abc import Callable
class A:
f: Callable = field(default=torch.log)
g: Callable = field(default=torch.nn.functional.normalize)
当然通过 staticmethod
关闭转换也是可以的。
class A:
f: Callable = staticmethod(torch.log)
g: Callable = staticmethod(torch.nn.functional.normalize)