PyTorch-Tutorial项目中的自编码器实现详解

PyTorch-Tutorial项目中的自编码器实现详解

PyTorch-Tutorial Build your neural network easy and fast, 莫烦Python中文教学 PyTorch-Tutorial 项目地址: https://gitcode.com/gh_mirrors/pyt/PyTorch-Tutorial

自编码器概述

自编码器(AutoEncoder)是一种无监督学习的神经网络模型,主要用于数据的降维和特征提取。它通过将输入数据压缩到一个低维空间(编码),然后再从这个低维表示重建原始数据(解码),从而学习数据的有用特征。

代码实现解析

1. 数据准备

首先加载MNIST手写数字数据集,这是深度学习中最常用的基准数据集之一:

train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST
)

数据加载器将数据集分成小批量,便于训练:

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

2. 自编码器模型结构

自编码器由编码器和解码器两部分组成:

class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        
        # 编码器部分
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),  # 压缩到3维特征
        )
        
        # 解码器部分
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28*28),
            nn.Sigmoid()  # 输出在0-1范围内
        )

编码器将784维(28×28)的输入图像逐步压缩到3维空间,解码器则从3维空间重建原始图像。使用Tanh作为激活函数,最后解码器输出使用Sigmoid确保像素值在0-1之间。

3. 训练过程

训练过程使用均方误差(MSE)作为损失函数,Adam优化器进行参数更新:

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

for epoch in range(EPOCH):
    for step, (x, b_label) in enumerate(train_loader):
        b_x = x.view(-1, 28*28)  # 展平输入
        b_y = x.view(-1, 28*28)  # 目标输出
        
        encoded, decoded = autoencoder(b_x)
        loss = loss_func(decoded, b_y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

4. 可视化结果

训练过程中实时显示原始图像和重建图像的对比:

# 初始化图像显示
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()

# 训练过程中更新显示
_, decoded_data = autoencoder(view_data)
for i in range(N_TEST_IMG):
    a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray')

训练完成后,将200个样本的3维编码结果在3D空间中可视化:

fig = plt.figure(2); ax = Axes3D(fig)
X, Y, Z = encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy()
values = train_data.train_labels[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255*s/9)); ax.text(x, y, z, s, backgroundcolor=c)

技术要点解析

  1. 数据预处理:MNIST图像被归一化到[0,1]范围,并展平为784维向量
  2. 网络设计:采用全连接层逐步压缩/扩展维度,使用Tanh激活函数防止梯度消失
  3. 损失函数:使用MSE衡量重建图像与原始图像的差异
  4. 可视化:3D编码空间展示不同数字类别的分布情况

实际应用场景

自编码器在实际中有多种应用:

  1. 数据降维:比PCA等线性方法更强大的非线性降维
  2. 异常检测:重建误差大的样本可能是异常值
  3. 图像去噪:训练时加入噪声,学习重建干净图像
  4. 特征提取:编码器部分可作为特征提取器

总结

本教程实现了一个基本的自编码器模型,展示了如何使用PyTorch构建、训练和评估自编码器。通过将784维的手写数字图像压缩到3维空间并重建,我们不仅理解了自编码器的工作原理,还通过可视化直观地看到了不同数字在潜在空间中的分布情况。

PyTorch-Tutorial Build your neural network easy and fast, 莫烦Python中文教学 PyTorch-Tutorial 项目地址: https://gitcode.com/gh_mirrors/pyt/PyTorch-Tutorial

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

资源下载链接为: https://pan.quark.cn/s/5c50e6120579 在Android移动应用开发中,定位功能扮演着极为关键的角色,尤其是在提供导航、本地搜索等服务时,它能够帮助应用获取用户的位置信息。以“baiduGPS.rar”为例,这是一个基于百度地图API实现定位功能的示例项目,旨在展示如何在Android应用中集成百度地图的GPS定位服务。以下是对该技术的详细阐述。 百度地图API简介 百度地图API是由百度提供的一系列开放接口,开发者可以利用这些接口将百度地图的功能集成到自己的应用中,涵盖地图展示、定位、路径规划等多个方面。借助它,开发者能够开发出满足不同业务需求的定制化地图应用。 Android定位方式 Android系统支持多种定位方式,包括GPS(全球定位系统)和网络定位(通过Wi-Fi及移动网络)。开发者可以根据应用的具体需求选择合适的定位方法。在本示例中,主要采用GPS实现高精度定位。 权限声明 在Android应用中使用定位功能前,必须在Manifest.xml文件中声明相关权限。例如,添加<uses-permission android:name="android.permission.ACCESS_FINE_LOCATION" />,以获取用户的精确位置信息。 百度地图SDK初始化 集成百度地图API时,需要在应用启动时初始化地图SDK。通常在Application类或Activity的onCreate()方法中调用BMapManager.init(),并设置回调监听器以处理初始化结果。 MapView的创建 在布局文件中添加MapView组件,它是地图显示的基础。通过设置其属性(如mapType、zoomLevel等),可以控制地图的显示效果。 定位服务的管理 使用百度地图API的LocationClient类来管理定位服务
资源下载链接为: https://pan.quark.cn/s/dab15056c6a5 Oracle Instant Client是一款轻量级的Oracle数据库连接工具,能够在不安装完整Oracle客户端软件的情况下,为用户提供访问Oracle数据库的能力。以“instantclient-basic-nt-12.1.0.1.0.zip”为例,它是针对Windows(NT)平台的Instant Client基本版本,版本号为12.1.0.1.0,包含连接Oracle数据库所需的基本组件。 Oracle Instant Client主要面向开发人员和系统管理员,适用于数据库查询、应用程序调试、数据迁移等工作。它支持运行SQL*Plus、PL/SQL Developer等管理工具,还能作为ODBC和JDBC驱动的基础,让非Oracle应用连接到Oracle数据库。 安装并解压“instantclient_12_1”后,为了使PL/SQL Developer等应用程序能够使用该客户端,需要进行环境变量配置。设置ORACLE_HOME指向Instant Client的安装目录,如“C:\instantclient_12_1”。添加TNS_ADMIN环境变量,用于存放网络配置文件(如tnsnames.ora)。将Instant Client的bin目录添加到PATH环境变量中,以便系统能够找到oci.dll等关键动态链接库。 oci.dll是OCI(Oracle Call Interface)库的重要组成部分。OCI是Oracle提供的C语言接口,允许开发者直接与数据库交互,执行SQL语句、处理结果集和管理事务等功能。确保系统能够找到oci.dll是连接数据库的关键。 tnsnames.ora是Oracle的网络配置文件,用于定义数据库服务名与网络连接参数的映射关系,包括服务器地址
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

江燕娇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值