CS231n-线性SVM分类Cifar10

本文详细介绍了如何使用线性SVM对CIFAR10数据集进行分类。从目标函数、代价函数(包含正则项和折叶损失)、梯度计算到模型的训练和预测函数,逐一深入讲解。最后通过自动化确定超参数并训练模型,得到测试集准确率约为37%。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

线性分类器其实我已经接触不少了,不同于KNN,它涉及到了更多的知识,比如cost function, objective function等,svm涉及到的知识确实比较多且难理解,但当我们得到相应公式后其实实现起来并不算繁琐,相反很容易理解

线性SVM

目标函数

目标函数我们在以前的线性回归,逻辑回归中都见到过

目标函数

代价函数

SVM的代价函数想要SVM在正确分类上的得分始终比不正确分类上的得分高出一个边界值delta, 所以它的代价函数如下:

代价函数

即在确定正确分类的分数(scores[y[i]])后,其他分类上的分数都要减去它并且加上一个边界值 ( scores[y[!i] - screos[y[i]] + delta ),当得到的值小于0时则代表正确分类比不正确分类高出了一个边界值,否则则要计算损失值。 比如,假设有3个分类,并且得到了分值[13, -7, 11], 第一个分类为正确分类,delta为10,那么根据代价函数,我们可以得到以下算式

计算公式

以上代价函数计算公式称为折叶损失(hinge loss),当然除此之外还有平方折叶损失SVM(即L2-SVM),就是加个平方,我们可以通过交叉验证或者验证集来确定到底选用哪个

正则项

在ML中,过拟合问题一直是影响模型准确率的重大因素,所以我们还要加上L2范式正则项(在这里,正则项还确保了SVM有最大边界(max margin)等好处),所以最终我们得到以下整个代价函数公式

完整代价函数

梯度计算

在训练过程中,我们需要通过最优化方法来是代价函数的损失值达到尽可能的小,所以我们对代价函数进行微分,然后计算其偏导数,得到以下公式

对于每一个训练样本,我们计算它在每个分类上的得分,每当它在某一分类产生了损失(即scores[y[!i] - screos[y[i]] + delta > 0),那么我们就将该分类上的参数梯度+Xi

错误分类的参数梯度计算

同时正确分类(y[i])的参数梯度-Xi

正确分类的参数梯度计算

再简单的说就是,对于每个Xi,每一个会产生损失值的分类(scores[y[!i] - screos[y[i]] + delta > 0)都会将其参数梯度+Xi,同时在正确分类(y[i])上的梯度-Xi

将以上的公式转化成代码,用非向量化实现(更容易理解)如下

def svm_loss_naive(W, X, y, reg):
  """
  Structured 
<think>嗯,用户想要用SVM分类CIFAR-10数据集,用Matlab实现。首先我得回忆一下CIFAR-10的结构。CIFAR-1010类图像,每类6000张,32x32的彩色图。所以数据维度应该是32x32x3,每张图转成向量的话是3072维。 然后SVM在Matlab里的实现,通常用fitcecoc函数,这是多类分类用的,内部用一对一或者一对多策略。用户可能对预处理步骤不太清楚,比如数据标准化、降维或者特征提取。因为直接用像素作为特征可能效果不好,但作为示例,可能需要先展示基础方法。 接下来,数据加载的问题。CIFAR-10在Matlab中可能需要用load函数读取,但用户可能不知道数据的具体格式。通常CIFAR-10的Matlab版本是.mat文件,里面包含data和labels。比如每个batch的文件名可能是data_batch_1.mat,测试集是test_batch.mat。 处理数据的时候,需要将图像数据从uint8转换为double,可能还需要归一化到0-1之间,除以255。标签可能需要从1开始的索引转换为0-9,或者反过来,具体看数据集的存储方式。比如,如果标签是0-9,而Matlab的索引从1开始,可能需要加1,或者fitcecoc会自动处理。 然后,考虑到数据量较大,全部训练可能内存不够,所以用户可能需要先提取部分数据,比如每个类别选几百张进行训练。代码里可能需要用datastore或者直接读取数据矩阵。但CIFAR-10的每个batch可能有10000张,五个训练batch,所以总共有50000张训练10000测试。如果内存不够,可能需要用小样本测试。 然后,SVM训练时间可能比较长,特别是对于高维数据。所以用户可能需要使用线性SVM,因为核函数如RBF计算更耗时。或者建议用户使用PCA降维,比如降到100维,这样可以加速训练,但可能会损失准确率。 接下来,代码步骤大概是: 1. 加载数据集,可能需要用load函数读取各个batch,合并训练数据。 2. 预处理:转换数据类型,归一化,reshape成向量。 3. 划分训练集和测试集,或者直接使用已有的划分。 4. 特征提取或降维(可选,但推荐)。 5. 训练多类SVM模型,使用fitcecoc,指定模板为线性SVM。 6. 预测测试集,计算准确率。 可能的错误点包括数据维度不正确,标签格式不对,内存不足。用户可能需要减少样本量,或者使用更高效的特征,比如HOG而不是原始像素。 另外,可能需要注意fitcecoc的默认参数,比如是否使用线性核,或者其他设置。例如,模板SVM可以设置Solver为'sgd'或'lbfgs',但线性SVM通常用SGD更快。 测试阶段,预测时的数据同样需要预处理,和训练集相同的步骤。最后计算准确率,可能比较低,因为原始像素+SVMCIFAR-10上效果一般,正确率可能在30-40%左右,但用户可能需要更高的方法,比如CNN,但问题明确要求SVM。 所以,综合这些步骤,给出代码示例,包括数据加载、预处理、训练、评估,并在最后提到可能的改进方法,比如特征提取、调参等。同时,提醒用户可能的内存问题,建议使用子样本。 相关问题方面,用户可能想知道如何提高准确率,处理大数据集,使用其他分类器,或者特征提取方法。所以问题可以围绕SVM参数优化、数据预处理、特征提取、与其他模型比较等。</think><< 使用SVM分类CIFAR-10数据集的MATLAB实现步骤如下: ### 1. **数据加载与预处理** CIFAR-10数据集包含60,000张32×32的彩色图像,分为10类。以下是加载和预处理代码: ```matlab % 加载CIFAR-10数据(假设数据文件为.mat格式) load('data_batch_1.mat'); % 示例加载一个训练批次 X_train = double(data); % 转换为double类型 y_train = double(labels) + 1; % 标签从0-9转为1-10(MATLAB要求类别从1开始) % 归一化到[0,1] X_train = X_train / 255; % 可选:PCA降维(减少计算量) [coeff, score, ~] = pca(X_train); X_train_pca = score(:,1:100); % 保留前100个主成分 ``` ### 2. **训练SVM分类器** 使用MATLAB的`fitcecoc`函数实现多类SVM: ```matlab % 使用线性SVM模板 template = templateSVM('KernelFunction', 'linear', 'Standardize', true); % 训练模型 model = fitcecoc(X_train_pca, y_train, 'Learners', template); ``` ### 3. **测试与评估** ```matlab % 加载测试数据 load('test_batch.mat'); X_test = double(data); y_test = double(labels) + 1; X_test = X_test / 255; % 对测试数据应用相同的PCA X_test_pca = X_test * coeff(:,1:100); % 预测 predicted_labels = predict(model, X_test_pca); % 计算准确率 accuracy = sum(predicted_labels == y_test) / numel(y_test); fprintf('Accuracy: %.2f%%\n', accuracy * 100); ``` ### **解释** 1. **数据预处理**: - CIFAR-10图像的像素值范围为0-255,归一化到[0,1]可加速收敛。 - PCA降维可减少计算复杂度(原始维度为3072,降维后为100)。 2. **模型选择**: - `templateSVM`使用线性核(`linear`),适合高维数据(如图像)。 - `fitcecoc`采用“一对多”策略处理多分类问题。 3. **性能瓶颈**: - 直接使用原始像素特征+SVM的准确率较低(约30-40%),因SVM对图像特征的表达能力有限。 - 更高准确率需结合特征提取(如HOG、CNN特征)。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值