MNIST手写数字识别实验(KNN)

本文介绍了使用KNN算法进行MNIST手写数字识别的实验过程。数据预处理后,发现未经优化的KNN计算量巨大。通过采样降低计算复杂性,但即便如此,不同k值下的正确率仅达到20%左右。引入二值化处理后,正确率显著提升至96.3%,但k值对结果的影响不明显。

数据集:MINST
数据预处理参考了https://blog.youkuaiyun.com/simple_the_best/article/details/75267863
处理出来有用的信息也就是 28 × 28 28 \times 28 28×28的矩阵和Label信息。
KNN的实现:

def KNN(train_dataset, train_labels, input_vec, distance, k=1):
    dis_labels = []
    n = len(train_dataset)
    for i in range(n):
        vec = train_dataset[i]
        label = train_labels[i]
        dis_labels.append([distance(vec, input_vec
### 实现MNIST手写数字识别kNN算法 为了在MATLAB中实现基于MNIST数据集的手写数字识别,可以遵循以下方法来构建完整的解决方案。 #### 数据预处理 加载并准备MNIST数据集是至关重要的一步。通常情况下,可以从官方或其他可信资源下载该数据集,并将其转换成适合用于机器学习模型的形式: ```matlab % 加载 MNIST 数据集 (假设已经下载好) load('mnist.mat'); % 这里假定文件名为 'mnist.mat' train_images = double(train_images); test_images = double(test_images); % 归一化像素值到 [0, 1] 范围内 train_images = train_images / 255; test_images = test_images / 255; % 将图像展平为向量形式 num_train_samples = size(train_images, 4); num_test_samples = size(test_images, 4); image_height = size(train_images, 1); image_width = size(train_images, 2); X_train = reshape(permute(train_images, [3 1 2]), num_train_samples, image_height * image_width)'; X_test = reshape(permute(test_images, [3 1 2]), num_test_samples, image_height * image_width)'; y_train = train_labels'; y_test = test_labels'; ``` #### kNN分类器定义 接下来定义一个简单的函数来进行最近邻查找。这里使用欧氏距离作为相似度衡量标准: ```matlab function predictions = knn_predict(X_train, y_train, X_test, k) num_test = size(X_test, 2); predictions = zeros(num_test, 1); for i = 1:num_test distances = sqrt(sum((bsxfun(@minus, X_train', X_test(:,i).')).^2)); [~, idx] = sort(distances); %#ok<ASGLU> closest_y = y_train(idx(1:k), :); [values, ~, ic] = unique(closest_y, 'rows'); counts = accumarray(ic, 1); [~, maxIdx] = max(counts); predictions(i) = values(maxIdx); end end ``` 此部分实现了对测试样本逐个计算其与所有训练样本之间的欧式距离,并选取最小的距离对应的类别标签作为预测结果[^1]。 #### 训练与评估 最后,调用上述编写的`knn_predict()` 函数完成整个过程,并统计准确率: ```matlab k_value = 5; % 设置合适的K值 predicted_labels = knn_predict(X_train, y_train, X_test, k_value); accuracy = sum(predicted_labels == y_test) / length(y_test); disp(['Accuracy: ', num2str(accuracy)]); ``` 值得注意的是,选择恰当的\( k \)值对于获得良好的性能至关重要。如果\( k \)值设得太大,则可能会引入噪声;反之则可能导致过拟合现象发生[^4]。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值