torch.nn.Embedding理解

本文介绍了 PyTorch 中的 Embedding 层,它是一个用于存储词嵌入的简单查找表。文章解释了如何使用该层根据词汇索引获取对应的词向量,并讨论了其在自然语言处理任务中的应用。
部署运行你感兴趣的模型镜像

Pytorch(0.3.1)官网的解释是:一个保存了固定字典和大小的简单查找表。这个模块常用来保存词嵌入和用下标检索它们。模块的输入是一个下标的列表,输出是对应的词嵌入。

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)

个人理解:这是一个矩阵类,里面初始化了一个随机矩阵,矩阵的长是字典的大小,宽是用来表示字典中每个元素的属性向量,向量的维度根据你想要表示的元素的复杂度而定。类实例化之后可以根据字典中元素的下标来查找元素对应的向量。

输入下标0,输出就是embeds矩阵中第0行。

放代码:

调试过程的参数:

用途:用作自然语言处理中作用很大

而对于一个词,我们自己去想它的属性不是很困难吗,所以这个时候就可以交给神经网络了,我们只需要定义我们想要的维度,比如100,然后通过神经网络去学习它的每一个属性的大小,而我们并不用关心到底这个属性代表着什么,我们只需要知道词向量的夹角越小,表示他们之间的语义更加接近

参考网址:https://my.oschina.net/earnp/blog/1113896

http://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/

您可能感兴趣的与本文相关的镜像

PyTorch 2.7

PyTorch 2.7

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

`torch.nn.functional.embedding` 是 PyTorch 中用于实现词嵌入(Embedding)操作的函数。词嵌入是将离散的符号(如单词、ID 等)映射为连续向量的技术,在自然语言处理、推荐系统等领域应用广泛。 ### 原理 词嵌入的核心思想是将每个离散的符号表示为一个固定长度的向量,这些向量能够捕捉符号之间的语义和句法关系。具体来说,`torch.nn.functional.embedding` 维护一个嵌入矩阵(Embedding Matrix),矩阵的每一行对应一个符号的嵌入向量。当输入一个符号的 ID 时,函数会从嵌入矩阵中取出对应行的向量作为输出。 ### 使用方法 `torch.nn.functional.embedding` 的函数签名如下: ```python torch.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False) ``` - `input`:输入的符号 ID 张量,可以是任意形状。 - `weight`:嵌入矩阵,形状为 `(num_embeddings, embedding_dim)`,其中 `num_embeddings` 是符号的总数,`embedding_dim` 是嵌入向量的维度。 - `padding_idx`:可选参数,指定用于填充的符号 ID,填充位置的嵌入向量将被置为零。 - `max_norm`:可选参数,指定嵌入向量的最大范数。 - `norm_type`:可选参数,指定计算范数的类型,默认为 2 范数。 - `scale_grad_by_freq`:可选参数,指定是否根据符号的频率对梯度进行缩放。 - `sparse`:可选参数,指定是否使用稀疏梯度。 以下是一个简单的使用示例: ```python import torch import torch.nn.functional as F # 定义嵌入矩阵 num_embeddings = 10 embedding_dim = 5 weight = torch.randn(num_embeddings, embedding_dim) # 输入符号 ID input_ids = torch.tensor([1, 3, 5]) # 进行嵌入操作 embedded = F.embedding(input_ids, weight) print(embedded) ``` ### 注意事项 - 输入的符号 ID 必须是整数类型,且范围在 `[0, num_embeddings - 1]` 之间。 - 嵌入矩阵 `weight` 通常是可训练的参数,可以通过 `torch.nn.Embedding` 模块来创建和管理。
评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值