Udacity Assignment: TensorFlow notMNIST Code and Output (Udacity Study Notes)

This post walks through a preprocessing pipeline for the notMNIST dataset, covering download, extraction, displaying sample images, dataset construction, and data cleaning, and finishes by training and evaluating a logistic regression model.

Code block 1: Load modules

In [1]:
from __future__ import print_function
import matplotlib.pyplot as plt  # plotting
import numpy as np               # array and matrix operations
import os
import sys
import tarfile                   # extracting .tar.gz archives
from IPython.display import display, Image
from scipy import ndimage
from sklearn.linear_model import LogisticRegression  # logistic regression model
from six.moves.urllib.request import urlretrieve     # file download helper
from six.moves import cPickle as pickle              # serialization
# Config the matplotlib backend as plotting inline in IPython
%matplotlib inline

Code block 2: Download the files

Note: downloading from this address is very slow, so I recommend not using the assignment's download function and instead downloading the data into a local folder yourself. The download links are below, and a minimal download sketch follows them for anyone who still prefers to script the step:

http://yaroslavvb.com/upload/notMNIST/

http://pan.baidu.com/s/1bpKt19P
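
For anyone who does script the download, here is a minimal sketch built on the urlretrieve import from code block 1. The helper name, the archive file names (notMNIST_large.tar.gz / notMNIST_small.tar.gz), and the destination folder are my own assumptions rather than the assignment's original helper; adjust them to whichever mirror you actually use.

```python
import os
from six.moves.urllib.request import urlretrieve

url = 'http://yaroslavvb.com/upload/notMNIST/'
data_root = '/home/zlong/workspace/udacity/notMNIST'  # local folder used later in this post

def download_if_missing(filename):
    """Download one archive into data_root unless it is already there (sketch only)."""
    dest = os.path.join(data_root, filename)
    if not os.path.exists(dest):
        print('Downloading', filename, '...')
        urlretrieve(url + filename, dest)
    return dest

# Assumed archive names on this mirror:
# train_archive = download_if_missing('notMNIST_large.tar.gz')
# test_archive = download_if_missing('notMNIST_small.tar.gz')
```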

Code block 3: Extract the files and record the extracted folder paths

Extraction through this code block is very slow, so I extracted the archives with an external tool and only kept the code that records the extracted folder paths for later use; the extraction part of the assignment code has been removed below. A small tarfile sketch of the removed step follows this paragraph.
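
For reference, a small sketch of the removed extraction step using the tarfile module imported in code block 1. The archive path is an assumption carried over from the links above; extracting the large archive can take quite a while, which is why I used an external tool instead.

```python
import os
import tarfile

def extract_archive(archive_path):
    """Extract a notMNIST .tar.gz next to itself and return the extracted folder (sketch only)."""
    root = os.path.splitext(os.path.splitext(archive_path)[0])[0]  # strip .tar.gz
    if os.path.isdir(root):
        print('%s already present - skipping extraction.' % root)
    else:
        print('Extracting %s, this may take a while...' % archive_path)
        with tarfile.open(archive_path) as tar:
            tar.extractall(os.path.dirname(archive_path))
    return root

# e.g. extract_archive('/home/zlong/workspace/udacity/notMNIST/notMNIST_large.tar.gz')
```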

In [2]:
num_classes = 10
np.random.seed(133)

# Build the list of per-class folder paths
def maybe_extract(filename, force=False):
  root = os.path.splitext(os.path.splitext(filename)[0])[0]  # remove .tar.gz
  data_folders = [os.path.join(root, d) for d in sorted(os.listdir(root))
    if os.path.isdir(os.path.join(root, d))]
  if len(data_folders) != num_classes:
    raise Exception(
      'Expected %d folders, one per class. Found %d instead.' % (
        num_classes, len(data_folders)))
  return data_folders

# Local folders that hold the extracted notMNIST data
train_filename = '/home/zlong/workspace/udacity/notMNIST/notMNIST_large'  
test_filename = '/home/zlong/workspace/udacity/notMNIST/notMNIST_small'
train_folders = maybe_extract(train_filename)
test_folders = maybe_extract(test_filename)

Problem 1: Display a sample of the extracted images

In [3]:
# Problem 1: Display a sample of the images that we just downloaded
nums_image_show = 2  # number of sample images to show per class
for index_class in range(num_classes):
    # index_class runs from 0 to 9
    imagename_list = os.listdir(train_folders[index_class])
    imagename_list_indice = imagename_list[0:nums_image_show]
    for index_image in range(nums_image_show):
        path = train_folders[index_class] +'/' + imagename_list_indice[index_image]
        display(Image(filename = path))

Code block 4: Load and normalize the image data

This code block does three things. First, it reads every image file in each class folder into a 3-D dataset object, where the first dimension indexes the images and the remaining two hold the pixel data; the images are read from disk with scipy's ndimage into ndarray objects. Second, it normalizes the pixel values, mapping the original 0..255 range to roughly [-0.5, 0.5] via (pixel - pixel_depth/2) / pixel_depth. Third, it serializes each class's ndarray to a .pickle file in the working directory (one .pickle file per class) and returns the file paths as train_datasets and test_datasets.

Note: packing the data into .pickle files makes later use much more convenient. The raw images have to be loaded one at a time in a loop, so using the raw files directly would mean repeating that loop every time the data are needed, whereas a .pickle file is read back in a single call. A short worked example of the normalization follows.
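
To make the normalization concrete, here is a tiny worked example (not part of the assignment code) showing how the formula maps the 0..255 pixel range onto roughly [-0.5, 0.5] with a mean near zero:

```python
import numpy as np

pixel_depth = 255.0
raw = np.array([0, 128, 255], dtype=np.float32)   # black, mid-grey, white
normalized = (raw - pixel_depth / 2) / pixel_depth
print(normalized)  # approximately [-0.5, 0.002, 0.5]
```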

In [4]:
image_size = 28  # Pixel width and height.
pixel_depth = 255.0  # Number of levels per pixel.

def load_letter(folder, min_num_images):
  """Load the data for a single letter label."""
  image_files = os.listdir(folder)
  dataset = np.ndarray(shape=(len(image_files), image_size, image_size),
                         dtype=np.float32)
  print(folder)
  num_images = 0
  for image in image_files:
    image_file = os.path.join(folder, image)
    try:
      image_data = (ndimage.imread(image_file).astype(float) - 
                    pixel_depth / 2) / pixel_depth
      if image_data.shape != (image_size, image_size):
        raise Exception('Unexpected image shape: %s' % str(image_data.shape))
      dataset[num_images, :, :] = image_data
      num_images = num_images + 1
    except IOError as e:
      print('Could not read:', image_file, ':', e, '- it\'s ok, skipping.')
    
  dataset = dataset[0:num_images, :, :]
  if num_images < min_num_images:
    raise Exception('Many fewer images than expected: %d < %d' %
                    (num_images, min_num_images))
    
  print('Full dataset tensor:', dataset.shape)
  print('Mean:', np.mean(dataset))
  print('Standard deviation:', np.std(dataset))
  return dataset
        
def maybe_pickle(data_folders, min_num_images_per_class, force=False):
  dataset_names = []
  for folder in data_folders:
    set_filename = folder + '.pickle'
    dataset_names.append(set_filename)
    if os.path.exists(set_filename) and not force:
      # You may override by setting force=True.
      print('%s already present - Skipping pickling.' % set_filename)
    else:
      print('Pickling %s.' % set_filename)
      dataset = load_letter(folder, min_num_images_per_class)
      try:
        with open(set_filename, 'wb') as f:
          pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
      except Exception as e:
        print('Unable to save data to', set_filename, ':', e)
  
  return dataset_names

train_datasets = maybe_pickle(train_folders, 45000)
test_datasets = maybe_pickle(test_folders, 1800)
/home/zlong/workspace/udacity/notMNIST/notMNIST_large/A.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_large/B.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_large/C.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_large/D.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_large/E.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_large/F.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_large/G.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_large/H.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_large/I.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_large/J.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_small/A.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_small/B.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_small/C.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_small/D.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_small/E.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_small/F.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_small/G.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_small/H.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_small/I.pickle already present - Skipping pickling.
/home/zlong/workspace/udacity/notMNIST/notMNIST_small/J.pickle already present - Skipping pickling.

Problem 2: Display images read back from the .pickle files

In [5]:
#Problem2 Displaying a sample of the labels and images from the ndarray

# Config the matplotlib backend as plotting inline in IPython
%matplotlib inline
import matplotlib.pyplot as plt
def load_and_displayImage_from_pickle(data_filename_set, NumClass, NumImage):
    if(NumImage <= 0):
        print('NumImage <= 0')
        return
    plt.figure('subplot')
    for index, pickle_file in enumerate(data_filename_set):
        with open(pickle_file, 'rb') as f:
            data = pickle.load(f)
            ImageList = data[0:NumImage, :, :]
            for i, img in enumerate(ImageList):
                # one row per class (NumClass rows), NumImage images per row
                plt.subplot(NumClass, NumImage, index*NumImage + i + 1)
                plt.imshow(img)
# Show 10 classes with 5 images each
load_and_displayImage_from_pickle(train_datasets, 10, 5)
load_and_displayImage_from_pickle(test_datasets, 10, 5)

Problem 3: Check whether the data are balanced

Balanced here means that each class contains roughly the same number of samples.

In [6]:
def show_sum_of_different_class(data_filename_set):
    plt.figure(1)
    #read .pickle file
    sumofdifferentclass = []
    for pickle_file in data_filename_set:
        with open(pickle_file,'rb') as f:
            data = pickle.load(f)
            print(len(data))
            sumofdifferentclass.append(len(data))

    #show the data
    x = range(10)
    plt.bar(x,sumofdifferentclass)    
    plt.show()

print('train_datasets:\n')    
show_sum_of_different_class(train_datasets)  
print('test_datasets:\n')    
show_sum_of_different_class(test_datasets) 
train_datasets:

52909
52911
52912
52911
52912
52912
52912
52912
52912
52911
test_datasets:

1872
1873
1873
1873
1873
1872
1872
1872
1872
1872

Code block 5: Merge the classes and carve out a validation set

This block does two things. First, it merges the per-class data: up to now there has been one data object per class, but for training we need a single object that holds samples from all 10 classes A through J. Second, it sets aside part of the training data as a validation set.

In [7]:
def make_arrays(nb_rows, img_size):
  if nb_rows:
    dataset = np.ndarray((nb_rows, img_size, img_size), dtype=np.float32)
    labels = np.ndarray(nb_rows, dtype=np.int32)
  else:
    dataset, labels = None, None
  return dataset, labels

def merge_datasets(pickle_files, train_size, valid_size=0):
  num_classes = len(pickle_files)
  valid_dataset, valid_labels = make_arrays(valid_size, image_size)
  train_dataset, train_labels = make_arrays(train_size, image_size)
  vsize_per_class = valid_size // num_classes
  tsize_per_class = train_size // num_classes
    
  start_v, start_t = 0, 0
  end_v, end_t = vsize_per_class, tsize_per_class
  end_l = vsize_per_class+tsize_per_class
  for label, pickle_file in enumerate(pickle_files):       
    try:
      with open(pickle_file, 'rb') as f:
        letter_set = pickle.load(f)
        # let's shuffle the letters to have random validation and training set
        np.random.shuffle(letter_set)
        if valid_dataset is not None:
          valid_letter = letter_set[:vsize_per_class, :, :]
          valid_dataset[start_v:end_v, :, :] = valid_letter
          valid_labels[start_v:end_v] = label
          start_v += vsize_per_class
          end_v += vsize_per_class
                    
        train_letter = letter_set[vsize_per_class:end_l, :, :]
        train_dataset[start_t:end_t, :, :] = train_letter
        train_labels[start_t:end_t] = label
        start_t += tsize_per_class
        end_t += tsize_per_class
    except Exception as e:
      print('Unable to process data from', pickle_file, ':', e)
      raise
    
  return valid_dataset, valid_labels, train_dataset, train_labels
            
            
train_size = 200000
valid_size = 10000
test_size = 10000

valid_dataset, valid_labels, train_dataset, train_labels = merge_datasets(
  train_datasets, train_size, valid_size)
_, _, test_dataset, test_labels = merge_datasets(test_datasets, test_size)

print('Training:', train_dataset.shape, train_labels.shape)
print('Validation:', valid_dataset.shape, valid_labels.shape)
print('Testing:', test_dataset.shape, test_labels.shape)
Training: (200000, 28, 28) (200000,)
Validation: (10000, 28, 28) (10000,)
Testing: (10000, 28, 28) (10000,)

Code block 6: Shuffle the merged data

The previous step only concatenated the classes into one large data object; this step shuffles the samples inside it. Only shuffled data give reasonably stable behaviour when training the model.

In [8]:
def randomize(dataset, labels):
  permutation = np.random.permutation(labels.shape[0])
  shuffled_dataset = dataset[permutation,:,:]
  shuffled_labels = labels[permutation]
  return shuffled_dataset, shuffled_labels
train_dataset, train_labels = randomize(train_dataset, train_labels)
test_dataset, test_labels = randomize(test_dataset, test_labels)
valid_dataset, valid_labels = randomize(valid_dataset, valid_labels)

Problem 4: Verify the merged data after shuffling

In [9]:
'''Problem 4: Convince yourself that the data is still good after shuffling!
'''
# data_set is the dataset; NumImage is the number of images to display
def displayImage_from_dataset(data_set, NumImage):
    if(NumImage <= 0):
        print('NumImage <= 0')
        return
    plt.figure('subplot')
    ImageList = data_set[0:NumImage, :, :]
    for index, img in enumerate(ImageList):
        # lay the samples out five per row
        plt.subplot(NumImage//5 + 1, 5, index + 1)
        plt.imshow(img)
    plt.show()
displayImage_from_dataset(train_dataset, 50)

Code block 7: Save the merged datasets to a .pickle file

In [10]:
data_root = '.' # Change me to store data elsewhere

print(data_root)
pickle_file = os.path.join(data_root, 'notMNIST.pickle')
print(pickle_file)

try:
  f = open(pickle_file, 'wb')
  save = {
    'train_dataset': train_dataset,
    'train_labels': train_labels,
    'valid_dataset': valid_dataset,
    'valid_labels': valid_labels,
    'test_dataset': test_dataset,
    'test_labels': test_labels,
    }
  pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
  f.close()
except Exception as e:
  print('Unable to save data to', pickle_file, ':', e)
  raise
.
./notMNIST.pickle

Problem 5: Data cleaning

In general the training, validation, and test sets will contain some overlapping samples, but too much overlap distorts how trustworthy the test results are. The data therefore need to be cleaned so that the three sets no longer intersect.

Note: ndarray data cannot be deduplicated with Python sets, and pairwise looping over arrays of this size would be extremely slow. The approach below therefore hashes every image first and compares hash digests to decide whether two images are equal; since the digests are short strings, the comparison is much faster. A tiny illustration of the idea appears just below, before the assignment code.
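
As that illustration (separate from the assignment code): two byte-identical 28x28 arrays produce the same SHA-256 digest, so comparing short hex strings stands in for comparing 784 floats element by element. Note this only catches exact duplicates, not near-duplicates.

```python
import hashlib
import numpy as np

a = np.zeros((28, 28), dtype=np.float32)
b = np.zeros((28, 28), dtype=np.float32)
c = np.ones((28, 28), dtype=np.float32)

# Identical contents give identical digests; different contents give different digests.
print(hashlib.sha256(a).hexdigest() == hashlib.sha256(b).hexdigest())  # True
print(hashlib.sha256(a).hexdigest() == hashlib.sha256(c).hexdigest())  # False
```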

In [11]:
# Hash the images first, then compare digests
import hashlib

# SHA-256 maps each 2-D image array to a digest, so comparing digests tells us whether two arrays are equal
def extract_overlap_hash_where(dataset_1,dataset_2):

    dataset_hash_1 = np.array([hashlib.sha256(img).hexdigest() for img in dataset_1])
    dataset_hash_2 = np.array([hashlib.sha256(img).hexdigest() for img in dataset_2])
    overlap = {}
    for i, hash1 in enumerate(dataset_hash_1):
        duplicates = np.where(dataset_hash_2 == hash1)
        if len(duplicates[0]):
            overlap[i] = duplicates[0]
    return overlap

#display the overlap
def display_overlap(overlap,source_dataset,target_dataset):
    overlap = {k: v for k,v in overlap.items() if len(v) >= 3}
    item = np.random.choice(list(overlap.keys()))
    imgs = np.concatenate(([source_dataset[item]],target_dataset[overlap[item][0:7]]))
    plt.suptitle(item)
    for i,img in enumerate(imgs):
        plt.subplot(2,4,i+1)
        plt.axis('off')
        plt.imshow(img)
    plt.show()

# Data cleaning: drop from dataset_1 every image that also appears in dataset_2
def sanitize(dataset_1,dataset_2,labels_1):
    dataset_hash_1 = np.array([hashlib.sha256(img).hexdigest() for img in dataset_1])
    dataset_hash_2 = np.array([hashlib.sha256(img).hexdigest() for img in dataset_2])
    overlap = []
    for i,hash1 in enumerate(dataset_hash_1):
        duplicates = np.where(dataset_hash_2 == hash1)
        if len(duplicates[0]):
            overlap.append(i)
    return np.delete(dataset_1,overlap,0),np.delete(labels_1, overlap, None)


overlap_test_train = extract_overlap_hash_where(test_dataset,train_dataset)
print('Number of overlaps:', len(overlap_test_train.keys()))
display_overlap(overlap_test_train, test_dataset, train_dataset)

test_dataset_sanit,test_labels_sanit = sanitize(test_dataset,train_dataset,test_labels)
print('Overlapping images removed from test_dataset: ', len(test_dataset) - len(test_dataset_sanit))

valid_dataset_sanit, valid_labels_sanit = sanitize(valid_dataset, train_dataset, valid_labels)
print('Overlapping images removed from valid_dataset: ', len(valid_dataset) - len(valid_dataset_sanit))

print('Training:', train_dataset.shape, train_labels.shape)
print('Validation:', valid_dataset_sanit.shape, valid_labels_sanit.shape)
print('Testing:', test_dataset_sanit.shape, test_labels_sanit.shape)

pickle_file_sanit = 'notMNIST_sanit.pickle'
try:
    f = open(pickle_file_sanit,'wb')
    save = {
        'train_dataset':train_dataset,
        'train_labels': train_labels,
        'valid_dataset': valid_dataset,
        'valid_labels': valid_labels,
        'test_dataset': test_dataset,
        'test_labels': test_labels,
    }
    pickle.dump(save,f,pickle.HIGHEST_PROTOCOL)
    f.close()
except Exception as e:
    print('Unable to save data to', pickle_file_sanit, ':', e)
    raise

statinfo = os.stat(pickle_file_sanit)
print('Compressed pickle size:', statinfo.st_size)
Number of overlaps: 1284
Overlapping images removed from test_dataset:  1284
Overlapping images removed from valid_dataset:  1069
Training: (200000, 28, 28) (200000,)
Validation: (8931, 28, 28) (8931,)
Testing: (8716, 28, 28) (8716,)
Compressed pickle size: 690800506

Problem 6: Model training

The training here uses a logistic regression model.

In [12]:
def train_and_predict(sample_size):
    regr = LogisticRegression()
    X_train = train_dataset[:sample_size].reshape(sample_size,784)
    y_train = train_labels[:sample_size]
    regr.fit(X_train,y_train)
    X_test = test_dataset.reshape(test_dataset.shape[0],28*28)
    y_test = test_labels

    pred_labels = regr.predict(X_test)
    print('Accuracy:', regr.score(X_test, y_test), 'when sample_size=', sample_size)

for sample_size in [50,100,1000,5000,len(train_dataset)]:
    train_and_predict(sample_size)
Accuracy: 0.6392 when sample_size= 50
Accuracy: 0.7437 when sample_size= 100
Accuracy: 0.8353 when sample_size= 1000
Accuracy: 0.8497 when sample_size= 5000
Accuracy: 0.8898 when sample_size= 200000
### NotMNIST dataset overview

The notMNIST dataset was designed for Python-based experiments: it mimics the classic MNIST dataset while being closer to real-world data. The classification task is harder than MNIST, and the data are not as "clean" as MNIST. The dataset was first published in 2011 as a tougher counterpart to MNIST.

#### Dataset structure

The dataset contains grayscale images for ten classes, the letters A through J, and is split into two subsets:

- **Small subset**: hand-cleaned, roughly 19,000 images, with an estimated label error rate of about 0.5%.
- **Large subset**: not specially cleaned, close to 500k images, noticeably noisier, with an estimated label error rate of around 6.5%.

### Download guide

Researchers and developers who want notMNIST sometimes run into download problems; in particular, the official links referenced by tutorials such as the Udacity Deep Learning course may be dead. In that case, Issue #1475 on the TensorFlow GitHub project lists working download links.

```python
import os
import tarfile
import urllib.request

url = 'http://yaroslavvb.com/upload/notMNIST/notMNIST_large.tar.gz'
filename = url.split('/')[-1]

if not os.path.exists(filename):
    print(f'Downloading {filename}...')
    urllib.request.urlretrieve(url, filename)

with tarfile.open(filename) as tar:
    tar.extractall('data')

print('Extraction completed.')
```

The script above shows how to download and extract the archive programmatically, which suits anyone who wants a fully automated data-preparation step.

### Usage notes

Once notMNIST has been downloaded and extracted, the next step is to load it into memory for analysis or model training. Since this is an image-recognition task, that usually involves preprocessing such as normalizing sizes, converting colour modes, and splitting into training and test sets. The exact steps depend on the framework; in a PyTorch environment, for instance, `torchvision.datasets.ImageFolder()` can read the extracted per-class image folders directly and apply the required transforms.
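
As an illustration of that last point, here is a minimal, hypothetical sketch of loading the extracted per-class folders with torchvision. It assumes torchvision is installed, the images sit under data/notMNIST_small/A ... J, and the handful of corrupted image files in the dataset have been removed beforehand (otherwise the default loader will raise on them).

```python
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# The glyphs are already 28x28; Grayscale() collapses them to a single channel.
transform = T.Compose([T.Grayscale(), T.ToTensor()])

dataset = ImageFolder('data/notMNIST_small', transform=transform)  # one subfolder per class A..J
loader = DataLoader(dataset, batch_size=64, shuffle=True)

images, labels = next(iter(loader))
print(images.shape, labels[:10])  # e.g. torch.Size([64, 1, 28, 28]) plus the first ten labels
```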