unfold函数

部署运行你感兴趣的模型镜像

文章目录

  • 1. 原理介绍
  • 2. pytorch源码验证:

1. 原理介绍

torch.unfold函数的作用是将卷积出来的元素提取后按列向量排列

  • 提取元素
    在这里插入图片描述
  • 按列生成矩阵
    在这里插入图片描述

2. pytorch源码验证:

  • pytorch:
import torch.nn as nn
import torch

if __name__ == "__main__":
    run_code = 0
    my_unfold = nn.Unfold(kernel_size=(2, 2))
    batch_size = 1
    in_channels = 2
    input_h = 3
    input_w = 4
    my_total = batch_size * in_channels * input_h * input_w
    my_shape = (batch_size, in_channels, input_w, input_h)
    my_matrix = torch.arange(my_total).reshape(my_shape).to(torch.float)
    my_output = my_unfold(my_matrix)
    kernel_size = (2,2)
    print(f"my_matrix.shape=\n{my_matrix.shape}")
    print(f"my_output.shape=\n{my_output.shape}")
    print(f"my_matrix=\n{my_matrix}")
    print(f"unfold_kernel={kernel_size}")
    print(f"my_output=\n{my_output}")
  • 结果:
my_matrix.shape=
torch.Size([1, 2, 4, 3])
my_output.shape=
torch.Size([1, 8, 6])
my_matrix=
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.],
          [ 9., 10., 11.]],

         [[12., 13., 14.],
          [15., 16., 17.],
          [18., 19., 20.],
          [21., 22., 23.]]]])
unfold_kernel=(2, 2)
my_output=
tensor([[[ 0.,  1.,  3.,  4.,  6.,  7.],
         [ 1.,  2.,  4.,  5.,  7.,  8.],
         [ 3.,  4.,  6.,  7.,  9., 10.],
         [ 4.,  5.,  7.,  8., 10., 11.],
         [12., 13., 15., 16., 18., 19.],
         [13., 14., 16., 17., 19., 20.],
         [15., 16., 18., 19., 21., 22.],
         [16., 17., 19., 20., 22., 23.]]])

您可能感兴趣的与本文相关的镜像

PyTorch 2.7

PyTorch 2.7

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值