from utils.MobileNetOps import *
from utils import helper
import pickle
import numpy as np
import tensorflow as tf
import matplotlib as mpl
import os
# 设置字符集,防止中文乱码
mpl.rcParams['font.sans-serif'] = [u'simHei']
mpl.rcParams['axes.unicode_minus'] = False
cifar10_dataset_folder_path = './cifar-10-batches-py'
if os.path.exists(cifar10_dataset_folder_path):
print('cifar10原始数据集存在!!!')
# 检查点,预处理的数据已经保存本地,每次可以从这里开始执行,之前的代码不用执行了。
valid_features, valid_labels = pickle.load(
open('../datas/cifar10/preprocess_validation.p', mode='rb')
)
def model_mobile_v3(inputs, num_classes, exp=3, bn_train=True):
"""
实现Mobile_Net_V3模型结构图。
:param inputs:
:param num_classes:
:param bn_train:
:param exp: 膨胀系数
:return:
"""
with tf.variable_scope('Network'):
with tf.variable_scope('conv1_1'):
net = conv2d(inputs, 16, 3, 3, 1, 1, name='conv2d', use_bias=True)
net = hswish(net)
# [N, 32, 32, 16]
"""
res_block(input, expansion_ratio, output_dims, stride, bn_train, name,
activation=hswish, use_se=True, se_reduction=8, use_bias=False, shortcut=True)
"""
net = res_block(
net, 1, 16, stride=1, bn_train=bn_train, activation=tf.nn.relu6, name='res2_1'
) # [N, 32, 32, 16]
net = res_block(
net, exp, 24, stride=2, bn_train=bn_train, activation=tf.nn.relu6, use_se=False,
name='res3_1') # [N,
7_MobileV3_for_cifar10.py
最新推荐文章于 2023-11-23 10:14:30 发布
本文介绍了一种基于MobileNetV3的深度学习模型在CIFAR-10数据集上的应用。该模型通过使用深度可分离卷积和倒残差结构,实现了高效的图像分类任务。文章详细描述了模型结构、训练过程以及验证方法。

最低0.47元/天 解锁文章
6262

被折叠的 条评论
为什么被折叠?



