pytorch --- torch.utils.data.DataLoder()的使用

本文详细介绍了PyTorch中DataLoader的使用方法,包括参数解释、自定义Dataset类、使用TensorDataset以及如何定义collate_fn函数进行数据处理,如padding等。
部署运行你感兴趣的模型镜像

torch.utils.data.DataLoader(dataset, batch_size, shuffle, num_workers, collate_fn)

参数说明

  • dataset 传入的数据集
  • batch_size每个batch有多少个样本
  • shuffle是否打乱数据
  • num_workers有几个进程来处理data loading
  • collate_fn将一个list的sample组成一个mini-batch的函数

具体使用

  1. Dataloader的处理逻辑是先通过Dataset类里面的__getitem__函数获取单个的数据,然后组合成batch,再使用collate_fn所指定的函数对这个batch做一些操作,比如padding之类的。
  2. 在NLP中的使用主要是要重构两个两个东西,一个是新建Dataset类构建dataset,必须继承torch.utils.data.Dataset类,内部要实现两个函数一个是__len__用来获取整个数据集的大小,一个是__getitem__用来从数据集中得到一个数据片段item。

如下代码新建MyDataset类来构建输入DataLoader第一个参数dataset

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, centers, contexts, negatives):
        assert len(centers) == len(contexts) == len(negatives)
        self.centers = centers
        self.contexts = contexts
        self.negatives = negatives

    def __getitem__(self, index):
        return (self.centers[index], self.contexts[index], self.negatives[index])

    def __len__(self):
        return len(self.centers)
  # args:
  #		centers, contexts, negatives 都是list, 且contexts中元素也是list, 并且长度不一定都相同

在torch中有一个专门构建dataset的函数TensorDataset, 使用如下

dataset = torch.utils.data.TensorDataset(x, y)
# 输入的x, y 是tensor类型
  1. 可以通过自定义collate_fn=myfunction来设计数据收集的方式,也就是通过上面的Dataset类中的__getitem__函数采样了batch_size数据,以一个包的形式传递给collate_fn所指定的函数, nlp任务中,经常在collate_fn指定的函数里面做padding

如下定义一个函数进行padding

def batchify(data):
    max_len = max(len(c)+len(n) for _, c, n in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        cur_len = len(context) + len(negative)
        centers += [center]
        contexts_negatives += [context + negative + [0]*(max_len-cur_len)]
        masks += [[1]*cur_len + [0]*(max_len-cur_len)]
        labels += [[1]*len(context) + [0]*(max_len-len(context))]
    return torch.tensor(centers).view(-1,1), torch.tensor(contexts_negatives), torch.tensor(masks), torch.tensor(labels)

总结
总的使用流程就是: 重建Dataset类(或直接用torch.utils.data.TensorDataset())来构建输入参数dataset, 设定其他参数, 自定义一个函数处理数据(传给collate_fn参数)如padding等

如下完整的传参过程

dataset = MyDataset(all_centers, all_contexts, all_negatives)
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True, collate_fn=batchify, num_workers=num_workers)

完整程序见githup: 完整程序 word2vec.py

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

根据Mend.io提供的安全报告,PyTorch 2.3.0版本存在多个高危漏洞(CVE-2024-48063等),主要涉及反序列化漏洞、内存损坏和拒绝服务攻击。以下是解决方案和详细说明: --- ### **1. 漏洞影响分析** | CVE编号 | 严重性 | 漏洞类型 | 影响范围 | |------------------|--------|------------------------------|------------------------| | CVE-2024-48063 | 9.8 | 反序列化未信任数据 | 所有PyTorch 2.3.0用户 | | CVE-2025-2999/8 | 5.3 | 内存损坏(RNN工具函数) | `unpack_sequence`等函数| | CVE-2025-3001/0 | 5.3 | 内存损坏(LSTM/JIT模块) | `torch.lstm_cell`等 | | CVE-2025-2148 | 5.0 | 内存损坏(Profiler回调) | JIT调试功能 | | CVE-2025-3136/21 | 3.3 | 内存损坏(CUDA/JIT模块) | GPU内存分配、模型加载 | | CVE-2025-2953 | 3.3 | 拒绝服务(MKL-DNN池化层) | 量化推理场景 | --- ### **2. 修复方案** #### **方案1:升级到安全版本** - **最新稳定版**:PyTorch 2.4.0+(已修复CVE-2024-48063等) ```bash pip install --upgrade torch>=2.4.0 ``` - **验证安装**: ```python import torch print(torch.__version__) # 应输出≥2.4.0 ``` #### **方案2:临时缓解措施(若无法升级)** - **禁用高危功能**: - 避免使用`torch.load()`加载不可信模型(CVE-2024-48063)。 - 替换RNN相关函数为安全实现: ```python # 替代pad_packed_sequence(示例) from torch.nn.utils.rnn import PackedSequence def safe_pad(packed_seq): return packed_seq[0] # 需根据实际需求调整 ``` - **限制CUDA使用**: ```python import os os.environ['CUDA_VISIBLE_DEVICES'] = '' # 强制使用CPU ``` #### **方案3:依赖隔离** - 使用虚拟环境隔离问题版本: ```bash conda create -n pytorch_safe python=3.9 conda activate pytorch_safe pip install torch==2.4.0 ``` --- ### **3. 长期安全建议** 1. **启用依赖项扫描**: - 集成Mend.io、Snyk等工具到CI/CD流程。 - 示例GitHub Action配置: ```yaml - name: Scan for vulnerabilities uses: snyk/actions/python@master with: command: test ``` 2. **监控安全公告**: - 订阅[PyTorch安全公告](https://github.com/pytorch/pytorch/security/advisories)。 3. **最小化权限**: - 运行PyTorch服务时使用非root用户。 ---
评论 2
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值