目录
因为科研需要,正在把一个工程化的 tf Transformer 项目转为 torch,便于后续实验和改进。在这个过程中踩了不少坑,写一篇文章记录下来。转载请注明出处。
常规修改
常规的修改包括将 tf 的函数、模块直接替换成 torch 中相同功能的函数或模块,这里举一些例子:
TensorFlow | Torch |
---|---|
tf.keras.layers.Layer | torch.nn.Module |
tf.keras.Model | torch.nn.Module |
tf.keras.layers.Dense() | torch.nn.Linear() |
tf.expand_dims() | Tensor.unsqueeze() |
tf.reduce_all() | torch.all() |
这都是很基础的内容,改几个函数或模块之后就熟悉了。其中有一些参数名也会有变化,例如 tf 中的 axis
参数名到了 torch 中会变成 dim
。
另外也有 torch 简化的内容,例如生成一个张量,同时指定数据类型和设备,可以调用 torch.tensor(x, dtype=torch.float32, device=yy.device)
。
特别需要注意的是,tf 是静态图,torch 是动态图,因此代码中涉及到事先指定张量形状的一些代码都是不必要的,例如后面会提到的 tf.while_loop
。
TensoFlow 中特有的函数
这部分有一些麻烦,tf 中有一部分函数在 torch 中没有直接对应的函数。这种情况大多数都可以问大模型生成,但有一部分函数,现阶段大模型并不能给出比较好的结果,用上去会出 bug。并且这样的 bug 并不好找,因此浪费了不少时间。这里记录一下一些有意思的函数:
tf.while_loop
这个函数是 tf 框架下的循环,需要指定循环的条件函数 cond
、循环体函数 body
、循环变量函数 loop_vars
、形状不变量 shape_invaiants
,可能还有并行数等。这样的 while_loop
可能是可以融合到 tf 的框架中实现优化加速,但是在 torch 中没有对应,只能用 python 本身的 while 循环来执行:
while condition(loop_var):
loop body
这里并不涉及到 shape,因为 torch 是动态图,不用考虑形状问题。tf 是静态图,需要指明形状。
tf.gather_nd
这个函数 gather_nd(params, indices)
的语义大致是,接收一个张量和索引,按照索引的范围组合出一个新的张量。我一开始尝试了用大模型生成一个等效的函数,结果并不能使用。然后在网上找是否有现成的实现。在国内的网站上找到了很多同质的实现,代码都是一样的,我就直接放进了项目。后来发现,目前很容易搜到的 gather_nd
的等价实现有 bug! 在我的项目中会产生维度错误,具体的表现是函数能正常结束,但是产生的结果的维度和预期不符。我没有深究 bug 是怎么产生的,最后去 StackOverflow: how-to-implement-tf-gather-nd-in-pytorch-with-the-argument-batch-dims 上找到了一个更好的实现(目前在我的项目里面没有出问题)。
原网站的最后有两个用户都给出了手动实现的版本,有一个是借用 numpy 实现的,只能运行于 CPU(我没有测试 CPU 版本的正确性),我直接采用了后面的 GPU 版本的,并把代码贴在下面。
def gather_nd(params, indices):
"""https://stackoverflow.com/questions/61783826/how-to-implement-tf-gather-nd-in-pytorch-with-the-argument-batch-dims"""
f = params
index = indices
shape = f.shape
ind_shape = index.shape
N = ind_shape[-1]
padding = [0, 0] * (len(shape) - N) + [0, 1] * N
f = torch.nn.functional.pad(f, padding)
bc = torch.tensor(shape[0:N], dtype=f.dtype, device=f.device)
bc = torch.broadcast_to(bc, ind_shape)
T1 = list(range(len(ind_shape)))
index = torch.where((bc > index) * (index >= 0), index, torch.tensor(shape[0:N], device=f.device))
index = torch.permute(index, [T1[-1]] + T1[0:-1])
return f[[ind for ind in index]]
tf.nest.map_structure
这个函数 tf.nest.map_structure(func: Any, *structure: Any, **kwargs: Any) -> Union[defaultdict, list, object]
接收一个函数和一个可以展开的结构,然后把函数作用在每一个结构展开的子成分上;如果不可展开,就直接作用在输入的变量上。这个可以采用递归的方式实现一个等价的,直接上代码
def map_structure(func, structure):
if isinstance(structure, dict):
return {k: map_structure(func, v) for k, v in structure.items()}
elif isinstance(structure, (list, tuple)):
return [map_structure(func, item) for item in structure]
else:
return func(structure)
其他问题
Dataloader 内存泄漏
tf 中的 dataloader 可以接收一个 Generator,并且这个 generator 生成的数据可以是 tuple。但是在 torch 中,我的 Dataset 采用的是 IterableDataset
,并且返回的是 tuple:(Tensor, Tensor), Tensor
,这样的结构产生了内存泄漏。最后修改成 yield Tensor, Tensor, Tensor
修复了内存泄漏
F.scaled_dot_product_attention 掩码参数问题
这是我找了最久的一个 bug。我看见 torch 的库里面有 F.scaled_dot_product_attention
函数,并且可以调用 flash attention 加速,就直接采用了库里的 F.scaled_dot_product_attention
。完成训练后,发现生成的序列中从第 3 位开始都是重复的 token。逐行对比和原来 tf 代码的差异都没有找到问题。最后看了 torch 的官方文档,里面给出了 F.scaled_dot_product_attention
的等效实现:
# from https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias
if enable_gqa:
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
可以注意到,坑在这里 ↓
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias
这里判断了 attn_mask
的类型,如果是 torch.bool
就取非再替换成 -inf
,否则直接相加。即,如果不是 bool,默认掩码是一个绝对值很大的负数。而我的项目中,掩码是 torch.float32
类型的,并且是用 0 和 1 区分的。直接调用的效果近似于没加掩码,这也解释了为什么生成的序列会大量重复。
我尝试改成了 attn_bias += attn_mask * (-1e7)
,结果发生了梯度爆炸,换成 float("-inf")
也是,最后查明原因是后面再执行 softmax
时,过小的 logits
发生了溢出,出现了除零错误。最后改成了 -1e4
得以解决。
模块的参数不随模型一起迁移
在 Encoder 和 Decoder 中会有多个相同的 Layer,这在 tf 中直接用一个列表生成:
self.decoder_layers = [TransformerDecoderLayer(params) for _ in range(num_layers_dec)]
但是在 torch 中,列表中的模块没有被注册上。一个表现是模型和数据都迁移到 GPU 上之后,在运算时仍然提示某一个参数才 CPU 上。必须用 nn.ModuleList
:
self.encoder_layers = nn.ModuleList([TransformerEncoderLayer(params) for _ in range(num_layers_enc)])
这样就注册上了,不会引起别的问题。