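Key points of the AlexNet paper:
- The network is trained on ImageNet, which contains over 15 million labeled images in roughly 22,000 categories.
- It uses the rectified linear unit (ReLU) as its nonlinearity. With ReLU, training is several times faster than with the traditional hyperbolic tangent (tanh) function.
- It applies data augmentation, including image translations, horizontal reflections, and patch extractions.
- It introduces dropout layers to counter overfitting on the training set.
- It is trained with mini-batch stochastic gradient descent, using specific values for momentum and weight decay.
- Training took five to six days on two GTX 580 GPUs.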
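
As a rough illustration of three of the points above (ReLU versus tanh, dropout, and SGD with momentum and weight decay), here is a minimal pure-Python sketch on scalars and lists. The function names are hypothetical. The default hyperparameters (learning rate 0.01, momentum 0.9, weight decay 0.0005) are the values commonly cited from the paper and are not stated in the notes above, so treat them as an assumption. The dropout shown is the "inverted" variant that rescales at training time, which is a common modern formulation rather than necessarily what the original network did.

```python
import math
import random


def relu(x):
    """ReLU: pass positive inputs through, clamp negatives to zero."""
    return max(0.0, x)


def relu_grad(x):
    """Derivative of ReLU: 1 for positive inputs, 0 otherwise."""
    return 1.0 if x > 0 else 0.0


def tanh_grad(x):
    """Derivative of tanh: 1 - tanh(x)^2, which shrinks toward 0 for large |x|."""
    return 1.0 - math.tanh(x) ** 2


def dropout(activations, p=0.5, training=True, rng=random):
    """Inverted dropout: zero each unit with probability p, rescale the rest.

    Assumption: the 1 / (1 - p) rescaling keeps the expected activation
    unchanged, so nothing has to be rescaled at inference time.
    """
    if not training or p == 0.0:
        return list(activations)
    keep = 1.0 - p
    return [a / keep if rng.random() < keep else 0.0 for a in activations]


def sgd_momentum_step(w, v, grad, lr=0.01, momentum=0.9, weight_decay=0.0005):
    """One SGD update with momentum and weight decay for a single scalar weight.

    v_new = momentum * v - weight_decay * lr * w - lr * grad
    w_new = w + v_new
    """
    v_new = momentum * v - weight_decay * lr * w - lr * grad
    return w + v_new, v_new
```

For a large positive input such as 5.0, `relu_grad` still returns 1 while `tanh_grad` is close to 0. This saturation is one commonly given reason why tanh networks train more slowly.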
GitHub code: classic network models built with PyTorch
Dataset: http://download.tensorflow.org/example_images/flower_photos.tgz
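
One possible way to fetch and unpack the archive with only the standard library is sketched below. The helper names and the local paths in the usage comment are hypothetical. The sketch also assumes the archive unpacks into a top-level `flower_photos` folder, so that extracting into `flower_data` yields `flower_data/flower_photos`.

```python
import os
import tarfile
import urllib.request

# URL taken from the dataset line above.
DATASET_URL = "http://download.tensorflow.org/example_images/flower_photos.tgz"


def download(url, archive_path):
    """Fetch url to archive_path unless the file is already there."""
    if not os.path.exists(archive_path):
        parent = os.path.dirname(archive_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        urllib.request.urlretrieve(url, archive_path)
    return archive_path


def extract_tgz(archive_path, dest_dir):
    """Unpack a gzip-compressed tar archive into dest_dir."""
    os.makedirs(dest_dir, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Newer Python versions support extraction filters; "data"
            # rejects unsafe members such as absolute paths.
            tar.extractall(dest_dir, filter="data")
        else:
            tar.extractall(dest_dir)
    return dest_dir


# Example usage (hypothetical local paths):
#   archive = download(DATASET_URL, "flower_data/flower_photos.tgz")
#   extract_tgz(archive, "flower_data")
```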
- Split the dataset into a training set and a validation set:
```python
import os
import random
from shutil import copy


def mkfile(file):
    # Create the directory (and any missing parents) if it does not exist yet.
    if not os.path.exists(file):
        os.makedirs(file)


# Root folder of the extracted dataset: one sub-folder per flower class.
file = 'flower_data/flower_photos'
# Skip stray .txt files (such as a license file); the rest are class folders.
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]

# Create train/<class> and val/<class> output folders.
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/' + cla)

mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/' + cla)

# Fraction of each class that goes to the validation set.
split_rate = 0.1
for cla in flower_class:
    cla_path = file + '/' + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    # Randomly pick the validation images; a set keeps the lookup below fast.
    eval_index = set(random.sample(images, k=int(num * split_rate)))
    for image in images:
        image_path = cla_path + image
        if image in eval_index:
            new_path = 'flower_data/val/' + cla
        else:
            new_path = 'flower_data/train/' + cla
        copy(image_path, new_path)
```
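
After running the split, a quick sanity check is to count the images that ended up in each class folder. The helpers below are a hypothetical sketch, not part of the original script; they assume the `flower_data/train` and `flower_data/val` layout produced above.

```python
import os


def count_images(split_dir):
    """Return {class_name: number_of_files} for one split directory."""
    counts = {}
    for cla in sorted(os.listdir(split_dir)):
        cla_path = os.path.join(split_dir, cla)
        if os.path.isdir(cla_path):
            counts[cla] = len(os.listdir(cla_path))
    return counts


def summarize_split(root="flower_data"):
    """Map each class to (train count, val count, validation fraction)."""
    train = count_images(os.path.join(root, "train"))
    val = count_images(os.path.join(root, "val"))
    summary = {}
    for cla in train:
        n_train, n_val = train[cla], val.get(cla, 0)
        total = n_train + n_val
        summary[cla] = (n_train, n_val, n_val / total if total else 0.0)
    return summary
```

With `split_rate = 0.1`, the validation fraction for each class should come out at or slightly below 0.1, because `int(num * split_rate)` rounds down.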