PyTorch in-place operations do not allow broadcasting: the tensor's shape and dtype must be exactly the same before and after the operation
While computing the loss I hit this error:
RuntimeError: output with shape [] doesn't match the broadcast shape [1]
The cause turned out to be the loss accumulation, which updated the loss in-place and added a scalar (0-dim) tensor to a 1-dim tensor, triggering the error. The snippet below reproduces it:
>>> a = torch.tensor(1)
>>> b = torch.zeros(1)
>>> a
tensor(1)
>>> b
tensor([0.])
>>> a / b
tensor([inf])
>>> a /= b
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: output with shape [] doesn't match the broadcast shape [1]
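A minimal sketch of how such a loss-accumulation loop can be fixed (the accumulator name and the per-step losses here are hypothetical, not from the original code): either use out-of-place addition, which broadcasts freely, or make the shapes agree before updating in-place, e.g. with `squeeze()`.

```python
import torch

# Out-of-place accumulation: a new tensor is created each step,
# so broadcasting [] + [1] -> [1] works without error.
total_loss = torch.zeros(())              # 0-dim scalar accumulator
for step_loss in [torch.ones(1), torch.ones(1)]:
    total_loss = total_loss + step_loss   # broadcasts to shape [1]

# In-place accumulation: match the shapes first, then it works.
inplace_total = torch.zeros(())
for step_loss in [torch.ones(1), torch.ones(1)]:
    inplace_total += step_loss.squeeze()  # [1] -> [], shapes now agree
```

With matching shapes the in-place version avoids allocating a new tensor each step, which is the usual reason for writing `loss += ...` in the first place.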
The root cause is that PyTorch forbids in-place operations whose left- and right-hand tensors have mismatched shapes: an in-place op never broadcasts, because the output memory is already allocated. Out-of-place operations create a fresh tensor for the result, so this problem doesn't arise and broadcasting happens automatically.
Likewise, in-place operations do not allow a dtype mismatch between the two tensors; no automatic type promotion happens, because the promoted result type may not be castable back to the output tensor's dtype. Out-of-place operations create a new tensor, so the usual type promotion applies:
>>> a = torch.tensor([1], dtype=torch.int32)
>>> b = torch.tensor([1.0], dtype=torch.float32)
>>> a
tensor([1], dtype=torch.int32)
>>> b
tensor([1.])
>>> a / b
tensor([1.])
>>> a /= b
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: result type Float can't be cast to the desired output type Int
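A short sketch of the dtype case, using the same `a` and `b` as above: out-of-place division promotes `int32 / float32` to `float32` automatically, while for an in-place update you have to cast the output tensor to the promoted dtype first.

```python
import torch

a = torch.tensor([1], dtype=torch.int32)
b = torch.tensor([1.0], dtype=torch.float32)

# Out-of-place: a new float32 tensor is created for the result.
c = a / b                   # int32 / float32 -> float32

# In-place: cast first so the result type matches the output dtype.
a_f = a.to(torch.float32)   # float32 copy of a
a_f /= b                    # dtypes now agree, so this works
```

Note that `a.to(torch.float32)` itself allocates a new tensor; there is no way to divide the int32 tensor in-place by a float and keep a float result in the original storage.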