OpenPose -tensorflow代码解析(2)—— 多进程数据读取 dataset.py

本文对OpenPose - tensorflow代码的数据读取部分进行解析。实现了图片数据增强,如色彩增强、随机翻转等;将关键点转化为热量图,包括关键点和亲和域的热量图;使用多进程将数据放入队列以减少训练时间。最后还介绍了数据读取脚本的测试方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

该openpose-tensorflow的工程是自己实现的,所以有些地方写的会比较简单,但阅读性强、方便使用。

论文翻译 || openpose – Realtime Multi-Person 2D Pose Estimation using Part Affinity Fields
工程实现 || 基于opencv使用openpose完成人体姿态估计

OpenPose -tensorflow代码解析(1)——工程概述&&训练前的准备
OpenPose -tensorflow代码解析(2)—— 数据增强和处理 dataset.py
OpenPose -tensorflow代码解析(3)—— 网络结构的搭建 Net.py
OpenPose -tensorflow代码解析(4)—— 训练脚本 train.py
OpenPose -tensorflow代码解析(5)—— 预测代码解析 predict.py

1 代码概述

将openpose的数据读取定义称一个类 class dataset

里面主要实现了3个功能:

  • 图片的数据增强:随机翻转、随机裁剪、旋转平移、色彩的增强
  • 将关键点转化成 openpose的网络的输出形式:关键点的热量图、亲和域的热量图
  • 将增强好的image、转化好的label 使用多进程放进队列中。
    这是让cpu提前处理数据放入队列中,当训练时,GPU完成一次反向传播后,能够直接获取到新的数据进行新一次的迭代,避免GPU处于空闲,减少总共训练时间

2 初始化

初始化部分,就是读取配置文件(配置脚本 opt.py附在最后),已经设定一些遍变量。其中需要说明的是:

  • self.point_num = cfg.OP.cpm_num 设置的是关键点的数量,不包含背景
  • self.paf_num = cfg.OP.paf_num 设置的是关键点之间的连接数量*2
  • self.shuffle_ref 设置的是关键点的具体连接方式
  • self.LR_morrir 如果关键点是对称的,对于关键点[0,1,2,3,4,5],假设它们对应的镜像点的索引为 [5,4,3,2,1,0]。如果关键点不是对称的,这里为空
  • self.Q_name = Queue(self.num_samples) 读取验证数据的名字的队列
  • self.Q_data = Queue(1000) 读取训练数据的队列

    这里的队列长度的设置是有讲究的。因为 Dataset 会被实例化成 【trainset、testset】,所以一定要保证当单个队列 Q_data 满数据时,不能占满整个内存,大约50%即可,否则可能发生另外一个队列无法放置数据
 class Dataset:
   def __init__(self, dataset_type):

       self.fortest = False

       self.annot_path  = cfg.TRAIN.annot_path if dataset_type == 'train' else cfg.TEST.annot_path
       if not os.path.exists(self.annot_path):
           print(self.annot_path+" 文件不存在!")
           exit()
       self.input_sizes = cfg.TRAIN.input_size if dataset_type == 'train' else cfg.TEST.input_size
       self.batch_size  = cfg.TRAIN.batch_size if dataset_type == 'train' else cfg.TEST.batch_size
       self.data_aug    = cfg.TRAIN.data_aug   if dataset_type == 'train' else cfg.TEST.data_aug

       self.WH_ratio = cfg.OP.WH_ratio
       self.stride = cfg.OP.strides
       self.point_num = cfg.OP.cpm_num
       self.paf_num = cfg.OP.paf_num
       self.shuffle_ref = [[0, 1], [1, 2], [2, 3], [3, 4],
                           [0, 5], [5, 6], [6, 7], [7, 8],
                           [0, 9], [9, 10], [10, 11], [11, 12],
                           [0, 13], [13, 14], [14, 15], [15, 16],
                           [0, 17], [17, 18], [18, 19], [19, 20],
                           [0, 21]]
      self.LR_morrir = []
      self.sigma = 0.8

       self.annotations = self.load_annotations()
       self.num_samples = len(self.annotations)    # 样本的数量
       self.num_step_one_epoch = int(np.ceil(self.num_samples / self.batch_size))  # 一轮读取批数
       self.batch_count = 0

       self.num_trianepoch = cfg.TRAIN.first_stage_epoch + cfg.TRAIN.second_stage_epoch

       self.Q_name = Queue(self.num_samples)  # 读取验证数据的名字的队列
       self.Q_data = Queue(1000)  # 读取训练数据的队列


