TensorFlow on Android(9): 运行一个DEMO

在Object Detection API的示例代码中包含了一个训练识别宠物的Demo,包括数据集和相应的一些代码。 虽然本系列中我们会自己准备数据和脚本来进行训练, 但是在这之前我们还需要安装一些库,配置一下环境。 在配置完成之后,我们可以运行一下这个训练宠物的Demo, 以便检查我们的环境配置是否OK,同时对训练过程有先有个整体的了解,然后我们再准备自己的数据和训练脚本。

请确保已经安装好了Python 2.7

安装Object Detection API

首先下载Object Detection API的代码

git clone https://github.com/tensorflow/models.git

然后安装TensorFlow(本系列文章使用tensorflow 1.3.0 )

sudo pip install tensorflow==1.3.0

接着是一些依赖库

sudo pip install pillow
sudo pip install lxml
sudo pip install jupyter
sudo pip install matplotlib

Object Detection API中的模型和训练参数是使用protobuf来序列化和反序列化的,所以在运行之前需要将相应的protobuf文件编译出来

#进入 tensorflow/models/research/
protoc object_detection/protos/*.proto --python_out=.

成功编译以后可以在object_detection/protos/ 下找到生成.py和.pyc文件

接下来将Object Detection API的库加入到PYTHONPATH中

#进入 tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

运行Object Detection API的脚本,以及我们之后自己写的脚本都会用到这些库,如果你不想每次运行前都敲这个命令的话,可以把这条命令加入到~/.bashrc中(需要将pwd展开为实际路径)

最后运行一下测试脚本来检测安装是否正确

#进入 tensorflow/models/research/
python object_detection/builders/model_builder_test.py

如果看到下面的输出,那么Object Detection API的安装就完成了。
enter image description here

下载数据集

数据集由图片和相应的标注文件组成:

wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
tar -xvf annotations.tar.gz
tar -xvf images.tar.gz

完成以后目录应该看起来是这样的
enter image description here

images:
enter image description here
annotations:
enter image description here

在image目录就是一些宠物猫狗的照片, 而在annotations文件夹里面是对相应照片的标注,在annotations文件夹中的和images文件夹中照片文件名一致的xml文件就是标注文件, 这些标注文件为PASCAL VOC格式,我们可以打开Abyssinian_1.xml看一下

enter image description here

标注内容主要为图片的源信息,比如高和宽, 物体的名称,以及所在位置:(xmin, ymin, xmax, ymax)所标识的矩形框。

还记得我们需要一个物体类别的数字编号和物体类别实际名称的对应关系的文件吗? 我们可以在这里找到:

object_detection/data/pet_label_map.pbtxt

文件内容看起来是这样的:
enter image description here

注意:所有物体类别的数字编号都是从1开始的,因为0是一个在数学计算中很特殊的值。 

生成TFRecord文件

Object Detection API的训练框架使用TFRecord格式的文件作为输入。所以这里需要将图片和标注转换为TFRecord格式的文件。

TFRecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等。

Demo里面包含了生成对应TFRecord格式文件的脚本,运行:

# 进入 tensorflow/models/research/
python object_detection/create_pet_tf_record.py \
--label_map_path=object_detection/data/pet_label_map.pbtxt \
--data_dir=DATA_DIR \
--output_dir=DATA_DIR

这里你需要将DATA_DIR替换为images和annotations所在的文件夹(父文件夹),不出意外的话,生成的文件应该看起像这样:
enter image description here

pet_train.record为训练集, pet_val.record为测试集。

准备转移学习

我们还需要一个Pre-trained模型来进行转移学习,因为我们想尽量的缩短学习的时间,在这里仍然选择上一节课中使用的ssd_mobilenet_v1_coco

下载以后解压备用:

enter image description here

在转移学习中要用的文件是model.ckpt.* 这三个文件。

准备配置文件

我们还需要一个配置文件来对训练的流程进行配置,比如使用什么算法, 选用什么优化器等。在object_detection/samples/configs/可以找到很多配置模板, 在这里使用object_detection/samples/configs/ssd_mobilenet_v1_pets.config作为起始的配置文件,我们需要在这个模板上面稍作修改。

这个配置文件是一个JSON格式的文件,里面有很多配置项,我们先挑一些必须修改的或者重要的项目:

train_input_reader: {
   tf_record_input_reader {
   input_path:   "PATH_OF_TRAIN_TFRECORD"
   }
   label_map_path: "PATH_OF_LABEL_MAP"
}

需要将PATH_OF_TRAIN_TFRECORD替换为pet_train.record的绝对路径,将PATH_OF_LABEL_MAP替换为pet_label_map.pbtxt的绝对路径;

eval_input_reader: {
  tf_record_input_reader {
    input_path: "PATH_OF_VAL_TFRECORD"
  }
  label_map_path: "PATH_OF_LABEL_MAP"
}

需要将PATH_OF_VAL_TFRECORD替换为pet_val.record的绝对路径,将PATH_OF_LABEL_MAP替换为pet_label_map.pbtxt的绝对路径;

train_config: {
  fine_tune_checkpoint: "CHECK_POINT_PATH"
  from_detection_checkpoint: true
  num_steps: 200000
}

如果将from_detection_checkpoint设为true的话,代表我们将从一个事先训练好的模型开始继续训练(转移学习), 此时需要将CHECK_POINT_PATH替换为model.ckpt的绝对路径(注意之前有三个文件, model.ckpt.index, model.ckpt.meta, model.ckpt.data-xxx 在配置时不需要加model.ckpt 之后的后缀), 如: fine_tune_checkpoint: "/root/ssd_mobilenet_v1_coco_11_06_2017/model.ckpt"

num_steps为训练迭代的步数, 我们这里暂时不修改。

将改好以后的配置文件重命名为pipeline.config

开始训练

准备好训练数据和配置文件以后, 我们就可以开始进行训练了。通常我们会把训练会用到的文件放到一起(训练目录), 这里建议把训练目录设置为这样:
enter image description here

注意:需要按照这个目录结构修改pipeline.config中的相应项

然后执行训练脚本:

# 进入 tensorflow/models/research/
python object_detection/train.py \
    --logtostderr \
    --pipeline_config_path=${TRAIN_DIR}/model/pipeline.config} \
    --train_dir=${TRAIN_DIR}/model/train

TRAIN_DIR需要替换为训练目录的绝对路径。

如果不出意外的话,你会听到CPU的风扇声开始响起来,电脑变得有点卡,同时可以在终端上看到以下输出:
enter image description here

每一行输出为: 训练迭代步数/当前损失值/每步训练所花时间

基本上可以看出,随着训练的进行, 每一个步的损失值是下降的,那是不是可以喝咖啡等待训练结束了呢?

眼尖的朋友可能已经发现问题了,每一步执行的时间大概在10秒左右,那么按照我们的配置20000步需要 200000 X 10秒 = 23天左右,这显然是不能接受的。

看来用笔记本的CPU进行训练可能不是一个好主意,我们需要更强的计算力:GPU。


我们配置好了训练环境,也把一个训练Demo运行了起来,但是笔记本的CPU运算能力显然不足应付这个任务,那么接下来让我们在GPU上面运行训练。

这里是一个基于TensorFlow Lite和Camera2 API的Android自动检测并拍照的Demo。 首先,在gradle文件中添加以下依赖: ``` implementation &#39;org.tensorflow:tensorflow-lite:2.5.0&#39; implementation &#39;org.tensorflow:tensorflow-lite-gpu:2.5.0&#39; ``` 接下来,创建一个Camera2的预览类CameraPreview,用于预览相机画面,并在其中初始化TensorFlow Lite模型。 ``` public class CameraPreview extends TextureView implements TextureView.SurfaceTextureListener, ImageReader.OnImageAvailableListener { private static final String MODEL_PATH = "model.tflite"; private static final String LABELS_PATH = "labels.txt"; private CameraDevice cameraDevice; private CameraCaptureSession cameraCaptureSession; private CaptureRequest.Builder captureRequestBuilder; private HandlerThread backgroundThread; private Handler backgroundHandler; private ImageReader imageReader; private Interpreter interpreter; private List<String> labels; public CameraPreview(Context context) { super(context); setSurfaceTextureListener(this); initModel(); } private void initModel() { try { // 加载模型 interpreter = new Interpreter(loadModelFile(), new Interpreter.Options()); // 加载标签 labels = loadLabelList(); } catch (IOException e) { e.printStackTrace(); } } private MappedByteBuffer loadModelFile() throws IOException { AssetFileDescriptor fileDescriptor = getContext().getAssets().openFd(MODEL_PATH); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long declaredLength = fileDescriptor.getDeclaredLength(); return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } private List<String> loadLabelList() throws IOException { List<String> labelList = new ArrayList<>(); BufferedReader reader = new BufferedReader(new InputStreamReader(getContext().getAssets().open(LABELS_PATH))); String line; while ((line = reader.readLine()) != null) { labelList.add(line); } reader.close(); return labelList; } @Override public void onSurfaceTextureAvailable(SurfaceTexture surfaceTexture, int width, int height) { openCamera(); } private void openCamera() { CameraManager cameraManager = (CameraManager) getContext().getSystemService(Context.CAMERA_SERVICE); try { String cameraId = cameraManager.getCameraIdList()[0]; CameraCharacteristics cameraCharacteristics = cameraManager.getCameraCharacteristics(cameraId); Size[] outputSizes = cameraCharacteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP).getOutputSizes(ImageFormat.JPEG); imageReader = ImageReader.newInstance(outputSizes[0].getWidth(), outputSizes[0].getHeight(), ImageFormat.JPEG, 1); imageReader.setOnImageAvailableListener(this, backgroundHandler); if (ActivityCompat.checkSelfPermission(getContext(), Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) { return; } cameraManager.openCamera(cameraId, new CameraDevice.StateCallback() { @Override public void onOpened(CameraDevice cameraDevice) { CameraPreview.this.cameraDevice = cameraDevice; createCameraPreviewSession(); } @Override public void onDisconnected(CameraDevice cameraDevice) { cameraDevice.close(); CameraPreview.this.cameraDevice = null; } @Override public void onError(CameraDevice cameraDevice, int error) { cameraDevice.close(); CameraPreview.this.cameraDevice = null; } }, backgroundHandler); } catch (CameraAccessException e) { e.printStackTrace(); } } private void createCameraPreviewSession() { SurfaceTexture surfaceTexture = getSurfaceTexture(); surfaceTexture.setDefaultBufferSize(1920, 1080); Surface previewSurface = new Surface(surfaceTexture); Surface readerSurface = imageReader.getSurface(); try { captureRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); captureRequestBuilder.addTarget(previewSurface); captureRequestBuilder.addTarget(readerSurface); cameraDevice.createCaptureSession(Arrays.asList(previewSurface, readerSurface), new CameraCaptureSession.StateCallback() { @Override public void onConfigured(CameraCaptureSession cameraCaptureSession) { CameraPreview.this.cameraCaptureSession = cameraCaptureSession; updatePreview(); } @Override public void onConfigureFailed(CameraCaptureSession cameraCaptureSession) { } }, backgroundHandler); } catch (CameraAccessException e) { e.printStackTrace(); } } private void updatePreview() { if (cameraDevice == null) { return; } captureRequestBuilder.set(CaptureRequest.CONTROL_MODE, CameraMetadata.CONTROL_MODE_AUTO); try { cameraCaptureSession.setRepeatingRequest(captureRequestBuilder.build(), null, backgroundHandler); } catch (CameraAccessException e) { e.printStackTrace(); } } @Override public void onSurfaceTextureSizeChanged(SurfaceTexture surfaceTexture, int width, int height) { } @Override public boolean onSurfaceTextureDestroyed(SurfaceTexture surfaceTexture) { closeCamera(); return true; } private void closeCamera() { if (cameraDevice != null) { cameraDevice.close(); cameraDevice = null; } if (cameraCaptureSession != null) { cameraCaptureSession.close(); cameraCaptureSession = null; } if (imageReader != null) { imageReader.close(); imageReader = null; } } @Override public void onSurfaceTextureUpdated(SurfaceTexture surfaceTexture) { } @Override public void onImageAvailable(ImageReader reader) { Image image = reader.acquireLatestImage(); Bitmap bitmap = getBitmap(image); image.close(); // 在子线程中进行模型推理 new Thread(() -> { String result = recognize(bitmap); if (result.equals("cat")) { // 拍照 takePicture(); } }).start(); } private Bitmap getBitmap(Image image) { ByteBuffer buffer = image.getPlanes()[0].getBuffer(); byte[] bytes = new byte[buffer.remaining()]; buffer.get(bytes); return BitmapFactory.decodeByteArray(bytes, 0, bytes.length); } private String recognize(Bitmap bitmap) { Bitmap resizedBitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true); TensorImage inputImage = new TensorImage(DataType.FLOAT32); inputImage.load(resizedBitmap); TensorBuffer outputBuffer = TensorBuffer.createFixedSize(new int[]{1, labels.size()}, DataType.FLOAT32); interpreter.run(inputImage.getBuffer(), outputBuffer.getBuffer()); float[] results = outputBuffer.getFloatArray(); int index = getMaxIndex(results); return labels.get(index); } private int getMaxIndex(float[] array) { int maxIndex = 0; float max = array[maxIndex]; for (int i = 1; i < array.length; i++) { if (array[i] > max) { max = array[i]; maxIndex = i; } } return maxIndex; } private void takePicture() { try { CaptureRequest.Builder builder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_STILL_CAPTURE); builder.addTarget(imageReader.getSurface()); builder.set(CaptureRequest.CONTROL_MODE, CameraMetadata.CONTROL_MODE_AUTO); builder.set(CaptureRequest.JPEG_ORIENTATION, getOrientation()); cameraCaptureSession.stopRepeating(); cameraCaptureSession.abortCaptures(); cameraCaptureSession.capture(builder.build(), null, null); } catch (CameraAccessException e) { e.printStackTrace(); } } private int getOrientation() { int rotation = ((Activity) getContext()).getWindowManager().getDefaultDisplay().getRotation(); int sensorOrientation = cameraDevice.getCameraCharacteristics(CameraCharacteristics.SENSOR_ORIENTATION); return (rotation + sensorOrientation + 270) % 360; } public void startBackgroundThread() { backgroundThread = new HandlerThread("Camera Background"); backgroundThread.start(); backgroundHandler = new Handler(backgroundThread.getLooper()); } public void stopBackgroundThread() { backgroundThread.quitSafely(); try { backgroundThread.join(); backgroundThread = null; backgroundHandler = null; } catch (InterruptedException e) { e.printStackTrace(); } } } ``` 然后,在Activity中使用CameraPreview类,并在onResume和onPause方法中开始和停止后台线程。 ``` public class MainActivity extends AppCompatActivity { private CameraPreview cameraPreview; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); cameraPreview = new CameraPreview(this); FrameLayout previewLayout = findViewById(R.id.preview_layout); previewLayout.addView(cameraPreview); } @Override protected void onResume() { super.onResume(); cameraPreview.startBackgroundThread(); } @Override protected void onPause() { cameraPreview.stopBackgroundThread(); super.onPause(); } } ``` 注意,这个Demo还需要一张名为"labels.txt"的标签文件和一个名为"model.tflite"的TensorFlow Lite模型文件,应该将它们放在assets目录下。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

sufish

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值