torchvision.utils.save_image 讲解

文章讲述了在使用PyTorch的`torchvision.utils.save_image`函数保存图像时,normalize参数的不同设置如何影响图像的显示和读取。当normalize为False时,保存的是包含三个通道的伪灰度图像;而normalize为True则会进行归一化处理。

因为这个参数曾让我犯过错误,在此记录一下。
其他参数百度即可,本文主要讲解 normalize 参数。

# Minimal demo of torchvision.utils.save_image.
# NOTE(review): the original call was truncated at "save_image(a" by the page
# scrape; reconstructed here with normalize=True — the parameter this article
# discusses. Confirm against the original post.
a = torch.tensor([0.8, 0.9])
a = a.unsqueeze(0)  # shape (1, 2): a single-row "image" with two pixels
save_path = "111.png"  # plain literal; the f-string had no placeholders
torchvision.utils.save_image(a, save_path, normalize=True)
# --- Imports ------------------------------------------------------------
import argparse
import random
import os
import numpy as np
import torch
import torchvision
import torch.backends.cudnn as cudnn
import torch.utils.data
import src.utils as utils
import src.dataset as dataset
import crnn.seq2seq as crnn

# Let cuDNN benchmark and pick the fastest conv algorithms for fixed shapes.
cudnn.benchmark = True

# --- Command-line configuration -----------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--train_list', type=str, help='path to train dataset list file')
parser.add_argument('--eval_list', type=str, help='path to evalation dataset list file')
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading num_workers')
parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
parser.add_argument('--img_height', type=int, default=64, help='the height of the input image to network')
parser.add_argument('--img_width', type=int, default=480, help='the width of the input image to network')
parser.add_argument('--hidden_size', type=int, default=256, help='size of the lstm hidden state')
parser.add_argument('--num_epochs', type=int, default=2, help='number of epochs to train for')
parser.add_argument('--learning_rate', type=float, default=0.0001, help='learning rate for Critic, default=0.00005')
parser.add_argument('--encoder', type=str, default='', help="path to encoder (to continue training)")
parser.add_argument('--decoder', type=str, default='', help='path to decoder (to continue training)')
parser.add_argument('--model', default='./model/', help='Where to store samples and models')
# NOTE(review): default=True combined with action='store_true' means this flag
# is always True; kept as-is to preserve the original CLI behavior.
parser.add_argument('--random_sample', default=True, action='store_true', help='whether to sample the dataset with random sampler')
parser.add_argument('--teaching_forcing_prob', type=float, default=0.5, help='where to use teach forcing')
parser.add_argument('--max_width', type=int, default=61, help='the width of the feature map out from cnn')
cfg = parser.parse_args()
print(cfg)

# --- Alphabet ------------------------------------------------------------
# One character per line in alphabet.txt; collapse into a single string.
with open('./data/alphabet.txt') as f:
    data = f.readlines()
alphabet = [x.rstrip() for x in data]
alphabet = ''.join(alphabet)
# Converter between text strings and integer label sequences.
converter = utils.ConvertBetweenStringAndLabel(alphabet)

# len(alphabet) + SOS_TOKEN + EOS_TOKEN
num_classes = len(alphabet) + 2


def train(image, text, encoder, decoder, criterion, train_loader, teach_forcing_prob=1):
    """Train the encoder/decoder pair for cfg.num_epochs epochs.

    image: pre-allocated input tensor, refilled in place with each batch.
    text:  unused here; kept to match the original call signature.
    """
    # One optimizer per sub-network.
    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=cfg.learning_rate, betas=(0.5, 0.999))
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=cfg.learning_rate, betas=(0.5, 0.999))

    # loss averager
    loss_avg = utils.Averager()

    for epoch in range(cfg.num_epochs):
        train_iter = iter(train_loader)
        for i in range(len(train_loader)):
            cpu_images, cpu_texts = next(train_iter)
            batch_size = cpu_images.size(0)

            # Re-enable gradients (evaluate() turns them off).
            for encoder_param, decoder_param in zip(encoder.parameters(), decoder.parameters()):
                encoder_param.requires_grad = True
                decoder_param.requires_grad = True
            encoder.train()
            decoder.train()

            target_variable = converter.encode(cpu_texts)
            utils.load_data(image, cpu_images)

            # CNN + BiLSTM
            encoder_outputs = encoder(image)
            # target_variable = target_variable.cuda()
            # start decoder for SOS_TOKEN
            decoder_input = target_variable[utils.SOS_TOKEN]
            decoder_hidden = decoder.initHidden(batch_size)

            loss = 0.0
            teach_forcing = True if random.random() > teach_forcing_prob else False
            if teach_forcing:
                # Teacher forcing: feed the ground-truth token as the next input.
                for di in range(1, target_variable.shape[0]):
                    decoder_output, decoder_hidden, decoder_attention = decoder(
                        decoder_input, decoder_hidden, encoder_outputs)
                    loss += criterion(decoder_output, target_variable[di])
                    decoder_input = target_variable[di]
            else:
                # Free running: feed the decoder's own top-1 prediction back in.
                for di in range(1, target_variable.shape[0]):
                    decoder_output, decoder_hidden, decoder_attention = decoder(
                        decoder_input, decoder_hidden, encoder_outputs)
                    loss += criterion(decoder_output, target_variable[di])
                    topv, topi = decoder_output.data.topk(1)
                    ni = topi.squeeze()
                    decoder_input = ni

            encoder.zero_grad()
            decoder.zero_grad()
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()

            loss_avg.add(loss)
            if i % 10 == 0:
                # NOTE: format string reconstructed — the scrape split it across lines.
                print('[Epoch {0}/{1}] [Batch {2}/{3}] Loss: {4}'.format(
                    epoch, cfg.num_epochs, i, len(train_loader), loss_avg.val()))
                loss_avg.reset()

        # save checkpoint
        torch.save(encoder.state_dict(), '{0}/encoder_{1}.pth'.format(cfg.model, epoch))
        torch.save(decoder.state_dict(), '{0}/decoder_{1}.pth'.format(cfg.model, epoch))


