PyTorch Learning

This post shows how to define a custom Flatten layer in PyTorch, how to read images, and how to train a model. It also covers reusing a pretrained model and optimizing only part of its layers, including appending new layers after the pretrained backbone (and checking the output shape of the retained layers) as well as strategies for optimizing specific layers only.

1 Flatten Layer Definition

Define a custom Flatten layer:

import torch
import torch.nn as nn


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # torch.prod gives the product of all feature dimensions,
        # i.e. the number of features per sample
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)
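
A quick sanity check (the input shape here is just an example) of how Flatten reshapes a batch:

    # flatten a dummy batch of [b, 512, 1, 1] feature maps
    x = torch.randn(2, 512, 1, 1)
    print(Flatten()(x).shape)   # torch.Size([2, 512])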

2 Reading Images and Their Class Names

Define a custom Dataset class that reads one image at a time:

import os, glob, csv, random, time

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class Readdata(Dataset):

    def __init__(self, root, resize, mode, namefile):
        """
        :param root:     directory containing the data
        :param resize:   target image size
        :param mode:     train, val, test
        :param namefile: name of the csv file holding image paths and class labels
        """
        self.root = root
        self.size = resize
        self.namefile = namefile
        self.name2label = {}  # class name --> label index
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            print(name)
            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)

        self.images, self.labels = self.load_csv()


    def load_csv(self):
        # If the csv file already exists, the image paths and labels are not rewritten.
        # To regenerate the file, use a different file name.
        if not os.path.exists(os.path.join(self.root, self.namefile)):
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
                images += glob.glob(os.path.join(self.root, name, '*.png'))

            # e.g. 1165 ['pokemon/pokeman/bulbasaur/00000159.jpg', ...]
            # print(len(images), images)
            random.shuffle(images)
            with open(os.path.join(self.root, self.namefile), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img, label])

                print('written into csv file:', self.namefile)

        # read csv
        images, labels = [], []
        with open(os.path.join(self.root, self.namefile)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)

                images.append(img)
                labels.append(label)

        assert len(images) == len(labels), 'num of imgs != num of labels'

        return images, labels

    def __len__(self):
        return  len(self.labels)

    def denormalize(self, x_hat):
        # undo transforms.Normalize: x = x_hat * std + mean
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        # print(mean.shape, std.shape)
        x = x_hat * std + mean

        return x

    def __getitem__(self, idx):
        # self.images[idx] stores a file path; the first transform below opens the image
        img, label = self.images[idx], self.labels[idx]

        tf = transforms.Compose([
            lambda x:Image.open(x).convert('RGB'),
            transforms.Resize(int(self.size*1.25)),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        img = tf(img)
        label = torch.tensor(label)

        return img, label
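
Each row of the csv written by load_csv pairs an image path with its integer label, for example (the path comes from the comment above; the label index depends on the alphabetical order of the class folders):

    pokemon/pokeman/bulbasaur/00000159.jpg,0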

After reading single samples, read a whole batch with a DataLoader:

    import visdom
    viz = visdom.Visdom()

    root = './数据目录'   # dataset root directory
    db = Readdata(root, 32, 'train', 'images.csv')

    x, y = next(iter(db))
    print('sample:', x.shape, y.shape, y)

    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
    loader = DataLoader(db, batch_size=16, shuffle=True, num_workers=4)

    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='label'))

        time.sleep(10)
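
Note: the visdom calls above assume a visdom server is already running (started with python -m visdom.server); otherwise nothing will be displayed in the browser.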

If the images are already organized into per-class folders, you can skip the custom class and use PyTorch's built-in dataset API directly:


    # # For data already in the standard folder layout, there is no need to write Readdata yourself
    # db = torchvision.datasets.ImageFolder(root=root, transform=transforms.Compose([
    #     transforms.Resize((64, 64)),
    #     transforms.ToTensor(),
    # ]))
    # loader = DataLoader(db, batch_size=32, shuffle=True)
    # # print(db.class_to_idx)
    #
    # for x, y in loader:
    #     viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
    #
    #     time.sleep(10)
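
ImageFolder derives the class labels from the sub-folder names, so it expects one folder per class, roughly like this (the class names below are just examples for this dataset):

    数据目录/
        bulbasaur/
            00000159.jpg
            ...
        pikachu/
            ...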

3 Training

  1. model
  2. optimizer
  3. loss function
  4. zero the gradients
  5. backpropagation

First define items 1-3:

    from torch import optim

    # device, lr, epochs and the train/val/test loaders are assumed to be defined beforehand
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

Start the training loop:

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):

        for step, (x, y) in enumerate(train_loader):

            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 1 == 0:  # validate every epoch

            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')

                viz.line([val_acc], [global_step], win='val_acc', update='append')


    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)

The evalute function used above:

def evalute(model, loader):
    # switch to evaluation mode (disables dropout, uses running batch-norm statistics)
    model.eval()
    
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total

4 Pretrained Models and Unfreezing

  1. Approach 1: take the pretrained layers and add your own layers

After taking the pretrained layers you need to append layers of your own, and it helps to check the output shape of the retained part first:

    from torchvision.models import resnet18

    trained_model = resnet18(pretrained=True)
    # keep the pretrained layers [0:-1]; train the final fully-connected layer yourself
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),                             # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)
                          ).to(device)
    # x = torch.randn(2, 3, 224, 224)
    # print(model(x).shape)
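
A minimal sketch (assuming a standard torchvision resnet18) of checking the backbone's output shape before appending the new layers:

    # everything except the final fc layer ends in adaptive average pooling,
    # so a [b, 3, 224, 224] input comes out as [b, 512, 1, 1]
    backbone = nn.Sequential(*list(trained_model.children())[:-1])
    out = backbone(torch.randn(2, 3, 224, 224))
    print(out.shape)   # torch.Size([2, 512, 1, 1])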
  2. Approach 2: directly optimize only the layers you want to train
     References:
     https://www.jianshu.com/p/d67d62982a24
     https://blog.youkuaiyun.com/TTdreamloong/article/details/84823705

Optimizing only part of the layers (original source: https://blog.youkuaiyun.com/u012494820/article/details/79068625):

    count = 0
    para_optim = []
    for k in model.children():
        count += 1
        # 6 should be changed properly: the first 6 children stay frozen
        if count > 6:
            for param in k.parameters():
                para_optim.append(param)
        else:
            for param in k.parameters():
                param.requires_grad = False
    optimizer = optim.RMSprop(para_optim, lr)
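
An equivalent shorthand is to hand the optimizer only the parameters that still require gradients after freezing; a minimal sketch (RMSprop is kept just to mirror the snippet above):

    # collect the parameters that were left trainable
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.RMSprop(trainable, lr)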
