TF-slim进行快速进行模型分类实验(1.运行demo)

本文详细介绍如何使用TensorFlow的Slim模块进行图像分类任务,包括环境搭建、数据集准备、模型训练及微调等关键步骤。通过实例演示,帮助读者掌握Slim模块的使用方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

引用参考了以下网站:

https://github.com/tensorflow/models/tree/master/research/slim

https://blog.youkuaiyun.com/stesha_chen/article/details/81976415

https://blog.youkuaiyun.com/rookie_wei/article/details/80796009

 

"""

contrib模块在tensorflow2.0将会移除,slim模块位于contrib中

tensorflow slim模块适用于图像分类,使用的数据格式为tfrecord

"""

 

备注:如果使用docker方式安装tensorflow的话

docker run --runtime=nvidia --name=yuankun_tfslim  -it -p 15003:8000 -p 15004:22 -p 15005:5000 -p 15006:6006  -v /home/yk:/yk tensorflow/tensorflow:latest-gpu

注意映射tensorboard 6006端口到外部。

 

1.在某些版本中tensorflow不存在TF-slim模块,执行下列语句确保slim存在

如果存在则不会报错(新版本中有slim模块,当前使用1.13.1)

python -c "import tensorflow.contrib.slim as slim; eval = slim.evaluation.evaluate_once"

 

2.安装TF models模块(该模块tensorflow并不自带,需自行安装)

cd $HOME/workspace

git clone https://github.com/tensorflow/models/   #安装tensorflow包根目录

 

运行以下命令,不报错则安装成功

cd $HOME/workspace/models/research/slim

python -c "from nets import cifarnet; mynet = cifarnet.cifarnet"

 

 

3.下载flower数据集转换为tfrecord格式

 

DATA_DIR=/tmp/data/flowers
python download_and_convert_data.py \
    --dataset_name=flowers \
    --dataset_dir="${DATA_DIR}"

      

4.开始训练

 

 

DATASET_DIR=/tmp/data/flowers
TRAIN_DIR=/tmp/train_logs
python train_image_classifier.py \
    --train_dir=${TRAIN_DIR} \
    --dataset_name=flowers \
    --dataset_split_name=train \
    --dataset_dir=${DATASET_DIR} \
    --model_name=inception_resnet_v2 \
    --max_number_of_steps=500 \
    --batch_size=32 \
    --learning_rate=0.0001 \
    --learning_rate_decay_type=fixed \
    --save_interval_secs=60 \
    --save_summaries_secs=60 \
    --log_every_n_steps=10 \
    --optimizer=rmsprop \
    --weight_decay=0.00004

      

备注:#model_name:定义所使用的模型

#model_name:定义所使用的模型

 

可供选择的model:

inception_resnet_v2

inception_v1

inception_v2

inception_v3

inception_v4

vgg_16

vgg_19

resnet_v1_50,resnet_v1_101,resnet_v1_152,resnet_v1_200

resnet_v2_50,resnet_v2_101,resnet_v2_152,resnet_v2_200

等等,具体参考https://github.com/tensorflow/models/tree/master/research/slim/nets

5.查看tensorboard日志

tensorboard --logdir=/tmp/train_logs

6.微调模型的方法

使用inception_resnet_v2模型机型微调

PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints
MODEL_NAME=inception_resnet_v2
TRAIN_DIR=/tmp/flowers-models/${MODEL_NAME}
DATASET_DIR=/tmp/data/flowers

mkdir ${PRETRAINED_CHECKPOINT_DIR}
wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
tar -xvf inception_resnet_v2_2016_08_30.tar.gz
#下载预训练模型
mv inception_resnet_v2.ckpt ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt

python train_image_classifier.py \
  --train_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --model_name=${MODEL_NAME} \
  --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt \
  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004
#微调说明:
--checkpoint_exclude_scopes  # 第一次不加载这些参数
--trainable_scopes # 重新训练这部分参数

 

### VGGish模型简介 VGGish是一种用于声音分类的深度学习模型,其架构受到图像识别领域著名VGG网络的影响[^2]。该模型采用卷积神经网络(CNN),通过多层卷积和池化操作来提取音频特征。此设计允许捕捉不同频率和时间尺度上的声音模式。 为了提升训练效率以及增强表现力,VGGish引入了批量归一化技术和ReLU作为激活函数。预训练权重的存在让开发者能够在个人音频数据集上迅速开展工作而不需要重新开始整个训练流程。 ### 使用VGGish模型的方法 #### 准备环境与资源 安装必要的Python库对于准备使用VGGish至关重要: ```bash pip install numpy resampy tensorflow tf_slim six soundfile ``` 接着,需从GitHub下载官方提供的源码仓库,并定位至`models/research/audioset/vggish`目录下[^3]。同时也要确保已获得预训练模型文件并将它们放置在同一路径内。 #### 测试配置正确性 执行简单的冒烟测试可以验证当前设置是否无误: ```python python vggish_smoke_test.py ``` 如果一切正常,则会看到输出提示"LGTGM"字样表明准备工作完成良好。 #### 进行实际推理 当确认环境搭建成功之后,就可以尝试对特定音频样本进行分析了。下面是一条命令用来指定输入WAV格式文件及其对应的TFRecord输出位置: ```python python vggish_inference_demo.py --wav_file speech_whistling2.wav --tfrecord_file b.npy ``` 这条指令将会读取给定的声音片段`speech_whistling2.wav`并通过VGGish模型计算得到相应的嵌入向量保存于`b.npy`之中。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值