1. 类别文件（class_indices.json）
{
"0" : "daisy" ,
"1" : "dandelion" ,
"2" : "roses" ,
"3" : "sunflowers" ,
"4" : "tulips"
}
2. 模型（model.py）
import torch
import torch. nn as nn
class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier.

    The network is a five-convolution feature extractor (channel widths
    48/128/192/192/128) followed by a three-layer fully connected
    classifier. The first linear layer expects ``128 * 6 * 6`` features,
    which is what the feature extractor produces for ``3 x 224 x 224``
    inputs.

    Args:
        num_classes: Size of the final output layer (number of logits).
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )  # this closing parenthesis was missing, making the module unparsable
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        x = self.features(x)
        # Keep the batch dimension and flatten everything else for the
        # fully connected layers.
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x
3. 训练
import os
import json
import sys
from tqdm import tqdm
import torch
import torch. nn as nn
from torchvision import transforms, datasets
from model import AlexNet
def main():
    """Train AlexNet on the flower dataset and keep the best checkpoint.

    Reads images from ``./flower_data/train`` and ``./flower_data/val``
    (``ImageFolder`` layout), writes the index-to-class mapping to
    ``class_indices.json``, trains for a fixed number of epochs, and saves
    the weights to ``./AlexNet.pth`` each time validation accuracy improves.

    Raises:
        FileNotFoundError: If the dataset directory does not exist.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f'using {device} ')

    image_path = os.path.join('./', 'flower_data')
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not os.path.exists(image_path):
        raise FileNotFoundError("image path does not exist")

    data_transform = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    }

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'),
                                         transform=data_transform['train'])

    # Invert class_to_idx so the prediction script can map an index back
    # to a class name.
    flower_list = train_dataset.class_to_idx
    print(flower_list)
    cla_dict = {val: key for key, val in flower_list.items()}
    with open('class_indices.json', 'w') as json_file:
        json.dump(cla_dict, json_file, indent=4)

    batch_size = 32
    # os.cpu_count() may return None; fall back to 1 so min() cannot fail.
    # Workers are capped at 8 and at the batch size (0 when batch_size <= 1).
    nw = min(os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8)
    print(f'using {nw} dataloader workers every process')

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'val'),
                                            transform=data_transform['val'])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,
                                                  shuffle=False, num_workers=nw)

    net = AlexNet(num_classes=5)
    net.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_step = len(train_loader)

    for epoch in range(epochs):
        # ---- training pass ----
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_fn(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            # The previous format spec (' .3f ' with a trailing space) was
            # invalid and raised ValueError on the first batch.
            train_bar.desc = f'train epoch {epoch + 1} / {epochs} loss: {batch_loss:.3f} '

        # ---- validation pass ----
        net.eval()
        acc = 0.0  # number of correctly classified validation images
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_images, val_labels in val_bar:
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accuracy = acc / val_num
        print(f'[epoch {epoch + 1} ] train_loss: {running_loss / train_step:.3f} , '
              f'val_accuracy: {val_accuracy:.3f} ')

        # Only keep the weights that generalize best so far.
        if val_accuracy > best_acc:
            best_acc = val_accuracy
            torch.save(net.state_dict(), save_path)
# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()
4. 预测
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib. pyplot as plt
from model import AlexNet
def main():
    """Classify a single image with the trained AlexNet and plot the result.

    Loads ``./4.jpeg``, applies the same resize/normalize preprocessing as
    the validation transform, restores weights from ``./AlexNet.pth``, and
    shows the image titled with the predicted class name (looked up in
    ``./class_indices.json``) and its softmax probability.

    Raises:
        FileNotFoundError: If the image, the class-index file, or the
            weights file is missing.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    img_path = './4.jpeg'
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    if not os.path.exists(img_path):
        raise FileNotFoundError(f' {img_path} does not exist')
    img = Image.open(img_path)
    plt.imshow(img)
    # Force three channels: grayscale or RGBA files would otherwise break
    # the 3-channel Normalize and the network's first convolution.
    img = data_transform(img.convert('RGB'))
    # Add the batch dimension expected by the model.
    img = torch.unsqueeze(img, dim=0)

    json_path = './class_indices.json'
    if not os.path.exists(json_path):
        raise FileNotFoundError(f'file {json_path} does not exist')
    with open(json_path, 'r') as f:
        class_dict = json.load(f)

    model = AlexNet(num_classes=5).to(device)
    weights_path = './AlexNet.pth'
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f'file {weights_path} does not exist')
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        output = model(img.to(device))
        output = torch.squeeze(output).cpu()
        predict = torch.softmax(output, dim=0)
        predict_class = int(torch.argmax(predict))

    # JSON object keys are strings, hence str(predict_class). The previous
    # format spec (' .3f ' with a trailing space) was invalid and raised
    # ValueError.
    print_res = (f"class: {class_dict[str(predict_class)]} , "
                 f"prob: {predict[predict_class].item():.3f} ")
    plt.title(print_res)
    plt.show()
# Run the prediction only when this file is executed as a script.
if __name__ == '__main__':
    main()