横向LR代码理解

横向LR的代码理解


横向LR以及其他类似的横向联邦学习的代码框架比较清晰,相对于纵向逻辑更好理清 整体的横向LR的结构如图:

1
我们可以看到逻辑就是A和B在本地训练若干轮(aggregate_iters)后来一次向Arbiter的聚合,之后收到Arbiter的回复后便结束了一轮。所以,我们想要理清整个横向LR的代码框架可以从client怎么向arbiter传送数据以及arbiter怎么拿到传来的数据这两方面考虑。至于arbiter反过来传给client以及client的接收和前面两步类似,所以只需了解前两步即可。

一、client向arbiter传送过程

还是以Guest方为例(Host方类似),找到federatedml / linear_model / logistic_regression / homo_logsitic_regression/homo_lr_guest.py文件后观察其fit函数,也就是训练函数,可以看到其结构如下:
2
下方绿色部分是指guest在本地训练的部分,上分红色部分是指将模型聚合(也就是发送给arbiter方的部分),从上方红色框的条件语句可以看到,当我们当前轮数满足了aggregate_iters的整数倍(也就是每隔aggregate_iters聚合一次)或者达到最大轮数时候,我们的guest方会向arbiter发送一次,我们这部分内容就重点关注是如何发送的。
step 1: 了解model_weights变量

weight = self.aggregator.aggregate_then_get(model_weights, degree=degree,suffix=self.n_iter_)

传输的源头起于上面这行代码,但传入的参数model_weights究竟是什么东西呢,如果我们想查看它甚至是想修改它应该怎么做呢?
model_weights的起源是

self.model_weights = self._init_model_variables(data_instances)

不难看出它是一个对象类,具体打印出type是
<class ‘federatedml.framework.weights.TransferableWeights’>,我们具体进入这个类可以发现当我们调用其 unboxed 方法可以将这个对象类中的值取出来,而这个取出来的值是numpy类型,这个类型是可以供我们修改的。

step 2 : 进入aggregate_then_get函数
3
可以看到这个函数分为发和收两部分,我们上边也提到了我们只看client向arbiter这一趟线,所以我们只关注 send_model 这个函数即可。
顺便一提,此函数所在的位置 federatedml / framework / homo / procedure / aggregator.py中,这个aggregator.py可以说是整个横向联邦的大管家,我们交接的主要函数都是通过它来转手的,而且看过往期教学视频的同学应该也知道,这个函数正是横向联邦的核心。

step 3: 进入send_model函数
4
先看secure_aggregate,它调用了上面的_func,从形参的字面意思可以看出它是把_func作为一个发射器了,我们先看一下secure_aggregate是什么样子

step 4 查看secure_aggregate函数
5
1、degree行是说是否有权重的相乘
2、enable_secure_aggregate行是说要不要加密(如果我们开始修改了model_weights的值,这里很可能会带上加密后改变原有值,所以如果数值变了不妨在这儿找找看)
3、remote_weights行调用的for_remote函数是将我们的
<class ‘federatedml.linear_model.linear_model_weight.LinearModelWeights’>类型的model_weights转变为:<class ‘federatedml.framework.weights.TransferableWeights’>以便后续一些操作

所以上面这个函数就是包装了一下我们的model_weights然后就利用_func发射出去了,我们接下来看一下这个发射函数什么样子:
step 5 : 查看_func函数
6
点击进入_model_scatter.send_model函数
7
继续进入remote_parties(这里边我认为是remote函数是传递功能)
在这里插入图片描述
step 6 进入remote函数
进入后会惊奇的发现这个函数啥也没写
7
对此我也困惑了许久,后来找到了实现它的函数位置(对于单机版的用户来说它在arch / api / impl / based_1x / federation_standalone.py中)
具体如上图
至此,client向arbiter传送的过程大体结束,接下来介绍一下arbiter如何收下这些数据。

二、arbiter接收client的数据

1
打开homo_lr_arbiter,py发现和guest的结构类似,我们下手的目标便是上图中红色框中的内容。

step 1 : 打开aggregator.aggregate_and_broadcast函数
2
因为我们想要看arbiter如何接收数据,所以我们应该关注上图红色框中的aggregate_model函数

