本文主要是使用caffe python做图片识别的示例包括训练数据lmdb生成,训练,以及模型测试,主要内容如下:
-
训练,验证数据lmdb生成,主要包括:样本的预处理 (直方图均衡化,resize),训练样本以及验证样本的lmdb的生成,以及mean_file mean.binaryproto生成
-
caffe中模型的定义,主要是修改 caffe Alexnet 训练文件train_val.prototxt ,以及训练参数文件solver.prototxt ,还有部署文件deploy.prototxt
-
训练验证数据准备完成之后,就是模型的训练
-
得到训练模型之后,一般会进行本地测试以及从数据库获取url测试然后将结果写到数据库中
先上个代码的框架图,说明见图片(下面会有详细的讲解):
下面给出最终的识别结果:
注:本文做图像分类的时候大概是在2016年,第一个分类模型用的是
Alexnet
这个模型现在基本不怎么用了。一般用的是googlenet v2版本
。
而且caffe的
model zoo
https://github.com/BVLC/caffe/wiki/Model-Zoo中有不少新的模型,比如Towards Principled Design of Deep Convolutional Networks: Introducing SimpNet
感兴趣的可以多多尝试下。
1. 训练,验证数据lmdb生成
-
对图片进行预处理包括直方图均衡化(Histogram equalization)以及resize到指定的大小,并生成lmdb格式,图片以及对于的标签(label)
-
按照一定的比例生成,训练样本lmdb以及验证样本lmdb,以及mean_file mean.binaryproto
-
在测试的时候,我们往往是从数据库中读取url以及id信息,然后将url转化为cv2 可以处理的图片样式,因此我们还要实现将url转化cv2可以处理的图片
1.1 图片进行预处理包括直方图均衡化,url->cv2 image 格式
下面通过代码来讲解(文件: utils->img_process.py):
# _*_coding:utf-8 _*_
import cv2
import urllib
import numpy as np
IMG_HEIGHT = 227
IMG_WIDTH = 227
# 对图片做直方图均衡化处理
def pre_process_img(img, img_height=IMG_HEIGHT, img_width=IMG_WIDTH):
# firstly histogram equalization
img[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
img[:, :, 1] = cv2.equalizeHist(img[:, :, 1])
img[:, :, 2] = cv2.equalizeHist(img[:, :, 2])
# resize image to size
img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_CUBIC)
return img
# 通过图片url将其转化为cv2可以处理的形式
def get_cv_img__from_url(url):
"""
read image from url to cv codec
:param url:
:return:
"""
try:
url_response = urllib.urlopen(url)
img_array = np.array(bytearray(url_response.read()), dtype=np.uint8)
img = cv2.imdecode(img_array, -1)
return img
except Exception, e:
print e
return None
if __name__ == '__main__':
url = 'http://www.sanyarb.com.cn/images/attachement/jpg/site2/20161009/A121475977636942_change_ljx6a9_b.jpg'
img = get_cv_img__from_url(url)
cv2.imshow("zhan lang", img)
img = pre_process_img(img)
cv2.imshow("pre_process_img", img)
cv2.waitKey()
pass
下面是下载网上的图片,然后对其进行直方图均衡化以及resize的运行的结果:
1.2 图片按照一定的比例生成训练样本以及验证样本lmdb]
# _*_coding:utf-8 _*_
import sys
sys.path.insert(0, '../../caffe_train_test/')
import os
import glob
import random
import numpy as np
import cv2
import caffe
from caffe.proto import caffe_pb2
import lmdb
from utils.img_process import *
# 根据图片和标签转化为对应的lmdb格式
def make_datum(img, l