pytorch中backward的几个要点

本文深入探讨PyTorch中backward机制的两大关键点:requires_grad属性与标量求导必要性。通过实例演示如何定义网络层并利用backward进行反向传播,强调requires_grad用于跟踪变量变化,而只有标量(loss)才能触发反向传播。

pytorch中backward的2个要点

1. requires_grad

用pytorch定义一个网络层,如果数据中加入requires_grad=True,那么对于这个变量而言,就有了属于自己的导数(grad),如果这个数据是矩阵,那么他的grad是同样大小的一个矩阵。我们将requires_grad视为该变量的一个属性,我们知道,我们在现实生活中,或者说在神经网络中,大部分的函数都是一阶连续可微的,也就是说,他的梯度具有唯一性。requires_grad的存在非常合理。

2. scale才能有backward

scale是标量的意思。
首先我们可用用如下语句查看等式中某个自变量的梯度。

print(x.grad, y.grad)

但是有个前提,我们必须要先对他的结果使用.backward()才能去查看,不然的话,他的梯度会显示为none。
非常需要注意的一点是,能够使用.backward()的必须是标量(scale),不然程序会报错。
结合实际的情况,我们看任何一个网络,使用backward的地方几乎只有一个,那就是loss.backward()。
首先loss肯定是一个标量,无论是MSE还是交叉熵,也无论是否加上了正则项,那都是求和之后的结果,也就是一个数值。这一点非常重要。

以下是我随意写的一个网络层,可以感受一下

import torch
import torch.nn as nn
class 
<think> 嗯,用户需要解决在Jupyter Notebook中使用PyTorch实现CNN来识别手机照片中的多个数字的问题。用户之前可能尝试过但遇到了环境配置或代码实现的问题。 首先需要明确几个关键点:使用PyTorch框架、在Jupyter Notebook环境下、处理MNIST数据集、最终目标是识别手机照片中的多个数字。用户提供的参考资料提到过jupyter环境配置问题和MNIST加载问题,这些可能是潜在障碍。 环境准备方面,需要确认用户是否已解决Jupyter环境问题。参考资料[2]提到环境配置失败的经历,因此会在回复中强调环境检查要点。考虑建议先测试torchtorchvision的导入,以及GPU可用性检查。 数据处理部分,手机照片与MNIST数据存在显著差异:背景复杂度、光照条件、数字位置等。需要设计预处理流程:灰度化、二值化、尺寸调整、对比度增强等。特别是多数字识别需要添加连通域分析步骤来分割单个数字。 模型构建采用经典LeNet结构比较合适,三层卷积加全连接的架构在MNIST上表现良好且训练快速。考虑在回复中给出完整模型定义代码并注释关键层作用。 训练流程需要包含数据增强(旋转、平移)以提高泛化能力。考虑提醒用户注意batch size设置不宜过大,并建议使用学习率衰减策略。 多数字识别是核心难点。准备详细说明连通域分析+轮廓检测的组合方法:通过Otsu阈值法、findContours获取数字区域、外接矩形分割。这个方案在开源项目中被验证有效。 测试阶段要强调预处理一致性。考虑给出具体代码示例展示如何将手机照片转换成MNIST风格的28x28灰度图。 最后考虑补充部署建议,比如模型保存加载方法和可能的性能优化方向(如量化)。所有代码片段都准备用Jupyter兼容格式,并标注关键参数调整点。 用户可能卡在预处理或数字分割环节,因此回复中会把多数字处理流程拆解为可操作的步骤,并提供备选方案(如YOLO检测的可行性)。参考的文献会集中在PyTorch官方文档和OpenCV图像处理教程。 </think> 以下是在Jupyter Notebook中使用PyTorch实现CNN识别手机照片中多个数字的完整解决方案: ### 一、环境准备(关键步骤) ```python # 验证环境配置 import torch import torchvision print("PyTorch版本:", torch.__version__) print("CUDA可用:", torch.cuda.is_available()) # 确保GPU支持 ``` *若遇环境问题*: 1. 创建独立环境:`conda create -n pytorch_env python=3.8` 2. 安装依赖:`conda install pytorch torchvision jupyter -c pytorch` 3. 添加内核:`python -m ipykernel install --user --name=pytorch_env` ### 二、数据处理与模型构建 ```python import torch.nn as nn import torch.optim as optim from torchvision import transforms, datasets # 数据预处理管道(适配手机图像) transform = transforms.Compose([ transforms.Grayscale(num_output_channels=1), # 手机图转灰度 transforms.Resize((28, 28)), # MNIST标准尺寸 transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) # MNIST标准化参数 ]) # 构建CNN模型(LeNet改进版) class DigitCNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) self.pool = nn.MaxPool2d(2, 2) self.fc1 = nn.Linear(64*7*7, 128) self.fc2 = nn.Linear(128, 10) # 10分类 def forward(self, x): x = self.pool(torch.relu(self.conv1(x))) x = self.pool(torch.relu(self.conv2(x))) x = x.view(-1, 64*7*7) x = torch.relu(self.fc1(x)) return self.fc2(x) model = DigitCNN() ``` ### 三、模型训练(MNIST预训练) ```python # 加载MNIST数据集 train_set = datasets.MNIST('./data', train=True, download=True, transform=transform) train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True) # 训练配置 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # 训练循环 for epoch in range(5): # 5个epoch足够收敛 for images, labels in train_loader: optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() print(f'Epoch {epoch+1} Loss: {loss.item():.4f}') ``` ### 四、手机照片多数字识别(核心实现) ```python import cv2 import numpy as np from PIL import Image def recognize_digits(image_path, model): # 手机图像预处理 img = cv2.imread(image_path) gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1] # 查找数字轮廓 contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) contours = contours[0] if len(contours) == 2 else contours[1] digits = [] for cnt in contours: x,y,w,h = cv2.boundingRect(cnt) if w>10 and h>20: # 过滤噪声 digit_img = gray[y:y+h, x:x+w] digit_img = Image.fromarray(digit_img).convert('L') digit_tensor = transform(digit_img).unsqueeze(0) # 添加batch维度 # 模型预测 with torch.no_grad(): output = model(digit_tensor) pred = output.argmax(dim=1).item() digits.append((x, pred)) # 按x坐标排序(从左到右) return sorted(digits, key=lambda x: x[0]) # 示例使用 result = recognize_digits("phone_photo.jpg", model) print("识别结果:", "".join([str(d[1]) for d in result])) ``` ### 五、关键优化点 1. **数据增强**(提升泛化性): ```python transform_train = transforms.Compose([ transforms.RandomRotation(10), transforms.RandomAffine(degrees=0, translate=(0.1,0.1)), ... # 原转换 ]) ``` 2. **模型改进**: ```python # 添加Dropout防止过拟合 self.dropout = nn.Dropout(0.5) # 在fc1后添加 x = self.dropout(torch.relu(self.fc1(x))) ``` 3. **手机图像适配技巧**: - 光照归一化:`cv2.createCLAHE(clipLimit=2.0).apply(gray)` - 形态学操作去噪: ```python kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3,3)) cleaned = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel) ``` ### 六、部署建议 1. 模型保存/加载: ```python torch.save(model.state_dict(), 'digit_model.pth') model.load_state_dict(torch.load('digit_model.pth')) ``` 2. 转ONNX提升兼容性: ```python torch.onnx.export(model, torch.randn(1,1,28,28), "digit.onnx") ``` > **重要提示**:手机照片需满足: > - 数字与背景对比度明显(建议深色背景) > - 避免强反光/阴影干扰 > - 单个数字高度建议>30像素
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值