# 一、导入库 (1. Import libraries)
import torch
import torch. nn as nn
import torchvision. transforms as transforms
import torchvision
from torchvision import transforms, datasets
import matplotlib. pyplot as plt
from PIL import Image
import torch. nn. functional as F
import warnings
import os, PIL, pathlib, random
# 二、导入数据 (2. Load the data)
# Dataset root: each class is expected to be a sub-directory of this folder.
data_dir = './data/weather_photos'
data_dir = pathlib.Path(data_dir)

data_paths = list(data_dir.glob('*'))

# Class names are the sub-directory names.
# - Path.name is portable; the previous str(path).split("\\")[2] only worked
#   with Windows separators and raised IndexError on POSIX.
# - Only directories are kept, and they are sorted, so that classeNames[i]
#   matches label i as assigned by datasets.ImageFolder (which sorts class
#   directory names). glob() order is otherwise arbitrary.
classeNames = sorted(path.name for path in data_paths if path.is_dir())
# Preview up to 3 x 8 = 24 sample images from the "cloudy" class folder.
image_folder = './data/weather_photos/cloudy/'

# Case-insensitive extension check so files like "x.JPG" are not skipped;
# sorted so the preview is the same on every run (os.listdir order is arbitrary).
image_files = sorted(
    f for f in os.listdir(image_folder)
    if f.lower().endswith((".jpg", ".png", ".jpeg"))
)

fig, axes = plt.subplots(3, 8, figsize=(16, 6))
# Hide every cell up front so unused cells (fewer than 24 images) stay blank.
for ax in axes.flat:
    ax.axis('off')

for ax, img_file in zip(axes.flat, image_files):
    img_path = os.path.join(image_folder, img_file)
    # Close the file handle; imshow converts the PIL image to an array
    # immediately, so the image is no longer needed after this call.
    with Image.open(img_path) as img:
        ax.imshow(img)
# 三、数据预处理 (3. Data preprocessing)
# Build the full dataset from the class sub-directories, split it 80/20 into
# train/test subsets, and wrap each subset in a DataLoader.
total_datadir = './data/weather_photos'

# Resize every image to 224x224, convert to a tensor, then normalise each
# channel with the widely used ImageNet mean/std values.
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
])

# ImageFolder labels each image by its sub-directory (class names sorted).
total_data = datasets.ImageFolder(total_datadir, transform=train_transforms)

train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size  # remainder, so sizes always sum to len
# NOTE(review): the split is unseeded and differs between runs; pass
# generator=torch.Generator().manual_seed(...) if reproducibility is needed.
train_dataset, test_dataset = torch.utils.data.random_split(
    total_data, [train_size, test_size])

batch_size = 32

train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
# Evaluation gains nothing from shuffling; a fixed order keeps test passes
# (and any per-sample inspection of predictions) reproducible.
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      num_workers=1)
# 四、创建模型 (4. Build the model)