然后

  • 定义 len(dataset) = num_step_one_epoch,也就是一轮训练的步数
  • 加载 train.txt 或者 test.txt 文件,获取到数据集的 image-label 的路径
  • 根据batch 的大小,设定好一批数据的numpy 数组。
    其中值得注意的是,self.input_size 的计算。当我们想要设置多尺度图片进行训练,只需要在opt.py 文件中,设置多个尺寸,这里就会每个batch的数据,随机获取一个尺寸的大小进行处理数据。所以 Prepare() 函数,是要在每个batch都进行调用,就不能放在 __init__() 中。
   def __len__(self):
       return self.num_step_one_epoch

   def load_annotations(self,):
       with open(self.annot_path, 'r') as f:
           txt = f.readlines()
           annotations = [line.split() for line in txt ]
       return annotations

   def Prepare(self):
       size = random.choice(self.input_sizes)

       self.input_size = [size, int(size//self.WH_ratio)]
       self.output_size = [int(self.input_size[0]//self.stride), int(self.input_size[1]//self.stride)]

       self.batch_image = np.zeros((self.batch_size, self.input_size[0], self.input_size[1], 3), dtype=float)
       self.batch_label_heatmap = np.zeros((
           self.batch_size, self.output_size[0], self.output_size[1], self.point_num+1), dtype=float)
       self.batch_label_vectmap = np.zeros((
           self.batch_size, self.output_size[0], self.output_size[1], len(self.shuffle_ref)*2), dtype=float)

2 数据增强

数据增强:色彩增强、随机翻转、随机旋转、随机平移
数据处理:给图片填充和缩放,图片的内容保持原本的长宽比例

   def load_data(self, image_path, label_path):

       if not os.path.exists(image_path):
           print(image_path+" 图片不存在")
           raise KeyError("%s does not exist ... " %image_path)
       
       image = np.array(cv2.imread(image_path))
       joint = np.loadtxt(label_path)

       show_image("image_or", image)  if self.fortest else None

       if self.data_aug:

           image = self.change_img(image)
           show_image("change_img", image) if self.fortest else None

           image,joint = self.random_horizontal_flip(image, joint)
           show_image("random_horizontal_flip", image) if self.fortest else None

           image,joint = self.random_horizontal_rotation(image, joint)
           show_image("random_horizontal_rotation", image) if self.fortest else None

           image, joint = self.random_translate(image, joint)
           show_image("random_horizontal_flip", image) if self.fortest else None

       image, joint = image_preporcess(image, self.input_size, joint)
       show_image("random_horizontal_flip", image) if self.fortest else None

       return image, joint

def show_image(name, image):
   cv2.namedWindow(name, 0)  # 0 窗口可伸缩
   cv2.resizeWindow(name, 500, 500)  # 初始窗口大小
   cv2.imshow(name, image)  # 展示图片
   cv2.waitKey(0)  # 保持展示
   # cv2.destroyAllWindows()  # 注销窗口
   

2.1 色彩增强

色彩的改变,在 Pillow 库中,有很方便的api


   def change_img(self,img):

       p = random.randint(0, 3)
       a1 = random.uniform(0.8, 2)
       a2 = random.uniform(0.8, 1.4)
       a3 = random.uniform(0.8, 1.7)
       a4 = random.uniform(0.8, 2.5)
       img = Image.fromarray(img)

       img = ImageEnhance.Color(img).enhance(a1) if p == 0 else img
       img = ImageEnhance.Brightness(img).enhance(a2) if p == 1 else img
       img = ImageEnhance.Contrast(img).enhance(a3) if p == 2 else img
       img = ImageEnhance.Sharpness(img).enhance(a4) if p == 3 else img
       img = np.array(img)

       return img
       


在这里插入图片描述

2.2 随机水平翻转

如果标注的关节点是镜像的,如人体的关节点,在做水平翻转时,主要关键点的位置和索引,都要进行镜像处理,也就是 joint = joint[self.LR_morrir,:]

   def random_horizontal_flip(self, image, joint):
       if random.random() < 0.5:
           _, w, _ = image.shape
           image = image[:, ::-1, :]
           joint[:, 0] = w - joint[:, 0]
           # joint = joint[self.LR_morrir,:]
       return image, joint
       


在这里插入图片描述

2.3 随机旋转

思路:
根据随机获取的角度值,得到相应的旋转矩阵;
用这个旋转矩阵,以及opencv中的api,对图片进行旋转;
用这个旋转矩阵,对关键点进行相应的旋转。


   def random_horizontal_rotation(self, image, joint):
       if random.random() < 0.7:
           # 设置旋转矩阵
           transform_matrix = affine_rotation_matrix(angle=(-10,10), x=self.input_size[1]//2, y=self.input_size[0]//2)
           # 使用旋转矩阵旋转图片
           image = affine_transform_cv2(image, transform_matrix)
           #  使用旋转矩阵旋转关键点
           joint = affine_transform_keypoints(joint, transform_matrix)
       return image, joint

# 设置旋转矩阵
def affine_rotation_matrix(angle, x, y):

   if isinstance(angle, tuple):
       theta = np.pi / 180 * np.random.uniform(angle[0], angle[1])
   else:
       theta = np.pi / 180 * angle
   rotation_matrix = np.array([[np.cos(theta), np.sin(theta), 0],
                               [-np.sin(theta), np.cos(theta), 0],
                               [0, 0, 1]])
   o_x = (x - 1) / 2.0
   o_y = (y - 1) / 2.0
   offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
   reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
   transform_matrix = np.dot(np.dot(offset_matrix, rotation_matrix), reset_matrix)
   return transform_matrix

# 使用旋转矩阵旋转图片
def affine_transform_cv2(x, transform_matrix, flags=None, border_mode='constant'):

   rows, cols = x.shape[0], x.shape[1]
   if flags is None:
       flags = cv2.INTER_AREA
   if border_mode is 'constant':
       border_mode = cv2.BORDER_CONSTANT
   elif border_mode is 'replicate':
       border_mode = cv2.BORDER_REPLICATE
   else:
       raise Exception("unsupport border_mode, check cv.BORDER_ for more details.")
   return cv2.warpAffine(x, transform_matrix[0:2, :], (cols, rows), flags=flags, borderMode=border_mode)

#  使用旋转矩阵旋转关键点
def affine_transform_keypoints(coords_list, transform_matrix):

   coords = coords_list.transpose([1, 0])
   coords = np.insert(coords, 2, 1, axis=0)

   coords_result = np.matmul(transform_matrix, coords)
   coords_result = coords_result[0:2, :].transpose([1, 0])
   return coords_result


在这里插入图片描述

2.4 随机平移

进行随机平移的操作,一定要保证不能将标签所在区域 平移超出图片的范围。
所以需要先计算关键点的最小凸集,然后用这个参数,来设定平移的范围。


   def random_translate(self, image, joint):

       if random.random() < 0.5:
           h, w, _ = image.shape

           # 求图片中所有点的最小凸集框的左上角和右下角
           max_bbox = np.concatenate([np.min(joint, axis=0), np.max(joint, axis=0)], axis=-1)

           # 获取最小凸集与图片的最上角的距离
           max_l_trans = max_bbox[0]
           max_u_trans = max_bbox[1]
           max_r_trans = w - max_bbox[2]
           max_d_trans = h - max_bbox[3]

           tx = random.uniform(-(max_l_trans - 1), (max_r_trans - 1))
           ty = random.uniform(-(max_u_trans - 1), (max_d_trans - 1))

           M = np.array([[1, 0, tx], [0, 1, ty]])
           image = cv2.warpAffine(image, M, (w, h))

           joint = joint + np.array([tx,ty])

       return image, joint


在这里插入图片描述

2.5 数据尺寸处理

我们需要将图片处理成 神经网络输入的尺寸。
原则是,填充短边 使长宽比例与神经网络输入长款比例一样,然后再进行缩放,保证图片没有被拉伸或压缩。具体实现的方式很多种,只要实现没有改变长款比例就行。

def image_preporcess(image, target_size, joint=None):

   # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)

   ih, iw    = target_size
   h,  w, _  = image.shape

   scale = min(iw/w, ih/h)
   nw, nh  = int(scale * w), int(scale * h)
   image_resized = cv2.resize(image, (nw, nh))

   image_paded = np.full(shape=[ih, iw, 3], fill_value=128, dtype=np.uint8)
   dw, dh = (iw - nw) // 2, (ih-nh) // 2
   image_paded[dh:nh+dh, dw:nw+dw, :] = image_resized
   # image_paded = image_paded / 255.


   if joint is None:
       return image_paded
   else:
       joint = joint * scale + + np.array([dw,dh])
       return image_paded, joint

3 关键点转化成热量图

3.1 生成关键点的热量图 heatmap

  • 生成热量图的数量为 num_keypoint + 1(背景类)。
  • 对于每一个索引的关键点,都生成一张热量图。关键点的坐标在热量图中相应的位置,会生成一个二元正态分布的数据。该函数实现的方式比较多
  • 背景类的热量图的数值,存在关键点的位置的像素值为0,其他为1。

   def get_heatmap(self, joint_list,sign=True):
       # 该函数当中的joint_list,需要的是关节点的坐标,与图片像素的索引的围度数据相反
       heatmap = np.zeros((self.point_num+1, self.output_size[0], self.output_size[1]), dtype=np.float32)
       for idx in range(self.point_num):
           joints = joint_list[idx]
           if joints[0] < 0 or joints[1] < 0:
               continue
           # print("==")
           self.put_heatmap(heatmap, idx, joints, self.sigma)

       heatmap = heatmap.transpose((1, 2, 0))
       heatmap[:, :, -1] = np.clip(1 - np.amax(heatmap, axis=2), 0.0, 1.0)  # background

       return heatmap

   def put_heatmap(self, heatmap, plane_idx, center, sigma):
       center_x, center_y = center
       _, height, width = heatmap.shape[:3]

       th = 4.6052 
       delta = math.sqrt(th * 2)
       x0 = int(max(0, center_x - delta * sigma))
       y0 = int(max(0, center_y - delta * sigma))
       x1 = int(min(width, center_x + delta * sigma))
       y1 = int(min(height, center_y + delta * sigma))
       for y in range(y0, y1):
           for x in range(x0, x1):
               d = (x - center_x) ** 2 + (y - center_y) ** 2
               exp = d / 2.0 / sigma / sigma
               if exp > th:
                   continue
               heatmap[plane_idx][y][x] = max(heatmap[plane_idx][y][x], math.exp(-exp))
               heatmap[plane_idx][y][x] = min(heatmap[plane_idx][y][x], 1.0)
       return heatmap

我们知道 一维的正太分布的公式为 f ( x ) = 1 2 π σ e x p − ( x − μ ) 2 2 ∗ σ f(x)=\frac{1}{\sqrt{2 \pi }\sigma}exp^{-\frac{(x-\mu)^2}{2*\sigma}} f(x)=2π σ1exp2σ(xμ)2
下面的图为二元正太分布示意图
在这里插入图片描述

3.2 生成亲和域的热量图 vectmap

  • 有连接关系的关键点对 n 组,会生成亲和域的热量图 n*2 组。
  • 在关键点对的连线上,一定宽度的像素值,都进行赋值。一张热量图中赋值为点对的方向向量的x分量,一张赋值为方向向量的 y 分量。
  • 多对连接点对 如果存在交叉重叠,那么重叠的位置的像素值,为多对连接点对的分量的平均值。

   def get_vectormap(self, joint_list,sign = True):
       # 该函数当中的joint_list,需要的是关节点的坐标,与图片像素的索引的围度数据相反
       vectormap = np.zeros((len(self.shuffle_ref)*2, self.output_size[0], self.output_size[1]), dtype=np.float32)
       countmap = np.zeros((len(self.shuffle_ref), self.output_size[0], self.output_size[1]), dtype=np.int16)
       for plane_idx, (j_idx1, j_idx2) in enumerate(self.shuffle_ref):
               center_from = joint_list[j_idx1]
               center_to = joint_list[j_idx2]
               # print("ceter from: ", center_from)
               # print("ceter to: ", center_to)
               if center_from[0] < -100 or center_from[1] < -100 or center_to[0] < -100 or center_to[1] < -100:
                   continue
               self.put_vectormap(vectormap, countmap, plane_idx, center_from, center_to)

       vectormap = vectormap.transpose((1, 2, 0))
       nonzeros = np.nonzero(countmap)
       for p, y, x in zip(nonzeros[0], nonzeros[1], nonzeros[2]):
           if countmap[p][y][x] <= 0:
               continue
           vectormap[y][x][p * 2 + 0] /= countmap[p][y][x]
           vectormap[y][x][p * 2 + 1] /= countmap[p][y][x]
       return vectormap.astype(np.float16)

   def put_vectormap(self, vectormap, countmap, plane_idx, center_from, center_to, threshold=1):
       _, height, width = vectormap.shape[:3]

       vec_x = center_to[0] - center_from[0]
       vec_y = center_to[1] - center_from[1]
       min_x = max(0, int(min(center_from[0], center_to[0]) - threshold))
       min_y = max(0, int(min(center_from[1], center_to[1]) - threshold))
       max_x = min(width, int(max(center_from[0], center_to[0]) + threshold))
       max_y = min(height, int(max(center_from[1], center_to[1]) + threshold))

       norm = math.sqrt(vec_x ** 2 + vec_y ** 2)
       if norm == 0:
           return
       vec_x /= norm
       vec_y /= norm
       for y in range(min_y, max_y):
           for x in range(min_x, max_x):
               bec_x = x - center_from[0]
               bec_y = y - center_from[1]
               dist = abs(bec_x * vec_y - bec_y * vec_x)

               if dist > threshold:
                   continue
               countmap[plane_idx][y][x] += 1
               vectormap[plane_idx * 2 + 0][y][x] = vec_x
               vectormap[plane_idx * 2 + 1][y][x] = vec_y

4 将data 多进程放入队列

  • 设置操作1:获取处理后的input、label,组成一个batch,将batch 数据放入到 Q_data 的队列中
  • 设置操作2:获取所有数据路径,打乱后放入 Q_name队列中
  • 设置多进程:多进程进行 操作1/2,

  def readdata(self,  image_path, label_path, num):

       image, joint = self.load_data(image_path, label_path)

       image = image.astype(np.float32)
       image = (image - np.mean(image, axis=(0,1)))/(np.std(image, axis=(0, 1))+1e-8)

       self.batch_image[num,:,:,:] = image
       self.batch_label_heatmap[num,:,:,:] = self.get_heatmap(joint / self.stride)
       self.batch_label_vectmap[num] = self.get_vectormap(joint / self.stride)
       return image, joint

   def Q_getname(self):
       for i in range(self.num_trianepoch):
           if self.data_aug: random.shuffle(self.annotations)
           for j in range(self.num_samples):
               if not os.path.exists(self.annotations[j][0]) or not os.path.exists(self.annotations[j][1]):
                   continue
               self.Q_name.put(self.annotations[j])

   def Q_getData(self, thread):
       self.Prepare()

       name = []
       while 1:
           if self.batch_count < self.num_step_one_epoch:   # 当【读取了几批】小于【一轮总批数】
               num = 0 # 统计批内读取个数
               while num < self.batch_size:        #【批内读取数据个数】小于【一个batch数值】
                   namefile = self.Q_name.get()
                   self.readdata(namefile[0], namefile[1], num)
                   name.append(namefile[0])
                   num += 1
               self.batch_count += 1 # 统计一轮的训练,读取了几个批次
               # print(name)
               # print(thread )
               self.Q_data.put([name, thread,
                                   self.batch_image,
                                    self.batch_label_heatmap,
                                    self.batch_label_vectmap])
               name = []
           else:
               self.batch_count = 0

   def start(self, P1):
       Process(target=self.Q_getname, args=()).start()
       for thread in range(P1):
           Process(target=self.Q_getData, args=(thread,)).start()
       return self.Q_data

5 全面测试数据读取是否正确

当我们编写好了数据读取的脚本,需要进行两方面的测试:

  • case1:单张图:输入图片的数据增强;神经网络输出相应的label的制作
  • case2:多进程的数据读取是否正确:避免出现多进程重复读取等情况
if __name__ == '__main__':
   
   case = 1 # 1:测试单张图片的数据增强 2:测试队列的获取图片的重复性的问题
   
   if case:

       humandata = Dataset("dotest")
       humandata.data_aug = True # 是否进行数据增强
      humandata.fortest = False  # 是否显示过程中每种增强后的图片

       humandata.Prepare()

       for s in range(len(humandata.annotations)):
           # r[s] = "590DSC_0165.png"
           image_path = humandata.annotations[s][0]
           label_path = humandata.annotations[s][1]

           print(image_path)
           print(label_path)
           image, joint = humandata.readdata(image_path, label_path ,0)
           print(joint.shape)

          c1 = []
          for ii in range(len(joint)):
              c1.append((int(joint[ii][0]), int(joint[ii][1])))
          for cc in range(len(joint)):
              cv2.circle(image, c1[cc], 2, (255, 0, 0), thickness=1)
           show_image("random_horizontal_flip", image)
           
           img_heatmap = np.zeros((humandata.output_size[0], humandata.output_size[1], 3))
           for i in range(humandata.batch_label_heatmap.shape[3]-1):
               H =humandata.batch_label_heatmap[0,:,:,i]
               H = np.array([H,H,H]).transpose([1,2,0])
               img_heatmap = img_heatmap + H
               # img_test1 = cv2.resize(H, (humandata.input_size[1], humandata.input_size[0]))
               # img_heatmap1 = cv2.resize(img_heatmap, (humandata.output_size[1], humandata.output_size[0]))
               # cv2.namedWindow('demo5', 0)  # 0 窗口可伸缩
               # cv2.resizeWindow('demo5', 500, 500)  # 初始窗口大小
               # cv2.imshow("demo5", H)  # 展示图片
               # cv2.waitKey(0)  # 保持展示
           cv2.namedWindow('demo5', 0)  # 0 窗口可伸缩
           cv2.resizeWindow('demo5', 500, 500)  # 初始窗口大小
           cv2.imshow("demo5", img_heatmap)  # 展示图片
           cv2.waitKey(0)  # 保持展示


           img_heatmap = np.zeros((humandata.output_size[0], humandata.output_size[1], 3))
           for i in range(humandata.batch_label_vectmap.shape[3]-1):
               H = abs(humandata.batch_label_vectmap[0,:,:,i] * 255)
               H = np.array([H,H,H]).transpose([1,2,0])
               img_heatmap = img_heatmap + H
               # img_test1 = cv2.resize(H, (humandata.input_size[1], humandata.input_size[0]))
               # img_heatmap1 = cv2.resize(img_heatmap, (humandata.input_size[1], humandata.input_size[0]))
               # cv2.namedWindow('demo5', 0)  # 0 窗口可伸缩
               # cv2.resizeWindow('demo5', 500, 500)  # 初始窗口大小
              # cv2.imshow("demo5", H)  # 展示图片
               # cv2.waitKey(0)  # 保持展示
           cv2.namedWindow('demo5', 0)  # 0 窗口可伸缩
           cv2.resizeWindow('demo5', 500, 500)  # 初始窗口大小
           cv2.imshow("demo5", img_heatmap)  # 展示图片
           cv2.waitKey(0)  # 保持展示
           
   else:
       humandata = Dataset("train")
       Q_traindata = humandata.start(3)
       for i in range(10):
           A = Q_traindata.get()
           print(A[1],A[0])      # 打印出队列存储的名字,以及来源的进程 id

在这里插入图片描述

6 附 opt.py 脚本

from easydict import EasyDict as edict
print("read config  ====================================")
cfg                             = edict()
cfg.OP                        = edict()
# Set the class name
cfg.OP.strides                = 8
cfg.OP.WH_ratio               = 1
cfg.OP.cpm_num = 22
cfg.OP.paf_num = 21*2

# Train options
cfg.TRAIN                       = edict()
cfg.TRAIN.annot_path            = "../data/train.txt"
cfg.TRAIN.batch_size            = 8
cfg.TRAIN.input_size            = [512]
# cfg.TRAIN.INPUT_SIZE            = [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
cfg.TRAIN.data_aug              = True
cfg.TRAIN.learn_rate_init       = 1e-4
cfg.TRAIN.learn_rate_end        = 1e-6

cfg.TRAIN.warmup_epoch         = 2
cfg.TRAIN.first_stage_epoch    = 100
cfg.TRAIN.second_stage_epoch   = 30
cfg.TRAIN.initial_weights        = None
cfg.TRAIN.ckpt_path        = "./model/checkpoint0/"
cfg.TRAIN.log_path = './model/log0/'
#

# TEST options
cfg.TEST                        = edict()
cfg.TEST.annot_path             = "../data/test.txt"
cfg.TEST.batch_size             = 8
cfg.TEST.input_size             = [512]
cfg.TEST.data_aug               = False
Traceback (most recent call last): File "/Users/hejiajia/Desktop/Code/diffusionDemo.py", line 516, in <module> for step, batch in enumerate(dataloader): File "/Users/hejiajia/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 631, in __next__ data = self._next_data() File "/Users/hejiajia/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data data = self._dataset_fetcher.fetch(index) # may raise StopIteration File "/Users/hejiajia/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch data = self.dataset.__getitems__(possibly_batched_index) File "/Users/hejiajia/opt/anaconda3/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 2766, in __getitems__ batch = self.__getitem__(keys) File "/Users/hejiajia/opt/anaconda3/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 2762, in __getitem__ return self._getitem(key) File "/Users/hejiajia/opt/anaconda3/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 2747, in _getitem formatted_output = format_table( File "/Users/hejiajia/opt/anaconda3/lib/python3.8/site-packages/datasets/formatting/formatting.py", line 639, in format_table return formatter(pa_table, query_type=query_type) File "/Users/hejiajia/opt/anaconda3/lib/python3.8/site-packages/datasets/formatting/formatting.py", line 407, in __call__ return self.format_batch(pa_table) File "/Users/hejiajia/opt/anaconda3/lib/python3.8/site-packages/datasets/formatting/formatting.py", line 521, in format_batch batch = self.python_features_decoder.decode_batch(batch) File "/Users/hejiajia/opt/anaconda3/lib/python3.8/site-packages/datasets/formatting/formatting.py", line 228, in decode_batch return self.features.decode_batch(batch) if self.features else batch File "/Users/hejiajia/opt/anaconda3/lib/python3.8/site-packages/datasets/features/features.py", line 2087, in decode_b
最新发布
03-12
<think>嗯,用户在使用PyTorch的DataLoader和datasets库时遇到了数据加载的问题,我得想想怎么帮他们解决。首先,我需要回忆一下常见的错误和解决方案。根据提供的引用内容,可能的问题包括路径错误、数据格式不匹配、多进程问题、预处理错误、数据集分割以及自定义数据集的问题。 首先,路径错误是常见的,尤其是内置数据集如果无法自动下载,用户可能需要手动下载并指定路径。这时候要检查路径是否正确,或者设置download=True。引用[4]提到PyTorch有预加载数据集,比如FashionMNIST,可能需要正确指定root参数。 然后是数据格式的问题。用户可能使用的数据集不符合PyTorch的Dataset格式,比如图片的尺寸不一致或者标签文件格式不对。这时候需要检查transform是否正确应用,或者自定义Dataset类时是否正确实现了__getitem__方法。引用[5]提到不同文件夹存放不同类别的数据,或者image和label分开存放的情况,可能需要不同的处理方式。 多进程问题,比如num_workers设置过高导致内存不足,或者在Windows下使用spawn方法的问题。这时候建议调整num_workers的数量,或者在if __name__ == '__main__'中运行代码,避免多进程问题。引用[3]提到prefetch_factor参数可以预取数据,减少等待时间,但可能也需要调整num_workers。 预处理错误,比如Normalize的参数不正确,或者transform没有正确组合。可能需要检查transforms.Compose的使用,以及mean和std的值是否适合当前数据集。引用[1]中的示例使用了transforms.ToTensor()和Normalize,用户需要确认这些步骤是否正确应用。 数据集分割的问题,比如没有正确划分训练集和测试集,导致数据泄漏或加载错误。这时候应该使用random_split或Subset来划分,确保每次加载正确的部分。引用[4]提到内置数据集通常有train和test参数,可以分开加载。 自定义数据集的问题,用户可能需要继承Dataset类并实现__len__和__getitem__方法。这里要检查路径是否正确,文件读取是否正常,以及返回的数据类型是否符合模型输入的要求。引用[5]讨论了两种自定义数据集的情况,可能需要参考这些例子。 此外,内存不足或显存不足也可能导致问题,尤其是在使用大规模数据集时。可以尝试减小batch_size或者使用更高效的数据格式。引用[3]提到prefetch_factor可能有助于减少等待时间,但如果资源有限,可能需要调整这个参数。 还有可能遇到版本问题,比如PyTorch或datasets库的版本不兼容,建议升级到最新版。如果使用了第三方库如Hugging Face的datasets,需要确保和PyTorch兼容,可能需要转换数据格式为torch张量。 最后,建议用户提供具体的错误信息,比如异常堆栈,这样可以更准确地定位问题。例如,如果出现KeyError,可能是数据标签的键名不正确;如果形状不匹配,可能是transform处理后的数据维度不对。</think>### 逐步解决方案 #### 1. **检查数据集路径和加载方式** - **问题**:内置数据集(如MNIST)无法自动下载或路径错误。 - **解决**: - 使用`root`参数指定正确路径,并设置`download=True`: ```python train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor()) ``` - 手动下载数据集文件(如`.npz`或`.gz`格式),保存到`root`目录[^1][^4]。 #### 2. **处理数据格式不匹配** - **问题**:数据维度或类型与模型输入不兼容(例如图像未转为张量)。 - **解决**: - 使用`transforms`组合预处理步骤: ```python transform = transforms.Compose([ transforms.Resize((32, 32)), # 统一尺寸 transforms.ToTensor(), # 转为张量 transforms.Normalize(mean=[0.5], std=[0.5]) # 标准化 ]) ``` - 检查自定义数据集的`__getitem__`方法是否返回`(data, label)`对[^5]。 #### 3. **解决多进程加载问题** - **问题**:`num_workers > 0`时出现内存错误或子进程崩溃(尤其在Windows)。 - **解决**: -代码放在`if __name__ == '__main__':`块中: ```python if __name__ == '__main__': dataloader = DataLoader(dataset, num_workers=4) ``` - 降低`num_workers`(如设为0或2)或使用`prefetch_factor=1`[^3]。 #### 4. **处理不规则数据** - **问题**:数据长度不一致(如文本序列或变长音频)。 - **解决**: - 自定义`collate_fn`函数处理填充或截断: ```python def collate_fn(batch): data = [item[0] for item in batch] labels = [item[1] for item in batch] # 对数据填充对齐 data_padded = torch.nn.utils.rnn.pad_sequence(data, batch_first=True) return data_padded, torch.stack(labels) dataloader = DataLoader(dataset, collate_fn=collate_fn) ``` - 参考PyTorch文档的`pad_sequence`和`pack_padded_sequence`方法。 #### 5. **验证数据集划分** - **问题**:训练集和测试集未正确分割,导致数据泄露。 - **解决**: - 使用`random_split`分割数据集: ```python train_size = int(0.8 * len(dataset)) test_size = len(dataset) - train_size train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size]) ``` - 确保不同阶段使用不同的`DataLoader`实例[^4]。 #### 6. **自定义数据集实现** - **问题**:自定义数据集未继承`Dataset`或未实现关键方法。 - **解决**: - 确保实现`__len__`和`__getitem__`: ```python class CustomDataset(torch.utils.data.Dataset): def __init__(self, img_dir, labels_path, transform=None): self.img_paths = [os.path.join(img_dir, f) for f in os.listdir(img_dir)] self.labels = pd.read_csv(labels_path) self.transform = transform def __len__(self): return len(self.img_paths) def __getitem__(self, idx): img = Image.open(self.img_paths[idx]) label = self.labels.iloc[idx] if self.transform: img = self.transform(img) return img, label ``` - 检查文件路径和标签索引是否正确[^5]。 --- ### 相关问题 1. **如何优化PyTorch DataLoader的加载速度?** - 调整`num_workers`和`prefetch_factor`,使用SSD硬盘或内存缓存。 2. **如何处理PyTorch中的类别不平衡数据?** - 使用`WeightedRandomSampler`或自定义采样器[^5]。 3. **如何加载Hugging Face的`datasets`库数据PyTorch?** - 用`.with_format("torch")`转换数据格式,再传入`DataLoader`。 --- ### 引用标识 : 【小白学习PyTorch教程】五、在 PyTorch 中使用 Datasets 和 DataLoader 自定义数据 [^2]: At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class. : PyTorch DataLoader 学习 : PyTorch中 Datasets & DataLoader 的介绍 : Pytorch之Dataset和Dataloader(加载数据)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值