Chapter 5 Single-Layer Networks: Classification
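The goal is to take an input vector $\mathbf{x} \in \mathbb{R}^D$ and assign it to one of $K$ discrete classes $\mathcal{C}_k$, where $k = 1, \ldots, K$.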
(Overview) Mind map
5.1
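The core of linear methods for classification is "cutting the data with hyperplanes":
- A two-class problem uses a single decision boundary (a hyperplane). A multi-class problem uses several, combined through "one-versus-the-rest" or "one-versus-one" strategies, or through a single K-class discriminant that assigns x to the class with the largest linear output.
- One-hot (1-of-K) encoding turns class labels into target vectors. Least squares then treats classification as regression on those targets, but it is sensitive to outliers, so in practice more robust models such as logistic regression are preferred.
In essence, a linear model is used to find the decision boundary. The model's assumptions must match the data distribution, otherwise outliers can easily pull the boundary off course.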
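Below is a minimal NumPy sketch of the least-squares approach described above. The helper names are illustrative and do not come from any library. It prepends a column of ones for the bias, encodes the labels as one-hot targets, solves for the weight matrix with `np.linalg.lstsq`, and assigns each input to the class with the largest output $y_k(\mathbf{x})$. The toy data also shows a known property of the method: with a bias term, the K outputs always sum to 1, yet some of them come out negative, so they cannot be read as probabilities.

```python
import numpy as np


def one_hot(labels, num_classes):
    """Encode integer labels 0..K-1 as 1-of-K target vectors (one row per sample)."""
    targets = np.zeros((len(labels), num_classes))
    targets[np.arange(len(labels)), labels] = 1.0
    return targets


def add_bias(X):
    """Prepend a column of ones so the bias w_k0 is learned together with the weights."""
    return np.hstack([np.ones((X.shape[0], 1)), X])


def fit_least_squares(X, labels, num_classes):
    """Minimize the sum-of-squares error ||X~ W - T||^2; W has shape (D + 1, K)."""
    W, *_ = np.linalg.lstsq(add_bias(X), one_hot(labels, num_classes), rcond=None)
    return W


def predict_least_squares(W, X):
    """Return the outputs y_k(x) and, per input, the index of the largest one."""
    outputs = add_bias(X) @ W
    return outputs, np.argmax(outputs, axis=1)


# Toy data: three well-separated clusters in 2D
X = np.array([[-0.5, 0.0], [0.5, 0.0],
              [3.5, 0.0], [4.5, 0.0],
              [-0.5, 4.0], [0.5, 4.0]])
labels = np.array([0, 0, 1, 1, 2, 2])

W = fit_least_squares(X, labels, num_classes=3)
outputs, predicted = predict_least_squares(W, X)
print(predicted)              # class assignments
print(outputs.sum(axis=1))    # every row sums to 1
print(outputs.min())          # some outputs are negative: not probabilities
```

You can add a few far-away points of one class to the toy data to see how strongly the least-squares fit reacts to outliers.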
5.2
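Decision theory is the last step of classification. Inference first computes the probabilities, using either a generative or a discriminative model. A class is then chosen according to the cost of each kind of error (the loss function), and inputs that are too uncertain can be rejected. Classifiers are evaluated with accuracy, precision, recall, and the ROC curve.
The key idea is to quantify uncertainty with probabilities and weigh the risk of each decision with a loss. This turns classification from guesswork into a well-founded choice.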
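Below is a minimal NumPy sketch of the decision step and two of the evaluation metrics mentioned above. It assumes the posteriors $p(\mathcal{C}_k\mid\mathbf{x})$ are already available. It uses the convention that `loss[k, j]` is the cost of choosing class j when the true class is k. `decide` picks the j with the smallest expected loss $\sum_k L_{kj}\,p(\mathcal{C}_k\mid\mathbf{x})$. If a threshold θ is given, it rejects inputs whose largest posterior is below θ. The function names, the loss values, and θ = 0.7 are illustrative.

```python
import numpy as np

REJECT = -1  # returned when the classifier declines to decide


def decide(posteriors, loss, theta=None):
    """Minimum-expected-loss decision with an optional reject option.

    posteriors: p(C_k | x) for k = 0..K-1
    loss:       loss[k, j] = cost of choosing class j when the true class is k
    theta:      if given, reject whenever max_k p(C_k | x) < theta
    """
    posteriors = np.asarray(posteriors, dtype=float)
    if theta is not None and posteriors.max() < theta:
        return REJECT
    expected_loss = posteriors @ np.asarray(loss, dtype=float)  # entry j: sum_k p_k * loss[k, j]
    return int(np.argmin(expected_loss))


def precision_recall(y_true, y_pred, positive):
    """Precision and recall for one class treated as "positive" (0.0 when undefined)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    tp = np.sum((y_pred == positive) & (y_true == positive))
    predicted_pos = np.sum(y_pred == positive)
    actual_pos = np.sum(y_true == positive)
    precision = tp / predicted_pos if predicted_pos else 0.0
    recall = tp / actual_pos if actual_pos else 0.0
    return float(precision), float(recall)


# Class 0 = "cancer", class 1 = "normal": a missed cancer costs far more than a false alarm
loss = np.array([[0.0, 1000.0],
                 [1.0, 0.0]])
zero_one = 1.0 - np.eye(2)  # 0-1 loss: every mistake costs 1
p = np.array([0.1, 0.9])    # posterior: "normal" is much more likely

print(decide(p, zero_one))                    # 1: the most probable class
print(decide(p, loss))                        # 0: the asymmetric loss flips the decision
print(decide([0.55, 0.45], loss, theta=0.7))  # -1: too uncertain, rejected

print(precision_recall([0, 0, 1, 1, 0], [0, 1, 1, 1, 0], positive=0))  # precision 1.0, recall 2/3
```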
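5.3 Generative Classifiers
Core idea: first model how the data of each class is generated, that is, the class-conditional densities $p(\mathbf{x}\mid\mathcal{C}_k)$ together with the priors $p(\mathcal{C}_k)$. Then use Bayes' theorem to compute the posterior $p(\mathcal{C}_k\mid\mathbf{x})$ and pick the class the input most probably belongs to.
Characteristics:
Pros: it can generate new data (for example, new text that imitates spam). It copes relatively well with small datasets, because prior knowledge compensates for limited data. It also supports outlier detection, flagging "odd emails" whose probability $p(\mathbf{x})$ is very low.
Cons: it requires assuming a form for the data distribution (for example, Gaussian), and it performs poorly if that assumption is wrong. It is also computationally demanding, especially for high-dimensional data such as images.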
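Below is a minimal NumPy sketch of a generative classifier. It rests on a simplifying assumption chosen for illustration: each class-conditional density is a Gaussian $\mathcal{N}(\mathbf{x}\mid\boldsymbol{\mu}_k,\sigma^2\mathbf{I})$, with one variance shared by all classes. Fitting means estimating the class means, the class priors, and the pooled variance. Bayes' theorem then gives the posteriors, and sampling from a fitted Gaussian generates new inputs.

With equal priors, the largest posterior always belongs to the nearest class mean. That is the same rule the `SimpleGenerativeModel` in the digit demo later in this section applies. The difference is that these posteriors are proper probabilities, whereas that model's scores are not.

```python
import numpy as np


def fit_gaussian_generative(X, y, num_classes):
    """Fit p(x | C_k) = N(x | mu_k, sigma^2 I) with a shared variance, plus priors p(C_k).

    Simplifying assumption for illustration: one isotropic variance for all classes.
    Every class must occur at least once in y.
    """
    means = np.array([X[y == k].mean(axis=0) for k in range(num_classes)])
    priors = np.array([np.mean(y == k) for k in range(num_classes)])
    # Pooled maximum-likelihood variance over all samples and dimensions (floored to avoid /0)
    sigma2 = max(float(np.mean((X - means[y]) ** 2)), 1e-12)
    return means, priors, sigma2


def posterior(X, means, priors, sigma2):
    """Bayes' theorem: p(C_k | x) is proportional to p(x | C_k) p(C_k)."""
    sq_dist = ((X[:, None, :] - means[None, :, :]) ** 2).sum(axis=2)  # shape (N, K)
    # Log of p(x | C_k) p(C_k), dropping the Gaussian normalizer shared by all classes
    a = -sq_dist / (2.0 * sigma2) + np.log(priors)
    a -= a.max(axis=1, keepdims=True)  # subtract the row maximum for numerical stability
    p = np.exp(a)
    return p / p.sum(axis=1, keepdims=True)


# Toy data: two equally sized classes drawn from Gaussians
rng = np.random.default_rng(0)
X = np.vstack([rng.normal([0.0, 0.0], 1.0, size=(50, 2)),
               rng.normal([4.0, 4.0], 1.0, size=(50, 2))])
y = np.repeat([0, 1], 50)

means, priors, sigma2 = fit_gaussian_generative(X, y, num_classes=2)
P = posterior(X, means, priors, sigma2)

# With equal priors, the most probable class is the one with the nearest mean
nearest_mean = np.argmin(((X[:, None, :] - means[None, :, :]) ** 2).sum(axis=2), axis=1)
print(np.array_equal(P.argmax(axis=1), nearest_mean))

# Because the model describes p(x | C_k), it can also generate a new input for class 1
new_sample = rng.normal(means[1], np.sqrt(sigma2))
print(new_sample)
```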
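5.4 Discriminative Classifiers
Core idea: do not model how the data was generated. Instead, model the decision boundary between classes directly, for example by fitting the posterior $p(\mathcal{C}_k\mid\mathbf{x})$ itself.
Characteristics:
Pros: it focuses on the classification task and trains efficiently, because there is no complex generative process to learn. It makes fewer assumptions about the data distribution, which suits large datasets, since it fits the boundary directly.
Cons: it cannot generate data, because it learns only the boundary and not the data distribution. It also depends on the training distribution, so if the data shifts, the model may need retraining.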
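For contrast, below is a minimal discriminative classifier: multiclass logistic regression (softmax regression). It models $p(\mathcal{C}_k\mid\mathbf{x})$ directly and is fitted by plain gradient descent on the cross-entropy. The learning rate and number of epochs are illustrative, untuned values. Nothing in the model describes how x itself is distributed, so, unlike a generative model, it cannot produce new samples.

```python
import numpy as np


def softmax(a):
    """Row-wise softmax; subtracting the row maximum avoids overflow in exp."""
    a = a - a.max(axis=1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=1, keepdims=True)


def add_bias(X):
    """Prepend a column of ones so the bias is learned as an ordinary weight."""
    return np.hstack([np.ones((X.shape[0], 1)), X])


def fit_softmax_regression(X, y, num_classes, lr=0.1, epochs=500):
    """Model p(C_k | x) = softmax(W^T x~)_k by gradient descent on the mean cross-entropy.

    lr and epochs are illustrative, untuned values.
    """
    X_tilde = add_bias(X)
    T = np.eye(num_classes)[y]  # one-hot targets
    W = np.zeros((X_tilde.shape[1], num_classes))
    for _ in range(epochs):
        P = softmax(X_tilde @ W)
        W -= lr * X_tilde.T @ (P - T) / X.shape[0]  # gradient of the mean cross-entropy
    return W


def predict_proba(W, X):
    """Posterior probabilities p(C_k | x), one row per input."""
    return softmax(add_bias(X) @ W)


# Toy data: two equally sized classes
rng = np.random.default_rng(0)
X = np.vstack([rng.normal([-2.0, -2.0], 1.0, size=(50, 2)),
               rng.normal([2.0, 2.0], 1.0, size=(50, 2))])
y = np.repeat([0, 1], 50)

W = fit_softmax_regression(X, y, num_classes=2)
proba = predict_proba(W, X)
accuracy = np.mean(proba.argmax(axis=1) == y)
print(f"training accuracy: {accuracy:.2f}")
```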
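To make this easier to understand, here is an interactive "handwritten digit recognition" page built on the idea of a generative classifier.
It is built in two steps:
First, use Python to load scikit-learn's digits dataset and compute, from the training set, the mean image of each digit 0-9 (these mean images are the model parameters):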
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

# Load the dataset
digits = load_digits()
X = digits.data    # image data, shape (1797, 64), pixel values 0-16
y = digits.target  # labels 0-9

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Compute the mean image (and sample count) of each digit
digit_means = np.zeros((10, 64))
digit_counts = np.zeros(10)
for i in range(10):
    digit_indices = (y_train == i)
    if np.any(digit_indices):
        digit_means[i] = np.mean(X_train[digit_indices], axis=0)
        digit_counts[i] = np.sum(digit_indices)

# A minimal "generative" model: one mean image per class,
# classify each sample by the nearest mean (Euclidean distance)
class SimpleGenerativeModel:
    def __init__(self):
        self.digit_means = np.zeros((10, 64))

    def train(self, X, y):
        for i in range(10):
            digit_indices = (y == i)
            if np.any(digit_indices):
                self.digit_means[i] = np.mean(X[digit_indices], axis=0)

    def predict(self, X):
        predictions = []
        for sample in X:
            similarities = []
            for i in range(10):
                distance = np.sqrt(np.sum((sample - self.digit_means[i]) ** 2))
                similarity = 1 / (1 + distance)  # smaller distance -> larger similarity
                similarities.append(similarity)
            best_digit = np.argmax(similarities)
            predictions.append(best_digit)
        return np.array(predictions)

# Train the model
model = SimpleGenerativeModel()
model.train(X_train, y_train)

# Evaluate the accuracy on the test set
y_pred = model.predict(X_test)
accuracy = np.mean(y_pred == y_test)
print(f"Model accuracy: {accuracy:.4f}")

# Plot the mean images in three batches
def display_digits(start, end):
    n = end - start + 1
    fig, axes = plt.subplots(1, n, figsize=(3 * n, 3))
    if n == 1:  # a single subplot is not returned as an array
        axes = [axes]
    for i, digit in enumerate(range(start, end + 1)):
        ax = axes[i]
        ax.imshow(digit_means[digit].reshape(8, 8), cmap='gray')
        ax.set_title(f'Digit {digit}\n(samples: {int(digit_counts[digit])})')
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(f'digits_{start}_to_{end}.png', dpi=300, bbox_inches='tight')
    plt.show()

# Batch 1: 0-3
display_digits(0, 3)
# Batch 2: 4-6
display_digits(4, 6)
# Batch 3: 7-9
display_digits(7, 9)

# Print the model parameters
print("\nModel parameters (digit_means):")
for i in range(10):
    print(f"\nMean image of digit {i}:")
    print(digit_means[i].reshape(8, 8).round(2))
Output figures (mean image of each digit):
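Next, the learned parameters for digits 0-9 (the printed digit_means) are entered into the HTML page as the digitMeans array and used to recognize handwritten digits. Note that the page scores each digit by the cosine similarity between the input and the mean image, whereas the Python model uses Euclidean distance, so the two can disagree on the same input.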
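Copying 640 numbers into the page by hand is error-prone. A small helper, such as the hypothetical `means_to_js` below, can print the whole `digitMeans` declaration from the `digit_means` array computed by the script above. `json.dumps` writes plain decimal literals, which are also valid JavaScript numbers, so no hand-written exponents are involved.

```python
import json

import numpy as np


def means_to_js(digit_means, decimals=2, name="digitMeans"):
    """Format an array of mean images (one row per digit) as a JavaScript const declaration.

    json.dumps writes plain decimal literals, which are also valid JavaScript numbers.
    """
    rows = np.round(np.asarray(digit_means, dtype=float), decimals).tolist()
    body = ",\n".join("  " + json.dumps(row) for row in rows)
    return f"const {name} = [\n{body}\n];"


# With digit_means from the training script, paste the output of:
#     print(means_to_js(digit_means))
print(means_to_js([[0.0, 11.944, 13.156]]))
```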
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Handwritten Digit Recognition Demo</title>
<script src="https://cdn.tailwindcss.com"></script>
<link href="https://cdn.jsdelivr.net/npm/font-awesome@4.7.0/css/font-awesome.min.css" rel="stylesheet">
<script>
tailwind.config = {
theme: {
extend: {
colors: {
primary: '#4F46E5',
secondary: '#EC4899',
accent: '#10B981',
neutral: '#6B7280',
},
fontFamily: {
sans: ['Inter', 'system-ui', 'sans-serif'],
},
}
}
}
</script>
<style type="text/tailwindcss">
@layer utilities {
.content-auto {
content-visibility: auto;
}
.digit-canvas {
image-rendering: pixelated;
}
}
</style>
</head>
<body class="bg-gray-50 min-h-screen">
<header class="bg-white shadow-sm sticky top-0 z-50">
<div class="container mx-auto px-4 py-3 flex justify-between items-center">
<div class="flex items-center space-x-2">
<i class="fa fa-paint-brush text-primary text-2xl"></i>
<h1 class="text-xl font-bold text-primary">Handwritten Digit Recognition Demo</h1>
</div>
</div>
</header>
<main class="container mx-auto px-4 py-8">
<section class="mb-12 text-center">
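<h2 class="text-2xl font-bold mb-4">Pure Front-End Handwritten Digit Recognition</h2>
<p class="text-neutral max-w-2xl mx-auto">
Write a digit (0-9) on the canvas below, then click "Recognize" to see the prediction.
Recognition runs entirely in your browser; no server is involved.
</p>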
</section>
<section class="bg-white rounded-xl shadow-lg p-6">
<div class="grid grid-cols-1 md:grid-cols-2 gap-8">
<!-- Writing area -->
<div class="bg-gray-50 rounded-lg p-4 border border-gray-100">
<h3 class="text-lg font-semibold mb-3 flex items-center">
<i class="fa fa-pencil text-primary mr-2"></i>
Write a digit here
</h3>
<div class="relative bg-white border border-gray-200 rounded-lg overflow-hidden mb-4">
<canvas id="writing-canvas" class="w-full h-64 cursor-crosshair"></canvas>
<div class="absolute inset-0 flex items-center justify-center text-gray-300 pointer-events-none">
<div class="text-center">
<i class="fa fa-hand-paper-o text-3xl mb-2"></i>
<p>Write a digit here</p>
</div>
</div>
</div>
<div class="flex gap-2">
<button id="clear-canvas" class="flex-1 bg-gray-200 hover:bg-gray-300 text-gray-700 py-2 px-4 rounded-lg transition-colors">
<i class="fa fa-eraser mr-1"></i> Clear
</button>
<button id="predict-button" class="flex-1 bg-primary hover:bg-primary/90 text-white py-2 px-4 rounded-lg transition-colors">
<i class="fa fa-magic mr-1"></i> Recognize
</button>
</div>
</div>
<!-- Recognition result -->
<div class="bg-gray-50 rounded-lg p-4 border border-gray-100">
<h3 class="text-lg font-semibold mb-3 flex items-center">
<i class="fa fa-search text-primary mr-2"></i>
Recognition result
</h3>
<div class="grid grid-cols-1 sm:grid-cols-2 gap-4">
<!-- Predicted digit -->
<div class="bg-white rounded-lg p-4 border border-gray-100">
<h4 class="font-medium text-primary mb-2">
<i class="fa fa-bolt mr-1"></i>
Predicted digit
</h4>
<div class="flex items-center justify-center h-32">
<div id="prediction-result" class="text-6xl font-bold text-primary">?</div>
</div>
</div>
<!-- Confidence -->
<div class="bg-white rounded-lg p-4 border border-gray-100">
<h4 class="font-medium text-neutral mb-2">
<i class="fa fa-percent mr-1"></i>
Confidence (cosine similarity)
</h4>
<div class="flex items-center justify-center h-32">
<div id="prediction-confidence" class="text-6xl font-bold text-neutral">0%</div>
</div>
</div>
</div>
<!-- Prediction details -->
<div class="bg-white rounded-lg p-4 border border-gray-100 mt-4">
<h4 class="font-medium mb-2">
<i class="fa fa-info-circle text-primary mr-1"></i>
Prediction details
</h4>
<div id="prediction-details" class="text-sm text-neutral">
Waiting for a handwritten digit...
</div>
</div>
</div>
</div>
</section>
</main>
<footer class="bg-gray-800 text-white py-6 mt-12">
<div class="container mx-auto px-4">
<div class="text-center text-gray-400 text-sm">
© 2025 Handwritten Digit Recognition Demo | Pure front-end implementation
</div>
</div>
</footer>
<script>
// Pretrained model parameters: the mean image of each digit (8x8, values 0-16)
const digitMeans = [
[
0.000e+00, 1.000e-02, 4.160e+00, 1.313e+01, 1.157e+01, 3.150e+00, 4.000e-02, 0.000e+00,
0.000e+00, 8.800e-01, 1.268e+01, 1.344e+01, 1.157e+01, 1.128e+01, 1.040e+00, 0.000e+00,
0.000e+00, 3.860e+00, 1.439e+01, 5.230e+00, 1.990e+00, 1.212e+01, 3.470e+00, 0.000e+00,
0.000e+00, 5.370e+00, 1.279e+01, 1.980e+00, 1.400e-01, 9.140e+00, 6.510e+00, 0.000e+00,
0.000e+00, 5.860e+00, 1.157e+01, 9.900e-01, 4.000e-02, 8.740e+00, 7.190e+00, 0.000e+00,
0.000e+00, 3.460e+00, 1.330e+01, 1.650e+00, 1.400e+00, 1.137e+01, 5.900e+00, 0.000e+00,
0.000e+00, 7.300e-01, 1.314e+01, 1.010e+01, 1.059e+01, 1.349e+01, 2.350e+00, 0.000e+00,
0.000e+00, 0.000e+00, 4.110e+00, 1.363e+01, 1.346e+01, 5.350e+00, 2.200e-01, 0.000e+00
],
[
0.000e+00, 1.000e-02, 2.440e+00, 9.160e+00, 1.051e+01, 6.230e+00, 9.900e-01, 0.000e+00,
0.000e+00, 9.000e-02, 3.990e+00, 1.260e+01, 1.399e+01, 8.550e+00, 1.230e+00, 0.000e+00,
1.000e-02, 9.700e-01, 6.900e+00, 1.476e+01, 1.423e+01, 7.560e+00, 7.400e-01, 0.000e+00,
1.000e-02, 2.090e+00, 9.170e+00, 1.441e+01, 1.389e+01, 6.450e+00, 4.000e-01, 0.000e+00,
0.000e+00, 1.090e+00, 6.880e+00, 1.191e+01, 1.365e+01, 5.600e+00, 3.800e-01, 0.000e+00,
0.000e+00, 4.100e-01, 5.190e+00, 1.029e+01, 1.340e+01, 5.870e+00, 3.500e-01, 0.000e+00,
0.000e+00, 9.000e-02, 4.600e+00, 1.081e+01, 1.356e+01, 7.610e+00, 2.090e+00, 6.200e-01,
0.000e+00, 1.000e-02, 2.250e+00, 9.020e+00, 1.305e+01, 8.410e+00, 2.920e+00, 1.390e+00
],
[
0.000e+00, 9.800e-01, 9.840e+00, 1.417e+01, 9.460e+00, 2.560e+00, 1.200e-01, 0.000e+00,
1.000e-02, 5.260e+00, 1.385e+01, 1.192e+01, 1.229e+01, 5.620e+00, 4.600e-01, 0.000e+00,
1.000e-02, 4.600e+00, 7.940e+00, 4.440e+00, 1.136e+01, 6.010e+00, 5.100e-01, 0.000e+00,
0.000e+00, 8.100e-01, 1.700e+00, 4.600e+00, 1.206e+01, 4.480e+00, 2.800e-01, 0.000e+00,
0.000e+00, 6.000e-02, 1.170e+00, 8.790e+00, 1.042e+01, 2.360e+00, 8.000e-02, 0.000e+00,
0.000e+00, 5.800e-01, 5.250e+00, 1.192e+01, 7.310e+00, 2.030e+00, 9.000e-01, 1.000e-02,
2.000e-02, 1.470e+00, 1.131e+01, 1.426e+01, 1.199e+01, 1.062e+01, 7.080e+00, 5.800e-01,
1.000e-02, 9.600e-01, 1.031e+01, 1.403e+01, 1.331e+01, 1.201e+01, 8.310e+00, 2.010e+00
],
[
0.000e+00, 7.200e-01, 8.480e+00, 1.417e+01, 1.428e+01, 7.590e+00, 7.900e-01, 1.000e-02,
1.000e-02, 4.360e+00, 1.271e+01, 9.000e+00, 1.137e+01, 1.212e+01, 2.030e+00, 2.000e-02,
1.000e-02, 2.280e+00, 3.730e+00, 3.230e+00, 1.204e+01, 9.310e+00, 8.000e-01, 0.000e+00,
0.000e+00, 2.600e-01, 1.510e+00, 8.850e+00, 1.417e+01, 5.680e+00, 8.000e-02, 0.000e+00,
0.000e+00, 7.000e-02, 9.700e-01, 5.590e+00, 1.207e+01, 1.133e+01, 2.190e+00, 0.000e+00,
0.000e+00, 3.800e-01, 1.320e+00, 1.010e+00, 4.440e+00, 1.205e+01, 6.320e+00, 0.000e+00,
0.000e+00, 8.300e-01, 7.110e+00, 6.430e+00, 8.350e+00, 1.291e+01, 6.010e+00, 4.000e-02,
0.000e+00, 5.500e-01, 9.380e+00, 1.455e+01, 1.370e+01, 8.490e+00, 1.440e+00, 3.000e-02
],
[
0.000e+00, 0.000e+00, 4.300e-01, 7.030e+00, 1.161e+01, 2.150e+00, 2.400e-01, 1.700e-01,
0.000e+00, 7.000e-02, 3.340e+00, 1.316e+01, 8.600e+00, 2.010e+00, 1.210e+00, 3.500e-01,
0.000e+00, 7.300e-01, 1.031e+01, 1.165e+01, 4.880e+00, 5.640e+00, 3.960e+00, 3.200e-01,
1.000e-02, 4.510e+00, 1.450e+01, 6.630e+00, 7.040e+00, 1.070e+01, 6.320e+00, 2.000e-02,
0.000e+00, 8.540e+00, 1.447e+01, 9.540e+00, 1.287e+01, 1.370e+01, 5.100e+00, 0.000e+00,
1.100e-01, 6.330e+00, 1.112e+01, 1.234e+01, 1.455e+01, 1.064e+01, 1.670e+00, 0.000e+00,
7.000e-02, 1.250e+00, 3.190e+00, 7.980e+00, 1.342e+01, 4.540e+00, 3.000e-02, 0.000e+00,
0.000e+00, 3.000e-02, 5.300e-01, 7.750e+00, 1.201e+01, 2.010e+00, 0.000e+00, 0.000e+00
],
[
0.000e+00, 1.040e+00, 1.035e+01, 1.304e+01, 1.375e+01, 1.195e+01, 4.170e+00, 5.000e-02,
1.000e-02, 4.120e+00, 1.482e+01, 1.222e+01, 8.420e+00, 6.360e+00, 2.210e+00, 3.000e-02,
0.000e+00, 5.580e+00, 1.439e+01, 5.880e+00, 2.040e+00, 7.000e-01, 7.000e-02, 0.000e+00,
0.000e+00, 5.010e+00, 1.425e+01, 1.213e+01, 9.050e+00, 3.930e+00, 3.900e-01, 0.000e+00,
0.000e+00, 1.900e+00, 7.510e+00, 8.460e+00, 9.030e+00, 7.570e+00, 1.770e+00, 0.000e+00,
0.000e+00, 2.700e-01, 1.230e+00, 3.720e+00, 7.580e+00, 8.190e+00, 2.290e+00, 0.000e+00,
0.000e+00, 9.000e-01, 5.700e+00, 8.140e+00, 1.111e+01, 7.270e+00, 1.230e+00, 0.000e+00,
0.000e+00, 1.010e+00, 1.087e+01, 1.479e+01, 9.120e+00, 2.190e+00, 1.300e-01, 0.000e+00
],
[
0.000e+00, 0.000e+00, 1.080e+00, 1.105e+01, 9.860e+00, 1.550e+00, 1.000e-02, 0.000e+00,
0.000e+00, 4.000e-02, 6.850e+00, 1.464e+01, 6.560e+00, 9.400e-01, 6.000e-02, 0.000e+00,
0.000e+00, 7.100e-01, 1.227e+01, 9.930e+00, 1.060e+00, 1.200e-01, 1.000e-02, 0.000e+00,
0.000e+00, 2.280e+00, 1.359e+01, 8.340e+00, 3.880e+00, 2.000e+00, 1.200e-01, 0.000e+00,
0.000e+00, 3.540e+00, 1.469e+01, 1.289e+01, 1.229e+01, 1.019e+01, 2.860e+00, 0.000e+00,
0.000e+00, 1.990e+00, 1.466e+01, 1.103e+01, 5.860e+00, 1.023e+01, 9.130e+00, 2.800e-01,
0.000e+00, 1.800e-01, 1.016e+01, 1.288e+01, 5.830e+00, 1.138e+01, 1.066e+01, 5.900e-01,
0.000e+00, 0.000e+00, 1.340e+00, 1.062e+01, 1.508e+01, 1.311e+01, 4.470e+00, 1.000e-01
],
[
0.000e+00, 1.200e-01, 4.980e+00, 1.301e+01, 1.422e+01, 1.102e+01, 5.390e+00, 1.080e+00,
0.000e+00, 9.300e-01, 1.043e+01, 1.180e+01, 1.104e+01, 1.251e+01, 5.490e+00, 6.300e-01,
0.000e+00, 7.000e-01, 4.900e+00, 2.600e+00, 6.940e+00, 1.143e+01, 3.280e+00, 1.400e-01,
0.000e+00, 5.400e-01, 4.320e+00, 6.290e+00, 1.206e+01, 1.245e+01, 5.100e+00, 0.000e+00,
0.000e+00, 1.300e+00, 8.830e+00, 1.351e+01, 1.478e+01, 1.122e+01, 4.360e+00, 0.000e+00,
0.000e+00, 1.100e+00, 5.300e+00, 1.186e+01, 1.099e+01, 4.200e+00, 6.700e-01, 0.000e+00,
0.000e+00, 1.000e-01, 3.060e+00, 1.245e+01, 5.940e+00, 2.300e-01, 0.000e+00, 0.000e+00,
0.000e+00, 9.000e-02, 6.290e+00, 1.183e+01, 2.290e+00, 1.000e-02, 0.000e+00, 0.000e+00
],
[
0.000e+00, 1.400e-01, 5.100e+00, 1.159e+01, 1.247e+01, 6.280e+00, 5.100e-01, 0.000e+00,
2.000e-02, 1.950e+00, 1.227e+01, 1.152e+01, 9.500e+00, 1.153e+01, 2.570e+00, 0.000e+00,
1.000e-02, 2.930e+00, 1.151e+01, 7.490e+00, 7.850e+00, 1.171e+01, 2.190e+00, 0.000e+00,
0.000e+00, 1.220e+00, 8.470e+00, 1.324e+01, 1.324e+01, 6.920e+00, 4.400e-01, 0.000e+00,
0.000e+00, 4.400e-01, 6.820e+00, 1.396e+01, 1.288e+01, 4.490e+00, 1.000e-01, 0.000e+00,
0.000e+00, 1.030e+00, 1.065e+01, 8.470e+00, 9.030e+00, 8.660e+00, 1.410e+00, 0.000e+00,
0.000e+00, 9.500e-01, 1.094e+01, 8.100e+00, 8.260e+00, 9.800e+00, 2.350e+00, 1.000e-02,
0.000e+00, 1.500e-01, 5.090e+00, 1.269e+01, 1.308e+01, 6.730e+00, 1.100e+00, 1.000e-02
],
[
0.000e+00, 1.100e-01, 5.610e+00, 1.174e+01, 1.129e+01, 6.050e+00, 1.680e+00, 1.500e-01,
0.000e+00, 2.240e+00, 1.251e+01, 9.590e+00, 9.920e+00, 1.147e+01, 2.760e+00, 1.600e-01,
0.000e+00, 3.360e+00, 1.227e+01, 5.870e+00, 7.990e+00, 1.414e+01, 3.330e+00, 6.000e-02,
0.000e+00, 1.850e+00, 1.022e+01, 1.224e+01, 1.315e+01, 1.408e+01, 4.010e+00, 0.000e+00,
0.000e+00, 1.400e-01, 2.900e+00, 4.960e+00, 4.990e+00, 1.145e+01, 4.880e+00, 0.000e+00,
0.000e+00, 1.900e-01, 4.900e-01, 5.100e-01, 2.390e+00, 9.620e+00, 5.940e+00, 3.000e-02,
0.000e+00, 7.600e-01, 6.040e+00, 4.670e+00, 5.660e+00, 1.036e+01, 5.270e+00, 1.300e-01,
0.000e+00, 6.000e-02, 5.660e+00, 1.194e+01, 1.316e+01, 8.940e+00, 2.090e+00, 5.000e-02
]
];
// Initialize the canvas
const canvas = document.getElementById('writing-canvas');
const ctx = canvas.getContext('2d');
const clearButton = document.getElementById('clear-canvas');
const predictButton = document.getElementById('predict-button');
const predictionResult = document.getElementById('prediction-result');
const predictionConfidence = document.getElementById('prediction-confidence');
const predictionDetails = document.getElementById('prediction-details');
// Set the drawing-buffer size and stroke style
// (CSS may display the canvas at a different size than 200 x 200)
canvas.width = 200;
canvas.height = 200;
ctx.fillStyle = 'white';
ctx.fillRect(0, 0, canvas.width, canvas.height);
ctx.strokeStyle = 'black';
ctx.lineWidth = 10;
ctx.lineCap = 'round';
// Drawing state
let isDrawing = false;
let lastX = 0;
let lastY = 0;
// Map client (CSS-pixel) coordinates to drawing-buffer coordinates;
// without this scaling, strokes land in the wrong place whenever the
// canvas is displayed larger or smaller than its 200 x 200 buffer
function toCanvasCoords(clientX, clientY) {
const rect = canvas.getBoundingClientRect();
return {
x: (clientX - rect.left) * canvas.width / rect.width,
y: (clientY - rect.top) * canvas.height / rect.height
};
}
// Begin a stroke at the given client position
function startStroke(clientX, clientY) {
isDrawing = true;
const p = toCanvasCoords(clientX, clientY);
lastX = p.x;
lastY = p.y;
}
// Draw a segment from the last point to the given client position
function continueStroke(clientX, clientY) {
if (!isDrawing) return;
const p = toCanvasCoords(clientX, clientY);
ctx.beginPath();
ctx.moveTo(lastX, lastY);
ctx.lineTo(p.x, p.y);
ctx.stroke();
lastX = p.x;
lastY = p.y;
}
// Mouse event handling
canvas.addEventListener('mousedown', (e) => startStroke(e.clientX, e.clientY));
canvas.addEventListener('mousemove', (e) => continueStroke(e.clientX, e.clientY));
canvas.addEventListener('mouseup', () => {
isDrawing = false;
});
canvas.addEventListener('mouseout', () => {
isDrawing = false;
});
// Touch event handling (preventDefault keeps the page from scrolling while drawing)
canvas.addEventListener('touchstart', (e) => {
e.preventDefault();
startStroke(e.touches[0].clientX, e.touches[0].clientY);
});
canvas.addEventListener('touchmove', (e) => {
e.preventDefault();
continueStroke(e.touches[0].clientX, e.touches[0].clientY);
});
canvas.addEventListener('touchend', (e) => {
e.preventDefault();
isDrawing = false;
});
// Clear the canvas and reset the result panel
clearButton.addEventListener('click', () => {
ctx.fillStyle = 'white';
ctx.fillRect(0, 0, canvas.width, canvas.height);
predictionResult.textContent = '?';
predictionConfidence.textContent = '0%';
predictionDetails.textContent = 'Waiting for a handwritten digit...';
});
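// Read the canvas pixels and predict
predictButton.addEventListener('click', () => {
// Get the canvas pixel data (RGBA, 4 bytes per pixel)
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
const pixels = imageData.data;
// Downsample to 8x8 and rescale to the 0-16 range of the training data (no binarization)
const smallImage = downsampleImage(pixels, canvas.width, canvas.height, 8, 8);
// Skip prediction when no ink was detected (all values zero or NaN)
if (smallImage.every(v => !(v > 0))) {
predictionResult.textContent = '?';
predictionConfidence.textContent = '0%';
predictionDetails.textContent = 'No ink detected. Please write a digit first.';
return;
}
// Predict the digit
const prediction = predictDigit(smallImage);
// Show the result
predictionResult.textContent = prediction.digit;
predictionConfidence.textContent = `${Math.round(prediction.confidence * 100)}%`;
// Show the per-digit similarity details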
let detailsHTML = '<div class="grid grid-cols-5 gap-1">';
for (let i = 0; i < 10; i++) {
const similarity = prediction.similarities[i];
const confidence = Math.round(similarity * 100);
const barWidth = Math.min(100, confidence);
detailsHTML += `
<div class="mb-2">
<div class="flex justify-between text-xs mb-1">
<span>${i}</span>
<span>${confidence}%</span>
</div>
<div class="h-2 bg-gray-200 rounded-full overflow-hidden">
<div class="h-full bg-primary" style="width: ${barWidth}%"></div>
</div>
</div>
`;
}
detailsHTML += '</div>';
predictionDetails.innerHTML = detailsHTML;
});
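// Downsample the image to 8x8 by averaging each block of source pixels
// (sampling a single pixel per block can miss thin strokes entirely)
function downsampleImage(pixels, width, height, targetWidth, targetHeight) {
const result = new Array(targetWidth * targetHeight).fill(0);
const scaleX = width / targetWidth;
const scaleY = height / targetHeight;
for (let y = 0; y < targetHeight; y++) {
for (let x = 0; x < targetWidth; x++) {
const x0 = Math.floor(x * scaleX);
const x1 = Math.floor((x + 1) * scaleX);
const y0 = Math.floor(y * scaleY);
const y1 = Math.floor((y + 1) * scaleY);
let sum = 0;
for (let sy = y0; sy < y1; sy++) {
for (let sx = x0; sx < x1; sx++) {
const srcIndex = (sy * width + sx) * 4;
// Grayscale as the mean of R, G and B
const gray = (pixels[srcIndex] + pixels[srcIndex + 1] + pixels[srcIndex + 2]) / 3;
// Invert: the training data stores ink as high values on a 0 background
sum += 255 - gray;
}
}
result[y * targetWidth + x] = sum / Math.max(1, (x1 - x0) * (y1 - y0));
}
}
// Rescale to 0-16 to match the training data; a blank image stays all zeros
const maxValue = Math.max(...result);
if (maxValue === 0) return result;
return result.map(pixel => pixel * 16 / maxValue);
}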
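// Predict the digit
function predictDigit(imageData) {
const similarities = [];
// Compare the input with each digit's mean image
for (let digit = 0; digit < 10; digit++) {
const mean = digitMeans[digit];
let similarity = 0;
// Cosine similarity (the Python model uses Euclidean distance instead)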
let dotProduct = 0;
let imageMagnitude = 0;
let meanMagnitude = 0;
for (let i = 0; i < imageData.length; i++) {
dotProduct += imageData[i] * mean[i];
imageMagnitude += imageData[i] * imageData[i];
meanMagnitude += mean[i] * mean[i];
}
imageMagnitude = Math.sqrt(imageMagnitude);
meanMagnitude = Math.sqrt(meanMagnitude);
if (imageMagnitude > 0 && meanMagnitude > 0) {
similarity = dotProduct / (imageMagnitude * meanMagnitude);
}
similarities.push(similarity);
}
// Pick the digit with the highest similarity
let bestDigit = 0;
let bestSimilarity = similarities[0];
for (let i = 1; i < similarities.length; i++) {
if (similarities[i] > bestSimilarity) {
bestSimilarity = similarities[i];
bestDigit = i;
}
}
return {
digit: bestDigit,
confidence: bestSimilarity,
similarities: similarities
};
}
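// Render a small preview of each mean image into #digit-previews, if present
function initDigitPreviews() {
const digitPreviews = document.getElementById('digit-previews');
// This page has no #digit-previews element; return instead of throwing a TypeError
if (!digitPreviews) return;
digitPreviews.innerHTML = '';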
for (let digit = 0; digit < 10; digit++) {
const previewContainer = document.createElement('div');
previewContainer.className = 'bg-white rounded-lg p-2 shadow-sm border border-gray-100 flex flex-col items-center';
const canvas = document.createElement('canvas');
canvas.className = 'digit-canvas w-16 h-16 mb-1';
canvas.width = 8;
canvas.height = 8;
const ctx = canvas.getContext('2d');
const imageData = ctx.createImageData(8, 8);
const data = imageData.data;
// Convert the mean values (0-16) to grayscale pixels (ink shown as white)
for (let i = 0; i < digitMeans[digit].length; i++) {
const value = digitMeans[digit][i];
const brightness = Math.round(value * 255 / 16);
data[i * 4] = brightness; // R
data[i * 4 + 1] = brightness; // G
data[i * 4 + 2] = brightness; // B
data[i * 4 + 3] = 255; // A
}
ctx.putImageData(imageData, 0, 0);
const label = document.createElement('div');
label.className = 'text-sm font-medium text-primary';
label.textContent = digit;
previewContainer.appendChild(canvas);
previewContainer.appendChild(label);
digitPreviews.appendChild(previewContainer);
}
}
// Initialize once the page has loaded
window.addEventListener('load', () => {
initDigitPreviews();
});
</script>
</body>
</html>
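Note: this demo still has some limitations.
If the pointer coordinates are not scaled to the canvas's drawing buffer, do not maximize the page, and write toward the left side of the canvas. The buffer is only 200 × 200 pixels, but CSS (w-full h-64) stretches the canvas on screen. Unscaled mouse positions are therefore drawn in the wrong place, and anything more than 200 CSS pixels to the right or down falls outside the buffer. Multiplying the pointer position by canvas.width / rect.width (and canvas.height / rect.height) removes this restriction.
In addition, the dataset and the crude conversion of the drawing into an 8×8 input leave a clear gap to a high-quality recognizer. The page is mainly meant to help understand generative classifiers, so the data is processed only minimally.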
A fairly "standard" handwritten 4 was recognized successfully. Recognition does sometimes fail, though, and it may take a few attempts.
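One way to narrow that gap is to prepare the drawing so it looks more like the training data. According to the dataset description shipped with scikit-learn, the digits come from normalized 32 × 32 bitmaps. Each bitmap is divided into non-overlapping 4 × 4 blocks, and the ink pixels in each block are counted, giving values 0-16.

The NumPy sketch below applies a similar idea to an inverted grayscale canvas image (ink = high values). It crops to the ink's bounding box, pads to a centered square, resamples to 32 × 32, averages 4 × 4 blocks, and rescales to 0-16. The function name and the 10% ink threshold are assumptions. The page's JavaScript would need its own port of these steps before `predictDigit` is called.

```python
import numpy as np


def to_digits_format(ink, size=8, block=4, threshold=0.1):
    """Turn a 2D ink image (ink = high values, background = 0) into a size x size array in 0..16.

    Assumed steps, loosely following how the sklearn digits were built:
    crop to the ink's bounding box, pad to a centered square (keeping the aspect
    ratio), resample to (size*block) x (size*block), average each tile, rescale.
    """
    ink = np.asarray(ink, dtype=float)
    if ink.max() <= 0:
        return np.zeros((size, size))  # blank canvas

    # Bounding box of pixels brighter than `threshold` times the maximum
    mask = ink > threshold * ink.max()
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    crop = ink[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]

    # Pad the shorter side with background so the digit stays centered
    h, w = crop.shape
    side = max(h, w)
    square = np.zeros((side, side))
    top, left = (side - h) // 2, (side - w) // 2
    square[top:top + h, left:left + w] = crop

    # Nearest-neighbour resample so the side divides evenly into tiles
    n = size * block
    idx = np.arange(n) * side // n
    resampled = square[np.ix_(idx, idx)]

    # Average each block x block tile, then scale the strongest tile to 16
    tiles = resampled.reshape(size, block, size, block).mean(axis=(1, 3))
    peak = tiles.max()
    return tiles * 16.0 / peak if peak > 0 else tiles


# A vertical bar (a crude "1") drawn at two different places on a 200 x 200 canvas
a = np.zeros((200, 200))
a[50:150, 20:30] = 255
b = np.zeros((200, 200))
b[10:110, 150:160] = 255
print(np.allclose(to_digits_format(a), to_digits_format(b)))  # position no longer matters
print(to_digits_format(a).round(1))
```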