华为昇思MindSpore深度学习框架初体验,真正的遥遥领先!
MindSpore深度学习框架介绍
华为昇思 MindSpore是一个全场景深度学习框架,对于初学者而言,其最大的优势在于API友好、调试难度低1。我今天报名参加了昇思25天学习打卡营,准备记录使用MindSpore框架的体验。今天的打卡内容是MNIST手写数字识别。
MNIST手写字识别
作为深度神经网络的入门级案例,手写数字识别数据集MNIST源自图灵奖得主,深度学习领域的奠基人之一Yann LeCun在1998年发表的一篇论文《Gradient-Based Learning Applied to Document Recognition》2 (原文使用了CNN但MindSpore案例使用了一个简单的多层感知器)。对于这样一个经典案例,我曾尝试过多种深度学习框架如Pytorch, Tensorflow等,但MindSpore框架给我的感觉完全是断崖式领先,这极大地来源于后发优势,能够在现有框架基础上解决问题(后续文章会出对比)。
数据集处理
MindSpore提供基于Pipeline的数据引擎,进行数据集和数据变换,对于MNIST数据集可以自动下载。
def datapipe(dataset, batch_size):
image_transforms = [
vision.Rescale(1.0 / 255.0, 0),
vision.Normalize(mean=(0.1307,), std=(0.3081,)),
vision.HWC2CHW()
]
label_transform = transforms.TypeCast(mindspore.int32)
dataset = dataset.map(image_transforms, 'image')
dataset = dataset.map(label_transform, 'label')
dataset = dataset.batch(batch_size)
return dataset
代码部分非常简洁清晰,具有极强的可读性,如果你经常处理图像问题,可以不用查询任何文档就能猜出代码块的含义,如
HWC2CHW()
对于数据的尺寸,类型的输出也十分清晰
for image, label in test_dataset.create_tuple_iterator():
print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")
print(f"Shape of label: {label.shape} {label.dtype}")
break
Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32
Shape of label: (64,) Int32
总体而言使用起来门槛很低,可理解性很强。
网络构建
网络构建的类代码与Pytorch类似,但相比之下更简洁,各层和激活函数的结构非常清晰
# Define model
class Network(nn.Cell):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.dense_relu_sequential = nn.SequentialCell(
nn.Dense(28*28, 512),
nn.ReLU(),
nn.Dense(512, 512),
nn.ReLU(),
nn.Dense(512, 10)
)
def construct(self, x):
x = self.flatten(x)
logits = self.dense_relu_sequential(x)
return logits
model = Network()
print(model)
并且能够直接打印网络的结构和层索引,一目了然
Network<
(flatten): Flatten<>
(dense_relu_sequential): SequentialCell<
(0): Dense<input_channels=784, output_channels=512, has_bias=True>
(1): ReLU<>
(2): Dense<input_channels=512, output_channels=512, has_bias=True>
(3): ReLU<>
(4): Dense<input_channels=512, output_channels=10, has_bias=True>
我认为这种结构展现非常有助于初学者的学习,并且可以完全按照结构编写代码,降低编程难度。总体而言,今天的打卡学习体验很好,接下来将继续深度体验其他内容和模块,做更多分享。
打卡截图
今日打卡截图:

闹了个乌龙,没有改时区。附修改时区代码(UTC+8)
import time
from datetime import datetime, timedelta, timezone
print(datetime.now(timezone(timedelta(hours=8))).strftime("%Y-%m-%d %H:%M:%S"),'edwardfeng')
1185

被折叠的 条评论
为什么被折叠?