step 2 : 打开aggregate_model函数
在这里插入图片描述
第一行models的获取便是我们要关注的核心部分,经过测试我发现他是一个generator对象,next(models)[0]代表每次生成的model,next(models)[1]代表每次生成的degree,但是在这儿我重要的话说三遍!!!!!千万千万千万不要用 LOGGER.debug打印了next(models)[0]后不将LOGGER语句注释掉,因为如果用LOGGER调用过一次next,它下边reduce函数就会从下一次调用,就会匹配不上,此处bug本人曾改了许久才发现了问题。
好了废话不多说,我们继续打开models的创造者

step 3 : 打开get_models_for_aggregate函数
6
看到yield便可知道,这是个生成器函数,第一次会弹出guest_model,后边会依次弹出各个host_model,在上图中的我的注释部分提到可以在“这儿”修改,是因为我之前用guest方举例,如果我们在guest方修改了数据,就可以从这个函数的这个地方拿出来看或者修改。
值得一提的就是,这里的guest_model类型是:<class ‘federatedml.framework.weights.TransferableWeights’>,这个类对象可以通过unboxed方法提取它的值,但是!但是!但是!

我们没有直接修改它的unboxed值的办法!所以我想了一招,直接去这个类中添加一个set方法后,调用set方法传入我们新改的值,如下:

在这里插入图片描述
之后再在上述guest_model的地方修改即可
在这里插入图片描述

到此为止,横向的LR结构的client向arbiter的单向流步骤大体已经讲完,从arbiter返回client类似,在此不过多赘述。
另外顺便提一下,在host方貌似数据都是加密过的,不能直接用unboxed看到其值,如果测试用的话,大家可以将conf文件中加密部分改为null即可显示数据。如下:
在这里插入图片描述
至此,横向LR的代码理解篇结束,祝大家事业有成一帆风顺!

