YOLO代码详解(二)

1 数据准备

def run():
    print_environment_info()
    parser = argparse.ArgumentParser(description="Trains the YOLO model.")
    parser.add_argument("-m", "--model", type=str, default="config/yolov3.cfg", help="Path to model definition file (.cfg)")
    parser.add_argument("-d", "--data", type=str, default="config/coco.data", help="Path to data config file (.data)")
    parser.add_argument("-e", "--epochs", type=int, default=300, help="Number of epochs")
    parser.add_argument("-v", "--verbose", action='store_true', help="Makes the training more verbose")
    parser.add_argument("--n_cpu", type=int, default=8, help="Number of cpu threads to use during batch generation")
    parser.add_argument("--pretrained_weights", type=str, help="Path to checkpoint file (.weights or .pth). Starts training from checkpoint model")
    parser.add_argument("--checkpoint_interval", type=int, default=1, help="Interval of epochs between saving model weights")
    parser.add_argument("--evaluation_interval", type=int, default=1, help="Interval of epochs between evaluations on validation set")
    parser.add_argument("--multiscale_training", action="store_true", help="Allow multi-scale training")
    parser.add_argument("--iou_thres", type=float, default=0.5, help="Evaluation: IOU threshold required to qualify as detected")
    parser.add_argument("--conf_thres", type=float, default=0.1, help="Evaluation: Object confidence threshold")
    parser.add_argument("--nms_thres", type=float, default=0.5, help="Evaluation: IOU threshold for non-maximum suppression")
    parser.add_argument("--logdir", type=str, default="logs", help="Directory for training log files (e.g. for TensorBoard)")
    parser.add_argument("--seed", type=int, default=-1, help="Makes results reproducable. Set -1 to disable.")
    args = parser.parse_args()
    print(f"Command line arguments: {
     args}")

    if args.seed != -1:
        provide_determinism(args.seed)

    logger = Logger(args.logdir)  # Tensorboard logger

    # Create output directories if missing
    os.makedirs("output", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # Get data configuration
    data_config = parse_data_config(args.data)
    train_path = data_config["train"]
    valid_path = data_config["valid"]
    class_names = load_classes(data_config["names"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  • if args.seed != -1:
    provide_determinism(args.seed)
    调整随机数生成的种子

  • logger = Logger(args.logdir)TensorBoard
    用于指定日志文件的目录或路径。通过将该目录或路径传递给Logger的构造函数,我们可以创建一个日志记录器实例,以便在指定的目录或路径下记录日志。

  • os.makedirs(“output”, exist_ok=True)
    训练结果

  • os.makedirs(“checkpoints”, exist_ok=True)
    checkpoints文件夹通常用于存储模型的训练检查点。在机器学习中,训练模型是一个迭代的过程,每次迭代都会更新模型的参数。为了在训练过程中进行断点续训或保存模型的中间状态,可以将模型的参数和其他相关信息保存在checkpoints文件夹中。

2 数据预处理

在深度学习中,Dataset和DataLoader是两个常用的类,用于加载和处理训练和测试数据。

Dataset:

Dataset是一个抽象类,用于表示数据集。它提供了一种统一的接口来访问训练和测试数据。通常情况下,我们需要自定义一个子类来实现具体的数据集,该子类需要继承Dataset并实现两个主要方法:
len():返回数据集中样本的数量。
getitem(index):返回指定索引的样本数据。可以根据索引从磁盘、内存或其他位置加载数据,并对其进行预处理。

DataLoader:

DataLoader是一个用于批量加载数据的实用程序类。它接受一个Dataset对象作为输入,并以批次的形式提供数据。DataLoader提供了以下功能:
数据批处理:将数据划分为小批次,便于模型的训练和优化。
数据随机化:在每个训练迭代中,可以随机化数据的顺序,以增加训练的多样性和泛化能力。
并行加载:可以使用多个数据加载器并行地加载数据,加快数据的读取速度。
数据预处理:可以在DataLoader中使用各种数据转换和增强操作,如缩放、裁剪、标准化等。
通过将Dataset与DataLoader结合使用,可以方便地加载和处理训练和测试数据,并将其提供给深度学习模型进行训练和推理。Dataset负责表示数据集本身,而DataLoader负责批量加载和处理数据,提供数据的迭代和批处理功能,从而更好地适应深度学习模型的训练需求。

加载模型

model = load_model(args.model, args.pretrained_weights)

TODO

if args.verbose:
    summary(model, input_size=(3, model.hyperparams['height'], model.hyperparams['height']))
mini_batch_size 
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值