Today, while running the code from Mu Li's "Image Classification Dataset" lesson, I hit an error:
RuntimeError: DataLoader worker (pid(s) 21316) exited unexpectedly
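I. Cause
I sent the error to DeepSeek for analysis, and it suggested several possible causes:
1. a worker process killed for lack of memory; 2. a bug in the dataset code; 3. a multiprocessing compatibility problem; 4. system resource limits.
I then checked the code and the full error output. Setting the num_workers argument of DataLoader() to 0 made the script run without error, and the detailed traceback also contained: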
RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase.
This probably means that you are not using fork to start your child processes and you have forgotten to use the proper idiom in the main module:......
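Putting this together with what I read, the problem is that the child processes created for DataLoader() re-import the main module. With no guard in place, the top-level code runs again in each child and tries to start workers of its own, which would recurse without end, so Python raises the error instead. The book's code is run in Jupyter, where code is written in cells and executed interactively, so this situation rarely comes up there.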
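To make the re-import concrete, here is a minimal sketch that simulates it without starting any process. It assumes, as I understand CPython's spawn implementation, that the child re-runs the main script under the module name "__mp_main__" rather than "__main__". The script text, the `events` list and the `run_as` helper are made up for illustration.

```python
import os
import runpy
import tempfile

# A tiny stand-in for a training script: one top-level statement
# and one statement protected by the __main__ guard.
SCRIPT = '''
events = ["top-level statement ran"]
if __name__ == "__main__":
    events.append("guarded block ran")
'''


def run_as(run_name):
    """Run SCRIPT from a temporary file under the given module name."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo_script.py")
        with open(path, "w", encoding="utf-8") as f:
            f.write(SCRIPT)
        namespace = runpy.run_path(path, run_name=run_name)
    return namespace["events"]


# What the parent process executes when the script is launched directly.
as_parent = run_as("__main__")
# What a spawned child would execute when it re-imports the script.
as_child = run_as("__mp_main__")

print("parent:", as_parent)
print("child: ", as_child)
```

The top-level statement runs in both cases, while the guarded statement runs only under "__main__". Any unguarded code that starts worker processes would therefore run again inside every child.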
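II. Solution
Next I modified the code. At first I moved only the line that calls DataLoader() under "if __name__ == '__main__':", but the error persisted. After reading the code more carefully alongside the information from DeepSeek, I found that all of the executable code has to go under if __name__ == '__main__':. I had simply assumed otherwise. In other words, the logic that launches the worker processes involves more than the DataLoader() call alone.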
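The layout described above can be sketched as follows: only imports and definitions sit at module level, and everything that executes is reachable solely from the guarded block. This is a simplified sketch, not the lesson's code. It assumes PyTorch is installed, uses a small synthetic dataset in place of Fashion-MNIST so that nothing is downloaded, and the names `make_dataset`, `count_batches` and `main` are hypothetical.

```python
import torch
from torch.utils import data


def make_dataset(num_samples=100):
    """Build a small synthetic dataset (a stand-in for Fashion-MNIST)."""
    features = torch.arange(num_samples, dtype=torch.float32).unsqueeze(1)
    labels = torch.zeros(num_samples, dtype=torch.long)
    return data.TensorDataset(features, labels)


def count_batches(num_workers, batch_size=32):
    """Iterate once over the dataset and return the number of batches."""
    loader = data.DataLoader(make_dataset(), batch_size=batch_size,
                             shuffle=True, num_workers=num_workers)
    # Worker processes are started when the loader is iterated,
    # not when the DataLoader object is constructed.
    return sum(1 for _ in loader)


def main():
    # Everything that executes lives here, so a re-import of this
    # module in a child process only defines functions.
    print("batches:", count_batches(num_workers=2))


if __name__ == "__main__":
    main()
```

Because the workers start at iteration time, guarding only the line that constructs the DataLoader is not enough when the loop over it stays at module level. Calling `count_batches(num_workers=0)` loads the data in the main process and avoids child processes entirely.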
For how if __name__ == '__main__': works, I referred to the Bilibili video 【Python中的 if __name__ == '__main__' 是干嘛的?】 https://www.bilibili.com/video/BV1T66mYpE7s/?share_source=copy_web&vd_source=3639f14589f273f696b4a18401257b0c
Below is the code that runs directly on Windows:
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 19 13:24:58 2025
@author: ASUS
"""
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
def get_fashion_mnist_labels(labels):
    """Return the text labels of the Fashion-MNIST dataset."""
    text_labels = [
        't-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt',
        'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
# Visualization
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
def get_dataloader_workers():
    # Use 3 worker processes to read the data.
    return 3
if __name__ == '__main__':
    d2l.use_svg_display()
    # The ToTensor instance converts image data from PIL type to 32-bit floating point tensors.
    trans = transforms.ToTensor()
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data_Fashion_MNIST", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data_Fashion_MNIST", train=False, transform=trans, download=True)
    print(len(mnist_train), len(mnist_test))
    # The first sample: index 0 is the image, index 1 is the label.
    print('size of picture 1 :', mnist_train[0][0].shape)
    X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
    show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))
    batch_size = 256
    train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                                 num_workers=get_dataloader_workers())
    timer = d2l.Timer()
    for X, y in train_iter:
        continue
    print(f'{timer.stop():.2f} sec')
The problem is solved, but the topic of "Windows starts child processes with spawn" is something I still need to study in more depth later.
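As a starting point for that follow-up, the standard library can report which start methods the current platform supports and which one it uses by default. The snippet below is only a small probe and prints whatever the local interpreter reports. According to the multiprocessing documentation, spawn is the only method available on Windows, while POSIX systems also offer fork and forkserver.

```python
import multiprocessing as mp
import sys

# Start methods supported on this platform; the first entry is the default.
available = mp.get_all_start_methods()
# The start method the default context will use for new processes.
default = mp.get_start_method()

print("platform:", sys.platform)
print("available start methods:", available)
print("default start method:", default)
```

Running this on both Windows and Linux shows why a script that works unchanged on one system can fail on the other.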