nerf及sat-nerf源码

神经辐射场(NeRF)是一种利用神经网络来表示和渲染复杂的三维场景的方法。它可以从一组二维图片中学习出一个连续的三维函数,这个函数可以给出空间中任意位置和方向上的颜色和密度。通过体积渲染的技术,NeRF可以从任意视角合成出逼真的图像,包括透明和半透明物体,以及复杂的光线传播效果今天让我们从代码层面上揭秘这个引爆元宇宙三维重建领域的方法。

本次解读只是笔者针对实现NeRF最重要部分的解读,其他细节还需要进一步优化。需要完整代码的读者们可以通过下面两个链接下载获取:

原论文及代码——https://github.com/bmild/nerf

大佬实现的pytorch版本——https://github.com/yenchenlin/nerf-pytorch

首先我们来看一下实现nerf代码的整体结构以及我们需要主要研究的代码函数部分:

图片

▲图|NeRF的算法流程

NeRF主要是下面的几个主要函数构成了大体的程序框架:

1)首先是加载NeRF运行需要的各种参数(包括所用的数据集、数据类型、输出文件位置、训练轮次、bs、数据采样、训练所用的网络参数、训练形式的选择等)详细说明可见下面的代码注释。

2)[config_parser](run_nerf.py)

def config_parser():
    import configargparse
    parser = configargparse.ArgumentParser()

    parser.add_argument('--config', is_config_file=True,
                        help='config file path')
    # 本次实验的名称,作为log中文件夹的名字
    parser.add_argument("--expname", type=str,
                        help='experiment name')
    # 输出目录
    parser.add_argument("--basedir", type=str, default='./logs/',
                        help='where to store ckpts and logs')
    # 指定数据集的目录
    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                        help='input data directory')

    # training options
    # 全连接的层数
    parser.add_argument("--netdepth", type=int, default=8,
                        help='layers in network')
    # 网络宽度
    parser.add_argument("--netwidth", type=int, default=256,
                        help='channels per layer')

    # 精细网络的全连接层数
    # 默认精细网络的深度和宽度与粗糙网络是相同的
    parser.add_argument("--netdepth_fine", type=int, default=8,
                        help='layers in fine network')
    parser.add_argument("--netwidth_fine", type=int, default=256,
                        help='channels per layer in fine network')

    # 这里的batch size,指的是光线的数量,像素点的数量
    # N_rand 配置文件中是1024
    # 32*32*4=4096
    # 800*800/4096=156 400*400/1024=156
    parser.add_argument("--N_rand", type=int, default=32 * 32 * 4,
                        help='batch size (number of random rays per gradient step)')
    # 学习率
    parser.add_argument("--lrate", type=float, default=5e-4,
                        help='learning rate')
    # 学习率衰减
    parser.add_argument("--lrate_decay", type=int, default=250,
                        help='exponential learning rate decay (in 1000 steps)')

    parser.add_argument("--chunk", type=int, default=1024 * 32,
                        help='number of rays processed in parallel, decrease if running out of memory')

    # 网络中处理的点的数量
    parser.add_argument("--netchunk", type=int, default=1024 * 64,
                        help='number of pts sent through network in parallel, decrease if running out of memory')

    # 合成的数据集一般都是True, 每次只从一张图片中选取随机光线
    # 真实的数据集一般都是False, 图形先混在一起
    parser.add_argument("--no_batching", action='store_true',
                        help='only take random rays from 1 image at a time')

    # 不加载权重
    parser.add_argument("--no_reload", action='store_true',
                        help='do not reload weights from saved ckpt')
    # 粗网络的权重文件的位置
    parser.add_argument("--ft_path", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')

    # rendering options
    parser.add_argument("--N_samples", type=int, default=64,
                        help='number of coarse samples per ray')
    parser.add_argument("--N_importance", type=int, default=0,
                        help='number of additional fine samples per ray')

    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    # 不适用视角数据
    parser.add_argument("--use_viewdirs", action='store_true',
                        help='use full 5D input instead of 3D')
    # 0 使用位置编码,-1 不使用位置编码
    parser.add_argument("--i_embed", type=int, default=0,
                        help='set 0 for default positional encoding, -1 for none')

    # L=10
    parser.add_argument("--multires", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    # L=4
    parser.add_argument("--multires_views", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')

    parser.add_argument("--raw_noise_std", type=float, default=0.,
                        help='std dev of noise added to regularize sigma_a output, 1e0 recommended')

    # 仅进行渲染
    parser.add_argument("--render_only", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    # 渲染test数据集
    parser.add_argument("--render_test", action='store_true',
                        help='render the test set instead of render_poses path')
    # 下采样的倍数
    parser.add_argument("--render_factor", type=int, default=0,
                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')

    # training options
    # 中心裁剪的训练轮数
    parser.add_argument("--precrop_iters", type=int, default=0,
                        help='number of steps to train on central crops')
    parser.add_argument("--precrop_frac", type=float,
                        default=.5, help='fraction of img taken for central crops')

    # dataset options
    # 数据格式
    parser.add_argument("--dataset_type", type=str, default='llff',
                        help='options: llff / blender / deepvoxels')

    # 对于大的数据集,test和val数据集,只使用其中的一部分数据
    parser.add_argument("--testskip", type=int, default=8,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

    ## deepvoxels flags
    parser.add_argument("--shape", type=str, default='greek',
                        help='options : armchair / cube / greek / vase')

    ## blender flags
    # 白色背景
    parser.add_argument("--white_bkgd", action='store_true',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')

    # 使用一半分辨率
    parser.add_argument("--half_res", action='store_true',
                        help='load blender synthetic data at 400x400 instead of 800x800')

    ## llff flags
    parser.add_argument("--factor", type=int, default=8,
                        help='downsample factor for LLFF images')
    parser.add_argument("--no_ndc", action='store_true',
                        help='do not use normalized device coordinates (set for non-forward facing scenes)')
    parser.add_argument("--lindisp", action='store_true',
                        help='sampling linearly in disparity rather than depth')
    parser.add_argument("--spherify", action='store_true',
                        help='set for spherical 360 scenes')
    parser.add_argument("--llffhold", type=int, default=8,
                        help='will take every 1/N images as LLFF test set, paper uses 8')

    # logging/saving options
    # log输出的频率
    parser.add_argument("--i_print", type=int, default=100,
                        help='frequency of console printout and metric loggin')
    parser.add_argument("--i_img", type=int, default=500,
                        help='frequency of tensorboard image logging')
    # 保存模型的频率
    # 每隔1w保存一个
    parser.add_argument("--i_weights", type=int, default=10000,
                        help='frequency of weight ckpt saving')
    # 执行测试集渲染的频率
    parser.add_argument("--i_testset", type=int, default=50000,
                        help='frequency of testset saving')
    # 执行渲染视频的频率
    parser.add_argument("--i_video", type=int, default=50000,
                        help='frequency of render_poses video saving')

    return parser

1)接下来是各种数据集的加载方式(包括blender、deepvoxels、linemod、llff数据集)接下来以加载blender数据集为例。整个模块以load_blender_data函数为入口读取blender数据图像以及图片的相机坐标系转化到世界坐标系的处理的json文件。通过pose_spherical函数的测算得到我们需要用到的渲染位姿render_poses。

2)[load_blender_data](load_blender.py)

具体的关于相机坐标与世界坐标系之间的转化,相机标定等资料可以参考下面的博客:

旋转矩阵——https://blog.youkuaiyun.com/csxiaoshui/article/details/65446125

相机标定——https://blog.youkuaiyun.com/Kalenee/article/details/99207102


import os
import torch
import numpy as np
import imageio
import json
import cv2

# 平移
trans_t = lambda t: torch.Tensor([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, t],
    [0, 0, 0, 1]]).float()

# 绕x轴的旋转
rot_phi = lambda phi: torch.Tensor([
    [1, 0, 0, 0],
    [0, np.cos(phi), -np.sin(phi), 0],
    [0, np.sin(phi), np.cos(phi), 0],
    [0, 0, 0, 1]]).float()

# 绕y轴的旋转
rot_theta = lambda th: torch.Tensor([
    [np.cos(th), 0, -np.sin(th), 0],
    [0, 1, 0, 0],
    [np.sin(th), 0, np.cos(th), 0],
    [0, 0, 0, 1]]).float()


def pose_spherical(theta, phi, radius):
    """
    theta: -180 -- +180,间隔为9
    phi: 固定值 -30
    radius: 固定值 4
    """
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180. * np.pi) @ c2w
    c2w = rot_theta(theta / 180. * np.pi) @ c2w
    c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w
    return c2w


def load_blender_data(basedir, half_res=False, testskip=1):
    """
    testskip: test和val数据集,只会读取其中的一部分数据,跳着读取
    """
    splits = ['train', 'val', 'test']
    # 存储了三个json文件的数据
    metas = {
   }
    for s in splits:
        with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    counts = [0]
    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []
        if s == 'train' or testskip == 0:
            skip = 1
        else:
            # 测试集如果数量很多,可能会设置testskip
            skip = testskip
        # 读取所有的图片,以及所有对应的transform_matrix
        for frame in meta['frames'][::skip]:
            fname = os.path.join(basedir, frame['file_path'] + '.png')
            imgs.append(imageio.imread(fname))
            poses.append(np.array(frame['transform_matrix']))
        # 归一化
        imgs = (np.array(imgs) / 255.).astype(np.float32)  # keep all 4 channels (RGBA),4通道 rgba
        poses = np.array(poses).astype(np.float32)
        # 用于计算train val test的递增值
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)
    # train val test 三个list
    i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]
    # train test val 拼一起
    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate(all_poses, 0)

    H, W = imgs[0].shape[:2]
    # meta使用了上面的局部变量,train test val 这个变量值是相同的,文件中这三个值确实是相同的
    camera_angle_x = float(meta['camera_angle_x'])
    # 焦距
    focal = .5 * W / np.tan(.5 * camera_angle_x)

    #  np.linspace(-180, 180, 40 + 1) 9度一个间隔
    # (40,4,4), 渲染的结果就是40帧
    render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0)

    if half_res:
        H = H // 2
        W = W // 2
        # 焦距一半
        focal = focal / 2.

        imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
        for i, img in enumerate(imgs):
            # 调整成一半的大小
            imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
        imgs = imgs_half_res

    return imgs, poses, render_poses, [H, W, focal], i_split

