理论学习:expand_as

文章介绍了如何使用torch库处理数据,计算给定预测分数与真实标签的top-k准确度,适用于监督学习任务。
部署运行你感兴趣的模型镜像
import torch

# 假设输出和目标
output = torch.tensor([
    [0.1, 0.2, 0.6, 0.1],  # 第一个样本的预测分数
    [0.1, 0.4, 0.2, 0.3],  # 第二个样本的预测分数
    [0.5, 0.1, 0.2, 0.2]   # 第三个样本的预测分数
])
target = torch.tensor([2, 1, 0])  # 真实的标签
# 定义准确率计算函数
def accuracy(output, target, topk):
    maxk = max(topk)
    _,pred = output.topk(maxk, 1, True, True)
    # pred = pred.t()
    # print(pred)
    print(target.view(1,-1).expand_as(pred))

# 计算准确率
accuracy(output, target, topk=(1,3))

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

def precompute_freqs_cis( hidden_size: int, max_seq_len: int, base: int = 10000, num_attention_heads: int = 8 ) -> torch.Tensor: """ 预计算复数形式的旋转位置编码 (Rotary Position Embedding, RoPE) 🔧 核心思想: - 每个 token 的表示在每个维度上被看作一个二维向量 (x, y) - 位置 i 对应的向量会被旋转 i×θ_d 弧度 - 使用复数乘法来高效实现这种旋转:e^(iθ) * (x + iy) 📚 所属知识领域: - 位置编码设计(Positional Encoding) - 复数数学在深度学习中的应用 - 因果语言建模的位置感知机制 Args: hidden_size: 模型隐藏层维度(如 512) max_seq_len: 最大支持序列长度(如 8192) base: 控制频率衰减速度的基数,默认为 10000 num_attention_heads: 注意力头数(用于确定 head_dim) Returns: freqs_cis: complex tensor of shape (max_seq_len, head_dim) 表示每个位置对应的复数旋转因子 e^(iθ) """ head_dim = hidden_size // num_attention_heads # 每个注意力头的维度 # 计算逆频率:θ_d = 1 / (base^(2d/head_dim)) → 控制不同维度有不同的周期性 inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) t = torch.arange(max_seq_len) # 位置索引 [0, 1, ..., S-1] # 构造每个位置和每个频率的乘积:freqs[s,d] = t[s] * inv_freq[d] freqs = torch.einsum("s,d->sd", t, inv_freq) # (S, D//2) # 将频率复制两次,扩展到完整 head_dim 维度 # 原因:我们对相邻两维做联合旋转(如 x₀,x₁ → 旋转成新坐标),所以需要两个相同频率 freqs = torch.cat([freqs, freqs], dim=-1) # now (S, D) # 转换为复数形式:cos(θ) + i*sin(θ),即单位圆上的点 freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # magnitude=1, angle=freqs return freqs_cis # shape: (max_seq_len, head_dim) def _rotate_half(x: torch.Tensor) -> torch.Tensor: """ 在最后一个维度上,将后半部分移到前面并取负。 示例: x = [x0, x1, x2, x3] → [-x2, -x3, x0, x1] 这对应于复数中乘以 i 的操作(相当于旋转90度): i*(x0 + ix1) = -x1 + ix0 → 交换并加负号 📌 注意:这是 apply_rotary_pos_emb 中的关键辅助函数 Args: x: 输入张量,shape = (*, D) Returns: 旋转后的张量,shape = (*, D) """ x1, x2 = x.chunk(2, dim=-1) # 分成前后两半:x1 是前半,x2 是后半 return torch.cat((-x2, x1), dim=-1) # 后半取负放前,形成 [-x2, x1] def apply_rotary_pos_emb( q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ 将 RoPE 应用于 query 和 key 张量。 数学公式: q_embed = q * cos(θ) + rotate_half(q) * sin(θ) k_embed = k * cos(θ) + rotate_half(k) * sin(θ) 📚 理论来源: - RoFormer: https://arxiv.org/abs/2104.09864 - LLaMA 使用此方法替代绝对位置嵌入 Args: q: Query 张量,shape = (bsz, n_heads, seq_len, head_dim) k: Key 张量,shape = (bsz, n_heads, seq_len, head_dim) freqs_cis: 预计算的复数旋转因子,shape = (seq_len, head_dim) Returns: q_embed: 加入位置信息的 query k_embed: 加入位置信息的 key """ # 提取实部(cos)和虚部(sin),并调整形状以便广播 cos = freqs_cis.real.view(1, 1, freqs_cis.size(0), -1) # (1,1,S,D) sin = freqs_cis.imag.view(1, 1, freqs_cis.size(0), -1) # (1,1,S,D) # 截断到当前序列长度(防止越界) q_len = q.size(2) k_len = k.size(2) # 应用旋转公式 q_out = (q * cos[:, :, :q_len, :]) + (_rotate_half(q) * sin[:, :, :q_len, :]) k_out = (k * cos[:, :, :k_len, :]) + (_rotate_half(k) * sin[:, :, :k_len, :]) return q_out.type_as(q), k_out.type_as(k) 检查一下公式和计算
最新发布
12-03
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值