线性层及其他层

# 1.线性拉平
# 线性拉平(Linear Flattening)是指将一个多维的线性结构(如数组、列表等)转换为一个一维的线性结构的过程。
# 在计算机科学中,这个过程通常用于将多维数组转换为一维数组,以便进行某些操作,如排序、搜索等。
import torch
from torch import nn
import torchvision
from torch.nn import Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10(root="D:\PyCharm\CIFAR10", train=False,
                                       transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64)

for data in dataloader:
    imgs, targets = data
    print(imgs.shape)
    output = torch.reshape(imgs, (1, 1, 1, -1))
    print(output.shape)

# 2.线性层
datset = torchvision.datasets.CIFAR10(root="D:\PyCharm\CIFAR10", train=False,
                                      transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)


# 当 drop_last=True 时,如果数据集的样本数量不能被 batch_size 整除,那么 DataLoader 会丢弃最后一个不完整的批次,以确保每个批次的样本数量都是 batch_size。

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear1 = Linear(196608, 10)

    def forward(self, input):
        output = self.linear1(input)
        return output


model = MyModel()
writer = SummaryWriter("logs12")
step = 0

for data in dataloader:
    imgs, targets = data
    print(f"imgs.shape:{imgs.shape}")
    for i, img in enumerate(imgs):
        img_input_hwc = img.permute(1, 2, 0)
        writer.add_image("input", img_input_hwc, step, dataformats="HWC")
    # 方式一:拉平
    output = torch.reshape(imgs, (1, 1, 1, -1))
    print(f"output.shape:{output.shape}")
    output = model(output)
    print(f"output = model(output)的shape:{output.shape}")
    for i, img in enumerate(imgs):
        img_output_hwc = img.permute(1, 2, 0)
        writer.add_image("output", img_output_hwc, step, dataformats="HWC")
    step = step + 1
writer.close()

# 方式二:拉平,展开为一维
dataset = torchvision.datasets.CIFAR10(root="D:\PyCharm\CIFAR10", train=False,
                                       transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)


class MyModel_1(nn.Module):
    def __init__(self):
        super(MyModel_1, self).__init__()
        self.linear2 = nn.Linear(196608, 10)

    def forward(self, input):
        output = self.linear2(input)
        return output


model_1 = MyModel_1()
writer = SummaryWriter("logs13")
step = 0

for data in dataloader:
    imgs, targets = data
    print(f"imgs.shape:{imgs.shape}")
    for i, img in enumerate(imgs):
        img_input_hwc = img.permute(1, 2, 0)
        writer.add_image("input", img_input_hwc, step, dataformats="HWC")
    output = torch.flatten(imgs)  # 方式二:拉平,展开为一维
    print(f"output.shape:{output.shape}")
    output = model_1(output)
    print(f"output =model_1(output):{output.shape}")
    for i, img in enumerate(imgs):
        img_output_hwc = img.permute(1, 2, 0)
        writer.add_image("output", img_output_hwc, step, dataformats="HWC")
    step = step + 1
writer.close()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

YLTommi

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值