输入的是训练数据路径data_path,label_path,还有batch_size;
输出的是:
train_data_batch,shape 是( [batch_size,img_row, img_col, 3]) # 可相应变化
train_label_batch,shape是([batch_size, img_row, img_col]) # 可相应变化
import cv2
import os
import numpy as np
def next_batch(data_path, lable_path, batch_size):
train_temp = np.random.randint(low=0, high=Train_nums + 1, size=batch_size) # 生成元素的值在[low,high)区间,随机选取
train_data_batch = np.zeros([batch_size,img_row, img_col, 3]) # 其中[img_row, img_col, 3]是原数据的shape,相应变化
train_label_batch = np.zeros([batch_size, img_row, img_col]) #
count = 0 # 后面就是读入图像,并打包成四维的batch
img_list = os.listdir(data_path)
for i in train_temp:
img_path = os.path.join(data_path, img_list[i]) # 图片文件
label_path = os.path.join(lable_path, img_list[i]) # label的图片名要和img对应上
train_data_batch[count, :, :, :] = cv2.imread(img_path)
train_label_batch[count, :, :] = cv2.imread(label_path)
count+=1
return train_data_batch, train_label_batch