基于卷积神经网络的大规模图像分类实战
1. 项目背景与问题描述
如今,美食自拍和以照片为中心的社交分享成为社交趋势。美食爱好者会在社交媒体和相关网站上上传大量美食和餐厅照片,并附上文字评论,这能显著提升餐厅的知名度。以 Yelp 为例,数百万独立访客访问该平台,撰写了超过 1.35 亿条评论,同时上传了大量照片。Yelp 通过向当地商家出售广告盈利,这些照片蕴含了丰富的本地商业信息。
本项目的挑战在于如何将这些图片转化为文字,即构建一个模型,能够自动为餐厅用户提交的照片添加多个标签,从而预测商业属性。
2. 图像数据集描述
为完成这一挑战,我们需要真实的数据集。Kaggle 是一个提供此类数据集的平台,可在 https://www.kaggle.com/c/yelp-restaurant-photo-classification 找到 Yelp 数据集及其描述。
餐厅的标签由 Yelp 用户在提交评论时手动选择,数据集中有 Yelp 社区标注的九种不同标签:
| 标签编号 | 标签含义 |
| ---- | ---- |
| 0 | good_for_lunch |
| 1 | good_for_dinner |
| 2 | takes_reservations |
| 3 | outdoor_seating |
| 4 | restaurant_is_expensive |
| 5 | has_alcohol |
| 6 | has_table_service |
| 7 | ambience_is_classy |
| 8 | good_for_kids |
数据集包含六个文件:
- train_photos.tgz:用作训练集的照片(234,545 张图像)
- test_photos.tgz:用作测试集的照片(500 张图像)
- train_photo_to_biz_ids.csv:提供照片 ID 到商业 ID 的映射(234,545 行)
- test_photo_to_biz_ids.csv:提供照片 ID 到商业 ID 的映射(500 行)
- train.csv:主要训练数据集,包括商业 ID 及其相应标签(1996 行)
- sample_submission.csv:示例提交文件,为预测提供正确格式,包括 business_id 和相应的预测标签
3. 整体项目工作流程
3.1 图像预处理
- 读取图像 :将 .jpg 格式的图像读取为 Scala 中的矩阵表示。
- 图像操作 :对图像进行一系列操作,如将所有图像变为正方形、调整大小为相同维度,最后应用灰度滤镜。
graph LR
A[读取图像] --> B[图像变正方形]
B --> C[调整图像大小]
C --> D[应用灰度滤镜]
3.2 模型训练
- 训练 CNN :为每个类别在训练数据上训练九个 CNN。
- 保存模型 :训练完成后,保存训练好的模型、CNN 配置和参数。
3.3 模型评估
- 聚合分类 :应用简单的聚合函数为每个餐厅分配类别。
- 测试评分 :对测试数据进行评分,并使用测试图像评估模型。
4. 实现 CNN 进行图像分类
4.1 工作流程步骤
- 从 train.csv 文件中读取所有商业标签。
- 读取并创建从图像 ID 到商业 ID 的映射(imageID → busID)。
- 从 photoDir 目录获取要加载和处理的图像列表,并获取 10,000 张图像的 ID(可自定义范围)。
- 将图像读取并处理为 photoID → vector 映射。
- 链接步骤 3 和步骤 4 的输出,对齐商业特征、图像 ID 和标签 ID,为 CNN 提取特征。
- 构建九个 CNN。
- 训练所有 CNN 并指定模型保存位置。
- 重复步骤 2 到步骤 6,从测试集中提取特征。
- 评估模型并将预测结果保存到 CSV 文件。
4.2 代码实现
val labelMap = readBusinessLabels("data/labels/train.csv")
val businessMap =
readBusinessToImageLabels("data/labels/train_photo_to_biz_ids.csv")
val imgs = getImageIds("data/images/train/", businessMap,
businessMap.map(_._2).toSet.toList).slice(0,100) // 20000 images
println("Image ID retreival done!")
val dataMap = processImages(imgs, resizeImgDim = 128)
println("Image processing done!")
val alignedData = new featureAndDataAligner(dataMap, businessMap,
Option(labelMap))()
println("Feature extraction done!")
val cnn0 = trainModelEpochs(alignedData, businessClass = 0, saveNN =
"models/model0")
val cnn1 = trainModelEpochs(alignedData, businessClass = 1, saveNN =
"models/model1")
val cnn2 = trainModelEpochs(alignedData, businessClass = 2, saveNN =
"models/model2")
val cnn3 = trainModelEpochs(alignedData, businessClass = 3, saveNN =
"models/model3")
val cnn4 = trainModelEpochs(alignedData, businessClass = 4, saveNN =
"models/model4")
val cnn5 = trainModelEpochs(alignedData, businessClass = 5, saveNN =
"models/model5")
val cnn6 = trainModelEpochs(alignedData, businessClass = 6, saveNN =
"models/model6")
val cnn7 = trainModelEpochs(alignedData, businessClass = 7, saveNN =
"models/model7")
val cnn8 = trainModelEpochs(alignedData, businessClass = 8, saveNN =
"models/model8")
val businessMapTE =
readBusinessToImageLabels("data/labels/test_photo_to_biz.csv")
val imgsTE = getImageIds("data/images/test//", businessMapTE,
businessMapTE.map(_._2).toSet.toList)
val dataMapTE = processImages(imgsTE, resizeImgDim = 128) // make them 128*128
val alignedDataTE = new featureAndDataAligner(dataMapTE, businessMapTE,
None)()
val Results = SubmitObj(alignedDataTE, "results/ModelsV0/")
val SubmitResults = writeSubmissionFile("kaggleSubmitFile.csv", Results,
thresh = 0.9)
5. 图像预处理
5.1 图像形状处理
由于 CNN 无法处理大小和形状各异的图像,我们需要对图像进行预处理。首先将不规则形状的图像变为正方形,代码如下:
def makeSquare(img: java.awt.image.BufferedImage):
java.awt.image.BufferedImage = {
val w = img.getWidth
val h = img.getHeight
val dim = List(w, h).min
img match {
case x if w == h => img // do nothing and returns the original one
case x if w > h => Scalr.crop(img, (w - h) / 2, 0, dim, dim)
case x if w < h => Scalr.crop(img, 0, (h - w) / 2, dim, dim)
}
}
5.2 图像大小调整
将所有图像调整为 128 x 128 大小,代码如下:
def resizeImg(img: java.awt.image.BufferedImage, width: Int, height: Int) =
{
Scalr.resize(img, Scalr.Method.BALANCED, width, height)
}
5.3 灰度转换
为简化计算,将图像转换为灰度图像,代码如下:
def pixels2Gray(R: Int, G: Int, B: Int): Int = (R + G + B) / 3
def makeGray(testImage: java.awt.image.BufferedImage):
java.awt.image.BufferedImage = {
val w = testImage.getWidth
val h = testImage.getHeight
for {
w1 <- (0 until w).toVector
h1 <- (0 until h).toVector
}
yield
{
val col = testImage.getRGB(w1, h1)
val R = (col & 0xff0000) / 65536
val G = (col & 0xff00) / 256
val B = (col & 0xff)
val graycol = pixels2Gray(R, G, B)
testImage.setRGB(w1, h1, new Color(graycol, graycol, graycol).getRGB)
}
testImage
}
5.4 综合处理
将上述三个步骤链在一起,代码如下:
import scala.Vector
import org.imgscalr._
object imageUtils {
implicit class imageProcessingPipeline(img:
java.awt.image.BufferedImage) {
// image 2 vector processing
def pixels2gray(R: Int, G:Int, B: Int): Int = (R + G + B) / 3
def pixels2color(R: Int, G:Int, B: Int): Vector[Int] = Vector(R, G, B)
private def image2vec[A](f: (Int, Int, Int) => A ): Vector[A] = {
val w = img.getWidth
val h = img.getHeight
for {
w1 <- (0 until w).toVector
h1 <- (0 until h).toVector
}
yield {
val col = img.getRGB(w1, h1)
val R = (col & 0xff0000) / 65536
val G = (col & 0xff00) / 256
val B = (col & 0xff)
f(R, G, B)
}
}
def image2gray: Vector[Int] = image2vec(pixels2gray)
def image2color: Vector[Int] = image2vec(pixels2color).flatten
// make image square
def makeSquare = {
val w = img.getWidth
val h = img.getHeight
val dim = List(w, h).min
img match {
case x if w == h => img
case x if w > h => Scalr.crop(img, (w-h)/2, 0, dim, dim)
case x if w < h => Scalr.crop(img, 0, (h-w)/2, dim, dim)
}
}
// resize pixels
def resizeImg(width: Int, height: Int) = {
Scalr.resize(img, Scalr.Method.BALANCED, width, height)
}
}
}
6. 提取图像元数据
使用
readMetadata()
方法读取 CSV 格式的图像元数据,该方法在
CSVImageMetadataReader.scala
脚本中定义:
def readMetadata(csv: String, rows: List[Int]=List(-1)): List[List[String]]
= {
val src = Source.fromFile(csv)
def reading(csv: String): List[List[String]]= {
src.getLines.map(x => x.split(",").toList)
.toList
}
try {
if(rows==List(-1)) reading(csv)
else rows.map(reading(csv))
}
finally {
src.close
}
}
6.1 商业标签读取
使用
readBusinessLabels()
方法将商业 ID 映射到标签集合,代码如下:
def readBusinessLabels(csv: String, rows: List[Int]=List(-1)): Map[String,
Set[Int]] = {
val reader = readMetadata(csv)
reader.drop(1)
.map(x => x match {
case x :: Nil => (x(0).toString, Set[Int]())
case _ => (x(0).toString, x(1).split(" ").map(y => y.toInt).toSet)
}).toMap
}
6.2 图像到商业 ID 映射读取
使用
readBusinessToImageLabels()
方法将图像 ID 映射到商业 ID,代码如下:
def readBusinessToImageLabels(csv: String, rows: List[Int] = List(-1)):
Map[Int, String] = {
val reader = readMetadata(csv)
reader.drop(1)
.map(x => x match {
case x :: Nil => (x(0).toInt, "-1")
case _ => (x(0).toInt, x(1).split(" ").head)
}).toMap
}
7. 图像特征提取
7.1 定义正则表达式
val patt_get_jpg_name = new Regex("[0-9]")
7.2 提取图像 ID
def getImgIdsFromBusinessId(bizMap: Map[Int, String], businessIds:
List[String]): List[Int] = {
bizMap.filter(x => businessIds.exists(y => y == x._2)).map(_._1).toList
}
7.3 获取图像路径
def getImageIds(photoDir: String, businessMap: Map[Int, String] = Map(-1 ->
"-1"), businessIds:
List[String] = List("-1")): List[String] = {
val d = new File(photoDir)
val imgsPath = d.listFiles().map(x => x.toString).toList
if (businessMap == Map(-1 -> "-1") || businessIds == List(-1)) {
imgsPath
}
else {
val imgsMap = imgsPath.map(x =>
patt_get_jpg_name.findAllIn(x).mkString.toInt -> x).toMap
val imgsPathSub = getImgIdsFromBusinessId(businessMap, businessIds)
imgsPathSub.filter(x => imgsMap.contains(x)).map(x => imgsMap(x))
}
}
7.4 处理图像
def processImages(imgs: List[String], resizeImgDim: Int = 128, nPixels: Int
= -1): Map[Int,Vector[Int]]= {
imgs.map(x => patt_get_jpg_name.findAllIn(x).mkString.toInt -> {
val img0 = ImageIO.read(new File(x))
.makeSquare
.resizeImg(resizeImgDim, resizeImgDim) // (128, 128)
.image2gray
if(nPixels != -1) img0.slice(0, nPixels)
else img0
}).filter( x => x._2 != ())
.toMap
}
7.5 数据对齐
val alignedData = new featureAndDataAligner(dataMap, businessMap,
Option(labelMap))()
class featureAndDataAligner(dataMap: Map[Int, Vector[Int]], bizMap:
Map[Int, String], labMap: Option[Map[String, Set[Int]]])(rowindices:
List[Int] = dataMap.keySet.toList) {
def this(dataMap: Map[Int, Vector[Int]], bizMap: Map[Int,
String])(rowindices: List[Int]) = this(dataMap, bizMap,
None)(rowindices)
def alignBusinessImgageIds(dataMap: Map[Int, Vector[Int]], bizMap:
Map[Int, String])
(rowindices: List[Int] = dataMap.keySet.toList): List[(Int, String,
Vector[Int])] = {
for {
pid <- rowindices
val imgHasBiz = bizMap.get(pid)
// returns None if img doe not have a bizID
val bid = if(imgHasBiz != None) imgHasBiz.get
else "-1"
if (dataMap.keys.toSet.contains(pid) && imgHasBiz != None)
}
yield {
(pid, bid, dataMap(pid))
}
}
def alignLabels(dataMap: Map[Int, Vector[Int]], bizMap: Map[Int, String],
labMap: Option[Map[String, Set[Int]]])(rowindices: List[Int] =
dataMap.keySet.toList): List[(Int, String, Vector[Int], Set[Int])] = {
def flatten1[A, B, C, D](t: ((A, B, C), D)): (A, B, C, D) = (t._1._1,
t._1._2, t._1._3, t._2)
val al = alignBusinessImgageIds(dataMap, bizMap)(rowindices)
for { p <- al
}
yield {
val bid = p._2
val labs = labMap match {
case None => Set[Int]()
case x => (if(x.get.keySet.contains(bid)) x.get(bid)
else Set[Int]())
}
flatten1(p, labs)
}
}
lazy val data = alignLabels(dataMap, bizMap, labMap)(rowindices)
// getter functions
def getImgIds = data.map(_._1)
def getBusinessIds = data.map(_._2)
def getImgVectors = data.map(_._3)
def getBusinessLabels = data.map(_._4)
def getImgCntsPerBusiness =
getBusinessIds.groupBy(identity).mapValues(x => x.size)
}
8. 准备 ND4j 数据集
8.1 创建数据集
def makeDataSet(alignedData: featureAndDataAligner, bizClass: Int): DataSet
= {
val alignedXData = alignedData.getImgVectors.toNDArray
val alignedLabs = alignedData.getBusinessLabels.map(x =>
if (x.contains(bizClass)) Vector(1, 0)
else Vector(0, 1)).toNDArray
new DataSet(alignedXData, alignedLabs)
}
8.2 转换为 INDArray
def makeDataSetTE(alignedData: featureAndDataAligner): INDArray = {
alignedData.getImgVectors.toNDArray
}
9. 训练 CNN 并保存模型
9.1 训练模型
def trainModelEpochs(alignedData: featureAndDataAligner, businessClass: Int
= 1, saveNN: String = "") = {
val ds = makeDataSet(alignedData, businessClass)
val nfeatures = ds.getFeatures.getRow(0).length // Hyperparameter
val numRows = Math.sqrt(nfeatures).toInt //numRows*numColumns == data*channels
val numColumns = Math.sqrt(nfeatures).toInt //numRows*numColumns == data*channels
val nChannels = 1 // would be 3 if color image w R,G,B
val outputNum = 9 // # of classes (# of columns in output)
val iterations = 1
val splitTrainNum = math.ceil(ds.numExamples * 0.8).toInt // 80/20 training/test split
val seed = 12345
val listenerFreq = 1
val nepochs = 20
val nbatch = 128 // recommended between 16 and 128
ds.normalizeZeroMeanZeroUnitVariance()
Nd4j.shuffle(ds.getFeatureMatrix, new Random(seed), 1) // shuffles rows in the ds.
Nd4j.shuffle(ds.getLabels, new Random(seed), 1) // shuffles labels accordingly
val trainTest: SplitTestAndTrain = ds.splitTestAndTrain(splitTrainNum,
new Random(seed))
// creating epoch dataset iterator
val dsiterTr = new ListDataSetIterator(trainTest.getTrain.asList(),
nbatch)
val dsiterTe = new ListDataSetIterator(trainTest.getTest.asList(),
nbatch)
val epochitTr: MultipleEpochsIterator = new
MultipleEpochsIterator(nepochs, dsiterTr)
val epochitTe: MultipleEpochsIterator = new
MultipleEpochsIterator(nepochs, dsiterTe)
//First convolution layer with ReLU as activation function
val layer_0 = new ConvolutionLayer.Builder(6, 6)
.nIn(nChannels)
.stride(2, 2) // default stride(2,2)
.nOut(20) // # of feature maps
.dropOut(0.5)
.activation("relu") // rectified linear units
.weightInit(WeightInit.RELU)
.build()
//First subsampling layer
val layer_1 = new
SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build()
//Second convolution layer with ReLU as activation function
val layer_2 = new ConvolutionLayer.Builder(6, 6)
.nIn(nChannels)
.stride(2, 2)
.nOut(50)
.activation("relu")
.build()
//Second subsampling layer
val layer_3 = new
SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build()
//Dense layer
val layer_4 = new DenseLayer.Builder()
.activation("relu")
.nOut(500)
.build()
// Final and fully connected layer with Softmax as activation function
val layer_5 = new
OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.nOut(outputNum)
.weightInit(WeightInit.XAVIER)
.activation("softmax")
.build()
val builder: MultiLayerConfiguration.Builder = new
NeuralNetConfiguration.Builder()
.seed(seed)
.iterations(iterations)
.miniBatch(true)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.regularization(true).l2(0.0005)
.learningRate(0.01)
.list(6)
.layer(0, layer_0)
.layer(1, layer_1)
.layer(2, layer_2)
.layer(3, layer_3)
.layer(4, layer_4)
.layer(5, layer_5)
.backprop(true).pretrain(false)
new ConvolutionLayerSetup(builder, numRows, numColumns, nChannels)
val conf: MultiLayerConfiguration = builder.build()
val model: MultiLayerNetwork = new MultiLayerNetwork(conf)
model.init()
model.setListeners(Seq[IterationListener](new
ScoreIterationListener(listenerFreq)).asJava)
model.fit(epochitTr)
val eval = new Evaluation(outputNum)
while (epochitTe.hasNext) {
val testDS = epochitTe.next(nbatch)
val output: INDArray = model.output(testDS.getFeatureMatrix)
eval.eval(testDS.getLabels(), output)
}
if (!saveNN.isEmpty) {
// model config
FileUtils.write(new File(saveNN + ".json"),
model.getLayerWiseConfigurations().toJson())
// model parameters
val dos: DataOutputStream = new
DataOutputStream(Files.newOutputStream(Paths.get(saveNN + ".bin")))
Nd4j.write(model.params(), dos)
}
}
9.2 保存模型
def saveNN(model: MultiLayerNetwork, NNconfig: String, NNparams: String) =
{
// save neural network config
FileUtils.write(new File(NNconfig),
model.getLayerWiseConfigurations().toJson())
// save neural network parms
val dos: DataOutputStream = new
DataOutputStream(Files.newOutputStream(Paths.get(NNparams)))
Nd4j.write(model.params(), dos)
}
9.3 加载模型
def loadNN(NNconfig: String, NNparams: String) = {
// get neural network config
val confFromJson: MultiLayerConfiguration =
MultiLayerConfiguration.fromJson(FileUtils.readFileToString(new
File(NNconfig)))
// get neural network parameters
val dis: DataInputStream = new DataInputStream(new
FileInputStream(NNparams))
val newParams = Nd4j.read(dis)
// creating network object
val savedNetwork: MultiLayerNetwork = new
MultiLayerNetwork(confFromJson)
savedNetwork.init()
savedNetwork.setParameters(newParams)
savedNetwork
}
10. 模型评估
10.1 模型评分
def scoreModel(model: MultiLayerNetwork, ds: INDArray) = {
model.output(ds)
}
10.2 聚合评分
def aggImgScores2Business(scores: INDArray, alignedData:
featureAndDataAligner ) = {
assert(scores.size(0) == alignedData.data.length, "alignedData and
scores length are different. They must be equal")
def getRowIndices4Business(mylist: List[String], mybiz: String): List[Int]
= mylist.zipWithIndex.filter(x => x._1 == mybiz).map(_._2)
def mean(xs: List[Double]) = xs.sum / xs.size
alignedData.getBusinessIds.distinct.map(x => (x, {
val irows = getRowIndices4Business(alignedData.getBusinessIds, x)
val ret =
for(row <- irows)
yield scores.getRow(row).getColumn(1).toString.toDouble
mean(ret)
}))
}
11. 执行主方法
package Yelp.Classifier
import Yelp.Preprocessor.CSVImageMetadataReader._
import Yelp.Preprocessor.featureAndDataAligner
import Yelp.Preprocessor.imageFeatureExtractor._
import Yelp.Evaluator.ResultFileGenerator._
import Yelp.Preprocessor.makeND4jDataSets._
import Yelp.Evaluator.ModelEvaluation._
import Yelp.Trainer.CNNEpochs._
import Yelp.Trainer.NeuralNetwork._
object YelpImageClassifier {
def main(args: Array[String]): Unit = {
// image processing on training data
val labelMap = readBusinessLabels("data/labels/train.csv")
val businessMap =
readBusinessToImageLabels("data/labels/train_photo_to_biz_ids.csv")
val imgs = getImageIds("data/images/train/", businessMap,
businessMap.map(_._2).toSet.toList).slice(0,20000) // 20000 images
println("Image ID retreival done!")
val dataMap = processImages(imgs, resizeImgDim = 256)
println("Image processing done!")
val alignedData =
new featureAndDataAligner(dataMap, businessMap,
Option(labelMap))()
println("Feature extraction done!")
// training one model for one class at a time. Many hyperparamters
// hardcoded within
val cnn0 = trainModelEpochs(alignedData, businessClass = 0, saveNN
= "models/model0")
val cnn1 = trainModelEpochs(alignedData, businessClass = 1, saveNN
= "models/model1")
val cnn2 = trainModelEpochs(alignedData, businessClass = 2, saveNN
= "models/model2")
val cnn3 = trainModelEpochs(alignedData, businessClass = 3, saveNN
= "models/model3")
val cnn4 = trainModelEpochs(alignedData, businessClass = 4, saveNN
= "models/model4")
val cnn5 = trainModelEpochs(alignedData, businessClass = 5, saveNN
= "models/model5")
val cnn6 = trainModelEpochs(alignedData, businessClass = 6, saveNN
= "models/model6")
val cnn7 = trainModelEpochs(alignedData, businessClass = 7, saveNN
= "models/model7")
val cnn8 = trainModelEpochs(alignedData, businessClass = 8, saveNN
= "models/model8")
// processing test data for scoring
val businessMapTE =
readBusinessToImageLabels("data/labels/test_photo_to_biz.csv")
val imgsTE = getImageIds("data/images/test//", businessMapTE,
businessMapTE.map(_._2).toSet.toList)
val dataMapTE = processImages(imgsTE, resizeImgDim = 128) // make
// them 256x256
val alignedDataTE = new featureAndDataAligner(dataMapTE,
businessMapTE, None)()
// creating csv file to submit to kaggle (scores all models)
val Results = SubmitObj(alignedDataTE, "results/ModelsV0/")
val SubmitResults = writeSubmissionFile("kaggleSubmitFile.csv",
Results, thresh = 0.9)
}
}
12. 模型评估结果
==========================Scores======================================
Accuracy: 0.6833
Precision: 0.53
Recall: 0.5222
F1 Score: 0.5261
======================================================================
本项目通过一系列步骤,从图像预处理、特征提取、模型训练到最终评估,实现了基于卷积神经网络的大规模图像分类。通过合理设置超参数和优化模型结构,我们可以进一步提高模型的性能。
13. 模型评估方法总结
模型评估是整个图像分类项目中至关重要的环节,它能够直观地反映模型的性能优劣。在本项目中,我们采用了多种评估指标,如准确率(Accuracy)、精确率(Precision)、召回率(Recall)和 F1 分数(F1 Score)。这些指标从不同角度衡量了模型的分类效果,具体解释如下:
-
准确率(Accuracy)
:表示模型正确分类的样本数占总样本数的比例。在本项目的评估结果中,准确率为 0.6833,这意味着模型大约有 68.33% 的样本分类是正确的。
-
精确率(Precision)
:是指模型预测为正类的样本中,实际为正类的比例。本项目的精确率为 0.53,说明在模型预测为正类的样本中,只有 53% 是真正的正类。
-
召回率(Recall)
:也称为灵敏度,它衡量的是实际为正类的样本中,被模型正确预测为正类的比例。本项目的召回率为 0.5222,即模型能够召回大约 52.22% 的正类样本。
-
F1 分数(F1 Score)
:是精确率和召回率的调和平均数,它综合考虑了精确率和召回率两个指标。F1 分数越接近 1,说明模型的性能越好。本项目的 F1 分数为 0.5261,表明模型在精确率和召回率之间取得了一定的平衡,但仍有提升的空间。
评估指标表格
| 评估指标 | 数值 |
|---|---|
| 准确率(Accuracy) | 0.6833 |
| 精确率(Precision) | 0.53 |
| 召回率(Recall) | 0.5222 |
| F1 分数(F1 Score) | 0.5261 |
14. 项目整体流程回顾
整个项目的流程可以概括为以下几个主要步骤,每个步骤都紧密相连,共同构成了一个完整的图像分类系统:
1.
数据准备
:从 Kaggle 平台获取 Yelp 数据集,包括训练集和测试集的图像以及相关的元数据文件。对数据进行整理和分析,了解数据的结构和特点。
2.
图像预处理
:由于原始图像的大小和形状各异,不适合直接输入到 CNN 中进行训练,因此需要对图像进行预处理。具体操作包括将图像变为正方形、调整图像大小为 128x128 像素,并将图像转换为灰度图像。这些操作可以使图像数据更加规整,便于后续的特征提取和模型训练。
3.
特征提取
:通过读取图像元数据文件,建立图像 ID 到商业 ID 的映射以及商业 ID 到标签的映射。然后,从预处理后的图像中提取特征,并将其转换为适合 CNN 输入的格式。
4.
模型训练
:为每个类别在训练数据上训练九个 CNN 模型。在训练过程中,需要设置合适的超参数,如学习率、批量大小、迭代次数等,并对数据集进行归一化和洗牌操作,以提高模型的训练效果。训练完成后,保存模型的配置和参数,以便后续的使用和评估。
5.
模型评估
:使用测试集对训练好的模型进行评估,计算准确率、精确率、召回率和 F1 分数等评估指标。通过对评估结果的分析,了解模型的性能表现,并根据评估结果对模型进行调整和优化。
项目流程 mermaid 流程图
graph LR
A[数据准备] --> B[图像预处理]
B --> C[特征提取]
C --> D[模型训练]
D --> E[模型评估]
15. 项目优化建议
尽管本项目已经取得了一定的成果,但仍有许多可以优化的地方,以下是一些具体的优化建议:
-
数据方面
-
数据增强
:通过对训练图像进行旋转、翻转、缩放等操作,增加训练数据的多样性,从而提高模型的泛化能力。
-
数据清洗
:进一步清理数据集中的噪声和重复数据,提高数据的质量。例如,去除模糊不清、内容无关的图像,以及重复上传的图像。
-
模型方面
-
调整超参数
:通过网格搜索、随机搜索等方法,寻找最优的超参数组合,如学习率、批量大小、迭代次数等。不同的超参数设置可能会对模型的性能产生显著影响,因此需要进行细致的调优。
-
改进模型结构
:尝试使用更复杂的 CNN 架构,如 ResNet、Inception 等,或者增加卷积层和全连接层的数量,以提高模型的表达能力。
-
集成学习
:将多个不同的 CNN 模型进行集成,通过投票、平均等方式综合各个模型的预测结果,从而提高模型的准确性和稳定性。
-
评估方面
-
使用更多评估指标
:除了准确率、精确率、召回率和 F1 分数外,还可以考虑使用其他评估指标,如 ROC 曲线、AUC 值等,以更全面地评估模型的性能。
-
交叉验证
:采用交叉验证的方法,将数据集划分为多个子集,轮流使用不同的子集进行训练和测试,从而更准确地评估模型的泛化能力。
16. 总结
本项目通过使用卷积神经网络实现了大规模图像分类任务,从项目背景的分析、数据集的准备、图像预处理、特征提取、模型训练到最终的模型评估,每个步骤都进行了详细的阐述和实现。通过对模型评估结果的分析,我们了解了模型的性能表现,并提出了一些优化建议。在实际应用中,我们可以根据具体的需求和数据特点,对项目进行进一步的优化和改进,以提高模型的性能和实用性。同时,本项目也为其他图像分类任务提供了一个可参考的范例,希望能够对相关领域的研究和实践有所帮助。
在未来的工作中,我们可以继续探索更先进的技术和方法,不断提升图像分类的准确率和效率,为更多的实际应用场景提供支持。例如,将图像分类技术应用于智能安防、医疗影像诊断、自动驾驶等领域,为这些领域的发展带来新的机遇和挑战。
超级会员免费看
8523

被折叠的 条评论
为什么被折叠?



