_, predicted = torch.max(outputs.data, dim)

This post explains how PyTorch's `torch.max()` returns the maximum values and their corresponding indices along different dimensions (dim=0 and dim=1), and demonstrates with examples how to get the row-wise or column-wise maxima and their indices.

With dim=1, the maximum is taken along each row, so the returned indices tell you which column holds the largest value in every row.

With dim=0, the maximum is taken along each column, so the returned indices tell you which row holds the largest value in every column.

_, predicted = torch.max(outputs.data, dim): keeps only the indices of the maxima (the values are discarded into `_`).

predicted = torch.max(outputs.data, dim): returns a named tuple containing both the maximum values and their indices.


import torch
tensor = torch.tensor([[1.2, -3, -6., -9.0, -4.56, 5.678]])
_, predicted = torch.max(tensor, 1)  # dim=1: maximum along each row, keep only the index
print(predicted)

'''
Returns the index of the maximum in each row
tensor([5])
'''


tensor = torch.tensor([[1.2, -3, -6., -9.0, -4.56, 5.678]])
predicted1 = torch.max(tensor, 1)  # without unpacking: the full named tuple is returned
print(predicted1)

'''
Returns both the maximum values and their indices
torch.return_types.max(
values=tensor([5.6780]),
indices=tensor([5]))
'''
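
Since the result is a named tuple, the values and indices can also be unpacked directly or accessed by attribute; a small illustration using the same tensor:

values, indices = torch.max(tensor, 1)  # tuple-style unpacking
print(values)            # tensor([5.6780])
print(indices)           # tensor([5])

result = torch.max(tensor, 1)
print(result.values)     # same as values above
print(result.indices)    # same as indices above
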
import torch
# dim=0: maximum along each column (comparing across rows)
tensor = torch.tensor([[1.2, -3, -6., -9.0, -4.56, 5.678], [1, 2, 3, 4, 5, 0]])
_, predicted = torch.max(tensor, 0)
print(predicted)

'''
Returns the index of the maximum in each column; since the tensor has only two rows, every index is either 0 or 1
tensor([0, 1, 1, 1, 1, 0])
'''
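
If only the indices are needed, torch.argmax gives the same answer in a single call; a quick check on the same two-row tensor:

print(torch.argmax(tensor, dim=0))  # tensor([0, 1, 1, 1, 1, 0])
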


tensor = torch.tensor([[1.2, -3, -6., -9.0, -4.56, 5.678], [1, 2, 3, 4, 5, 0]])
predicted1 = torch.max(tensor, 0)  # named tuple with the column-wise maxima and their row indices
print(predicted1)

'''
Returns both the maximum values and their indices
torch.return_types.max(
values=tensor([1.2000, 2.0000, 3.0000, 4.0000, 5.0000, 5.6780]),
indices=tensor([0, 1, 1, 1, 1, 0]))
'''
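
This is why `_, predicted = torch.max(outputs.data, 1)` shows up in almost every classification evaluation loop: each row of `outputs` holds one sample's class scores, and the index of the row maximum is the predicted class. A minimal sketch, assuming a hypothetical `outputs` tensor (a stand-in for real model scores) and matching `labels`:

import torch

# hypothetical scores for 3 samples over 4 classes, plus their true labels
outputs = torch.tensor([[0.1, 2.3, 0.5, 0.2],
                        [1.7, 0.4, 0.3, 0.1],
                        [0.2, 0.1, 0.9, 3.0]])
labels = torch.tensor([1, 0, 2])

_, predicted = torch.max(outputs, 1)            # predicted class index per sample
correct = (predicted == labels).sum().item()    # count of correct predictions
accuracy = correct / labels.size(0)
print(predicted)   # tensor([1, 0, 3])
print(accuracy)    # 0.666... (2 of 3 correct)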

 
