[pytorch] Differences between the old and new versions of torch.utils.data.TensorDataset()

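This post covers a change to PyTorch's TensorDataset class. The newer version drops the data_tensor and target_tensor parameters in favor of variable positional arguments. It compares the old and new usage and shows how to adjust existing code so that TensorDataset is used correctly.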

Calling it the old way, with the data_tensor and target_tensor keyword arguments, raised an error.

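The cause is that the newer version removed the old data_tensor and target_tensor parameters. The constructor now takes a variable number of positional arguments, the familiar *args pattern. The new implementation looks like this: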
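To see why the old keyword call fails, here is a minimal pure-Python sketch. ListDataset is a hypothetical stand-in that copies only the *tensors signature of the new constructor, using plain lists instead of tensors; it is not PyTorch code.

```python
class ListDataset:
    """Hypothetical stand-in with the same *args signature as the new TensorDataset."""

    def __init__(self, *tensors):
        # Every positional argument is collected into the single tuple `tensors`
        assert all(len(tensors[0]) == len(t) for t in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        # One sample is the index-th element of every wrapped sequence
        return tuple(t[index] for t in self.tensors)

    def __len__(self):
        return len(self.tensors[0])


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
y = [0, 1, 0]

ds = ListDataset(x, y)  # new style: positional arguments
print(len(ds), ds[1])   # 3 ([3.0, 4.0], 1)

try:
    # Old style: there is no parameter named data_tensor, so binding fails
    ListDataset(data_tensor=x, target_tensor=y)
except TypeError as err:
    print("TypeError:", err)
```

Because *tensors only collects positional arguments and there is no **kwargs, any keyword such as data_tensor= is rejected before the method body even runs.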

from torch.utils.data import Dataset


class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
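The behavior of the class above can be checked directly. A short sketch, assuming a PyTorch version with the *tensors constructor; the shapes and random data are made up for illustration:

```python
import torch
from torch.utils.data import TensorDataset

features = torch.randn(4, 3)         # 4 samples, 3 features each (example data)
labels = torch.tensor([0, 1, 1, 0])  # 4 labels
ids = torch.arange(4)                # any number of tensors can be wrapped

ds = TensorDataset(features, labels, ids)
print(len(ds))  # 4: the size of the first dimension

# __getitem__ returns one slice per wrapped tensor, in the order passed in
x0, y0, i0 = ds[0]
print(x0.shape, y0.item(), i0.item())  # torch.Size([3]) 0 0

# The tensors must agree along dimension 0, otherwise the assert in __init__ fails
try:
    TensorDataset(torch.randn(4, 3), torch.randn(5))
except AssertionError:
    print("first dimensions differ")
```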

So with the new version, pass the tensors directly as positional arguments:

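import torch
import torch.utils.data as Data

x, y = torch.randn(100, 3), torch.randn(100)  # example features and targets

# Old usage (keyword arguments, rejected by the new constructor)
# train_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

# New usage: pass the tensors positionally
train_dataset = Data.TensorDataset(x, y)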
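If an older code base has many calls in the keyword form, one option is a small wrapper that accepts both styles and forwards everything positionally. make_tensor_dataset is a hypothetical helper written for illustration, not part of PyTorch; it assumes the old call always passed both tensors, data first and target second.

```python
import torch
from torch.utils.data import TensorDataset


def make_tensor_dataset(*tensors, data_tensor=None, target_tensor=None):
    """Hypothetical shim: accept the old keyword style or the new positional style."""
    if data_tensor is not None or target_tensor is not None:
        # Old style: require both keywords and no positional tensors
        if tensors or data_tensor is None or target_tensor is None:
            raise TypeError("pass either positional tensors or both data_tensor and target_tensor")
        # Forward positionally, data first, then target
        tensors = (data_tensor, target_tensor)
    return TensorDataset(*tensors)


x = torch.randn(10, 2)           # example features
y = torch.randint(0, 2, (10,))   # example binary targets

old_style = make_tensor_dataset(data_tensor=x, target_tensor=y)
new_style = make_tensor_dataset(x, y)
print(len(old_style), len(new_style))  # 10 10
```

A shim like this only buys time; replacing the keyword calls with TensorDataset(x, y) directly is the simpler long-term fix.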

 