<think>好的,用户现在需要我给出调整字体后的完整代码。让我回顾一下之前的对话,确保我理解他的需求。用户之前提供了自己的代码,并希望调整字体大小和清晰度,同时避免重叠。我已经给出了分步优化方案,现在用户可能希望将这些调整整合到完整的代码中,以便直接运行。 首先,我需要确认用户提供的原始代码的结构,以及我建议的修改点。用户之前的代码包括导入库、字体设置、数据读取、创建子图、循环绘制散点图、颜色条设置以及保存图像。我的优化建议涉及字体参数调整、子图布局、标签旋转、图例定位等。 接下来,我需要将这些优化点整合到用户的原始代码中,同时保持代码结构清晰。需要确保所有修改都被正确嵌入,例如在plt.rcParams.update中添加字体家族,调整子图尺寸和间距,添加标签旋转逻辑,修改图例的位置和样式,以及优化保存图像的参数。 还需要注意用户可能遗漏的部分,比如在字体设置中指定一个跨平台兼容的字体,如'DejaVu Sans',以防止不同系统上默认字体不一致的问题。此外,动态旋转标签的逻辑需要检查数据中参数值的长度,以决定是否旋转刻度标签,这可能需要处理不同的数据类型,比如浮点数可能需要转换为字符串并检查长度。 另外,颜色条的设置可能需要调整刻度字体大小和标题间距,确保其与整体样式一致。在保存图像时,添加bbox_inches='tight'和facecolor参数可以避免边缘被裁剪并确保背景为白色。 需要验证代码是否完整,所有部分是否衔接正确,比如调整后的子图尺寸和间距参数是否正确,循环中的每个ax是否都应用了动态旋转和刻度限制。还要检查图例的bbox_to_anchor是否合适,避免图例超出图像范围或被遮挡。 最后,确保代码注释清晰,用户能够理解每个修改的作用,可能在代码中添加注释说明关键调整点,帮助用户后续进一步自定义。同时,提醒用户运行后检查输出图片,确认所有文字清晰不重叠,并根据需要微调参数如wspace或hspace。</think>这是整合所有优化方案的完整代码版本: ```python import pandas as pd import matplotlib.pyplot as plt import numpy as np # ================= 全局样式设置 ================= plt.rcParams.update({ 'font.family': 'DejaVu Sans', # 跨平台无衬线字体 'font.size': 18, 'axes.titlesize': 20, # 标题放大 'axes.labelsize': 20, 'xtick.labelsize': 16, # 缩小刻度字体 'ytick.labelsize': 16, 'legend.title_fontsize': 14, 'legend.fontsize': 12, 'figure.autolayout': False # 禁用自动布局 }) # ================= 数据准备 ================= df = pd.read_csv("/root/autodl-tmp/runs/train/tune2/tune_results.csv") index_order = np.arange(len(df)) parameters = ["lr0", "lrf", "momentum", "weight_decay", "warmup_epochs", "warmup_momentum", "cls", "hsv_h", "hsv_s", "hsv_v"] # ================= 画布设置 ================= fig, axs = plt.subplots(3, 3, figsize=(28, 21)) # 增大画布尺寸 plt.subplots_adjust( wspace=0.4, # 横向间距 hspace=0.6 # 纵向间距 ) # ================= 可视化循环 ================= for ax, parameter in zip(axs.flatten(), parameters): # 散点图绘制 scatter = ax.scatter(df[parameter], df["fitness"], c=index_order, cmap='viridis', alpha=0.7) # 最佳点标注 best_index = df['fitness'].idxmax() best_param = df.loc[best_index, parameter] best_fitness = df.loc[best_index, 'fitness'] ax.scatter(best_param, best_fitness, color='red', s=150, edgecolors='black', zorder=5, linewidth=2, marker='*', label=f'Best: {best_param:.4f}\nFitness: {best_fitness:.3f}') # 动态标签调整 ax.set_xlabel(parameter, fontsize=22, labelpad=12) ax.set_ylabel('Fitness', fontsize=22, labelpad=12) # 智能刻度控制 ax.xaxis.set_major_locator(plt.MaxNLocator(5)) ax.yaxis.set_major_locator(plt.MaxNLocator(5)) # 自动旋转逻辑 sample_value = str(df[parameter].iloc[0]) rotation = 45 if len(sample_value) > 5 else 0 ax.tick_params(axis='x', rotation=rotation, pad=8) # 外置图例 ax.legend( fontsize=12, bbox_to_anchor=(0.5, -0.35), loc='upper center', borderaxespad=0., framealpha=0.8, handletextpad=0.5 ) # ================= 颜色条优化 ================= cbar = fig.colorbar( scatter, ax=axs.ravel().tolist(), orientation='horizontal', fraction=0.02, pad=0.15 ) cbar.ax.set_title('Parameter Search Order', fontsize=18, pad=15) cbar.ax.tick_params(labelsize=14) # ================= 最终输出 ================= plt.savefig( "tune2_optimized.png", dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none' ) plt.close() ``` **关键改进说明**: 1. **字体架构优化**: - 使用`DejaVu Sans`作为首选字体,兼容中文显示 - 建立从标题(20pt)→轴标签(20pt)→刻度(16pt)→图例(12pt)的字体层级 2. **防碰撞系统**: - `plt.subplots_adjust`设置更大的间距参数 - 通过`labelpad`增加标签与坐标轴的间距 - `bbox_to_anchor`将图例定位到子图外下方 3. **智能刻度系统**: - `MaxNLocator(5)`确保每个坐标轴不超过5个主刻度 - 根据参数值的字符串长度自动触发45度旋转 - `pad=8`增加刻度标签与坐标线的间距 4. **视觉增强**: - 散点图添加`alpha=0.7`提升重叠点的辨识度 - 最佳点使用星形标记(`marker='*'`)并提升zorder - 颜色条添加独立标题 **输出验证建议**: ```python # 检查最终图片尺寸 from PIL import Image img = Image.open("tune2_optimized.png") print(f"图片尺寸:{img.size} 像素") # 预期输出应大于(4000,3000) # 检查字体嵌入情况 print("图片包含的字体:", [font.name for font in img.getfonts()]) ``` 如果发现特定子图仍有重叠,可以通过单独调整该子图的`rotation`参数: ```python # 示例:强制旋转第三个子图的刻度 axs[0,2].tick_params(axis='x', rotation=45, pad=10) ```
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值