I. Background
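1. Code Reproduction | Demucs Music Source Separation: Demucs Architecture Principles (CSDN blog)
2. Where the code for each Hybrid Transformer module lives in the project
3. What is each Hybrid Transformer module under the hood? (first impressions)
4. How do the data dimensions change after each Hybrid Transformer module?
7. Hybrid Transformer: dissecting the frequency-domain decoder module and the ISTFT module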
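Split by module, Hybrid Transformer Demucs consists of seven modules: the STFT module, the time-domain encoder module, the frequency-domain encoder module, the Cross-Domain Transformer Encoder module, the time-domain decoder module, the frequency-domain decoder module, and the ISTFT module. Already covered: the STFT module, the frequency-domain encoder module (the time-domain encoder is similar, so it will not be covered separately), the frequency-domain decoder module (the time-domain decoder is similar, so it will not be covered separately), and the ISTFT module.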
Goal of this post: dissect the Cross-Domain Transformer Encoder module.
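To place this module in context, the seven-module split above can be written down as two branches that share a single stage. This is only a sketch: the stage labels are descriptive names taken from the module list above, not identifiers from the Demucs code base, and `branch_stages` is a hypothetical helper.

```python
from typing import List

# Descriptive label for the one stage shared by both branches (not a Demucs identifier).
CROSS = "Cross-Domain Transformer Encoder"


def branch_stages(branch: str) -> List[str]:
    """Return the ordered stages that one branch passes through (hypothetical helper)."""
    if branch == "frequency":
        # The spectrogram branch is bracketed by STFT and ISTFT.
        return ["STFT", "frequency-domain encoder", CROSS,
                "frequency-domain decoder", "ISTFT"]
    if branch == "time":
        # The waveform branch works on the raw signal, so it needs no STFT/ISTFT.
        return ["time-domain encoder", CROSS, "time-domain decoder"]
    raise ValueError(f"unknown branch: {branch!r}")


freq = branch_stages("frequency")
time_ = branch_stages("time")
all_modules = set(freq) | set(time_)   # the seven modules listed above
shared = set(freq) & set(time_)        # the stage both branches pass through
print(len(all_modules), shared)
```

Both branches pass through the Cross-Domain Transformer Encoder, which makes it the one stage that sees the time-domain and frequency-domain representations together.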
II. The Cross-Domain Transformer Encoder Module
2.1 Composition of the Cross-Domain Transformer Encoder module
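Before reading the constructor, it may help to reproduce the few quantities it derives from its arguments. The sketch below is plain Python, not Demucs code: `derived_settings` is a hypothetical helper and `dim=512` is only an illustrative value. The layer ordering follows the `classic_parity` comment in the constructor (layer indices whose parity equals `classic_parity` are classic self-attention layers, the others are cross-attention layers); how the layers are actually built lies beyond the part of the source quoted here, so treat that part as an assumption.

```python
def derived_settings(dim, hidden_scale=4.0, num_heads=8,
                     num_layers=6, cross_first=False):
    """Recompute values the constructor derives from its arguments (hypothetical helper)."""
    # Same check as the constructor: channels must split evenly across heads.
    assert dim % num_heads == 0
    # Same formula as the constructor: width of the feed-forward hidden layer.
    hidden_dim = int(dim * hidden_scale)
    # Same rule as the constructor.
    classic_parity = 1 if cross_first else 0
    # Assumption based on the source comment: idx % 2 == classic_parity
    # marks a classic (self-attention) layer, anything else a cross-attention layer.
    layer_kinds = ["self" if idx % 2 == classic_parity else "cross"
                   for idx in range(num_layers)]
    return {
        "hidden_dim": hidden_dim,
        "head_dim": dim // num_heads,  # channels handled by each attention head
        "layer_kinds": layer_kinds,
    }


default = derived_settings(512)
print(default["hidden_dim"], default["head_dim"])  # 2048 64
print(default["layer_kinds"])  # ['self', 'cross', 'self', 'cross', 'self', 'cross']
print(derived_settings(512, cross_first=True)["layer_kinds"][:2])  # ['cross', 'self']
```

With the default `cross_first=False`, the stack would start with a self-attention layer and alternate from there; setting `cross_first=True` shifts the pattern by one layer.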
The core source code of the Cross-Domain Transformer Encoder is shown below:
class CrossTransformerEncoder(nn.Module):
    def __init__(
        self,
        dim: int,
        emb: str = "sin",
        hidden_scale: float = 4.0,
        num_heads: int = 8,
        num_layers: int = 6,
        cross_first: bool = False,
        dropout: float = 0.0,
        max_positions: int = 1000,
        norm_in: bool = True,
        norm_in_group: bool = False,
        group_norm: int = False,
        norm_first: bool = False,
        norm_out: bool = False,
        max_period: float = 10000.0,
        weight_decay: float = 0.0,
        lr: tp.Optional[float] = None,
        layer_scale: bool = False,
        gelu: bool = True,
        sin_random_shift: int = 0,
        weight_pos_embed: float = 1.0,
        cape_mean_normalize: bool = True,
        cape_augment: bool = True,
        cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
        sparse_self_attn: bool = False,
        sparse_cross_attn: bool = False,
        mask_type: str = "diag",
        mask_random_seed: int = 42,
        sparse_attn_window: int = 500,
        global_window: int = 50,
        auto_sparsity: bool = False,
        sparsity: float = 0.95,
    ):
        super().__init__()
        """
        """
        assert dim % num_heads == 0
        hidden_dim = int(dim * hidden_scale)
        self.num_layers = num_layers
        # classic parity = 1 means that if idx%2 == 1 there is a
        # classical encoder else there is a cross encoder
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift = sin_random_shift
        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            self.cape_glob_loc_scale = cape_glob_loc_scale
        if emb == "scaled":
            self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
        self.lr = lr
        activation: tp.Any = F.gelu if gelu else F.relu
        self.norm_in: nn.Module
        self.norm_in_t: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
            self.norm_in_t = nn.LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()
        # spectrogram layers
        self.layers = nn.ModuleList()
        # temporal layers