这篇论文提出的核心创新主要集中在新型联邦适应范式 (FedPA) 和个性化模型改进上,特别是结合了预训练模型的丰富知识并确保用户隐私。以下是该论文创新点与代码的综合分析:
1. 新的联邦适应范式 FedPA
- 预训练模型与隐私保护结合:FedPA 是一种基于基础模型的联邦推荐适应范式,将预训练模型的知识作为联邦推荐系统的起点,旨在利用这些知识为联邦学习提供热启动,同时保持用户数据隐私。代码中
MLP模型类中的embedding_user_features和embedding_video_features是通过预训练生成的用户和视频嵌入特征,作为基础模型。 - 优化联邦推荐系统:在代码中,通过
fed_train_a_round函数实现联邦学习过程中的模型更新。该函数不仅仅更新个性化的嵌入参数,还冻结项目嵌入模块和预测函数模块,以减少计算量。
2. 个性化低秩适配器和自适应门学习机制
- 个性化低秩适配器:代码中通过
lora模块实现了用户级别和用户组级别的低秩矩阵分解。MLP类中利用self.loras参数字典存储了不同用户特定的低秩适配器参数,以减少个性化所需的参数量。这种方式允许对用户个性化建模时的参数共享,克服了现有方法的局限性。 - 自适应门学习机制:在代码
gate_layers中,实现了自适应门机制,将公共和个性化权重动态分配。每层的gate_layers使用多分支权重分配策略,允许根据不同特征自适应融合知识。模型在forward函数中调用gate_layers,从而实现不同分支的动态融合,增强了多分支融合的灵活性。
3. 优化策略与隐私增强
- 优化策略:FedPA 的优化策略主要体现在通过冻结部分模块、更新用户相关参数以减少通信和计算量。代码中
freeze方法冻结了项目嵌入模块和预测函数模块,这降低了需要传输的参数量,从而提高了系统效率。 - 隐私保护:论文中提到的隐私保护技术在代码实现上体现在将本地差分隐私(LDP)技术集成到模型中。在
engine.py中的aggregate_clients_params函数对客户端上传的模型参数进行聚合,并添加了噪声,以保护隐私。这种设计既保障了隐私,又保留了模型性能。
实验部分
-
数据集和评估指标:代码中的
train.py和data.py加载了推荐数据集,并使用了 AUC 和 Precision 作为评估指标。具体来说,SampleGenerator类负责分割数据集并生成数据加载器,而fed_evaluate函数则用于在每轮联邦学习后进行模型评估。 -
实验内容 :
- 与基线模型对比:FedPA 被测试与多种个性化联邦推荐模型对比,代码中实现了 warm-starting 策略并比较了不同的初始化方式。实验表明 FedPA 通过低秩个性化和动态门控机制实现了性能的提升,且通信开销显著减少。
- 低秩个性化分析和公共知识影响:在
forward方法中,模型通过门控机制和低秩矩阵分解来测试用户个性化的效果。同时,通过冻结项目嵌入和预测函数模块验证了公共知识对模型性能的影响。 - 隐私增强实验:为保证隐私增强的效果,代码通过添加不同强度的噪声来测试隐私保护对模型性能的影响,控制噪声的强度可以在隐私和模型性能之间取得平衡。
综合分析
FedPA 模型在架构上通过低秩适配器和自适应门机制有效增强了个性化推荐的性能,并且通过差分隐私确保了模型的安全性。
核心创新点 1:提出新的联邦适应范式 FedPA
创新点:FedPA 作为一种新的联邦推荐系统范式,首次引入预训练模型作为联邦学习的基础模型,结合联邦学习框架以保护用户隐私。
代码实现:
-
预训练模型加载与热启动:代码中
train.py的fed_train_a_round函数中,通过以下代码实现了预训练模型的加载(热启动):# 加载预训练模型的参数 if round_id == 0: self.model.load_state_dict(torch.load(self.config['pretrain_model_dir']), strict=False)通过这一代码段,在联邦训练的第一轮加载预训练模型参数,从而提供了一个热启动的初始模型状态,使得联邦推荐可以在更好的起点上开始。
-
用户和视频嵌入的初始化:在
model.py中的MLP类构造函数中,定义了embedding_user_features和embedding_video_features,用于初始化用户和视频嵌入层。此设计旨在将预训练模型的嵌入特征作为起始点:self.embedding_user_features = torch.nn.ModuleDict() for feature_name in self.selected_user_features: feature_unique_value_count = user_features[feature_name].max() + 1 self.embedding_user_features[feature_name] = torch.nn.Embedding(num_embeddings=feature_unique_value_count, embedding_dim=self.latent_dim) self.embedding_video_features = torch.nn.ModuleDict() for feature_name in self.selected_video_features: feature_unique_value_count = video_features[feature_name].max() + 1 self.embedding_video_features[feature_name] = torch.nn.Embedding(num_embeddings=feature_unique_value_count, embedding_dim=self.latent_dim)
核心创新点 2:个性化低秩适配器
创新点:在用户级和用户组级上设计个性化低秩适配器,通过低秩矩阵减少个性化建模所需的参数量,同时高效利用参数。
代码实现:
-
用户和用户组低秩适配器:在
MLP类中使用了lora参数字典来管理用户和用户组的低秩适配器参数。具体的代码如下:# 定义用户级和用户组级低秩矩阵 self.loras = torch.nn.ParameterDict() for idx, (in_size, out_size) in enumerate(zip(self.layers[:-1], self.layers[1:])): self.loras[str(idx)+'_lora_a'] = torch.nn.Parameter(torch.zeros(in_size, config['rank'])) self.loras[str(idx)+'_lora_b'] = torch.nn.Parameter(torch.zeros(config['rank'], out_size)) # 为每个用户组生成低秩矩阵 for lora_user_feature in config['lora_user_features']: lora_user_feature_values = self.lora_user_feature_value_dict[lora_user_feature] idx_lora_a_name = lora_user_feature + str(idx)+'_lora_a_' idx_lora_b_name = lora_user_feature + str(idx)+'_lora_b_' for lora_user_feature_value in lora_user_feature_values: self.loras[idx_lora_a_name+str(lora_user_feature_value)] = torch.nn.Parameter(torch.zeros(in_size, config['rank'])) self.loras[idx_lora_b_name+str(lora_user_feature_value)] = torch.nn.Parameter(torch.zeros(config['rank'], out_size)) -
适配器的动态构建和使用:在
forward函数中,user_lora_vector和lora_vector_list中分别存储了用户级和用户组级的个性化向量。代码如下:user_lora_vector = (self.lora_dropout(vector) @ self.loras[str(idx)+'_lora_a'] @ self.loras[str(idx)+'_lora_b']) * self.lora_scaling lora_vector_list = [user_lora_vector] for lora_user_feature in lora_user_feature_values.keys(): lora_user_feature_value = lora_user_feature_values[lora_user_feature] lora_vector_list.append((self.lora_dropout(vector) @ self.loras[lora_user_feature+str(idx)+'_lora_a_'+str(lora_user_feature_value)] @ self.loras[lo

最低0.47元/天 解锁文章
840

被折叠的 条评论
为什么被折叠?



