[Code Walkthrough] The Triplane Part of Triplane Meets Gaussian Splatting

Today I'll walk through the code logic of TriplaneGaussian. For an introduction to the paper itself, see this earlier blog post.

Project repository: GitHub


Below I'll pick out the key parts; the line-by-line explanations are in the code comments.

triplane.py

Since the paper is Triplane Meets Gaussian Splatting, let's start with models/tokenizers/triplane.py.

This code implements a learnable triplane positional-embedding module. It maps features onto three orthogonal planes, a representation that makes subsequent 3D voxel or point-sampling operations convenient.

class TriplaneLearnablePositionalEmbedding(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        plane_size: int = 32 # resolution of each plane
        num_channels: int = 1024 # number of channels per plane
        # There are 3 planes (XY, XZ, YZ), so the total token count is 3 * plane_size^2

    cfg: Config

    def configure(self) -> None:
        super().configure()
        # Initialize the learnable triplane embeddings
        self.embeddings = nn.Parameter(
            torch.randn(
                (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
                dtype=torch.float32,
            )
            * 1
            / math.sqrt(self.cfg.num_channels) # Xavier-style scaling to avoid exploding gradients
        )

    def forward(self, batch_size: int, cond_embeddings: Float[Tensor, "B Ct"] = None) -> Float[Tensor, "B Ct Nt"]:
        embeddings = repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size) # replicate once per batch element
        if cond_embeddings is not None:
            # If cond_embeddings are given (e.g., image-feature conditions), add them to each plane as conditional modulation
            embeddings = embeddings + cond_embeddings
        # Finally flatten to (B, Ct, Nt) with Nt = 3 * H * W, i.e., the pixels of the three planes become one token sequence
        return rearrange(
            embeddings,
            "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
        )

    def detokenize(
        self, tokens: Float[Tensor, "B Ct Nt"]
    ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
        # This is essentially the inverse of forward
        batch_size, Ct, Nt = tokens.shape
        assert Nt == self.cfg.plane_size**2 * 3
        assert Ct == self.cfg.num_channels
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
            Np=3,
            Hp=self.cfg.plane_size,
            Wp=self.cfg.plane_size,
        )
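
To make the tokenize/detokenize round trip concrete, here is a small standalone sketch that reproduces the two rearrange calls above with plain tensors (shapes only, no learnable parameters; the dimensions are the config defaults):

```python
import torch
from einops import rearrange, repeat

B, Np, Ct, Hp, Wp = 2, 3, 1024, 32, 32
embeddings = torch.randn(Np, Ct, Hp, Wp)

# forward: replicate per batch element, then flatten the three planes into tokens
tokens = rearrange(
    repeat(embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=B),
    "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
)
print(tokens.shape)  # torch.Size([2, 1024, 3072]), Nt = 3 * 32 * 32

# detokenize: recover the three planes from the token sequence
planes = rearrange(tokens, "B Ct (Np Hp Wp) -> B Np Ct Hp Wp", Np=Np, Hp=Hp, Wp=Wp)
assert planes.shape == (B, Np, Ct, Hp, Wp)
```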

renderer.py

Next let's look at models/renderer.py. I'll skip the matrix-transform utility functions and the Camera and GaussianModel classes.

GSLayer

Let's start with GSLayer, which maps input features to the per-Gaussian attributes: position offset, opacity, scaling, rotation, spherical-harmonics coefficients, and so on.

class GSLayer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 128 # number of input feature channels
        feature_channels: dict = field(default_factory=dict) # output dimension of each predicted attribute
        xyz_offset: bool = True # whether to predict a position offset
        restrict_offset: bool = False # whether to bound the offset magnitude
        use_rgb: bool = False # if True, the shs head outputs RGB directly instead of SH coefficients
        clip_scaling: Optional[float] = None # clamp the predicted scaling to prevent numerical blow-up
        init_scaling: float = -5.0
        init_density: float = 0.1

    cfg: Config

    def configure(self, *args, **kwargs) -> None:
        self.out_layers = nn.ModuleList()
        # Iterate over every key (attribute type) in feature_channels
        for key, out_ch in self.cfg.feature_channels.items():
            if key == "shs" and self.cfg.use_rgb:
            # 如果 shs 且 use_rgb=True,输出通道变为 3(直接预测 RGB)
                out_ch = 3
            # 使用 nn.Linear 将 in_channels 线性映射到目标通道数 out_ch
            layer = nn.Linear(self.cfg.in_channels, out_ch)

            # Initialization: zero the weights and biases so the default outputs start near 0
            # (the biases of specific heads are then overridden below)
            if not (key == "shs" and self.cfg.use_rgb):
                nn.init.constant_(layer.weight, 0)
                nn.init.constant_(layer.bias, 0)
            if key == "scaling":
                nn.init.constant_(layer.bias, self.cfg.init_scaling)
            elif key == "rotation":
                nn.init.constant_(layer.bias, 0)
                nn.init.constant_(layer.bias[0], 1.0)
            elif key == "opacity":
                nn.init.constant_(layer.bias, inverse_sigmoid(self.cfg.init_density))

            # Collect all linear layers in self.out_layers so that forward can iterate over them
            self.out_layers.append(layer)

    def forward(self, x, pts):
        ret = {}
        for k, layer in zip(self.cfg.feature_channels.keys(), self.out_layers):
            v = layer(x)
            if k == "rotation":
                # normalize to a unit quaternion
                v = torch.nn.functional.normalize(v)
            elif k == "scaling":
                # trunc_exp keeps the scales positive
                v = trunc_exp(v)
                if self.cfg.clip_scaling is not None:
                    v = torch.clamp(v, min=0, max=self.cfg.clip_scaling)
            elif k == "opacity":
                v = torch.sigmoid(v)
            elif k == "shs":
                if self.cfg.use_rgb:
                    v = torch.sigmoid(v)
                # reshape the SH coefficients (or RGB) into (N, -1, 3)
                v = torch.reshape(v, (v.shape[0], -1, 3))
            elif k == "xyz":
                if self.cfg.restrict_offset:
                    # cap the offset step (presumably scene extent 1.2 / plane resolution 32)
                    max_step = 1.2 / 32
                    v = (torch.sigmoid(v) - 0.5) * max_step
                # the prediction is an offset added to the query point (or the point itself)
                v = v + pts if self.cfg.xyz_offset else pts
            ret[k] = v

        return GaussianModel(**ret)
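
The initialization above makes the Gaussians start small, at a fixed opacity, and with the identity rotation (the bias (1, 0, 0, 0) is the identity quaternion). A quick sanity check of the opacity and scaling heads, assuming the usual inverse_sigmoid(x) = log(x / (1 - x)) and trunc_exp behaving like exp in this range:

```python
import math

init_density, init_scaling = 0.1, -5.0
inverse_sigmoid = lambda x: math.log(x / (1.0 - x))
sigmoid = lambda x: 1.0 / (1.0 + math.exp(-x))

# With zero weights the opacity head outputs its bias, so after sigmoid:
print(sigmoid(inverse_sigmoid(init_density)))  # 0.1 -> Gaussians start at opacity 0.1

# The scaling head starts at exp(init_scaling), i.e. very small Gaussians:
print(math.exp(init_scaling))  # ~0.0067
```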

The approach here is in the same vein as MVSGaussian, MVSplat, and similar methods: just let the network learn the Gaussian parameters.

GS3DRenderer

The file also contains a GS3DRenderer class, a renderer based on 3D Gaussian Splatting that additionally supports triplane-based feature queries (the query_triplane method).

class GS3DRenderer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        mlp_network_config: Optional[dict] = None # config for an optional MLP that further processes the queried features
        gs_out: dict = field(default_factory=dict) # config dict passed to GSLayer (output channels, etc.)
        sh_degree: int = 3 # spherical-harmonics degree
        scaling_modifier: float = 1.0 # global scale modifier for the Gaussians
        random_background: bool = False # whether to use a random background
        radius: float = 1.0 # radius of the 3D scene coordinates, used to normalize triplane queries
        feature_reduction: str = "concat" # how to fuse the three plane features: "mean" or "concat"
        projection_feature_dim: int = 773 # dimension of the extra projection features
        background_color: Tuple[float, float, float] = field(
            default_factory=lambda: (1.0, 1.0, 1.0)
        ) # default background color

    cfg: Config # bind the config type to the instance

    def configure(self, *args, **kwargs) -> None:
        # Determine the input feature dimension from feature_reduction
        if self.cfg.feature_reduction == "mean":
            mlp_in = 80 # per-plane channel count of the queried triplane features
        elif self.cfg.feature_reduction == "concat":
            mlp_in = 80 * 3
        else:
            raise NotImplementedError
        # Append projection_feature_dim as extra input features (e.g., 80 * 3 + 773 = 1013 with the defaults)
        mlp_in = mlp_in + self.cfg.projection_feature_dim
        if self.cfg.mlp_network_config is not None:
            # If mlp_network_config is provided, instantiate an MLP mapping the features to the channels GSLayer expects
            self.mlp_net = MLP(mlp_in, self.cfg.gs_out.in_channels, **self.cfg.mlp_network_config)
        else:
            # Otherwise, use the input dimension directly as GSLayer's input channels
            self.cfg.gs_out.in_channels = mlp_in
        self.gs_net = GSLayer(self.cfg.gs_out) # the final Gaussian parameter head

    def forward_gs(self, x, p):
        # Run the MLP first (if configured), then hand off to GSLayer to produce the Gaussians
        if self.cfg.mlp_network_config is not None:
            x = self.mlp_net(x)
        return self.gs_net(x, p)

    # As the name suggests, this method renders the Gaussians from a single viewpoint
    def forward_single_view(self,
        gs: GaussianModel,
        viewpoint_camera: Camera,
        background_color: Optional[Float[Tensor, "3"]],
        ret_mask: bool = True,
        ):
        # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
        screenspace_points = torch.zeros_like(gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=self.device) + 0
        try:
            screenspace_points.retain_grad()
        except:
            pass
        
        # Background and rasterization setup
        bg_color = background_color
        # tan of half the horizontal/vertical field of view, needed by the rasterizer
        tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
        tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

        # Rasterization settings: image size, camera matrices, SH degree, background color, etc.
        raster_settings = GaussianRasterizationSettings(
            image_height=int(viewpoint_camera.height),
            image_width=int(viewpoint_camera.width),
            tanfovx=tanfovx,
            tanfovy=tanfovy,
            bg=bg_color,
            scale_modifier=self.cfg.scaling_modifier,
            viewmatrix=viewpoint_camera.world_view_transform,
            projmatrix=viewpoint_camera.full_proj_transform.float(),
            sh_degree=self.cfg.sh_degree,
            campos=viewpoint_camera.camera_center,
            prefiltered=False,
            debug=False
        )

        rasterizer = GaussianRasterizer(raster_settings=raster_settings)

        # Gaussian parameters
        means3D = gs.xyz
        means2D = screenspace_points
        opacity = gs.opacity

        # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
        # scaling / rotation by the rasterizer.
        scales = None
        rotations = None
        cov3D_precomp = None
        scales = gs.scaling
        rotations = gs.rotation

        # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
        # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
        shs = None
        colors_precomp = None
        if self.gs_net.cfg.use_rgb:
            colors_precomp = gs.shs.squeeze(1)
        else:
            shs = gs.shs

        # Rasterize visible Gaussians to the image, obtaining their on-screen radii
        with torch.autocast(device_type=self.device.type, dtype=torch.float32):
            rendered_image, radii = rasterizer(
                means3D = means3D,
                means2D = means2D,
                shs = shs,
                colors_precomp = colors_precomp,
                opacities = opacity,
                scales = scales,
                rotations = rotations,
                cov3D_precomp = cov3D_precomp)
        
        # Outputs
        ret = {
            "comp_rgb": rendered_image.permute(1, 2, 0),
            "comp_rgb_bg": bg_color
        }
        
        # Optional mask rendering: splat all-white colors over a black background to get a coverage mask
        if ret_mask:
            mask_bg_color = torch.zeros(3, dtype=torch.float32, device=self.device)
            raster_settings = GaussianRasterizationSettings(
                image_height=int(viewpoint_camera.height),
                image_width=int(viewpoint_camera.width),
                tanfovx=tanfovx,
                tanfovy=tanfovy,
                bg=mask_bg_color,
                scale_modifier=self.cfg.scaling_modifier,
                viewmatrix=viewpoint_camera.world_view_transform,
                projmatrix=viewpoint_camera.full_proj_transform.float(),
                sh_degree=0,
                campos=viewpoint_camera.camera_center,
                prefiltered=False,
                debug=False
            )
            rasterizer = GaussianRasterizer(raster_settings=raster_settings)
            
            with torch.autocast(device_type=self.device.type, dtype=torch.float32):
                rendered_mask, radii = rasterizer(
                    means3D = means3D,
                    means2D = means2D,
                    # shs = ,
                    colors_precomp = torch.ones_like(means3D),
                    opacities = opacity,
                    scales = scales,
                    rotations = rotations,
                    cov3D_precomp = cov3D_precomp)
                ret["comp_mask"] = rendered_mask.permute(1, 2, 0)

        return ret
    
    def query_triplane(
        self,
        positions: Float[Tensor, "*B N 3"], # 3D sample points (B, N, 3)
        triplanes: Float[Tensor, "*B 3 Cp Hp Wp"], # triplane features (B, 3, C, H, W)
    ) -> Dict[str, Tensor]:
        batched = positions.ndim == 3
        # Ensure a batch dimension
        if not batched:
            # no batch dimension
            triplanes = triplanes[None, ...]
            positions = positions[None, ...]

        # Normalize coordinates: map the 3D points into [-1, 1] for grid_sample
        positions = scale_tensor(positions, (-self.cfg.radius, self.cfg.radius), (-1, 1))
        # Build the 2D sampling indices, i.e., each point's projection onto the three planes: XY, XZ, YZ
        indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
                (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
                dim=-3,
            )
        # Bilinearly sample the three planes
        out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
            rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3),
            rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3),
            align_corners=False,
            mode="bilinear",
        )
        # Fuse the per-plane features
        if self.cfg.feature_reduction == "concat":
            out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
        elif self.cfg.feature_reduction == "mean":
            out = reduce(out, "(B Np) Cp () N -> B N Cp", Np=3, reduction="mean")
        else:
            raise NotImplementedError
        
        if not batched:
            out = out.squeeze(0)

        return out

    # Render one batch element: call forward_single_view once per camera, then stack the results
    def forward_single_batch(
        self,
        gs_hidden_features: Float[Tensor, "Np Cp"],
        query_points: Float[Tensor, "Np 3"],
        c2ws: Float[Tensor, "Nv 4 4"],
        intrinsics: Float[Tensor, "Nv 4 4"],
        height: int,
        width: int,
        background_color: Optional[Float[Tensor, "3"]],
    ):
        gs: GaussianModel = self.forward_gs(gs_hidden_features, query_points)
        out_list = []
       
        # Iterate over all camera viewpoints
        for c2w, intrinsic in zip(c2ws, intrinsics):
            out_list.append(self.forward_single_view(
                                gs, 
                                Camera.from_c2w(c2w, intrinsic, height, width),
                                background_color
                            ))
        
        out = defaultdict(list)
        for out_ in out_list:
            for k, v in out_.items():
                out[k].append(v)
        out = {k: torch.stack(v, dim=0) for k, v in out.items()}
        out["3dgs"] = gs

        return out

    def forward(self, 
        gs_hidden_features: Float[Tensor, "B Np Cp"], # batched hidden features for the Gaussian points
        query_points: Float[Tensor, "B Np 3"], # Gaussian point positions
        c2w: Float[Tensor, "B Nv 4 4"],
        intrinsic: Float[Tensor, "B Nv 4 4"],
        height,
        width,
        additional_features: Optional[Float[Tensor, "B C H W"]] = None, # extra features concatenated after the triplane query
        background_color: Optional[Float[Tensor, "B 3"]] = None,
        **kwargs):
        batch_size = gs_hidden_features.shape[0]
        out_list = []
        # Call query_triplane to sample per-point features from the triplanes
        gs_hidden_features = self.query_triplane(query_points, gs_hidden_features)
        if additional_features is not None:
            gs_hidden_features = torch.cat([gs_hidden_features, additional_features], dim=-1)

        # Rendering loop over the batch
        for b in range(batch_size):
            out_list.append(self.forward_single_batch(
                gs_hidden_features[b],
                query_points[b],
                c2w[b],
                intrinsic[b],
                height, width,
                background_color[b] if background_color is not None else None))

        out = defaultdict(list)
        for out_ in out_list:
            for k, v in out_.items():
                out[k].append(v)
        for k, v in out.items():
            if isinstance(v[0], torch.Tensor):
                out[k] = torch.stack(v, dim=0)
            else:
                out[k] = v
        return out
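
A side note on forward_single_view: the tanfovx/tanfovy values it computes follow the standard pinhole relation tan(FoV/2) = (size / 2) / focal. A quick numeric check with hypothetical intrinsics (the width and fx below are made up for illustration):

```python
import math

width, fx = 512, 711.1  # hypothetical image width and focal length in pixels
FoVx = 2 * math.atan(width / (2 * fx))
print(math.tan(FoVx * 0.5))  # ~0.36
print(width / (2 * fx))      # same value, by construction
```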

By now it should be clear that the authors let the network learn a mechanism that produces the triplane representation automatically from the input features, rather than defining the triplane features by hand.

query_triplane does not read from fixed voxel/grid features: the triplane representation is generated dynamically from point features, and the query_points are then sampled from those planes by projection. The resulting feature field is essentially a function that can be queried for features at any 3D point.

Its implementation works like this:

  • Each Gaussian point is assigned a latent feature.
  • The triplane features are generated from these latent features, and query_triplane samples them at the projected query points.
  • The sampled features are then used to predict and render the Gaussians.

The triplane representation is thus learned implicitly rather than stored explicitly: each point's feature is predicted by the network, so it can adapt to the point distribution. The sketch below walks through the projection-and-sampling step in isolation.
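
A minimal, self-contained version of that step, using a tiny triplane (3 planes, 2 channels, 4×4 resolution) and radius = 1.0 so the [-radius, radius] -> [-1, 1] normalization is the identity:

```python
import torch
import torch.nn.functional as F

triplanes = torch.randn(3, 2, 4, 4)          # (Np=3, Cp=2, Hp=4, Wp=4)
points = torch.tensor([[0.5, -0.25, 0.1]])   # one 3D query point, already in [-1, 1]

# Project the point onto the XY, XZ, and YZ planes (same index choice as query_triplane)
indices2d = torch.stack((points[:, [0, 1]], points[:, [0, 2]], points[:, [1, 2]]), dim=0)

feats = F.grid_sample(
    triplanes,           # (3, Cp, Hp, Wp)
    indices2d[:, None],  # (3, 1, N, 2)
    align_corners=False,
    mode="bilinear",
)                        # -> (3, Cp, 1, N)

# "concat" reduction: one feature vector of length Np * Cp per point
fused = feats.squeeze(2).permute(2, 0, 1).reshape(points.shape[0], -1)
print(fused.shape)  # torch.Size([1, 6])
```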

Triplane Compression in NeRF

Triplane compression in NeRF simplifies learning a 3D scene and improves efficiency by introducing a triplane representation: spatial information is stored on three mutually perpendicular 2D feature planes [^1]. The information on these planes can then be recombined, via interpolation or other operations, into complete volume-density and color fields.

Method overview

In a typical triplane-based method, an input spatial coordinate is mapped to its position on each of the three planes. The feature vectors at those positions are then combined linearly or processed by convolutions to obtain the final color and opacity outputs [^2]. This not only reduces the computational burden but also lets the network capture complex geometry and texture details more efficiently.

The key implementation steps are:

1. Feature plane construction: build three orthogonal feature planes, typically produced by a CNN generator [^3].
2. Spatial mapping & sampling: project a given world coordinate (x, y, z) onto the three planar spaces and sample the corresponding features.
3. Aggregation: fuse the features sampled from the planes into a single vector representing the point's properties (color, opacity, etc.); this can be a simple weighted sum, or MLP layers applying non-linear transformations before aggregation.
4. Rendering pipeline integration: feed the result into a conventional volume-rendering pipeline.

The following sketch shows one basic way to organize such a module in PyTorch. Note that the original snippet left FeatureExtractor and MLPNetwork undefined; here they are filled in with small placeholder MLPs on 2D coordinates (a real triplane implementation would instead store learned 2D feature grids and sample them with grid_sample):

```python
import torch
import torch.nn as nn


def make_mlp(in_dim, out_dim, hidden=64):
    # Placeholder for the undefined FeatureExtractor / MLPNetwork in the original snippet
    return nn.Sequential(
        nn.Linear(in_dim, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, out_dim)
    )


class TriPlaneModel(nn.Module):
    def __init__(self, feat_dim=32, out_dim=4):
        super().__init__()
        # One feature extractor per orthogonal plane (XY, YZ, ZX)
        self.plane_xy = make_mlp(2, feat_dim)
        self.plane_yz = make_mlp(2, feat_dim)
        self.plane_zx = make_mlp(2, feat_dim)
        # Aggregator that maps the fused plane features to the point's properties
        self.aggregator = make_mlp(feat_dim, out_dim)

    def forward(self, coords):
        """coords: (N, 3) spatial coordinates -> (N, out_dim) properties (e.g. RGB + density)."""
        feat_xy = self.plane_xy(coords[:, :2])      # project onto the XY plane
        feat_yz = self.plane_yz(coords[:, 1:])      # project onto the YZ plane
        feat_zx = self.plane_zx(coords[:, [0, 2]])  # project onto the ZX plane
        # Mean-aggregate the three per-plane features, then decode
        combined_features = torch.stack((feat_xy, feat_yz, feat_zx)).mean(dim=0)
        return self.aggregator(combined_features)
```

This keeps the tri-planar encoding/decoding modular, so different architectural choices (feature extractors, aggregation schemes) can be swapped in easily later.
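A quick usage example for the sketch above (random inputs, untrained weights):

```python
model = TriPlaneModel()
coords = torch.rand(8, 3)  # 8 random 3D points
out = model(coords)
print(out.shape)           # torch.Size([8, 4]) -> e.g. RGB + density per point
```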