def evaluate(image, text, encoder, decoder, data_loader, max_eval_iter=100):
    """Greedy-decode up to max_eval_iter batches and print token accuracy."""
    # Freeze both networks for evaluation.
    for e, d in zip(encoder.parameters(), decoder.parameters()):
        e.requires_grad = False
        d.requires_grad = False
    encoder.eval()
    decoder.eval()

    val_iter = iter(data_loader)
    n_correct = 0
    n_total = 0
    loss_avg = utils.Averager()

    for i in range(min(len(data_loader), max_eval_iter)):
        # BUG FIX: `val_iter.next()` is Python-2-only (AttributeError on
        # Python 3); use the builtin next(), as train() already does.
        cpu_images, cpu_texts = next(val_iter)
        batch_size = cpu_images.size(0)
        utils.load_data(image, cpu_images)

        target_variable = converter.encode(cpu_texts)
        n_total += len(cpu_texts[0]) + 1  # +1 accounts for the EOS token

        decoded_words = []
        decoded_label = []
        encoder_outputs = encoder(image)
        # target_variable = target_variable.cuda()
        decoder_input = target_variable[0]
        decoder_hidden = decoder.initHidden(batch_size)

        # Greedy decode until EOS or the target length is exhausted.
        for di in range(1, target_variable.shape[0]):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.data.topk(1)
            ni = topi.squeeze(1)
            decoder_input = ni
            if ni == utils.EOS_TOKEN:
                decoded_label.append(utils.EOS_TOKEN)
                break
            else:
                decoded_words.append(converter.decode(ni))
                decoded_label.append(ni)

        # Count position-wise matches against the ground truth.
        for pred, target in zip(decoded_label, target_variable[1:, :]):
            if pred == target:
                n_correct += 1

        if i % 10 == 0:
            texts = cpu_texts[0]
            print('pred: {}, gt: {}'.format(''.join(decoded_words), texts))

    accuracy = n_correct / float(n_total)
    # NOTE(review): nothing is ever added to loss_avg here, so the printed
    # loss is the averager's default value — confirm utils.Averager semantics.
    print('Test loss: {}, accuray: {}'.format(loss_avg.val(), accuracy))


