MaskNet 是一种基于注意力机制的深度学习模型,特别适用于推荐系统和点击率预测任务。它通过动态生成掩码 (mask) 来选择性地关注输入特征的不同部分,从而提高模型的表达能力和泛化性能。
一、核心概念:实例引导的掩码机制
MaskNet 的核心创新是实例引导的掩码机制 (Instance-Guided Masking):
- 传统神经网络对所有输入实例使用相同的网络结构
- MaskNet 为每个输入实例动态生成不同的掩码,从而调整网络结构
- 这种动态调整使模型能够针对不同输入 "智能" 地关注重要特征,忽略噪声
可以将 MaskNet 看作是一种 "自适应网络",每个输入都有自己的 "专属网络结构"。
二、模型架构详解
1. 基础组件:MaskBlock
class MaskBlock(layers.Layer):
def __init__(self, agg_dim, num_mask_block, mask_block_ffn_size, dropout=0., l2_reg=0., **kwargs):
# ... 初始化代码 ...
参数解释:
agg_dim
:特征聚合维度num_mask_block
:掩码块数量(串行或并行)mask_block_ffn_size
:每个掩码块的前馈网络大小
核心功能:
- 实现实例引导的掩码生成和应用
- 支持两种模式:串行 (Serial) 和并行 (Parallel)
2. 掩码生成与应用流程
def instance_guided_mask(self, inputs, is_training, output_size):
agg = self.agg_dnn(inputs) # 特征聚合
if self.dropout_rate > 0:
agg = layers.Dropout(self.dropout_rate)(agg, training=is_training)
return self.mask_proj_layers[str(output_size)](agg) # 生成掩码
def mask_block(self, inputs, mask, block_index, is_training):
masked = inputs * mask # 应用掩码
output = self.serial_blocks[block_index](masked) # 通过前馈网络
output = self.ln_layers[block_index](output) # 层归一化
return tf.nn.relu(output) # ReLU激活
流程:
- 特征聚合:通过神经网络将输入特征压缩为低维表示
- 掩码生成:将聚合后的特征映射为掩码向量(与输入维度相同)
- 掩码应用:将输入与掩码逐元素相乘,选择性地 "关闭" 某些特征
- 变换与归一化:通过前馈网络处理掩码后的特征,再进行归一化和激活
3. 两种模型变体
class SerialMaskNet(MaskBlock):
# 串行模式:多个掩码块依次作用
def serial_model(self, inputs, is_training):
output = inputs
for i in range(self.num_mask_block):
mask = self.instance_guided_mask(output, is_training, self.mask_block_ffn_size[i])
output = self.mask_block(output, mask, i, is_training)
return output
class ParallelMaskNet(MaskBlock):
# 并行模式:多个掩码块同时作用,结果拼接
def parallel_model(self, inputs, is_training):
output_list = []
for i in range(self.num_mask_block):
mask = self.instance_guided_mask(inputs, is_training, self.mask_block_ffn_size[0])
output = self.serial_blocks[0](inputs * mask)
output_list.append(output)
return tf.concat(output_list, axis=-1)
区别:
- 串行模式:多个掩码块按顺序应用,后一个块处理前一个块的输出
- 并行模式:所有掩码块同时处理输入,结果拼接后输出
三、模型整体流程
1. 输入处理与特征嵌入
def build_inputs(self):
# 创建输入层
for feat in self.feature_config_list:
if feat["type"] == "float":
self.inputs[feat_name] = layers.Input(shape=(1,), dtype="float32")
elif feat["type"] == "cate":
self.inputs[feat_name] = layers.Input(shape=(1,), dtype="int64")
# 分类特征嵌入
for feat in self.cate_features:
vocab_size = feat.get("vocab_size", 10000)
self.embedding_layers[feat_name] = layers.Embedding(
input_dim=vocab_size,
output_dim=self.agg_dim
)
# 连续特征投影
for feat in self.float_features:
self.float_projections[feat_name] = layers.Dense(self.agg_dim)
# 合并所有特征
self.deep_layer_features = tf.concat(self.float_features + self.cate_features, axis=-1)
流程:
- 分类特征通过 Embedding 层转换为向量
- 连续特征通过全连接层投影到相同维度
- 所有特征拼接为统一表示
2. 掩码网络处理
def build_model(self):
self.build_inputs()
masknet_output = self.net_func(self.deep_layer_features, is_training=True)
outputs = self.final_layer(masknet_output)
model = super().get_model(inputs=self.inputs, outputs=outputs, name='MaskNet')
return model
流程:
- 输入特征进入 MaskNet 网络
- 网络根据输入动态生成掩码,选择性地关注不同特征
- 处理后的特征通过最终分类层输出预测结果
四、关键技术点
1. 动态掩码机制
掩码是 MaskNet 的核心,它实现了:
- 实例特异性:每个输入实例都有自己的掩码
- 特征选择性:可以选择性地增强或抑制某些特征
- 自适应调整:掩码的生成是自适应的,取决于输入内容
这种机制使模型能够:
- 对不同类型的输入使用不同的特征组合
- 学习特征间的复杂交互关系
- 提高模型在稀疏数据上的表现
2. 串行与并行架构
特性 | 串行模式 | 并行模式 |
---|---|---|
处理方式 | 依次通过多个掩码块 | 同时通过多个掩码块 |
信息流动 | 信息逐步提炼 | 多视角处理后整合 |
计算复杂度 | 较低(顺序执行) | 较高(并行计算) |
适用场景 | 特征间有明确的层次关系 | 特征间关系平等且需要多视角 |
3. 与其他模型的对比
模型 | 特征交互方式 | 结构灵活性 |
---|---|---|
FM/FFM | 固定的二阶特征交互 | 低 |
DeepFM | 固定的 DNN 结构 | 中 |
FiBiNet | 注意力加权的特征交互 | 高 |
MaskNet | 动态生成的实例特异性结构 | 极高 |
五、应用场景与优势
1. 适用场景
- 推荐系统:处理高维稀疏的用户 - 物品交互数据
- 点击率预测:捕捉用户行为与广告 / 商品之间的复杂关系
- 多模态学习:整合文本、图像、数值等多种类型的特征
- 特征选择:在大量特征中自动识别重要特征
2. 主要优势
- 动态结构:针对不同输入动态调整网络结构
- 特征选择性:自动识别和关注重要特征
- 稀疏数据处理:通过掩码机制有效处理稀疏特征
- 可解释性:掩码可以作为特征重要性的一种解释
- 泛化能力:减少过拟合,提高模型在不同场景下的表现
六、训练与调优建议
-
参数配置:
agg_dim
:通常设置为 32-256,根据特征复杂度调整num_mask_block
:串行模式建议 2-4,并行模式建议 3-6mask_block_ffn_size
:与agg_dim
相当或略大
-
训练技巧:
- 使用 Adam 优化器,学习率 0.001-0.0001
- Dropout 率设置为 0.2-0.5,防止过拟合
- L2 正则化系数设置为 1e-5-1e-3
-
模型选择:
- 特征间关系层次分明时,选择串行模式
- 特征间关系平等且需要多视角时,选择并行模式
七、总结
MaskNet 通过实例引导的掩码机制,为深度学习模型带来了动态结构能力,使其能够根据不同输入自适应地调整网络结构。这种创新设计使其在推荐系统、CTR 预测等高维稀疏数据场景中表现出色,同时提高了模型的解释性和泛化能力。