1 数据准备
def run():
print_environment_info()
parser = argparse.ArgumentParser(description="Trains the YOLO model.")
parser.add_argument("-m", "--model", type=str, default="config/yolov3.cfg", help="Path to model definition file (.cfg)")
parser.add_argument("-d", "--data", type=str, default="config/coco.data", help="Path to data config file (.data)")
parser.add_argument("-e", "--epochs", type=int, default=300, help="Number of epochs")
parser.add_argument("-v", "--verbose", action='store_true', help="Makes the training more verbose")
parser.add_argument("--n_cpu", type=int, default=8, help="Number of cpu threads to use during batch generation")
parser.add_argument("--pretrained_weights", type=str, help="Path to checkpoint file (.weights or .pth). Starts training from checkpoint model")
parser.add_argument("--checkpoint_interval", type=int, default=1, help="Interval of epochs between saving model weights")
parser.add_argument("--evaluation_interval", type=int, default=1, help="Interval of epochs between evaluations on validation set")
parser.add_argument("--multiscale_training", action="store_true", help="Allow multi-scale training")
parser.add_argument("--iou_thres", type=float, default=0.5, help="Evaluation: IOU threshold required to qualify as detected")
parser.add_argument("--conf_thres", type=float, default=0.1, help="Evaluation: Object confidence threshold")
parser.add_argument("--nms_thres", type=float, default=0.5, help="Evaluation: IOU threshold for non-maximum suppression")
parser.add_argument("--logdir", type=str, default="logs", help="Directory for training log files (e.g. for TensorBoard)")
parser.add_argument("--seed", type=int, default=-1, help="Makes results reproducable. Set -1 to disable.")
args = parser.parse_args()
print(f"Command line arguments: {
args}")
if args.seed != -1:
provide_determinism(args.seed)
logger = Logger(args.logdir) # Tensorboard logger
# Create output directories if missing
os.makedirs("output", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)
# Get data configuration
data_config = parse_data_config(args.data)
train_path = data_config["train"]
valid_path = data_config["valid"]
class_names = load_classes(data_config["names"])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
if args.seed != -1:
provide_determinism(args.seed)
调整随机数生成的种子 -
logger = Logger(args.logdir)TensorBoard
用于指定日志文件的目录或路径。通过将该目录或路径传递给Logger的构造函数,我们可以创建一个日志记录器实例,以便在指定的目录或路径下记录日志。 -
os.makedirs(“output”, exist_ok=True)
训练结果 -
os.makedirs(“checkpoints”, exist_ok=True)
checkpoints文件夹通常用于存储模型的训练检查点。在机器学习中,训练模型是一个迭代的过程,每次迭代都会更新模型的参数。为了在训练过程中进行断点续训或保存模型的中间状态,可以将模型的参数和其他相关信息保存在checkpoints文件夹中。
2 数据预处理
在深度学习中,Dataset和DataLoader是两个常用的类,用于加载和处理训练和测试数据。
Dataset:
Dataset是一个抽象类,用于表示数据集。它提供了一种统一的接口来访问训练和测试数据。通常情况下,我们需要自定义一个子类来实现具体的数据集,该子类需要继承Dataset并实现两个主要方法:
len():返回数据集中样本的数量。
getitem(index):返回指定索引的样本数据。可以根据索引从磁盘、内存或其他位置加载数据,并对其进行预处理。
DataLoader:
DataLoader是一个用于批量加载数据的实用程序类。它接受一个Dataset对象作为输入,并以批次的形式提供数据。DataLoader提供了以下功能:
数据批处理:将数据划分为小批次,便于模型的训练和优化。
数据随机化:在每个训练迭代中,可以随机化数据的顺序,以增加训练的多样性和泛化能力。
并行加载:可以使用多个数据加载器并行地加载数据,加快数据的读取速度。
数据预处理:可以在DataLoader中使用各种数据转换和增强操作,如缩放、裁剪、标准化等。
通过将Dataset与DataLoader结合使用,可以方便地加载和处理训练和测试数据,并将其提供给深度学习模型进行训练和推理。Dataset负责表示数据集本身,而DataLoader负责批量加载和处理数据,提供数据的迭代和批处理功能,从而更好地适应深度学习模型的训练需求。
加载模型
model = load_model(args.model, args.pretrained_weights)
TODO
if args.verbose:
summary(model, input_size=(3, model.hyperparams['height'], model.hyperparams['height']))
mini_batch_size