Sat- nerf深度损失过程

首先损失函数定义在metrics.py,代码如下:

class DepthLoss(torch.nn.Module):
    def __init__(self, lambda_ds=1.0):
        super().__init__()
        # 初始化lambda_ds参数,用于调节深度损失的权重,并且将其缩小为原来的1/3
        self.lambda_ds = lambda_ds / 3.
        # 初始化均方误差损失函数(MSELoss),并设置reduce=False表示不对损失值进行平均
        self.loss = torch.nn.MSELoss(reduce=False)

    def forward(self, inputs, targets, weights=1.):
        # 创建一个字典,用来存储不同类型的深度损失
        loss_dict = {
   
   }

        # 默认使用'coarse'(粗略)类型
        typ = 'coarse'
        # 计算输入深度与目标深度之间的损失,并将结果存入字典
        loss_dict[f'{
     
     typ}_ds'] = self.loss(inputs['depth_coarse'], targets)

        # 如果输入中包含'fine'(精细)类型的深度数据
        if 'depth_fine' in inputs:
            typ = 'fine'
            # 计算精细深度的损失,并存入字典
            loss_dict[f'{
     
     typ}_ds'] = self.loss(inputs['depth_fine'], targets)

        # 对每个损失项应用权重
        for k in loss_dict.keys():
            # 计算加权的平均损失,并乘以lambda_ds来调整损失的权重
            loss_dict[k] = self.lambda_ds * torch.mean(weights * loss_dict[k])

        # 计算所有损失项的总和
        loss = sum(l for l in loss_dict.values())

        # 返回总损失以及包含各个深度损失的字典
        return loss, loss_dict

需要三个输入inputs, targets, weights=1(inputs为输入深度,target为gt,weight为权重),得到两个输出loss, loss_dict(loss为总和,loss_dict为记录单个损失的字典)。

当main函数运行到

system = NeRF_pl(args)  # 初始化 NeRF 模型系统,传入配置参数,为模型训练做好准备工作,确保所有需要的配置和资源都已经到位。

开始调用Nerf_pl 的_init_,会在其中实例化 DepthLoss 类:

self.depth_loss = DepthLoss(lambda_ds=args.ds_lambda)  # 初始化深度损失对象,传入深度监督系数

当运行trainer.fit(system)时,训练启动:

当执行trainer.fit(system)时,Lightning接管了训练过程,
Lightning首先调用prepare_data()准备数据集
然后调用configure_optimizers()设置优化器和学习率调度器

训练循环:

Lightning自动开始训练循环,每个epoch包含:

训练步骤: Lightning自动从train_dataloader()加载数据:

    def train_dataloader(self):
        """创建并返回训练数据加载器字典

        根据配置参数创建不同模态(颜色/深度)的训练数据加载器。当self.depth为True时,
        会同时创建颜色数据和深度数据的加载器。数据加载器使用4个工作进程进行数据加载,
        启用内存锁页(pin_memory)以加速GPU数据传输,并自动进行批次数据打乱。

        Returns:
            dict: 包含数据加载器的字典,键为模态名称("color"/"depth"),
                值为对应的torch.utils.data.DataLoader实例
        """
        # 创建颜色数据的训练集加载器(第一个数据集)
        a = DataLoader(self.train_dataset[0],
                       shuffle=True,
                       num_workers=4,
                       batch_size=self.args.batch_size,
                       pin_memory=True)
        loaders = {
   
   "color": a}

        # 当需要加载深度数据时,创建第二个数据加载器
        if self.depth:
            b = DataLoader(self.train_dataset[1],#数据从上面dataloade 的self.train_dataset[1],这是一个SatelliteDataset_depth类的实例,在prepare_data()方法中创建
                           shuffle=True,
                           num_workers=4,
                           batch_size=self.args.batch_size,
                           pin_memory=True)
            loaders["depth"] = b#通过batch["depth"]访问

        return loaders

可以看到深度数据从上面prepare_data()方法中创建的self.train_dataset[1],这是一个SatelliteDataset_depth类的实例。
接下来,调用training_step()处理每个批次:

    def training_step(self, batch, batch_nb):
        self.log("lr", train_utils.get_learning_rate(self.optimizer))
        self.train_steps += 1

        rays = batch["color"]["rays"] # (B, 11)
        rgbs = batch["color"]["rgbs"] # (B, 3)
        ts = None if not self.use_ts else batch["color"]["ts"].squeeze() # (B, 1)

        results = self(rays, ts)
        if 'beta_coarse' in results and self.get_current_epoch(self.train_steps) < 2:
            loss, loss_dict = self.loss_without_beta(results, rgbs)
        else:
            loss, loss_dict = self.loss(results, rgbs)
        self.args.noise_std *= 0.9

        if self.depth:
            tmp = self(batch["depth"]["rays"], batch["depth"]["ts"].squeeze())
            kp_depths = torch.flatten(batch["depth"]["depths"][:, 0])
            kp_weights = 1. if self.args.ds_noweights else torch.flatten(batch["depth"]["depths"][:, 1])
            loss_depth, tmp = self.depth_loss(tmp, kp_depths, kp_weights)#tmp是作为imput输入进去了,kp_depths是target,kp_weights是权重
            if self.train_steps < self.ds_drop :
                loss += loss_depth
            for k in tmp.keys():
                loss_dict[k] = tmp[k]

        self.log("train/loss", loss)
        typ = "fine" if "rgb_fine" in results else "coarse"

        with torch.no_grad():
            psnr_ = metrics.psnr(results[f"rgb_{
     
     typ}"], rgbs)
            self.log("train/psnr", psnr_)
        for k in loss_dict.keys():
            self.log("train/{}".format
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ajaxm

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值