数据集处理完成之后是神经网络创建的部分,首先是对空间点进行位置编码,通过下面的函数进行位置编码:

embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

[create_nerf](run_nerf.py)

[get_embedder](run_nerf_helpers.py)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

▲图|2D图像不加位置编码恢复结果

NeRF的Positional encoding过程就是将空间5D位姿进行傅里叶变化:

在这里插入图片描述

代码的详细注释如下:


# Positional encoding (section 5.1)
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']  # 3
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            # tensor([  1.,   2.,   4.,   8.,  16.,  32.,  64., 128., 256., 512.])
            freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                # sin(x),sin(2x),sin(4x),sin(8x),sin(16x),sin(32x),sin(64x),sin(128x),sin(256x),sin(512x)
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns

        # 3D坐标是63,2D方向是27
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


# 位置编码相关
def get_embedder(multires, i=0):
    """
    multires: 3D 坐标是10,2D方向是4
    """
    if i == -1:
        return nn.Identity(), 3

    embed_kwargs = {
   
        'include_input': True,
        'input_dims': 3,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj: eo.embed(x)
    # 第一个返回值是lamda,给定x,返回其位置编码
    return embed, embedder_obj.out_dim

做完位置编码后通过下面函数进入粗网络训练

    model = NeRF(D=args.netdepth, W=args.netwidth,   

                            input\_ch=input\_ch,output\_ch=output\_ch, 

                            skips=skips, 

                            input\_ch\_views=input\_ch\_views, 

                            use\_viewdirs=args.use\_viewdirs).to(device)

[NeRF](run_nerf_helpers.py)这个函数结构定义了训练NeRF的神经网络的大小,层数结构。


▲图|NeRF网络结构


class NeRF(nn.Module):
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ajaxm

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值