#P19 神经网络-非线性激活
#常用:ReLU、Sigmoid
#nn_relu.py
import torch
import torchvision
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# input = torch.tensor([[1, -0.5],
# [-1, 3]])
#
# input = torch.reshape(input, (-1, 1, 2, 2))#需要有batch_size,因此还是用reshape设置形状,一般设置成(batch_size, channel, H_in, W_in)比较常用
# print(input.shape)
# CIFAR-10 test split (train=False). With download=True the archive is fetched
# into ./data_nn if it is not already there; ToTensor converts each PIL image
# into a float tensor so the samples can be stacked into batches.
dataset = torchvision.datasets.CIFAR10(
    root="./data_nn",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
# Batches of 64 samples, yielded in dataset order (DataLoader's default).
dataloader = DataLoader(dataset=dataset, batch_size=64)
class West(nn.Module):
    """Minimal demo network that applies one non-linear activation.

    The forward pass runs ``Sigmoid`` element-wise, so the output has exactly
    the same shape as the input. For image batches whose pixels lie in
    [0, 1], sigmoid maps them into roughly [0.5, 0.73], which is why the
    logged output images look washed out (like a white filter).
    """

    def __init__(self):
        super().__init__()
        # Alternative activation explored in this tutorial:
        #     self.relu1 = ReLU()
        # ReLU sets negative inputs to 0 and keeps positive inputs unchanged.
        # inplace=True overwrites the input tensor with the result (the
        # original values are lost); the default inplace=False keeps the
        # input intact, so the argument is normally left unset.
        self.sigmoid1 = Sigmoid()

    def forward(self, input):
        """Apply the sigmoid activation element-wise.

        Args:
            input: tensor of any shape. The name shadows the ``input``
                builtin but is kept so keyword callers keep working.

        Returns:
            Tensor of the same shape, each element mapped through
            1 / (1 + exp(-x)).
        """
        # output = self.relu1(input)  # ReLU variant of the same demo
        output = self.sigmoid1(input)
        return output
west = West()  # instantiate the demo network
# Quick check on the small hand-made tensor from the commented-out code above:
# output = west(input)
# print(output)

# Log every batch's input images next to their activated counterparts so they
# can be compared in TensorBoard (e.g. `tensorboard --logdir=logs_Sigmoid`).
# The `with` block guarantees the writer is flushed and closed even if an
# iteration raises part-way through.
with SummaryWriter("./logs_Sigmoid") as writer:
    for step, (imgs, _targets) in enumerate(dataloader):
        writer.add_images("input", imgs, global_step=step)
        output = west(imgs)
        writer.add_images("output", output, step)
'''
#ReLU附官方网址和说明
ReLU网址
https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html#torch.nn.ReLU
ReLU说明
class torch.nn.ReLU(inplace=False)
Shape:
Input: (N, ∗), where ∗ means any number of additional dimensions.
Output: (N, ∗), same shape as the input.
#Sigmoid附官方网址和说明
Sigmoid网址
https://pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html#torch.nn.Sigmoid
Sigmoid说明
class torch.nn.Sigmoid(*args, **kwargs)
Shape:
Input: (N, ∗), where ∗ means any number of additional dimensions.
Output: (N, ∗), same shape as the input.
'''
#ReLU附官方网址和说明
# ReLU网址
# https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html#torch.nn.ReLU
# ReLU说明
# class torch.nn.ReLU(inplace=False)
# Shape:
#   Input: (N, ∗), where ∗ means any number of additional dimensions.
#   Output: (N, ∗), same shape as the input.
#Sigmoid附官方网址和说明
# Sigmoid网址
# https://pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html#torch.nn.Sigmoid
# Sigmoid说明
# class torch.nn.Sigmoid(*args, **kwargs)
# Shape:
#   Input: (N, ∗), where ∗ means any number of additional dimensions.
#   Output: (N, ∗), same shape as the input.