pytorch中torch.chunk()方法

chunk方法可以对张量分块,返回一个张量列表:

torch.chunk(tensorchunksdim=0) → List of Tensors

Splits a tensor into a specific number of chunks.

Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.(如果指定轴的元素个数被chunks除不尽,那么最后一块的元素个数变少)

Parameters:
  • tensor (Tensor) – the tensor to split
  • chunks (int) – number of chunks to return(分割的块数)
  • dim (int) – dimension along which to split the tensor(沿着哪个轴分块)
 
import numpy as np
import torch

data = torch.from_numpy(np.random.rand(3, 5))
print(str(data))
>>
tensor([[0.6742, 0.5700, 0.3519, 0.4603, 0.9590],
        [0.9705, 0.8673, 0.8854, 0.9029, 0.5473],
        [0.0199, 0.4729, 0.4001, 0.7581, 0.5045]], dtype=torch.float64)

for i, data_i in enumerate(data.chunk(5, 1)): # 沿1轴分为5块
    print(str(data_i))
>>
tensor([[0.6742],
        [0.9705],
        [0.0199]], dtype=torch.float64)
tensor([[0.5700],
        [0.8673],
        [0.4729]], dtype=torch.float64)
tensor([[0.3519],
        [0.8854],
        [0.4001]], dtype=torch.float64)
tensor([[0.4603],
        [0.9029],
        [0.7581]], dtype=torch.float64)
tensor([[0.9590],
        [0.5473],
        [0.5045]], dtype=torch.float64)  

for i, data_i in enumerate(data.chunk(3, 0)): # 沿0轴分为3块
    print(str(data_i))
>>
tensor([[0.6742, 0.5700, 0.3519, 0.4603, 0.9590]], dtype=torch.float64)
tensor([[0.9705, 0.8673, 0.8854, 0.9029, 0.5473]], dtype=torch.float64)
tensor([[0.0199, 0.4729, 0.4001, 0.7581, 0.5045]], dtype=torch.float64)
   
for i, data_i in enumerate(data.chunk(3, 1)): # 沿1轴分为3块,除不尽
    print(str(data_i))
>>
tensor([[0.6742, 0.5700],
        [0.9705, 0.8673],
        [0.0199, 0.4729]], dtype=torch.float64)
tensor([[0.3519, 0.4603],
        [0.8854, 0.9029],
        [0.4001, 0.7581]], dtype=torch.float64)
tensor([[0.9590],
        [0.5473],
        [0.5045]], dtype=torch.float64)

 

 

转载于:https://www.cnblogs.com/jiangkejie/p/10309766.html

### PyTorch `split` 和 `chunk` 函数的区别 #### 功能描述 PyTorch 的 `split` 和 `chunk` 都用于分割张量,但两者的工作方式有所不同。 对于 `torch.split(tensor, split_size_or_sections, dim=0)` 函数而言,此方法允许指定要切割成的小片段大小或是各个部分的具体尺寸列表。如果提供的是单个整数,则表示每一部分的长度;如果是元组或列表,则精确指定了各分片的大小[^1]。 ```python import torch tensor_example = torch.arange(8) # 使用 split 方法按每份两个元素来划分张量 result_split = torch.split(tensor_example, 2) print([t.tolist() for t in result_split]) ``` 另一方面,`torch.chunk(tensor, chunks, dim=0)` 则更倾向于平均分配输入张量到指定数量的部分中去。当无法均匀切分时,最后一块可能会小于其他部分。这里只接受一个参数来定义希望得到多少个子集。 ```python # 使用 chunk 方法将张量分为三等分 result_chunk = torch.chunk(tensor_example, 3) print([t.tolist() for t in result_chunk]) ``` 这两种操作都支持通过设置维度参数 (`dim`) 来控制沿哪个轴执行拆分,默认情况下是在第零维上工作。 #### 输出差异展示 上述代码会分别打印出由 `split` 和 `chunk` 返回的结果: - 对于 `split`: 如果传入了合适的总和等于原始张量长度的分割尺寸数组,那么将会获得完全按照给定规格被分开的新张量集合; - 而对于 `chunk`: 不管怎样都会尝试尽可能公平地把原数据分成所请求的数量级,即使这意味着某些部分可能比其他的稍大一些或小一点。 因此,在实际编程过程中可以根据具体需求选择合适的方法来进行张量的操作处理。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值