CIFAR-10 model

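This post presents an image classifier built with TFLearn. It was adapted from a bird classifier (the checkpoint and run names in the code still say "bird-classifier"), but the network below has 10 outputs, one for each CIFAR-10 class. The classifier runs 32x32-pixel images through a convolutional neural network and uses data augmentation to improve training. The final model performs well on the training set.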
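As far as I understand TFLearn, `ImageAugmentation` applies its transforms on the fly to each training batch, so no extra images are stored. The script below configures three of them: a random left-right flip, a random rotation of up to 25 degrees and a random blur with sigma up to 3. The following is a minimal NumPy sketch of the same three ideas, not TFLearn's implementation. The helper names are hypothetical, rotation uses nearest-neighbour sampling with zero fill, the blur is a separable Gaussian with reflected borders, and applying rotation and blur each with probability 0.5 is my assumption.

```python
import numpy as np


def random_flip_leftright(img, rng):
    """Mirror an (H, W, C) image horizontally with probability 0.5."""
    if rng.random() < 0.5:
        return img[:, ::-1, :]
    return img


def random_rotation(img, max_angle, rng):
    """Rotate by a random angle in [-max_angle, max_angle] degrees about the
    centre, keeping the size; nearest-neighbour sampling, zero fill."""
    angle = np.deg2rad(rng.uniform(-max_angle, max_angle))
    h, w = img.shape[:2]
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    ys, xs = np.mgrid[0:h, 0:w]
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    # Inverse mapping: for every output pixel, find the source pixel.
    src_x = np.rint(cos_a * (xs - cx) + sin_a * (ys - cy) + cx).astype(int)
    src_y = np.rint(-sin_a * (xs - cx) + cos_a * (ys - cy) + cy).astype(int)
    inside = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.zeros_like(img)
    out[inside] = img[src_y[inside], src_x[inside]]
    return out


def gaussian_kernel_1d(sigma):
    """Normalized 1-D Gaussian kernel truncated at 3 sigma."""
    radius = max(1, int(np.ceil(3 * sigma)))
    x = np.arange(-radius, radius + 1)
    k = np.exp(-(x ** 2) / (2.0 * sigma ** 2))
    return k / k.sum()


def random_blur(img, sigma_max, rng):
    """Blur each channel with a Gaussian whose sigma is uniform in [0, sigma_max]."""
    sigma = rng.uniform(0.0, sigma_max)
    if sigma < 1e-3:  # practically no blur
        return img
    k = gaussian_kernel_1d(sigma)
    r = len(k) // 2
    out = np.pad(img.astype(np.float64), ((r, r), (r, r), (0, 0)), mode="reflect")
    # Separable convolution: first along the height axis, then along the width axis.
    for axis in (0, 1):
        out = np.apply_along_axis(lambda v: np.convolve(v, k, mode="valid"), axis, out)
    return out.astype(img.dtype)


def augment(img, rng, max_angle=25.0, sigma_max=3.0):
    """Flip, then (each with probability 0.5, an assumption) rotate and blur."""
    img = random_flip_leftright(img, rng)
    if rng.random() < 0.5:
        img = random_rotation(img, max_angle, rng)
    if rng.random() < 0.5:
        img = random_blur(img, sigma_max, rng)
    return img


rng = np.random.default_rng(0)
image = rng.random((32, 32, 3)).astype(np.float32)
augmented = augment(image, rng)
print(augmented.shape, augmented.dtype)  # (32, 32, 3) float32
```

Every transform keeps the 32x32x3 shape, which is what lets augmented images go straight into the fixed-size input layer.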

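The script also normalizes its inputs: `add_featurewise_zero_center()` and `add_featurewise_stdnorm()` tell TFLearn to compute a mean and a standard deviation over the training data and to apply (x - mean) / std to every image fed to the network. The NumPy sketch below shows that arithmetic. The function names are hypothetical. My understanding is that TFLearn's default uses a single scalar mean and std for the whole data set, with an optional per-channel mode; treat the `per_channel` switch here as an illustration, not a description of TFLearn's internals.

```python
import numpy as np


def fit_normalizer(train_images, per_channel=False):
    """Mean and standard deviation of the training set: one scalar each by
    default, or one value per colour channel if per_channel=True."""
    axes = (0, 1, 2) if per_channel else None  # images are (N, H, W, C)
    return train_images.mean(axis=axes), train_images.std(axis=axes)


def normalize(images, mean, std, eps=1e-8):
    """Zero-centre and scale to unit standard deviation."""
    return (images - mean) / (std + eps)


rng = np.random.default_rng(0)
X_train = rng.random((100, 32, 32, 3)) * 255.0  # stand-in for real images
X_val = rng.random((20, 32, 32, 3)) * 255.0

mean, std = fit_normalizer(X_train)
X_train_n = normalize(X_train, mean, std)
X_val_n = normalize(X_val, mean, std)  # reuse the *training* statistics
print(float(X_train_n.mean()), float(X_train_n.std()))
```

The statistics are fitted once on the training data and reused unchanged for validation and test images, so every image the network sees is scaled the same way.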

# Import tflearn and some helpers
import numpy as np
import tflearn
from tflearn.data_utils import shuffle, to_categorical
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression
from tflearn.data_preprocessing import ImagePreprocessing
from tflearn.data_augmentation import ImageAugmentation
import pickle

"""
Based on the tflearn example located here:
https://github.com/tflearn/tflearn/blob/master/examples/images/convnet_cifar10.py
"""


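# Load the data set.
# The original notes loaded a pickled copy of the data set instead, e.g.:
#X, Y, X_test, Y_test = pickle.load(open("/home/mao/Downloads/full_dataset.pkl", "rb"), encoding='bytes')
#X, Y, X_test, Y_test = pickle.load(open("full_dataset.pkl", "rb"))
# Here CIFAR-10 is loaded with TFLearn's built-in loader, as in the upstream example.
from tflearn.datasets import cifar10
(X, Y), (X_test, Y_test) = cifar10.load_data()


# Shuffle the training data and one-hot encode the labels (10 classes)
X, Y = shuffle(X, Y)
Y = to_categorical(Y, 10)
Y_test = to_categorical(Y_test, 10)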


# Normalize the data: zero-center it and scale it to unit standard deviation,
# using statistics computed over the training set
img_prep = ImagePreprocessing()
img_prep.add_featurewise_zero_center()
img_prep.add_featurewise_stdnorm()


# Augment the training data on the fly with random left-right flips,
# rotations of up to 25 degrees and blurring (sigma up to 3).
img_aug = ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_rotation(max_angle=25.)
img_aug.add_random_blur(sigma_max=3.)


# Define our network architecture:


# Input is a 32x32 image with 3 color channels (red, green and blue)
network = input_data(shape=[None, 32, 32, 3],
                     data_preprocessing=img_prep,
                     data_augmentation=img_aug)


# Step 1: Convolution
network = conv_2d(network, 32, 3, activation='relu')


# Step 2: Max pooling
network = max_pool_2d(network, 2)


# Step 3: Convolution again
network = conv_2d(network, 64, 3, activation='relu')


# Step 4: Convolution yet again
network = conv_2d(network, 64, 3, activation='relu')


# Step 5: Max pooling again
network = max_pool_2d(network, 2)


# Step 6: Fully-connected layer with 512 nodes
network = fully_connected(network, 512, activation='relu')


# Step 7: Dropout: during training, randomly zero activations to reduce over-fitting
# (TFLearn's second argument is keep_prob, so 0.5 keeps half of them)
network = dropout(network, 0.5)


# Step 8: Fully-connected output layer with 10 outputs, one per CIFAR-10 class;
# softmax turns them into class probabilities for the final prediction
network = fully_connected(network, 10, activation='softmax')


# Tell tflearn how we want to train the network
network = regression(network, optimizer='adam',
                     loss='categorical_crossentropy',
                     learning_rate=0.001)


# Wrap the network in a model object
model = tflearn.DNN(network, tensorboard_verbose=0, checkpoint_path='bird-classifier.tfl.ckpt')
#model.load("bird-classifier.tfl.ckpt-630")


# Train it! We'll do 100 training passes and monitor it as it goes.
model.fit(X, Y, n_epoch=100, shuffle=True, validation_set=(X_test, Y_test),
          show_metric=True, batch_size=96,
          snapshot_epoch=True,
          run_id='bird-classifier')


# Save model when training is complete to a file
model.save("bird-classifier.tfl")
print("Network trained and saved as bird-classifier.tfl!")
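To see how the image shrinks through the network and where its parameters sit, the layer sizes can be worked out by hand. The pure-Python sketch below assumes TFLearn's defaults as I understand them: `conv_2d` uses stride 1 and `'same'` padding (keeping the spatial size), `max_pool_2d` uses `'same'` padding with a stride equal to its kernel size (halving the spatial size here), and every layer has biases. The helper functions are hypothetical and do not need TFLearn.

```python
import math


def conv_same(shape, n_filters, kernel):
    """Conv layer, stride 1, 'same' padding: keeps H and W."""
    h, w, c = shape
    params = kernel * kernel * c * n_filters + n_filters  # weights + biases
    return (h, w, n_filters), params


def pool_same(shape, kernel):
    """Max pooling with stride == kernel, 'same' padding: H and W divided, rounded up."""
    h, w, c = shape
    return (math.ceil(h / kernel), math.ceil(w / kernel), c), 0


def dense(shape, n_units):
    """Fully-connected layer over the flattened input, with biases."""
    n_in = math.prod(shape)
    return (n_units,), n_in * n_units + n_units


LAYERS = [
    ("conv_2d 32, 3x3", lambda s: conv_same(s, 32, 3)),
    ("max_pool_2d 2", lambda s: pool_same(s, 2)),
    ("conv_2d 64, 3x3", lambda s: conv_same(s, 64, 3)),
    ("conv_2d 64, 3x3", lambda s: conv_same(s, 64, 3)),
    ("max_pool_2d 2", lambda s: pool_same(s, 2)),
    ("fully_connected 512", lambda s: dense(s, 512)),
    ("dropout 0.5", lambda s: (s, 0)),  # no parameters, shape unchanged
    ("fully_connected 10", lambda s: dense(s, 10)),
]


def layer_table(input_shape=(32, 32, 3)):
    """Return (layer name, output shape, parameter count) for each layer."""
    rows, shape = [], input_shape
    for name, layer in LAYERS:
        shape, params = layer(shape)
        rows.append((name, shape, params))
    return rows


rows = layer_table()
for name, shape, params in rows:
    print(f"{name:<20} {str(shape):<14} {params:>9,}")
print("total trainable parameters:", f"{sum(p for _, _, p in rows):,}")
```

Under these assumptions the feature maps shrink from 32x32 to 16x16 to 8x8, and the model has 2,159,114 trainable parameters. About 97% of them (2,097,664) belong to the 512-node fully-connected layer, which sees the flattened 8x8x64 = 4,096 features.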
