程序功能:读取文件夹内图片并输出形状[m,n_H,n_W,n_C]的数组
m:图片数量
n_H:图片高度
n_W:图片宽度
n_C:图片维数
def read_picture(path,n_C):
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
#function:读取path路径下的图片,并转为形状为[m,n_H,n_W,n_C]的数组
#path:str,图片所在路径
#n_C:int,图像维数,黑白图像输入1,rgb图像输入3
#datas:返回维度为(m,n_H,n_W,n_C)的array(数组)矩阵
datas=[]
x_dirs=os.listdir(path)
for x_file in x_dirs:
fpath=os.path.join(path,x_file)
if n_C == 1 :
_x=Image.open(fpath).convert("L")
#plt.imshow(_x,"gray") #显示图像(只显示最后一张)
elif n_C ==3:
_x=Image.open(fpath)
#plt.imshow(_x) #显示图像(只显示最后一张)
else:
print("错误:图像维数错误")
n_W=_x.size[0]
n_H=_x.size[1]
#若要对图像进行放大缩小,激活(去掉注释)以下函数
'''
rat=0.8 #放大/缩小倍数
n_W=int(rat*n_W)
n_H=int(rat*n_H)
_x=_x.resize((n_W,n_H)) #直接给n_W,n_H赋值可将图像变为任意大小
'''
datas.append(np.array(_x))
_x.close()
datas=np.array(datas)
m=datas.shape[0]
datas=datas.reshape((m,n_H,n_W,n_C))
#print(datas.shape)
return datas