PyTorch:torch.Tensor.unsqueeze()、squeeze()

本文详细介绍了PyTorch中unsqueeze和squeeze两个函数的用途。unsqueeze用于在指定位置增加一个维度,使得不同形状的张量能够进行运算,如在不满足广播机制的情况下进行相加。例如,将一个3*4的张量通过unsqueeze(1)转换为3*1*4,以匹配3*2*4的张量。而squeeze则用于删除张量中所有维度为1的轴,不影响数据本身,如将3*1*4的张量转换为3*4。
部署运行你感兴趣的模型镜像

目录

1、unsqueeze()

2、squeeze()

 

 


1、unsqueeze()

作用:给指定的tensor增加一个指定(之前不存在的)的维度。通常用在两tensor相加,但不满足shape一致,同时又不符合广播机制的情况下。

举例说明:

# 两tensor相加,但不满足shape一致,同时又不符合广播机制的例子:
a = torch.ones(3,2,4)  # 创建3*2*4的全为1的tensor
b = torch.zeros(3,4)   # 创建3*4的全为0的tensor
c = a + b   # 直接相加 这样相加会报错
print('a:',a)
print('b:',b)
print('c:',c)

'''   报错如下   '''
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

'''==================================================================================='''

# 两tensor相加,但不满足shape一致,使用unsqueeze对b进行维度扩充,使其满足广播机制的例子:
a = torch.ones(3,2,4)  # 创建3*2*4的全为1的tensor
b = torch.zeros(3,4)   # 创建3*4的全为0的tensor
b = b.unsqueeze(1) # 在增添一个维度1,即将b的shape转换为(3,1,4)
c = a + b   # 相加 这次就可以相加了
print('a:', a)
print('b:', b)
print('c:', c)

'''   结果如下   '''
a: tensor([[[1., 1., 1., 1.],
            [1., 1., 1., 1.]],
           [[1., 1., 1., 1.],
            [1., 1., 1., 1.]],
           [[1., 1., 1., 1.],
            [1., 1., 1., 1.]]])

b: tensor([[[0., 0., 0., 0.]],
           [[0., 0., 0., 0.]],
           [[0., 0., 0., 0.]]])

c: tensor([[[1., 1., 1., 1.],
            [1., 1., 1., 1.]],
           [[1., 1., 1., 1.],
            [1., 1., 1., 1.]],
           [[1., 1., 1., 1.],
            [1., 1., 1., 1.]]])

2、squeeze()

作用:与unsqueeze()相反,squeeze的作用是删除tensor中的所有维度为1的维度,注意只是删除维度,对数据没有影响。

举例说明:

import torch
a = torch.ones(3, 1, 4)   # 创建3*1*4的全为1的tensor
b = a.squeeze() # 删除b的维度为1的维度,即将b的shape转换为(3,4)
print('a:', a)
print('shape of a:', a.shape)
print('b:', b)
print('shape of b:', b.shape)

'''   运行结果   '''
a: tensor([[[1., 1., 1., 1.]],
           [[1., 1., 1., 1.]],
           [[1., 1., 1., 1.]]])
shape of a: torch.Size([3, 1, 4])

b: tensor([[1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.]])
shape of b: torch.Size([3, 4])

 

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

