题目1.基于CNN+RNN的手写数字图像序列预测
一、实验背景
手写数字识别是计算机视觉中的经典任务。基于此,本实验可以设计一个更有趣的综合性实验:给定一组连续手写数字图像序列,模型需要预测序列中下一个数字。
二、实验要求:
CNN(采用:自定义CNN或AlexNet或ResNet等)用于对MNIST手写数字图像进行分类。
RNN(采用:自定义RNN或LSTM或GRU等)用于数字序列进行预测(基于序列的前几个数字预测下一个数字)。
这个任务既涉及图像分类,又涉及序列建模,是CNN + RNN 综合应用的典型案例。
三、数据集
1.图像数据集
名称:MNIST 手写数字数据集,可直接使用PyTorch内置数据集。
torchvision.datasets.MNIST(root='./data', train=True, download=True)
内容:28×28 灰度图像,训练集 60,000 张,测试集 10,000 张
2.序列生成方法
将 MNIST 图像随机组成长度为 5 的序列,用前 4 个数字预测第 5 个数字作为标签。
序列规则示例:
l 等差数列:1,3,5,7 → 预测9
l 等比数列:2,4,8,16 → 预测32
l 斐波那契:1,1,2,3 → 预测5
l 简单模式:1,2,1,2 → 预测1
这种方式可以快速生成序列数据,无需额外大数据集。
四、代码框架
1.生成数字序列数据集
class NumberSequenceDataset(Dataset):
"""数字序列数据集"""
def __init__(self, num_samples=5000, sequence_length=5):
self.num_samples = num_samples
self.sequence_length = sequence_length
self.sequences, self.labels = self._generate_sequences()
def _generate_sequences(self):
"""生成多种模式的数字序列"""
sequences = []
labels = []
for _ in range(self.num_samples):
# 随机选择一种序列模式
pattern_type = np.random.choice(['arithmetic', 'geometric', 'fibonacci', 'alternating'])
if pattern_type == 'arithmetic':
# 等差数列
start = np.random.randint(0, 5)
diff = np.random.randint(1, 4)
sequence = [start + i * diff for i in range(self.sequence_length)]
next_num = sequence[-1] + diff
elif pattern_type == 'geometric':
# 等比数列
start = np.random.randint(1, 3)
ratio = np.random.randint(2, 4)
sequence = [start * (ratio ** i) for i in range(self.sequence_length)]
next_num = sequence[-1] * ratio
elif pattern_type == 'fibonacci':
# 斐波那契数列变种
a, b = np.random.randint(1, 3), np.random.randint(1, 3)
sequence = [a, b]
for i in range(2, self.sequence_length):
sequence.append(sequence[i-1] + sequence[i-2])
next_num = sequence[-1] + sequence[-2]
else: # alternating
# 交替序列
a, b = np.random.randint(0, 5), np.random.randint(0, 5)
sequence = [a if i % 2 == 0 else b for i in range(self.sequence_length)]
next_num = a if self.sequence_length % 2 == 0 else b
# 限制数字范围在0-9之间(模拟MNIST数字)
sequence = [min(max(x, 0), 9) for x in sequence]
next_num = min(max(next_num, 0), 9)
sequences.append(sequence)
labels.append(next_num)
return np.array(sequences), np.array(labels)
def __len__(self):
return len(self.sequences)
def __getitem__(self, idx):
sequence = torch.FloatTensor(self.sequences[idx])
label = torch.LongTensor([self.labels[idx]]).squeeze()
return sequence, label
2.数据加载和预处理
print("=== 数据准备 ===")
# MNIST数据变换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载MNIST数据集
train_mnist = torchvision.datasets.MNIST(
root='./data', train=True, download=True, transform=transform
)
test_mnist = torchvision.datasets.MNIST(
root='./data', train=False, download=True, transform=transform
)
# 创建数字序列数据集
train_sequences = NumberSequenceDataset(num_samples=5000, sequence_length=5)
test_sequences = NumberSequenceDataset(num_samples=1000, sequence_length=5)
# 创建数据加载器
mnist_train_loader = DataLoader(train_mnist, batch_size=64, shuffle=True)
mnist_test_loader = DataLoader(test_mnist, batch_size=64, shuffle=False)
sequence_train_loader = DataLoader(train_sequences, batch_size=64, shuffle=True)
sequence_test_loader = DataLoader(test_sequences, batch_size=64, shuffle=False)
print(f"MNIST训练集: {len(train_mnist)} 张图像")
print(f"MNIST测试集: {len(test_mnist)} 张图像")
print(f"序列训练集: {len(train_sequences)} 个序列")
print(f"序列测试集: {len(test_sequences)} 个序列")
3.模型定义
(1)定义手写数字识别CNN模型(采用:自定义RNN或LSTM或GRU等):补充!
(2)定义数字序列预测RNN模型(采用:自定义RNN或LSTM或GRU等):补充!
4.训练模型
(1)训练、优化CNN模型:补充!
(2)训练、优化RNN模型:补充!
5.评估模型
(1)评估CNN模型:补充!
(2)评估RNN模型:补充!
6.数字序列预测测试
(1)输入数字图像序列,实现下一个数字预测:补充!