pytorch/torch中的view

博客介绍了一种类似numpy中reshape的操作,该操作不会改变原始数据,并且重点在于它是从最后一个维度进行reshape的。

类似于numpy中的reshape,且不改变原始数据。重点是,它是从最后一个维度进行reshape中的

import torch
a = torch.randint(0, 10, size=(3, 4))
print(a)
print(a.view(4, 3))
print(a.view(1, 12))

分别得到

tensor([[7, 1, 7, 2],
        [9, 5, 7, 1],
        [6, 0, 9, 9]])
        
tensor([[7, 1, 7],
        [2, 9, 5],
        [7, 1, 6],
        [0, 9, 9]])
        
tensor([[7, 1, 7, 2, 9, 5, 7, 1, 6, 0, 9, 9]])
<think>我们遇到了一个PyTorch ONNX导出时的警告信息。这个警告表明在尝试进行符号形状推断中的常量折叠时出现了问题。 具体警告信息:`UserWarning: Constant folding in symbolic shape inference fails: Dimension specified as -1 but tensor has no dimensions` ### 问题分析 1. **触发原因**: - 当尝试对一个0维张量(标量)进行形状推断时,指定了维度为-1(通常表示该维度由其他维度推断得出),但0维张量没有维度,因此导致错误。 - 此警告在PyTorch内部文件`shape_type_inference.cpp`的第438行触发。 2. **常见场景**: - 模型包含标量张量(0维)的操作 - 使用`torch.onnx.export()`导出模型时 - 模型中有`view`或`reshape`等操作试图将标量视为向量处理 3. **潜在影响**: - 通常不会阻止导出过程,但可能导致导出的ONNX模型存在潜在问题 - 可能影响后续模型部署(如TensorRT、ONNX Runtime等引擎加载) ### 解决方案 #### 方法1:升级PyTorch版本 ```bash pip install --upgrade torch torchvision ``` - 新版本可能已修复此问题 #### 方法2:避免使用0维张量 在导出前将标量转换为1维张量: ```python # 原始代码(可能产生0维张量) output = tensor.mean() # 修改为保持维度 output = tensor.mean(keepdim=True) # 保持维度 output = output.reshape(1) # 或显式转为1维 ``` #### 方法3:修改导出配置 跳过常量折叠(不推荐,可能影响性能): ```python torch.onnx.export(..., do_constant_folding=False) ``` #### 方法4:定位具体操作 1. 使用调试模式定位问题节点: ```python torch.onnx.export(..., verbose=True) ``` 2. 检查输出日志中触发警告的具体操作 ### 完整导出示例(修复后) ```python import torch import torch.nn as nn class SampleModel(nn.Module): def forward(self, x): # 避免产生0维张量 return x.mean(dim=[2,3], keepdim=True) # 保持维度 model = SampleModel() dummy_input = torch.randn(1, 3, 256, 256) # 导出ONNX torch.onnx.export( model, dummy_input, "model.onnx", input_names=["input"], output_names=["output"], opset_version=12 # 使用较新opset ) ``` ### 相关文档 - PyTorch ONNX导出问题追踪:https://github.com/pytorch/pytorch/issues - ONNX形状推断规范:https://github.com/onnx/onnx/blob/master/docs/ShapeInference.md §§相关问题§§ 1.PyTorch模型导出ONNX时,还有哪些常见警告需要特别关注? 2. 如何通过修改模型结构设计来避免ONNX导出时的维度不匹配问题? 3. 在ONNX导出过程中,常量折叠(Constant Folding)的作用是什么?禁用后会对模型性能产生什么影响? 4. 针对不同部署后端(如TensorRT、OpenVINO),ONNX模型需要进行哪些不同的后处理? 5. 如何利用ONNX Runtime对导出的模型进行正确性验证?</think>这个警告信息表明在使用 PyTorch 的 ONNX 导出功能时,在符号形状推断的常量折叠过程中遇到了问题。以下是详细解析: ### 警告原因分析 ```python UserWarning: Constant folding in symbolic shape inference fails: Dimension specified as -1 but tensor has no dimensions ``` 1. **核心问题**: - 尝试对一个**0维张量(标量)** 进行形状推断 - 使用了维度值 `-1`(通常表示自动推断维度大小) - 但0维张量没有维度,导致内部逻辑失败 2. **触发位置**: - PyTorch 内部文件:`/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp:438` - 发生在 ONNX 导出时的符号形状推断阶段 ### 问题根源与解决方案 #### 常见触发场景 1. **模型输出标量值**: ```python # 模型返回单个标量值(如损失值) return loss.mean() # 产生0维张量 ``` 2. **视图操作错误**: ```python # 尝试对标量进行reshape/view操作 scalar = torch.tensor(0.5) scalar.view(-1) # 错误:对标量使用-1维度 ``` 3. **形状推断冲突**: ```python # 动态维度(-1)与静态维度的冲突 dynamic_tensor = torch.randn(1, -1) # 非法但可能间接引发 ``` #### 解决方案 **方法1:避免输出0维张量** ```python # 修改模型输出保持维度 def forward(self, x): loss = self.calculate_loss(x) # 添加虚拟维度 [1] 替代标量 [] return loss.unsqueeze(0) # 形状变为 [1] ``` **方法2:显式指定维度(推荐)** ```python torch.onnx.export( model, args, "model.onnx", # 关键参数:禁用问题优化 do_constant_folding=False, # 关闭常量折叠 dynamic_axes={'input': {0: 'batch'}}, # 显式指定动态轴 opset_version=13 # 使用较新opset ) ``` **方法3:升级PyTorch版本** ```bash # 此问题可能在较新版本已修复 pip install torch --upgrade ``` ### 完整导出示例(修复后) ```python import torch import torch.nn as nn class SafeModel(nn.Module): def forward(self, x): # 确保输出至少1维 output = x.mean(dim=1, keepdim=True) # 保持维度 [B,1] return output model = SafeModel() dummy = torch.randn(1, 3, 224, 224) # 安全导出配置 torch.onnx.export( model, dummy, "model.onnx", input_names=["input"], output_names=["output"], dynamic_axes={"input": {0: "batch_size"}}, # 显式动态批次 opset_version=14, # 使用最新opset do_constant_folding=False # 禁用问题优化 ) ``` ### ONNX导出最佳实践 1. **维度检查**: ```python def validate_shapes(model, input): with torch.no_grad(): out = model(input) assert out.dim() >= 1, "输出不能是标量" ``` 2. **版本兼容性**: | PyTorch版本 | 推荐opset | 主要改进 | |------------|-----------|----------| | 1.8+ | 13 | 动态形状支持 | | 1.10+ | 14 | 优化控制流 | | 2.0+ | 16 | 完整动态量化 | 3. **验证导出模型**: ```python import onnx model = onnx.load("model.onnx") onnx.checker.check_model(model) # 验证模型完整性 ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值