在VsCode中利用TensorFlow.js结合迁移学习实现商标识别。
一、加载商标数据并可视化
数据保存在data文件夹下面,需要先在data文件夹下创建一个静态服务器,用于加载图片。
http-server data --cors
Available on:
http://192.168.4.167:8080
http://127.0.0.1:8080
Hit CTRL-C to stop the server
编写获取图片的脚本文件。
const IMAGE_SIZE = 224;
const loadImg = (src) => {
return new Promise(resolve => {
const img = new Image();
img.crossOrigin = "anonymous";
img.src = src;
img.width = IMAGE_SIZE;
img.height = IMAGE_SIZE;
img.onload = () => resolve(img);
});
};
export const getInputs = async () => {
const loadImgs = [];
const labels = [];
for (let i = 0; i < 30; i += 1) {
['android', 'apple', 'windows'].forEach(label => {
const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;
const img = loadImg(src);
loadImgs.push(img);
labels.push([
label === 'android' ? 1 : 0,
label === 'apple' ? 1 : 0,
label === 'windows' ? 1 : 0,
]);
});
}
const inputs = await Promise.all(loadImgs);
return {
inputs,
labels,
};
}
创建index.html文件,作为程序的入口文件,在index.html中利用script标签跳转到script.js文件,在script.js中编写主要代码。
加载图片数据。
import {getInputs} from "./data";
window.onload = async() =>{
const {inputs, labels} = await getInputs();
console.log(inputs, labels)
};

利用TensorFlow.js中的tfvis进行可视化。
import * as tfvis from "@tensorflow/tfjs-vis"
// 可视化图片
const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });
inputs.forEach(img => {
surface.drawArea.appendChild(img);
});

每行显示两个,旁边的滚动体可以拉动查看更多图片。
二、定义模型结构
加载MobileNet模型并截断所有的卷积池化操作,生成截断模型。并定义新的全连接层。
import * as tf from "@tensorflow/tfjs"
// mobilenet模型存放位置
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
// 加载MobileNet模型
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
// 查看模型结构
mobilenet.summary();
// 截断mobilenet卷积操作
const layer = mobilenet.getLayer('conv_pw_13_relu');
const truncatedMobilenet = tf.model({
inputs: mobilenet.inputs,
outputs: layer.output
});
// 定义全连接层
const model = tf.sequential();
model.add(tf.layers.flatten({
inputShape: layer.outputShape.slice(1)
}));
model.add(tf.layers.dense({
units: 10,
activation: 'relu'
}));
// 定义输出层
model.add(tf.layers.dense({
units: NUM_CLASSES,
activation: 'softmax'
}));
// 配置损失函数和优化器
model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });
三、迁移学习下的模型训练
首先先定义一个工具类utils.js,用于处理输入到截断模型(mobilenet)中的数据。
import * as tf from '@tensorflow/tfjs';
// img格式转成tensor
export function img2x(imgEl){
return tf.tidy(() => {
const input = tf.browser.fromPixels(imgEl)
.toFloat()
.sub(255 / 2)
.div(255 / 2)
.reshape([1, 224, 224, 3]);
return input;
});
}
// 图片文件转成img格式
export function file2img(f) {
return new Promise(resolve => {
const reader = new FileReader();
reader.readAsDataURL(f);
reader.onload = (e) => {
const img = document.createElement('img');
img.src = e.target.result;
img.width = 224;
img.height = 224;
img.onload = () => resolve(img);
};
});
}
训练数据经过截断模型输出,转为可以用于自定义的全连接层的输入数据。
// 先经过截断模型
const { xs, ys } = tf.tidy(() => {
const xs = tf.concat(inputs.map(imgEl =>truncatedMobilenet.predict(img2x(imgEl))));
const ys = tf.tensor(labels);
return { xs, ys };
});
// 截断模型的输出当成自定义模型的输入
await model.fit(xs, ys, {
epochs: 20,
callbacks: tfvis.show.fitCallbacks(
{ name: '训练效果' },
['loss'],
{ callbacks: ['onEpochEnd'] }
)
});

