SALAD: Skeleton-aware Latent Diffusion for Text-driven Motion Generation and Editing
Seokhyeon Hong, Chaelin Kim, Serin Yoon, Junghyun Nam, Sihun Cha, Junyong Noh
Paper: https://arxiv.org/pdf/2503.13836
Code: https://github.com/seokhyeonhong/salad
Key innovations
1. Skeleton-aware VAE
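Earlier VAEs and VQ-VAEs treat all joints in a frame as one vector (treat a pose as a single vector) and leave it to the network to model the relationships between joints. This paper instead brings in the topology of the skeleton: it explicitly uses information from adjacent joints and adjacent frames and encourages information exchange between them (somewhat in the spirit of ST-GCN), so that the model learns a skeleton-aware motion latent representation.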
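The idea of exchanging information only between adjacent joints and adjacent frames can be sketched with a toy example. This is not the layer used in SALAD: the 5-joint skeleton `PARENTS`, the scalar per-joint feature, and the plain mean aggregation are all assumptions chosen to keep the illustration small.

```python
# Toy illustration only: mean aggregation over skeletal and temporal neighbours.
# PARENTS and both function names are hypothetical, not taken from SALAD.

PARENTS = [-1, 0, 1, 0, 3]  # parent index per joint; -1 marks the root


def build_neighbors(parents):
    """For each joint, return the sorted list of itself plus its adjacent joints."""
    neighbors = [{j} for j in range(len(parents))]
    for child, parent in enumerate(parents):
        if parent >= 0:
            neighbors[child].add(parent)
            neighbors[parent].add(child)
    return [sorted(n) for n in neighbors]


def neighbor_mean(x, parents):
    """x[t][j] is a scalar feature of joint j at frame t.

    Each output value averages over adjacent joints in the same frame and
    over the same joint in the previous and next frame, so information
    flows only along the skeleton and along the time axis.
    """
    neighbors = build_neighbors(parents)
    num_frames = len(x)
    out = []
    for t in range(num_frames):
        row = []
        for j in range(len(parents)):
            values = [x[t][k] for k in neighbors[j]]  # spatial neighbours, self included
            values += [x[s][j] for s in (t - 1, t + 1) if 0 <= s < num_frames]  # temporal
            row.append(sum(values) / len(values))
        out.append(row)
    return out


# A single non-zero value at the root in frame 0 spreads to joints 1 and 3
# (its skeletal neighbours) and to the root in frame 1, but not to joint 2.
x = [[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
out = neighbor_mean(x, PARENTS)
```

A single-vector pose encoder has no such restriction: every joint can influence every other joint from the first layer on, and the network has to discover the skeleton structure by itself.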

1.1 Motion representation
First, the original 263-dimensional pose vector of each frame is split into joint-wise features; N denotes the number of frames in the motion.
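The 263 dimensions can be checked against the split sizes used in the repository code quoted in this post. The sketch below only does that bookkeeping; the function names are illustrative and not part of SALAD, and the group descriptions follow the docstring of the encoder.

```python
# Bookkeeping sketch for the pose vector layout; function names are illustrative.

def group_sizes(joints_num):
    """Sizes of the five feature groups in one pose vector, in storage order."""
    return {
        "root": 4,                    # rotation velocity (1), linear velocity (2), height (1)
        "ric": 3 * (joints_num - 1),  # joint positions, root excluded
        "rot": 6 * (joints_num - 1),  # 6D joint rotations, root excluded
        "vel": 3 * joints_num,        # joint velocities, root included
        "contact": 4,                 # foot-contact labels
    }


def joints_from_pose_dim(pose_dim):
    """Invert pose_dim = 12 * J - 1 to recover the joint count J."""
    return (pose_dim + 1) // 12


sizes = group_sizes(22)
total = sum(sizes.values())       # 4 + 63 + 126 + 66 + 4
joints = joints_from_pose_dim(263)
```

In general the groups add up to 4 + 3(J - 1) + 6(J - 1) + 3J + 4 = 12J - 1, which is why the encoder can recover the joint count as (pose_dim + 1) // 12.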
Source: salad/models/vae/trainer.py (seokhyeonhong/salad on GitHub, main branch)
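# Move the batch to the training device, then split the last dimension
# into the five feature groups: root, positions, rotations, velocities, contacts.
motion = batch_data.to(self.opt.device, dtype=torch.float32)
root, ric, rot, vel, contact = torch.split(
    motion,
    [4, 3 * (self.opt.joints_num - 1), 6 * (self.opt.joints_num - 1), 3 * self.opt.joints_num, 4],
    dim=-1,
)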
Source: salad/models/vae/encdec.py (seokhyeonhong/salad on GitHub, main branch)
import torch
import torch.nn as nn

# get_activation is a helper from the SALAD repository and is not shown in this excerpt.


class MotionEncoder(nn.Module):
    def __init__(self, opt):
        super(MotionEncoder, self).__init__()
        self.pose_dim = opt.pose_dim
        self.joints_num = (self.pose_dim + 1) // 12  # pose_dim = 12J - 1, so 263 gives 22 joints
        self.latent_dim = opt.latent_dim
        self.contact_joints = opt.contact_joints

        # One small MLP per joint; the input width depends on the joint type.
        self.layers = nn.ModuleList()
        for i in range(self.joints_num):
            if i == 0:
                input_dim = 7   # root: 4 root features + 3 velocity
            elif i in self.contact_joints:
                input_dim = 13  # presumably 12 joint features + 1 foot-contact label
            else:
                input_dim = 12  # 3 position + 6 rotation + 3 velocity
            self.layers.append(nn.Sequential(
                nn.Linear(input_dim, self.latent_dim),
                get_activation(opt.activation),
                nn.Linear(self.latent_dim, self.latent_dim),
            ))

    def forward(self, x):
        """
        x: [bs, nframes, pose_dim]

        nfeats = 12J - 1
            - root_rot_velocity (B, seq_len, 1)
            - root_linear_velocity (B, seq_len, 2)
            - root_y (B, seq_len, 1)
            - ric_data (B, seq_len, (joint_num - 1)*3)
            - rot_data (B, seq_len, (joint_num - 1)*6)
            - local_velocity (B, seq_len, joint_num*3)
            - foot contact (B, seq_len, 4)
        """
        B, T, D = x.size()

        # split
        root, ric, rot, vel, contact = torch.split(x, [4, 3 * (self.joints_num - 1), 6 * (self.joints_num - 1), 3 * self.joints_num, 4], dim=-1)
        ric = ric.reshape(B, T, self.joints_num - 1, 3)
        rot = rot.reshape(B, T, self.joints_num - 1, 6)
        vel = vel.reshape(B, T, self.joints_num, 3)

        # joint-wise input
        joints = [torch.cat([root, vel[:, :, 0]], dim=-1)]  # [B, T, 7]
        for i in range(1, self.joints_num):
            joints.append(torch.cat([ric[:, :, i - 1], rot[:, :, i - 1], vel[:, :, i]], dim=-1))  # [B, T, 12]
        ...  # the rest of forward() is not included in this excerpt
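The regrouping done in `forward` can be followed on a single frame without PyTorch. The sketch below is a dependency-free analogue, not the repository code. Two things in it are assumptions: the indices in `CONTACT_JOINTS` are placeholders for `opt.contact_joints`, and attaching the k-th contact label to the k-th contact joint is inferred from `input_dim = 13` in `__init__`, because the quoted excerpt stops before the contact labels are handled.

```python
# Dependency-free sketch of the joint-wise regrouping for one frame.
# CONTACT_JOINTS and the contact-to-joint assignment are assumptions.

JOINTS_NUM = 22
CONTACT_JOINTS = [7, 10, 8, 11]  # hypothetical stand-in for opt.contact_joints


def split_by_sizes(values, sizes):
    """Cut a flat list into consecutive chunks, like torch.split with a size list."""
    chunks, start = [], 0
    for size in sizes:
        chunks.append(values[start:start + size])
        start += size
    return chunks


def joint_wise(frame, joints_num=JOINTS_NUM, contact_joints=CONTACT_JOINTS):
    """Regroup one flat pose vector into a list of per-joint feature lists."""
    j = joints_num
    root, ric, rot, vel, contact = split_by_sizes(
        frame, [4, 3 * (j - 1), 6 * (j - 1), 3 * j, 4]
    )
    joints = [root + vel[0:3]]  # root joint: 4 root values + its velocity = 7
    for i in range(1, j):
        joints.append(
            ric[3 * (i - 1):3 * i] + rot[6 * (i - 1):6 * i] + vel[3 * i:3 * (i + 1)]
        )  # 3 + 6 + 3 = 12 values
    # Assumed: the k-th contact label is appended to the k-th contact joint (12 + 1 = 13).
    for k, joint_idx in enumerate(contact_joints):
        joints[joint_idx] = joints[joint_idx] + [contact[k]]
    return joints


frame = list(range(263))  # dummy frame whose values are just the column indices
per_joint = joint_wise(frame)
widths = [len(features) for features in per_joint]
```

With the column indices as dummy values, `per_joint[i]` shows directly which columns of the 263-dimensional vector end up in joint i, and `widths` matches the 7 / 12 / 13 input sizes of the per-joint MLPs.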
