# Training-data collection: build an augmenting image generator over `data_dir`,
# then pull batches from a generator and accumulate images / one-hot labels
# until `sample_count` reaches zero, finally converting both lists to arrays.
# NOTE(review): indentation is not visible in this excerpt, so exactly which
# statements are nested under `if is_train:` and under each loop cannot be
# determined from here — confirm against the full file.
# `images`, `labels`, `data_dir` and `augment_times` are expected to be defined
# before this point (they are not defined in this excerpt).
if is_train:
# Augmentation settings: rescale factor plus random rotation / zoom / shifts.
# NOTE(review): presumably Keras' ImageDataGenerator, where `rotation_range`
# is expressed in degrees — 0.2 would then be almost no rotation; confirm.
# NOTE(review): `rescale = .3` combined with the later `astype('uint8')` looks
# suspicious (pixel values scaled, then cast to integers) — confirm intent.
data_gen2 = ImageDataGenerator(rescale = .3, rotation_range =0.2 , zoom_range = 0.2, width_shift_range = 0.2, height_shift_range= 0.2)
# One 160x160 RGB image per batch, with categorical labels.
dataflow_generator2 = data_gen2.flow_from_directory(data_dir,target_size=(160, 160),batch_size=1,color_mode='rgb',class_mode='categorical')
# Mapping exposed by the generator as `class_indices`
# (presumably class name -> integer index; verify).
labels_dict = dataflow_generator2.class_indices
print("labels_dict",labels_dict)
# Target count: number of files found times `augment_times`.
sample_count = len(dataflow_generator2.filenames)*augment_times
print('sample_count:',sample_count)
#filenames = dataflow_generator2.filenames
#labels = dataflow_generator2.class_indices
#print(filenames)
#print(labels)
# NOTE(review): this loop iterates `dataflow_generator`, not the
# `dataflow_generator2` created above and used for `labels_dict` and
# `sample_count`. If no `dataflow_generator` is defined earlier this raises
# NameError; if one is, the augmenting generator above is never consumed.
# Confirm which generator is intended.
# NOTE(review): Keras directory iterators normally yield batches forever, so
# the `break` below is presumably the only way out of this loop — confirm.
for image_data in dataflow_generator:
# TODO: display the image with plt.imshow and plt.show()
#print(len(image_data[1]))
# image_data[0] holds the batch images, image_data[1] the matching labels.
for j in range(0,len(image_data[1])):
# Cast pixel data to uint8 before storing it.
image = image_data[0][j].astype('uint8')
images.append(image)
labels.append(image_data[1][j])
#print(image_data[1]) #label
#print(image_data[0][0].shape) #image
# Count down towards the target; once it is reached, convert the accumulated
# lists to NumPy arrays and stop iterating.
sample_count -= 1
if sample_count <= 0:
images = np.array(images)
labels = np.array(labels)
break