可以看出训练损失值降得非常低,因为采用迁移模型,卷积层的参数是使用别人训练好的,这部分参数的训练结果是非常优秀的。
四、预测
编写前端页面用于上传带预测图片,就是编写一个上传按钮。
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
将预测图片先经过mobileNet预测,吐出来的结果再经过自定义模型预测。
window.predict = async (file) => {
const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {
const x = img2x(img);
const input = truncatedMobilenet.predict(x);
return model.predict(input);
});
const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
alert(`预测结果:${BRAND_CLASSES[index]}`);
}, 0);
};
五、完整代码
index.html
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
script.js
import * as tf from "@tensorflow/tfjs"
import * as tfvis from "@tensorflow/tfjs-vis"
import {getInputs} from "./data";
import {img2x, file2img} from "./utils"
// mobilenet模型存放位置
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];
window.onload = async() =>{
// 加载图片
const {inputs, labels} = await getInputs();
// 可视化图片
const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });
inputs.forEach(img => {
surface.drawArea.appendChild(img);
});
// 加载MobileNet模型
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
// 查看模型结构
mobilenet.summary();
// 截断mobilenet卷积操作
const layer = mobilenet.getLayer('conv_pw_13_relu');
const truncatedMobilenet = tf.model({
inputs: mobilenet.inputs,
outputs: layer.output
});
// 定义全连接层
const model = tf.sequential();
model.add(tf.layers.flatten({
inputShape: layer.outputShape.slice(1)
}));
model.add(tf.layers.dense({
units: 10,
activation: 'relu'
}));
// 定义输出层
model.add(tf.layers.dense({
units: NUM_CLASSES,
activation: 'softmax'
}));
// 配置损失函数和优化器
model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });
// 先经过截断模型
const { xs, ys } = tf.tidy(() => {
const xs = tf.concat(inputs.map(imgEl => truncatedMobilenet.predict(img2x(imgEl))));
const ys = tf.tensor(labels);
return { xs, ys };
});
// 截断模型的输出当成自定义模型的输入
await model.fit(xs, ys, {
epochs: 20,
callbacks: tfvis.show.fitCallbacks(
{ name: '训练效果' },
['loss'],
{ callbacks: ['onEpochEnd'] }
)
});
// 预测
window.predict = async (file) => {
const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {
const x = img2x(img);
const input = truncatedMobilenet.predict(x);
return model.predict(input);
});
const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
alert(`预测结果:${BRAND_CLASSES[index]}`);
}, 0);
};
};
data.js
const IMAGE_SIZE = 224;
const loadImg = (src) => {
return new Promise(resolve => {
const img = new Image();
img.crossOrigin = "anonymous";
img.src = src;
img.width = IMAGE_SIZE;
img.height = IMAGE_SIZE;
img.onload = () => resolve(img);
});
};
export const getInputs = async () => {
const loadImgs = [];
const labels = [];
for (let i = 0; i < 30; i += 1) {
['android', 'apple', 'windows'].forEach(label => {
const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;
const img = loadImg(src);
loadImgs.push(img);
labels.push([
label === 'android' ? 1 : 0,
label === 'apple' ? 1 : 0,
label === 'windows' ? 1 : 0,
]);
});
}
const inputs = await Promise.all(loadImgs);
return {
inputs,
labels,
};
}
utils.js
import * as tf from '@tensorflow/tfjs';
export function img2x(imgEl){
return tf.tidy(() => {
const input = tf.browser.fromPixels(imgEl)
.toFloat()
.sub(255 / 2)
.div(255 / 2)
.reshape([1, 224, 224, 3]);
return input;
});
}
export function file2img(f) {
return new Promise(resolve => {
const reader = new FileReader();
reader.readAsDataURL(f);
reader.onload = (e) => {
const img = document.createElement('img');
img.src = e.target.result;
img.width = 224;
img.height = 224;
img.onload = () => resolve(img);
};
});
}
本文介绍如何在VsCode中利用TensorFlow.js结合迁移学习实现商标识别。文章详细讲解了数据加载、模型定义、训练过程及预测方法。
2309

被折叠的 条评论
为什么被折叠?



