import urllib
import os
import tarfile
import numpy as np
import pickle as pk
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
#下载、解压CIFAR数据集
url='https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
file_path='/Users/mac/Downloads/TensorFlow/09图像识别问题:卷积神经网络与应用/data/cifar-10-python.tar.gz'
#如果目标文件不存在,则从指定url下载该文件
if not os.path.isfile(file_path):
urllib.request.urlretrieve(url,file_path)
#如果目录下不存在文件,则解压
if not os.path.exists('/Users/mac/Downloads/TensorFlow/09图像识别问题:卷积神经网络与应用/data/cifar-10-batches-py'):
tfile=tarfile.open('/Users/mac/Downloads/TensorFlow/09图像识别问题:卷积神经网络与应用/data/cifar-10-python.tar.gz','r:gz')
tfile.extractall('/Users/mac/Downloads/TensorFlow/09图像识别问题:卷积神经网络与应用/data')
#载入数据
def load_batch(file): #读取一个批次的数据
with open(file,'rb') as f:
data_dict=pk.load(f,encoding='bytes')
images=data_dict[b'data']
labels=data_dict[b'labels']
#将一维图片数据调整为四维数组
images=images.reshape(10000,3,32,32)
#将(10000,3,32,32)调整参数数组维度为(10000,32,32,3)
images=images.transpose(0,2,3,1)
labels=np.array(labels)
return images,labels
def load_data(data_dir):
images_train=[]
labels_train=[]
cifar10-图像识别相关代码
最新推荐文章于 2025-03-06 16:00:31 发布