def main():
    """Build datasets and networks, then train and evaluate."""
    if not os.path.exists(cfg.model):
        os.makedirs(cfg.model)

    # create train dataset
    train_dataset = dataset.TextLineDataset(text_line_file=cfg.train_list, transform=None)
    sampler = dataset.RandomSequentialSampler(train_dataset, cfg.batch_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.batch_size, shuffle=False, sampler=sampler,
        num_workers=int(cfg.num_workers),
        collate_fn=dataset.AlignCollate(img_height=cfg.img_height, img_width=cfg.img_width))

    # create test dataset
    test_dataset = dataset.TextLineDataset(
        text_line_file=cfg.eval_list,
        transform=dataset.ResizeNormalize(img_width=cfg.img_width, img_height=cfg.img_height))
    test_loader = torch.utils.data.DataLoader(
        test_dataset, shuffle=False, batch_size=1, num_workers=int(cfg.num_workers))

    # create crnn/seq2seq/attention network
    encoder = crnn.Encoder(channel_size=1, hidden_size=cfg.hidden_size)
    # for prediction of an indefinite long sequence
    decoder = crnn.Decoder(hidden_size=cfg.hidden_size, output_size=num_classes,
                           dropout_p=0.1, max_length=cfg.max_width)
    print(encoder)
    print(decoder)
    encoder.apply(utils.weights_init)
    decoder.apply(utils.weights_init)
    if cfg.encoder:
        print('loading pretrained encoder model from %s' % cfg.encoder)
        encoder.load_state_dict(torch.load(cfg.encoder))
    if cfg.decoder:
        print('loading pretrained encoder model from %s' % cfg.decoder)
        decoder.load_state_dict(torch.load(cfg.decoder))

    # create input tensor (refilled in place by utils.load_data each batch)
    image = torch.FloatTensor(cfg.batch_size, 3, cfg.img_height, cfg.img_width)
    text = torch.LongTensor(cfg.batch_size)

    criterion = torch.nn.NLLLoss()

    # assert torch.cuda.is_available(), "Please run 'train.py' script on nvidia cuda devices."
    # encoder.cuda()
    # decoder.cuda()
    # image = image.cuda()
    # text = text.cuda()
    # criterion = criterion.cuda()

    # train crnn
    train(image, text, encoder, decoder, criterion, train_loader,
          teach_forcing_prob=cfg.teaching_forcing_prob)

    # do evaluation after training
    evaluate(image, text, encoder, decoder, test_loader, max_eval_iter=100)


if __name__ == "__main__":
    main()
09-25
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
import tkinter as tk
from tkinter import Button, messagebox
from PIL import Image, ImageDraw, ImageOps
import os

# ====================== Training section ====================== #

# Normalize with the published MNIST mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training data (downloaded on first run).
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Test data.
test_dataset = datasets.MNIST('data', train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1000)


class Improved_MNIST_CNN(nn.Module):
    """Three conv blocks (each halving spatial size) followed by a 3-layer MLP head."""

    def __init__(self):
        super(Improved_MNIST_CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # 28 -> 14 -> 7 -> 3 after three 2x2 pools, hence 128*3*3 features.
        self.fc1 = nn.Linear(128 * 3 * 3, 512)
        self.dropout1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 128)
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = self.dropout2(x)
        x = self.fc3(x)
        return x


# Model, loss, optimizer, and a step-decay LR schedule.
model = Improved_MNIST_CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)


def train_model(model, train_loader, test_loader, optimizer, scheduler, epochs=15):
    """Run the training loop, evaluating each epoch and checkpointing the best model."""
    model.train()
    best_accuracy = 0.0
    for epoch in range(epochs):
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}')

        # Decay the learning rate once per epoch.
        scheduler.step()

        # Evaluate on the held-out test set after every epoch.
        accuracy = test_model(model, test_loader)
        avg_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch+1} completed, Avg Loss: {avg_loss:.6f}, Test Accuracy: {accuracy:.2f}%')

        # Keep only the best-performing checkpoint.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), 'mnist_model_best.pth')
            print(f"Saved best model with accuracy: {best_accuracy:.2f}%")
    return best_accuracy