<think>我们正在讨论的是RoPE(旋转位置嵌入)在Transformer模型中的实现,特别是在前向传播函数(forwardfunction)中如何处理位置信息。用户询问的是GCP(可能指某个具体模型或代码库,但这里更可能是泛指的模型实现)的forward函数实现细节,重点关注RoPE模型、position_ids和cache_position参数。在PyTorch中实现带有RoPE的Transformer模型时,forward函数通常需要以下步骤:1.准备输入(如隐藏状态、注意力掩码等)2.处理位置信息:通过position_ids确定每个token的位置3.应用RoPE:使用预计算的sin/cos缓存(通过cache_position访问)对query和key向量进行旋转4.计算注意力分数并完成后续操作我将分步解释关键实现点:###1.**位置ID(position_ids)的作用**-`position_ids`是一个张量,形状通常为`(batch_size,sequence_length)`,表示每个token在序列中的绝对位置索引(从0开始)。-在自回归生成中,`position_ids`通常是递增的,例如输入序列为[0,1,2,...,seq_len-1],但在处理padding或分段序列时可能不同。-在forward函数中,`position_ids`用于索引RoPE缓存中的对应位置。###2.**缓存位置(cache_position)的用途**-`cache_position`是RoPE缓存中的索引机制。它通常是一个张量,包含预计算的sin和cos值,形状为`(max_sequence_length,dim//2,2)`(其中dim是头维度大小)。-由于RoPE的三角函数值仅取决于位置和固定的频率参数,因此在模型初始化时即可预计算并缓存。-在推理时(尤其是生成任务),缓存避免了对每个新token重复计算相同位置的sin/cos值,从而提升效率。###3.**RoPE的具体应用**RoPE的数学原理是将query和key向量按位置旋转。对于query向量$q$在位置$m$处,旋转操作相当于:$$q_m=\begin{pmatrix}q_0\cosm\theta_0-q_1\sinm\theta_0\\q_1\cosm\theta_0+q_0\sinm\theta_0\\q_2\cosm\theta_1-q_3\sinm\theta_1\\\vdots\end{pmatrix}$$其中$\theta_j=10000^{-2j/d}$($d$是维度,$j$是维度索引)。在代码中,通过以下步骤实现:-**步骤1**:将query和key向量重塑为复数形式(实部为原值,虚部为交错位置)。-**步骤2**:根据`position_ids`获取对应位置的sin/cos值。-**步骤3**:执行旋转(复数乘法或使用实数矩阵乘法)。###4.**PyTorch实现示例**以下是一个简化版的forward函数片段(基于常见的Transformer类实现):```pythonimporttorchimporttorch.nnasnnclassRotaryEmbedding(nn.Module):def__init__(self,dim,max_seq_len=2048):super().__init__()inv_freq=1.0/(10000**(torch.arange(0,dim,2).float()/dim))t=torch.arange(max_seq_len).type_as(inv_freq)freqs=torch.einsum("i,j->ij",t,inv_freq)emb=torch.cat((freqs,freqs),dim=-1)self.register_buffer("cos_cached",emb.cos()[None,None,:,:])self.register_buffer("sin_cached",emb.sin()[None,None,:,:])defforward(self,x,seq_len=None):#返回缓存的cos/sin值,适用于整个序列returnself.cos_cached[:,:,:seq_len,...],self.sin_cached[:,:,:seq_len,...]classAttentionWithRoPE(nn.Module):def__init__(self,hidden_size,num_heads,max_seq_len=2048):super().__init__()self.rotary_emb=RotaryEmbedding(hidden_size//num_heads,max_seq_len)#...初始化其他注意力参数defforward(self,query,key,position_ids,attention_mask):batch_size,seq_len=query.size(0),query.size(1)head_dim=query.size(-1)//self.num_heads#步骤1:应用RoPE旋转cos,sin=self.rotary_emb(query,seq_len=seq_len)#调整cos/sin形状以匹配query/key的头部结构cos=cos.squeeze(1)#[1,seq_len,head_dim]sin=sin.squeeze(1)#将query/key重塑为多头的形式[batch,heads,seq_len,head_dim]query=query.view(batch_size,seq_len,self.num_heads,head_dim).transpose(1,2)key=key.view(batch_size,seq_len,self.num_heads,head_dim).transpose(1,2)#根据位置索引提取对应位置的cos/sin#position_ids形状:[batch_size,seq_len]->索引cos和sin#注意:这里假设position_ids可以直接用作索引,且cos/sin已对齐所有位置q_cos=cos[position_ids].unsqueeze(1)#[batch_size,1,seq_len,head_dim]q_sin=sin[position_ids].unsqueeze(1)#同样的方式处理key的位置(此处可能相同,但如果是双向则可能不同)#旋转query:q_rot=q*cos+rotate_half(q)*sinq_half_left=query[...,:head_dim//2]q_half_right=query[...,head_dim//2:]q_rot=torch.cat((-q_half_right,q_half_left),dim=-1)#对后半部分旋转90度query_rotated=query*q_cos+q_rot*q_sin#同样旋转keyk_half_left=key[...,:head_dim//2]k_half_right=key[...,head_dim//2:]k_rot=torch.cat((-k_half_right,k_half_left),dim=-1)key_rotated=key*q_cos+k_rot*q_sin#注意:这里使用key的cos/sin,若与query相同则用q_cos#步骤2:计算注意力分数#...(使用旋转后的query和key计算点积)scores=torch.matmul(query_rotated,key_rotated.transpose(-2,-1))/(head_dim**0.5)#应用attention_maskifattention_maskisnotNone:scores=scores+attention_maskattn_weights=torch.softmax(scores,dim=-1)#...后续操作returnattn_output```###关键实现细节:-**位置索引的动态处理**:`position_ids`可能因输入序列长度变化,因此缓存需要支持可变长序列(如本例中,`RotaryEmbedding`缓存了整个最大序列长度的值,然后切片至当前序列长度)。-**旋转操作的高效性**:使用`torch.cat`和逐元素乘法实现旋转,避免了复杂矩阵乘法(实际应用中通常使用融合算子优化)。-**缓存的注册**:通过`register_buffer`将cos/sin张量注册为模型的缓冲区,确保在设备(CPU/GPU)间正确移动。###实际库中的差异在HuggingFaceTransformers库中,类似的实现可在`GPTNeoX`或`Llama`模型的代码中找到(如`modeling_llama.py`中的`LlamaRotaryEmbedding`)。这些实现通常:-使用更高效的旋转公式(如直接使用复数乘法或向量化操作)。-支持增量解码(在`cache_position`中记录当前生成位置)。###总结RoPE模型的forward函数实现关键在于:1.使用`position_ids`标识每个token的位置。2.通过缓存(`cache_position`机制)避免重复计算三角函数值。3.对query和key进行旋转嵌入以编码位置信息[^1]。§§相关问题§§1.在增量解码过程中,`cache_position`如何更新?2.RoPE实现中的数值稳定性问题有哪些解决方法?3.如何处理位置索引超出预计算缓存长度的情况(如外推)?[^1]:旋转位置嵌入的理论基础:Su,J.etal.(2022).*RoFormer:EnhancedTransformerwithRotaryPositionEmbedding*.arXiv:2104.09864.</think> 以下是关于RoPE模型的`forward`函数实现细节的完整解析,重点解释`position_ids`和`cache_position`在PyTorch中的具体应用: ### 一、RoPE模型的前向传播架构 RoPE模型的`forward`函数通常包含以下核心模块: ```python def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs ) -> Tuple[torch.Tensor, ...]: # 1. 位置标识处理 seq_len = hidden_states.shape[1] if position_ids is None: position_ids = self._init_position_ids(seq_len) # 默认位置序列 # 2. 缓存位置处理 if cache_position is None: cache_position = position_ids elif cache_position.shape[1] < seq_len: cache_position = self._update_cache(cache_position, seq_len) # 3. 应用RoPE位置编码 q, k = self._apply_rope( query_states, key_states, position_ids=cache_position # 使用缓存位置索引 ) # 4. 注意力计算 attn_output = self._attn(q, k, v, attention_mask) return (attn_output, cache_position) # 返回更新后的缓存 ``` ### 二、核心参数解析 #### 1. `position_ids`:位置标识张量 - **作用**:标识每个token在序列中的绝对位置(形状:`(batch_size, seq_length)`) - **生成逻辑**: ```python def _init_position_ids(self, seq_len): return torch.arange(seq_len, device=self.device).expand(1, -1) ``` - **使用场景**: - 训练阶段:每次从0开始生成新位置序列 - 增量推理:通过`cache_position`动态扩展位置序列 #### 2. `cache_position`:缓存位置系统 - **核心作用**:管理位置编码的预计算缓存 - **数据格式**:整型张量`(batch_size, current_length)` - **生命周期**: ```mermaid graph LR A[第一次调用] --> B[创建初始缓存: 0->N] B --> C[后续调用] C --> D{新序列更长?} D -->|是| E[扩展缓存: N->M] D -->|否| F[直接使用缓存] ``` ### 三、RoPE应用的关键实现 #### 1. 缓存初始化(预计算正弦/余弦) ```python class RotaryEmbedding(nn.Module): def __init__(self, dim, max_seq_len=2048): super().__init__() # 频率参数计算 inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) # 预计算所有位置的旋转矩阵 position = torch.arange(max_seq_len) sinusoid = torch.einsum('i,j->ij', position, inv_freq) emb = torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1) # 注册为不参与训练的缓存 self.register_buffer("rotary_emb", emb) # [max_len, dim] ``` #### 2. 缓存应用(前向传播) ```python def _apply_rope(self, q, k, position_ids): # 获取缓存的旋转矩阵子集 emb = self.rotary_emb[position_ids] # [bs, seq_len, dim] # 将旋转矩阵分解为正弦/余弦分量 cos = emb[..., : emb.shape[-1] // 2] sin = emb[..., emb.shape[-1] // 2 :] # 旋转向量计算 (高效实现) q_embed = q * cos + self._rotate_half(q) * sin k_embed = k * cos + self._rotate_half(k) * sin return q_embed, k_embed ``` #### 3. 旋转辅助函数 ```python def _rotate_half(x): x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) ``` ### 四、增量推理的特殊处理 在文本生成场景下的缓存更新机制: ```python def _update_cache(self, cache_position, new_len): # 创建新增位置序列 new_positions = torch.arange( cache_position.shape[-1], new_len, device=cache_position.device ) # 水平拼接实现缓存扩展 updated_cache = torch.cat([cache_position, new_positions.unsqueeze(0)], dim=1) # 更新预计算缓存(如果超出原始范围) if new_len > self.rotary_emb.size(0): self._extend_rotary_cache(new_len) return updated_cache ``` ### 五、典型使用场景示例 #### 1. 训练阶段调用 ```python outputs = model( input_ids, position_ids=None, # 自动生成0->511 cache_position=None # 等同于position_ids ) ``` #### 2. 文本生成推理 ```python cache_pos = None for i in range(max_new_tokens): outputs = model( input_ids=next_token, cache_position=cache_pos # 传入历史缓存 ) logits, cache_pos = outputs # 接收更新后的缓存 ``` ### 六、设计优势与注意事项 1. **性能优势**: - 将$O(L^2)$的位置计算复杂度降为$O(1)$的缓存访问 - 避免每次前向传播重复计算三角函数 2. **使用限制**: - 最大序列长度受预计算缓存限制 - 动态扩展缓存需保证旋转矩阵维度一致性 3. **最佳实践**: ```python # 初始化时预留足够缓存空间 model = RoPEModel(max_position=4096) # 输入位置溢出处理 if current_len > model.max_pos: position_ids = position_ids % model.max_pos # 循环位置策略 ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

地球被支点撬走啦

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值