稀疏嵌入(sparse embeddings)是一个计算稀疏嵌入的方法,通过输入的隐藏状态和 token ID 生成稀疏嵌入,并对未使用的 token 进行处理,以确保它们不会影响模型的后续操作。
以下代码是一个 PyTorch 方法的实现:
def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = True):
token_weights = torch.relu(self.sparse_linear(hidden_state))
if not return_embedding: return token_weights
sparse_embedding = torch.zeros(input_ids.size(0), input_ids.size(1), self.vocab_size,
dtype=token_weights.dtype,
device=token_weights.

最低0.47元/天 解锁文章
959

被折叠的 条评论
为什么被折叠?



