import torch
import torchvision
from torch.utils.data import DataLoader
# CIFAR-10 training split stored under ./data (downloaded on first use).
# ToTensor converts every image into a torch tensor before it is batched.
das = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# Serve the dataset in batches of 35 samples, in dataset order (no shuffling).
dataloader = DataLoader(dataset=das, batch_size=35)
class my_nn(torch.nn.Module):
    """Minimal network made of a single 2-D convolution layer.

    With the default arguments the layer maps 3 input channels to 6 output
    channels using a 3x3 kernel, stride 1 and no padding, so every spatial
    dimension shrinks by 2 (for example 32x32 -> 30x30).

    Args:
        in_channels: Number of channels expected in the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the convolution kernel.
        stride: Stride of the convolution.
        padding: Padding applied to both sides of each spatial dimension.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 6,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
    ) -> None:
        super().__init__()
        # The attribute name ``conv1`` is kept so state_dict keys stay the same.
        self.conv1 = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run ``input`` through the convolution and return the result."""
        return self.conv1(input)
from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer; event files are written to the "nn_rgylin" directory.
wirte = SummaryWriter("nn_rgylin")
count = 0  # global step handed to TensorBoard, incremented once per batch
rgylin = my_nn()

# The writer buffers events, so it is closed in ``finally`` to guarantee that
# everything is flushed to disk even if a batch raises part-way through.
try:
    for img, target in dataloader:
        print(img.shape)
        output = rgylin(img)
        print(output.shape)
        # The model emits 6 channels, which cannot be shown as an RGB image.
        # Folding them into 3-channel images doubles the batch dimension:
        # (N, 6, 30, 30) -> (2N, 3, 30, 30).
        output1 = torch.reshape(output, (-1, 3, 30, 30))
        wirte.add_images("rglyin_nn", output1, count, dataformats="NCHW")
        count += 1
finally:
    wirte.close()