代码全文
python
import matplotlib.pyplot as plt
import numpy as np
import os
# 读入mnist数据集
m_x = np.loadtxt('mnist_x', delimiter=' ')
m_y = np.loadtxt('mnist_y')
# 数据集可视化
data = np.reshape(np.array(m_x[0], dtype=int), [28, 28])
plt.figure()
plt.imshow(data, cmap='gray')
# 将数据集分为训练集和测试集
ratio = 0.8
split = int(len(m_x) * ratio)
# 打乱数据
np.random.seed(0)
idx = np.random.permutation(np.arange(len(m_x)))
m_x = m_x[idx]
m_y = m_y[idx]
x_train, x_test = m_x[:split], m_x[split:]
y_train, y_test = m_y[:split], m_y[split:]
详细解释
1. 导入库
python
import matplotlib.pyplot as plt
import numpy as np
import os
-
import matplotlib.pyplot as plt:
-
作用:导入 Matplotlib 库中的 pyplot 模块,通常用于数据可视化(比如画图、显示图像)。
-
用途:这里用它来显示 MNIST 数据集中的手写数字图像。
-
别名 plt:为了方便调用,约定俗成用 plt 作为简写。
-
-
import numpy