def test_model(model, test_loader):
    """Return test-set accuracy (percent); restores train mode before returning."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = 100. * correct / total
    model.train()
    return accuracy


def run_training():
    """Train the global model, then save the final weights."""
    print("开始训练模型...")
    best_accuracy = train_model(model, train_loader, test_loader, optimizer, scheduler, epochs=15)
    print(f"训练完成! 最佳准确率: {best_accuracy:.2f}%")
    torch.save(model.state_dict(), 'mnist_model_final.pth')
    print("模型已保存为: mnist_model_final.pth")
    return best_accuracy


# ====================== Recognition section ====================== #

def load_model(model_path='mnist_model_best.pth'):
    """Load trained weights onto a fresh CPU model; return None on failure."""
    model = Improved_MNIST_CNN()
    try:
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            model.eval()
            print(f"成功加载模型: {model_path}")
            return model
        else:
            print(f"警告: 找不到模型文件 '{model_path}'")
            return None
    except Exception as e:
        print(f"加载模型时出错: {e}")
        return None


class DigitRecognizer:
    """Tkinter canvas app: draw a digit, preprocess it to MNIST format, classify it."""

    def __init__(self, model):
        self.model = model
        self.transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        # Main window.
        self.root = tk.Tk()
        self.root.title("MNIST手写数字识别")
        self.root.geometry("400x500")

        # Title label.
        self.title_label = tk.Label(self.root, text="手写数字识别", font=("Arial", 16))
        self.title_label.pack(pady=10)

        # Drawing canvas.
        self.canvas_width = 280
        self.canvas_height = 280
        self.canvas = tk.Canvas(
            self.root,
            width=self.canvas_width,
            height=self.canvas_height,
            bg="white",
            cursor="cross",
        )
        self.canvas.pack(pady=10)

        # Draw while the left mouse button is held down.
        self.canvas.bind("<B1-Motion>", self.draw)

        # Shadow PIL image mirroring the canvas (white background, black ink).
        self.image = Image.new("L", (self.canvas_width, self.canvas_height), 255)
        self.draw_img = ImageDraw.Draw(self.image)

        # Button row.
        button_frame = tk.Frame(self.root)
        button_frame.pack(pady=10)

        self.recognize_btn = Button(
            button_frame, text="识别", command=self.recognize,
            width=10, height=2, bg="#4CAF50", fg="white", font=("Arial", 12)
        )
        self.recognize_btn.pack(side=tk.LEFT, padx=10)

        self.clear_btn = Button(
            button_frame, text="清除", command=self.reset,
            width=10, height=2, bg="#F44336", fg="white", font=("Arial", 12)
        )
        self.clear_btn.pack(side=tk.LEFT, padx=10)

        # Result label.
        self.result_label = tk.Label(
            self.root, text="结果: 请书写数字并点击'识别'",
            font=("Arial", 14), pady=10
        )
        self.result_label.pack()

        # Status bar.
        self.status_var = tk.StringVar()
        self.status_var.set("就绪")
        self.status_bar = tk.Label(
            self.root, textvariable=self.status_var,
            bd=1, relief=tk.SUNKEN, anchor=tk.W
        )
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)

        print("请在画布上书写数字,然后点击'识别'按钮...")

    def reset(self):
        """Clear both the tkinter canvas and the shadow PIL image."""
        self.canvas.delete("all")
        self.image = Image.new("L", (self.canvas_width, self.canvas_height), 255)
        self.draw_img = ImageDraw.Draw(self.image)
        self.result_label.config(text="结果: 请书写数字并点击'识别'")
        self.status_var.set("画布已清除")

    def draw(self, event):
        """Paint a round brush stroke at the cursor on canvas and PIL image alike."""
        x, y = event.x, event.y
        r = 10  # brush radius
        self.canvas.create_oval(x - r, y - r, x + r, y + r, fill="black", outline="black")
        self.draw_img.ellipse([x - r, y - r, x + r, y + r], fill=0)

    def preprocess(self):
        """Crop, scale, and center the drawing into MNIST style; None if canvas empty."""
        # Invert so the digit is white on black, matching MNIST.
        inverted_img = ImageOps.invert(self.image)

        # Bounding box of the ink; None means nothing was drawn.
        bbox = inverted_img.getbbox()
        if not bbox:
            return None

        cropped = inverted_img.crop(bbox)

        # Scale the longer side down to 20px, preserving aspect ratio.
        width, height = cropped.size
        max_dim = max(width, height)
        scale = 20.0 / max_dim
        new_width = int(width * scale)
        new_height = int(height * scale)
        resized = cropped.resize((new_width, new_height), Image.LANCZOS)

        # Paste centered onto a black 28x28 background.
        final_img = Image.new("L", (28, 28), 0)
        x_offset = (28 - new_width) // 2
        y_offset = (28 - new_height) // 2
        final_img.paste(resized, (x_offset, y_offset))
        return final_img

    def recognize(self):
        """Classify the current drawing and show the prediction with confidence."""
        if self.model is None:
            messagebox.showerror("错误", "模型未加载成功,请先训练模型")
            return

        processed_img = self.preprocess()
        if processed_img is None:
            self.status_var.set("错误: 未检测到书写内容")
            messagebox.showwarning("警告", "未检测到书写内容,请在画布上书写数字")
            return

        # To normalized tensor with a batch dimension.
        tensor = self.transform(processed_img).unsqueeze(0)

        with torch.no_grad():
            output = self.model(tensor)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            _, predicted = torch.max(output, 1)
            digit = predicted.item()
            confidence = probabilities[digit].item() * 100

        self.result_label.config(text=f"识别结果: {digit} (置信度: {confidence:.1f}%)")
        self.status_var.set(f"识别完成: {digit} (置信度: {confidence:.1f}%)")
        # Optionally inspect the preprocessed image:
        # processed_img.show()


def main():
    """Load (or train, then load) the model and launch the recognizer UI."""
    model_path = 'mnist_model_best.pth'
    model = None
    if os.path.exists(model_path):
        model = load_model(model_path)
    else:
        print("未找到预训练模型,开始训练新模型...")
        run_training()
        model = load_model(model_path)

    if model:
        recognizer = DigitRecognizer(model)
        recognizer.root.mainloop()


if __name__ == "__main__":
    main()
07-01
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值