工业缺陷检测实战——磁瓦表面缺陷分类

 第一步:准备数据

6种磁瓦表面缺陷:self.class_indict = ["MT_Blowhole", "MT_Break", "MT_Crack", "MT_Fray", "MT_Free", "MT_Uneven"]

,总共有1330张图片,每个文件夹单独放一种数据

第二步:搭建模型

本文选择一个ConvNext网络,其原理介绍如下:

ConvNext (Convolutional Network Net Generation), 即下一代卷积神经网络, 是近些年来 CV 领域的一个重要发展. ConvNext 由 Facebook AI Research 提出, 仅仅通过卷积结构就达到了与 Transformer 结构相媲美的 ImageNet Top-1 准确率, 这在近年来以 Transformer 为主导的视觉问题解决趋势中显得尤为突出.

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import json
import os
import argparse
import time

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets

from model import convnext_tiny as create_model
from utils import  create_lr_scheduler, get_params_groups, train_one_epoch, evaluate,plot_class_preds


def main(args):
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    # 创建定义文件夹以及文件
    filename = 'record.txt'
    save_path = 'runs'
    path_num = 1
    while os.path.exists(save_path + f'{path_num}'):
        path_num += 1
    save_path = save_path + f'{path_num}'
    os.mkdir(save_path)
    f = open(save_path + "/" + filename, 'w')
    f.write("{}\n".format(args))

    # print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    # 实例化SummaryWriter对象
    # #######################################
    tb_writer = SummaryWriter(log_dir=save_path + "/experiment")
    if os.path.exists(save_path + "/weights") is False:
        os.makedirs(save_path + "/weights")

    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # 实例化训练数据集
    train_data_set = datasets.ImageFolder(root=os.path.join(args.data_path, "train"),
                                          transform=data_transform["train"])

    # 实例化验证数据集
    val_data_set = datasets.ImageFolder(root=os.path.join(args.data_path, "val"),
                                        transform=data_transform["val"])

    # 生成class_indices.json文件,包括有模型对应的序列号
    # #######################################
    classes_list = train_data_set.class_to_idx
    cla_dict = dict((val, key) for key, val in classes_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw)

    val_loader = torch.utils.data.DataLoader(val_data_set,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw)

    model = create_model(num_classes=args.num_classes).to(device)

    # Write the model into tensorboard
    # #######################################
    init_img = torch.zeros((1, 3, 224, 224), device=device)
    tb_writer.add_graph(model, init_img)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        # weights_dict = torch.load(args.weights, map_location=device)
        weights_dict = torch.load(args.weights, map_location=device)["model"]
        # 删除有关分类类别的权重
        for k in list(weights_dict.keys()):
            if "head" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除head外,其他权重全部冻结
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    # pg = [p for p in model.parameters() if p.requires_grad]
    pg = get_params_groups(model, weight_decay=args.wd)
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=args.wd)
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs,
                                       warmup=True, warmup_epochs=10)

    best_acc = 0.0
    for epoch in range(args.epochs):
        # 计时器time_start
        time_start = time.time()
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch,
                                                lr_scheduler=lr_scheduler)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)
        time_end = time.time()
        f.write("[epoch {}] train_loss: {:.3f},train_acc:{:.3f},val_loss:{:.3f},val_acc:{:.3f},Spend_time:{:.3f}S"
                .format(epoch + 1, train_loss, train_acc, val_loss, val_acc, time_end - time_start))
        f.flush()

        # add Training results into tensorboard
        # #######################################
        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        # add figure into tensorboard
        # #######################################
        fig = plot_class_preds(net=model,
                               images_dir=r"plot_img",
                               transform=data_transform["val"],
                               num_plot=6,
                               device=device)
        if fig is not None:
            tb_writer.add_figure("predictions vs. actuals",
                                 figure=fig,
                                 global_step=epoch)

        if val_acc >= best_acc:
            best_acc = val_acc
            f.write(',save best model')
            torch.save(model.state_dict(), save_path + "/weights/bestmodel.pth")
        f.write('\n')
    f.close()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=6)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--wd', type=float, default=5e-2)
    parser.add_argument('--data-path', type=str,
                        default=r"E:\Industrial_inspection\data\MagneticTile_c")
    parser.add_argument('--weights', type=str,
                        default=r"convnext_tiny_1k_224_ema.pth",
                        help='initial weights path')
    # 是否冻结head以外所有权重
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

项目完整文件下载请见演示与介绍视频的简介处给出:➷➷➷

工业缺陷检测实战——磁瓦表面缺陷分类_哔哩哔哩_bilibili

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI街潜水的八角

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值