Embedding源码解析
这部分代码相对比较简单,先附上官网的doc链接,这里重点介绍我们常用到的一些东西
# 初始化函数
def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2., scale_grad_by_freq=False,
sparse=False, _weight=None):
参数说明:
- (num_embeddings, embedding_dim)对应token的数目,和每个token对应的维度,这里都是index形式。
- padding_idx指明哪个idx为全0,且不需梯度更

本文解析了PyTorch中Embedding模块的源码,详细介绍了初始化函数及from_pretrained函数的工作原理,特别关注padding_idx参数的作用,并讨论了在使用不同精度处理weights时遇到的cudnn错误及解决方案。
最低0.47元/天 解锁文章
1万+

被折叠的 条评论
为什么被折叠?



