2021SC@SDUSC
1. class TriangleMultiplication source code
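import haiku as hk

from alphafold.model import common_modules


class TriangleMultiplication(hk.Module):

  def __init__(self, config, global_config, name='triangle_multiplication'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, act, mask, is_training=True):
    # act: pair activations [N_res, N_res, c_z]; mask: pair mask [N_res, N_res]
    del is_training  # not used by this module
    c = self.config
    gc = self.global_config

    # Add a trailing channel axis so the mask broadcasts over channels.
    mask = mask[..., None]

    # Layer-normalize the pair activations over the channel axis.
    act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
                       name='layer_norm_input')(act)
    input_act = act  # normalized input, saved for later use in the module

    # Linear projection to num_intermediate_channel channels; masked pairs -> 0.
    left_projection = common_modules.Linear(
        c.num_intermediate_channel,
        name='left_projection')
    left_proj_act = mask * left_projection(act)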
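The excerpt ends at the left projection. The steps it does show can be reproduced in plain NumPy to make the shapes and the masking concrete: `mask[..., None]` broadcasting, layer normalization over the last axis, and a masked linear projection. This is a minimal sketch, not AlphaFold's implementation. The helper names `layer_norm` and `masked_left_projection` are hypothetical, as are the random `weight`/`bias` values. The learned LayerNorm scale and offset are fixed at their initial values (1 and 0). `eps=1e-5` is assumed to match Haiku's default.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over the last (channel) axis, like hk.LayerNorm(axis=[-1]).
    # Learned scale/offset are assumed to be at their initial values 1 and 0.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def masked_left_projection(act, mask, weight, bias):
    """Sketch of the first steps of TriangleMultiplication.__call__.

    act:    pair activations, shape [N_res, N_res, c_z]
    mask:   pair mask, shape [N_res, N_res]
    weight: hypothetical projection weights, shape [c_z, num_intermediate_channel]
    bias:   hypothetical projection bias, shape [num_intermediate_channel]
    """
    mask = mask[..., None]                        # [N_res, N_res, 1]
    act = layer_norm(act)                         # 'layer_norm_input'
    left_proj_act = mask * (act @ weight + bias)  # 'left_projection' + masking
    return left_proj_act


rng = np.random.default_rng(0)
n_res, c_z, c_hidden = 4, 8, 16
act = rng.normal(size=(n_res, n_res, c_z))
mask = np.ones((n_res, n_res))
mask[-1, :] = 0.0   # treat the last residue as padding
mask[:, -1] = 0.0
weight = rng.normal(size=(c_z, c_hidden))
bias = np.zeros(c_hidden)

out = masked_left_projection(act, mask, weight, bias)
print(out.shape)  # (4, 4, 16)
```

Because the mask has a trailing axis of size 1, a single zero in the `[N_res, N_res]` mask clears all `num_intermediate_channel` channels of that pair. Masking the projection, rather than the input, keeps the bias term from leaking into padded positions.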
