【pytorch】torch.utils.data.TensorDataset()原版与新版的差异

本文解析了PyTorch中TensorDataset类的更新,详细介绍了新版中data_tensor和target_tensor参数被移除的原因,以及如何调整代码以适应新版本。通过对比原版与新版的使用方法,帮助读者理解并正确应用TensorDataset。
部署运行你感兴趣的模型镜像

使用时发现出错了,如下:

原因是新版把之前的data_tensor 和target_tensor去掉了,输入变成了可变参数,也就是我们平常使用*args

class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

所以新版的使用方法是直接传入参数:

# 原版使用方法
train_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

# 新版使用方法
train_dataset = Data.TensorDataset(x,y)

 

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

C:\Users\Lenovo\Anaconda3\envs\sam-env\python.exe E:\work\danzhutiqu\image\xuanzhuan\sam\fenge.py A module that was compiled using NumPy 1.x cannot be run in NumPy 2.0.2 as it may crash. To support both 1.x and 2.x versions of NumPy, modules must be compiled with NumPy 2.0. Some module may need to rebuild instead e.g. with 'pybind11>=2.12'. If you are a user of the module, the easiest solution will be to downgrade to 'numpy<2' or try to upgrade the affected module. We expect that some modules will need time to support NumPy 2. Traceback (most recent call last): File "E:\work\danzhutiqu\image\xuanzhuan\sam\fenge.py", line 9, in <module> from groundingdino.util.inference import load_model, load_image, predict File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\groundingdino\util\inference.py", line 8, in <module> from torchvision.ops import box_convert File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\torchvision\__init__.py", line 7, in <module> from torchvision import models File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\torchvision\models\__init__.py", line 16, in <module> from . import detection File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\torchvision\models\detection\__init__.py", line 1, in <module> from .faster_rcnn import * File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\torchvision\models\detection\faster_rcnn.py", line 16, in <module> from .anchor_utils import AnchorGenerator File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\torchvision\models\detection\anchor_utils.py", line 10, in <module> class AnchorGenerator(nn.Module): File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\torchvision\models\detection\anchor_utils.py", line 63, in AnchorGenerator device: torch.device = torch.device("cpu"), UserWarning: Failed to initialize NumPy: _ARRAY_API not found (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\utils\tensor_numpy.cpp:68.) Disabling PyTorch because PyTorch >= 2.1 is required but found 1.12.1+cu113 None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used. FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\TensorShape.cpp:2895.) final text_encoder_type: bert-base-uncased Traceback (most recent call last): File "E:\work\danzhutiqu\image\xuanzhuan\sam\fenge.py", line 112, in <module> dino_model, sam_predictor, device = setup_models(DEVICE) File "E:\work\danzhutiqu\image\xuanzhuan\sam\fenge.py", line 31, in setup_models dino = load_model(GROUNDING_DINO_CONFIG, GROUNDING_DINO_CHECKPOINT) File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\groundingdino\util\inference.py", line 32, in load_model model = build_model(args) File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\groundingdino\models\__init__.py", line 17, in build_model model = build_func(args) File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\groundingdino\models\GroundingDINO\groundingdino.py", line 388, in build_groundingdino model = GroundingDINO( File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\groundingdino\models\GroundingDINO\groundingdino.py", line 108, in __init__ self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type) File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\groundingdino\util\get_tokenlizer.py", line 25, in get_pretrained_language_model return BertModel.from_pretrained(text_encoder_type) File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\transformers\utils\import_utils.py", line 2157, in __getattribute__ requires_backends(cls, cls._backends) File "C:\Users\Lenovo\Anaconda3\envs\sam-env\lib\site-packages\transformers\utils\import_utils.py", line 2143, in requires_backends raise ImportError("".join(failed)) ImportError: BertModel requires the PyTorch library but it was not found in your environment. Check out the instructions on the installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. Please note that you may need to restart your runtime after installation. Process finished with exit code 1
最新发布
11-14
你遇到的这个错误,虽然提示是: ``` ImportError: BertModel requires the PyTorch library but it was not found in your environment. ``` 但实际上 **PyTorch 是已经安装了** 的! 真正的问题藏在前面那几行警告中,尤其是这三类关键信息: --- ## 🔥 根本原因分析(按优先级排序) ### ❌ 1. **NumPy 版本冲突:NumPy 2.0 不兼容旧编译模块** ```text A module that was compiled using NumPy 1.x cannot be run in NumPy 2.0.2 as it may crash. ... UserWarning: Failed to initialize NumPy: _ARRAY_API not found ``` 👉 这是最严重的错误。 `torchvision`、`transformers` 等库中的某些 `.pyd`(C 扩展)文件是在 **NumPy 1.x 环境下编译的**,而你现在装的是 **NumPy 2.0.2**,导致无法加载。 > ⚠️ 结果:即使 `import torch` 成功,底层调用也会失败 → `BertModel` 报“未找到 PyTorch” --- ### ❌ 2. **PyTorch 版本过低或损坏** ```text Disabling PyTorch because PyTorch >= 2.1 is required but found 1.12.1+cu113 ``` 👉 你的环境中显示 PyTorch 版本是 `1.12.1`,但当前 Hugging Face Transformers 和 Grounding DINO 要求: > ✅ **PyTorch ≥ 2.1** 你目前的版本太老,不支持新功能(如 `torch.compile`, `SDPA`),甚至会被自动禁用! --- ### ❌ 3. **依赖混杂 + Conda/Pip 冲突** 你在 Conda 环境里可能混合使用了 `conda install pytorch` 和 `pip install torch`,导致: - 多个不同版本共存 - `.pth` 文件路径错乱 - C++ 扩展 Python 接口不匹配 --- ## ✅ 解决方案:重建干净环境(推荐做法) 我们来**彻底解决这个问题**,通过创建一个全新的、兼容性良好的虚拟环境。 ### ✅ 步骤 1:删除旧环境(避免污染) ```bash conda deactivate conda env remove -n sam-env ``` --- ### ✅ 步骤 2:创建新环境并安装正确版本的 PyTorch ```bash # 创建新环境(Python 3.9 最稳定) conda create -n sam-env python=3.9 conda activate sam-env # 安装 PyTorch 2.3 + CUDA 11.8(官方推荐命令) conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia # 强制降级 NumPy 到 1.26.x(避免 NumPy 2.0 兼容问题) pip install "numpy<2" # 安装其他必要包 pip install transformers timm opencv-python supervision groundingdino-pip ``` > 💡 注意:`groundingdino-pip` 是社区维护的可用版本,原版常有依赖问题。 --- ### ✅ 步骤 3:验证安装是否成功 ```bash python -c " import torch, numpy as np, transformers print(f'✅ Torch version: {torch.__version__}') print(f'CUDA available: {torch.cuda.is_available()}') print(f'NumPy version: {np.__version__}') from transformers import BertModel print('✅ BertModel 可以正常导入!') " ``` ✅ 正确输出应类似: ``` ✅ Torch version: 2.3.0+cu118 CUDA available: True NumPy version: 1.26.4 ✅ BertModel 可以正常导入! ``` --- ### ✅ 步骤 4:确保 IDE 使用正确的解释器 #### VSCode 用户: - 按 `Ctrl+Shift+P` - 输入:`Python: Select Interpreter` - 选择: ``` C:\Users\Lenovo\Anaconda3\envs\sam-env\python.exe ``` #### Jupyter 用户: 注册内核: ```bash pip install ipykernel python -m ipykernel install --name sam-env --display-name "Python (sam-env)" ``` 然后重启 Jupyter Lab,在新建 Notebook 时选择该内核。 --- ### ✅ 步骤 5:运行你的脚本 ```bash python E:\work\danzhutiqu\image\xuanzhuan\sam\fenge.py ``` 现在应该可以正常运行! --- ## 🛠️ 补充说明:为什么不能留着 NumPy 2.0? | 问题 | 原因 | |------|------| | `_ARRAY_API not found` | NumPy 2.0 更改了 C API 接口,旧版扩展无法识别 | | `Failed to initialize NumPy` | `torchvision` 编译时链接的是 NumPy 1.x ABI | | `Models won't be available` | Transformers 库检测到 PyTorch 初始化失败 | 📌 当前大多数深度学习库(包括 `torchvision`, `timm`, `albumentations`)都尚未完全支持 NumPy 2.0。 > ✅ 所以目前最稳妥的做法是:**强制降级 NumPy < 2** --- ## ✅ 总结:修复清单 | 项目 | 是否完成 | |------|----------| | 删除旧环境 `sam-env` | ✅ | | 重新创建 Python 3.9 环境 | ✅ | | 安装 PyTorch 2.3 + CUDA 支持 | ✅ | | `pip install "numpy<2"` | ✅ | | 安装 `transformers`, `timm`, `groundingdino-pip` | ✅ | | 验证 `BertModel` 可导入 | ✅ | | 设置 IDE 使用正确解释器 | ✅ | | 重启运行环境(Jupyter/VSCode) | ✅ | --- ###
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值