介绍PPML in SPU,定点数和浮点数的区别,迁移现有算法全流程
讲师:周金金
一、PPML in SPU
1. ML和MPC技术栈对比
ML和MPC之间存在巨大鸿沟!
- ML:关注于具体的某个算法、optimizer、网络结构
- MPC:底层密码学协议
2. SPU是什么
二、浮点数和定点数
在mpc中,一般会将浮点数encoding为定点数
1. 浮点数表示
就是计算机组成原理中学到的知识
FP32标是的数据有32个bit,其中:
- S S S:符号,1bit,0表示正数,1表示负数
- E E E:指数,8bit,以2为底
- f f f:尾数,23bit,决定了精度
(1) 规格数
- 取值范围: ( − 2 × 2 127 , − 1 × 2 − 126 ) ∪ [ 1 ∗ 2 − 126 , 2 × 2 127 ) (-2 \times 2 ^{127}, -1 \times 2 ^{-126}) \cup[1 * 2^{-126}, 2\times 2^{127}) (−2×2127,−1×2−126)∪[1∗2−126,2×2127)
-
E E E的范围 [ 1 , 254 ] [1,254] [1,254]
例如上述 E E E占了 8 b i t 8bit 8bit,原始可取范围为 [ 0 , 255 ] [0,255] [0,255],但是对于0和255作为特殊标记。
取0时表示非规格数,取255时表示无穷或者NaN
-
后面的23bit为小数部分
(2) 非规格数
F = ( − 1 ) S × 0. f × 2 1 − 127 F = (-1)^{S} \times0.f\times 2^{1-127} F=(−1)S×0.f×21−127
- E = 0 E=0 E=0
- 范围: ( − 1 × 2 − 126 , 1 × 2 − 126 ) (-1 \times 2^{-126}, 1 \times 2^{-126}) (−1×2−126,1×2−126)
(3) 特殊值
2. 定点数表示
因此,上述表示的数据是 5.625 5.625 5.625
-
小数和整数的bit比较固定
-
一般的,对于定义在环上的 L b i t L bit Lbit的数据,小数位数为 f x p = F fxp=F fxp=F
-
取值范围: ( − 2 L − 1 − F , 2 L − 1 − F ) (-2^{L-1-F}, 2^{L-1-F}) (−2L−1−F,2L−1−F)
例如上图,最小值为11111111全1表示的数据,取值 − 7.9375 -7.9375 −7.9375,表示为 − [ 2 L − 1 − F − 1 + ( 1 − 2 − f ) ] -[2^{L-1-F}-1+(1-2^{-f})] −[2L−1−F−1+(1−2−f)](大家可以代入上式计算一下),因为 − [ 2 L − 1 − F − 1 + ( 1 − 2 − f ) ] = − [ 2 L − 1 − F − 2 − f ] > − 2 L − 1 − F -[2^{L-1-F}-1+(1-2^{-f})]=-[2^{L-1-F}-2^{-f}] > -2^{L-1-F} −[2L−1−F−1+(1−2−f)]=−[2L−1−F−2−f]>−2L−1−F,所以表示为上述开区间。
最大表示为01111111(除了符号为0),取值为 7.9375 7.9375 7.9375,表示为 2 L − 1 − F − 1 + ( 1 − 2 − f ) = 2 L − 1 − F − 2 − f 2^{L-1-F}-1+(1-2^{-f})=2^{L-1-F}-2^{-f} 2L−1−F−1+(1−2−f)=2L−1−F−2−f,因为该值小于 2 L − 1 − F 2^{L-1-F} 2L−1−F,因此表示为开区间。
-
SPU中的取值范围: ( − 2 L − 2 − F , 2 L − 2 − F ) (-2^{L-2-F}, 2^{L-2-F}) (−2L−2−F,2L−2−F)
SPU中进行的compare是基于msb,因此为了使得基于msb的比较能够work。
-
两个数的间距为 2 − F 2^{-F} 2−F
-
3. 比较
定点数的范围: ( − 2 44 , 2 44 ) (-2^{44}, 2^{44}) (−244,244)
三、将明文算法迁移到SPU的流程
https://github.com/secretflow/spu/blob/0.9.0b2/docs/tutorials/develop_your_first_mpc_application.ipynb
1. 使用JAX实现算法
- 可以将算法用JAX实现:
- JAX.numpy:Numpy的平替
- JAX.law:low level算子
- JAX.scipy:Scipy的平替
2. SPU下验证数值精度
- 先在明文环境下计算数值
- 模拟定点数上运行的所有操作
- 提供真实的数值计算精度环境,运行速度更快、快速实验
3. SPU下验证性能
测试密态下的实际性能
- 在真实的MPC协议上通过多进程/Docker进行仿真
- 提供算法有效的性能结果
四、实践
1. 使用JAX编写明文算法
(1) 准备数据集
这里使用Cancer数据集,并对其进行了标准化。
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import MinMaxScaler
import pandas as pd
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
# normally, LR works only when the features have been normalized!
scalar = MinMaxScaler(feature_range=(-2, 2))
cols = X.columns
X = scalar.fit_transform(X)
X = pd.DataFrame(X, columns=cols)
(2) 实现明文算法
Implement algorithm in plaintext
在这一部分中,我们首先忘记 MPC 设置(数据分割方式、协议等)并以plain text形式实现算法。
-
Optimizer我们选择简单有效的SGD(Stochastic Gradient Descent,随机梯度下降)
-
model我们选择使用逻辑回归Logistic Regression
-
逻辑回归问题的梯度计算公式:
g = 1 n ∑ i ( s i g m o i d ( w T x i ) − y i ) x i g = \frac{1}{n}\sum_{i}(sigmoid(\mathbf{w}^Tx_i)-y_i)x_i g=n1i∑(sigmoid(wTxi)−yi)xi -
对于SGD选择经过改进的policy-sad,能对大多数场景中的训练进行加速。
https://github.com/secretflow/spu/blob/0.9.1dev20240614/docs/development/policy_sgd_insight.rst
如果用 k k k表示第k次迭代,用 i i i表示第i个epoch
-
在第一个epoch计算 d k d_k dk
d k = 1 ∑ j p g r a d j 2 + ϵ d_k=\frac{1}{\sqrt{\sum_j^pgrad_j^2}+\epsilon} dk=∑jpgradj2+ϵ1 -
权重更新方法:
w i , k = w i , k − 1 − d k × α × g r a d w_{i,k} = w_{i,k-1} - d_k \times \alpha \times grad wi,k=wi,k−1−dk×α×grad
-
-
-
实现sigmoid的近似计算
逻辑回归的原始响应函数是 Sigmoid函数,但是由于其中包含 MPC 中比较耗时的操作(如 exp 和除法)。因此,通常使用其他 MPC 友好函数来近似 Sigmoid函数。这里我们给出两种方法,即一阶泰勒近似
sigmoid_t1
和平方根近似sigmoid_sr
。-
一阶泰勒近似
y = 1 2 + 1 4 x y = \frac{1}{2}+\frac{1}{4}x y=21+41x如果 x x x趋近于-2,则 y y y趋近于0
如果 x x x趋于2,则 y y y将趋近于1
-
平方根近似
y = 1 2 ( x 1 + x 2 ) + 1 2 y = \frac{1}{2}(\frac{x}{\sqrt{1+x^2}}) + \frac{1}{2} y=21(1+x2x)+21
# import some basic library # use jnp just like np import jax.numpy as jnp import jax.lax import numpy as np from functools import partial def sigmoid_t1(x, limit: bool = True): ''' taylor series referenced from: https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/ ''' T0 = 1.0 / 2 T1 = 1.0 / 4 ret = T0 + x * T1 if limit: return jnp.select([ret < 0, ret > 1], [0, 1], ret) else: return ret def sigmoid_sr(x): """ https://en.wikipedia.org/wiki/Sigmoid_function#Examples Square Root approximation functions: F(x) = 0.5 * ( x / ( 1 + x^2 )^0.5 ) + 0.5 sigmoid_sr almost perfect fit to sigmoid if x out of range [-3,3] highly recommended use this appr as GDBT's default sigmoid method. """ return 0.5 * (x / jnp.sqrt(1 + jnp.power(x, 2))) + 0.5 def sigmoid(x, method='t1'): if method == 't1': return sigmoid_t1(x) else: return sigmoid_sr(x)
-
-
policy-sgd 需要在第一个时期调整学习率
类似与adam,对梯度进行一定标准化
# Note: we leave a method param in this function for next part, in plaintext, we won't invoke low-level op in most conditions. def compute_dk_func(x, eps=1e-6, method='norm'): # Same as Adam, need add small eps to avoid zero-division error if method == 'norm': return 1 / (jnp.linalg.norm(x) + eps) else: # invoke low-level rsqrt op by hand return jax.lax.rsqrt(jnp.sum(jnp.square(x)) + eps)
-
SGDClassifier的实现
- 初始化时提供了6个超参数,
dk_method
用于policy-sgd
class SSLRSGDClassifier: def __init__( self, epochs: int, learning_rate: float, batch_size: int, sig_type: str = 't1', eps: float = 1e-6, # eps is the small number for computing dk dk_method: str = 'norm', # method to compute dk, default is use jnp.linalg.norm function ): # parameter check. assert epochs > 0, f"epochs should >0" assert learning_rate > 0, f"learning_rate should >0" assert batch_size > 0, f"batch_size should >0" assert sig_type in ['t1', 'sr'], f"sig_type should one of ['t1', 'sr']" assert eps > 0, f"eps should >0" assert dk_method in [ 'norm', 'rsqrt', ], f"dk_method should one of ['norm', 'rsqrt']" self._epochs = epochs self._learning_rate = learning_rate self._batch_size = batch_size self._sig_type = sig_type self._eps = eps self._dk_method = dk_method self._weights = jnp.zeros(())
-
原生更新梯度的方式
-
根据 x w \mathbf{x}\mathbf{w} xw计算 y ^ i = S i g m o i d ( x w ) \hat{y}_i = Sigmoid(\mathbf{x}\mathbf{w}) y^i=Sigmoid(xw)
-
计算梯度:
g r a d = 1 n ∑ i ( y ^ i − y i ) x i grad = \frac{1}{n}\sum_{i}(\hat{y}_i-y_i)x_i grad=n1i∑(y^i−yi)xi -
批量梯度下降法更新特征权重 w \mathbf{w} w
w i , k = w i , k − 1 − d k × α × g r a d w_{i,k} = w_{i,k-1} - d_k \times \alpha \times grad wi,k=wi,k−1−dk×α×grad
def _update_weights( self, x, # array-like y, # array-like w, # array-like total_batch: int, batch_size: int, dk_arr, # array-like ): num_feat = x.shape[1] assert w.shape[0] == num_feat + 1, "w shape is mismatch to x" assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array" w = w.reshape((w.shape[0], 1)) compute_dk = False if dk_arr is None: compute_dk = True dk_arr = [] for idx in range(total_batch): begin = idx * batch_size end = min((idx + 1) * batch_size, x.shape[0]) rows = end - begin # padding one col for bias in w x_slice = jnp.concatenate((x[begin:end, :], jnp.ones((rows, 1))), axis=1) y_slice = y[begin:end, :] pred = jnp.matmul(x_slice, w) pred = sigmoid(pred, method=self._sig_type) err = pred - y_slice grad = jnp.matmul(jnp.transpose(x_slice), err) / rows if compute_dk: dk = compute_dk_func(grad, self._eps, self._dk_method) dk_arr.append(dk) else: dk = dk_arr[idx] step = self._learning_rate * grad * dk w = w - step if compute_dk: dk_arr = jnp.array(dk_arr) return w, dk_arr
-
-
训练
进行epochs轮计算,每个epoch中每次选择batch_size个样本进行迭代。
def fit(self, x, y): """Fit LR with policy-sgd.""" assert len(x.shape) == 2, f"expect x to be 2 dimension array, got {x.shape}" assert len(y.shape) == 2, f"expect y to be 2 dimension array, got {y.shape}" num_sample = x.shape[0] num_feat = x.shape[1] batch_size = min(self._batch_size, num_sample) total_batch = (num_sample + batch_size - 1) // batch_size # always fit intercept weights = jnp.zeros((num_feat + 1, 1)) dk_arr = None # do train for _ in range(self._epochs): weights, dk_arr = self._update_weights( x, y, weights, total_batch, batch_size, dk_arr ) self._weights = weights self.dk_arr = dk_arr return
-
预测
根据 w \mathbf{w} w和 b i a s bias bias计算预测值:
y ^ = s i g m o d ( w x + b i a s ) \hat{y} = sigmod(\mathbf{w}\mathbf{x}+bias) y^=sigmod(wx+bias)def predict_proba(self, x): """Probability estimates""" num_feat = x.shape[1] w = self._weights assert w.shape[0] == num_feat + 1, f"w shape is mismatch to x={x.shape}" assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array" w.reshape((w.shape[0], 1)) bias = w[-1, 0] w = jnp.resize(w, (num_feat, 1)) pred = jnp.matmul(x, w) + bias pred = sigmoid(pred, method=self._sig_type) return pred
-
使用
plain_model = SSLRSGDClassifier( epochs=3, learning_rate=0.1, batch_size=8, sig_type='t1', eps=1e-6, dk_method='norm' ) # train plain_model.fit(X.values, y.values.reshape(-1, 1)) # X, y should be two-dimension array # predict predict_prob = plain_model.predict_proba(X.values) # score from sklearn.metrics import roc_auc_score print(f"auc = {roc_auc_score(y.values, predict_prob)}")
auc = 0.9903083
- 初始化时提供了6个超参数,
2. 在SPU上的适配
SPU下验证精度(simulation)
通常情况下,您只需使用 spu 执行类似 LR 的操作即可在安全上下文中运行程序:将数据集移动到 PYU 或 SPU,使用您声明的 SPU 运行程序并reveal您需要的一些信息(但是reveal是一个非常危险的操作,您应该在实际应用中非常小心地使用它)。
但是,我们稍后会看到,您可能会遇到明文和秘密之间的巨大度量差距(如 LR 中的 auc)。开发人员可以更简单地运行 MPC 程序,并具有高度的灵活性来调整超参数,例如环的大小、fxp 或底层 MPC 协议等,这将是一个更好的选择。
因此,在这一部分中,我们将展示如何使用simulator运行我们的算法,就像运行普通的 MPC 程序一样,并进行最少的实验来关注和验证程序的缺陷。使用simulator而不是直接使用 SPU 设备运行程序有两个优点:
- 更少的代码:无需处理大量的 DeviceObject 并在 SPU 之间从 PYU 移动数据。
- 更快的实验:没有连接ray cluster,端到端运行实验。
(1) 两方模拟
我们首先定义一个简单的simulator,使用 CHEETAH 协议和 64 位环,采用 2pc 设置。
- 支持多种MPC协议
- 支持改变ring大小
- 支持改变fxp
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu
sim = spsim.Simulator.simple(2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64)
也可以使用
spu.RuntimeConfig
进行自定义。设置enable_hal_profile
和enable_pphlo_profile
可以在程序运行后看到对应的详细调用情况,比如通信量和总的通信量[2024-06-25 05:03:03.393] [info] [api.cc:222] Link details: total send bytes 240, recv bytes 400, send actions 3
config_aby = spu.RuntimeConfig( protocol = spu_pb2.ProtocolKind.ABY3, field = spu.FieldType.FM64, fxp_fraction_bits = 18, enable_hal_profile=True, enable_pphlo_profile=True ) sim_aby = spsim.Simulator(3, config_aby)
定义训练和预测函数:就是根据之前的plaintext进行了一次封装
def fit_and_predict(
x,
y,
epochs=3,
learning_rate=0.1,
batch_size=8,
sig_type='t1',
eps=1e-6,
dk_method='norm',
):
model = SSLRSGDClassifier(
epochs=epochs,
learning_rate=learning_rate,
batch_size=batch_size,
sig_type=sig_type,
eps=eps,
dk_method=dk_method,
)
model.fit(x, y)
return model.predict_proba(x)
result = spsim.sim_jax(sim, fit_and_predict)(
X.values, y.values.reshape(-1, 1)
) # X, y should be two-dimension array
评估拟合后的模型:
roc_auc_score(y, result)
emm,最后的auc只有0.4…(不用慌,可以选择其他协议或者调参来优化)。
(2) 三方模拟
更换了三方协议ABY3,使用64bit的环。
sim_aby = spsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64)
result = spsim.sim_jax(sim_aby, fit_and_predict)(X.values, y.values.reshape(-1, 1))
auc = roc_auc_score(y, result)
这样,得到的auc约为0.97。
当程序在不进行任何修改的情况下以分片形式运行计算,经过 3 个 epoch 的训练后,auc 可能会急剧下降(对于 cheetah 来说从 0.990 下降到 0.490)!
我们将进行一些分析,并尝试首先从应用程序的角度对其进行修复,然后从 MPC 的角度进行更深入的思考。
(3) 算法角度继续思考
-
一阶泰勒近似Sigmoid计算
y ^ = 1 2 + 1 4 ( w x ) \hat{y} = \frac{1}{2}+\frac{1}{4}(\mathbf{wx}) y^=21+41(wx)如果内积 w x \mathbf{wx} wx趋近于-2,则 y y y趋近于0
如果内积 w x \mathbf{wx} wx大于等于2,则 y y y将限制在1;(带截断)
-
grad的计算:
g r a d = 1 n ∑ i ( s i g m o i d ( w T x i ) − y i ) x i grad = \frac{1}{n}\sum_i(sigmoid(\mathbf{w}^Tx_i)-y_i)x_i grad=n1i∑(sigmoid(wTxi)−yi)xi
如果使用t1的公式近似sigmoid,很有可能grad=0。 -
dk的计算
d k = 1 ∑ j p g r a d j 2 + ϵ d_k=\frac{1}{\sqrt{\sum_j^pgrad_j^2}+\epsilon} dk=∑jpgradj2+ϵ1
如果碰巧 grad 的所有元素非常接近 0(MPC 中可能几乎没有误差),那么第一个 epoch 计算出的 d k dk dk就会变得非常大(分母非常接近于0)。
- 参数更新
w i , k = w i , k − 1 − d k × α × g r a d w_{i,k} = w_{i,k-1} - d_k \times \alpha \times grad wi,k=wi,k−1−dk×α×grad
参数更新依赖于 d k d_k dk,如果 d k d_k dk过大,那么参数更新会变得异常剧烈,可能会导致训练失败。
-
可以考虑的解决方案:
- 增大batch_size:减少梯度为0的概率
- 增大 d k d_k dk计算使用的 ϵ \epsilon ϵ:减少极端情况下学习率的抖动
- 使用不带截断的Sigmoid近似方法(SR):减少梯度为0的概率
-
例如,对于上面使用CHEETAH的simulator,这里我们修改了batch_size重新执行
result = spsim.sim_jax(sim, partial(fit_and_predict, batch_size=64))( X.values, y.values.reshape(-1, 1) ) roc_auc_score(y, result) # rather pool under cheetah protocol!
0.9892579673378785
Congratulations~结果正常啦
(4) MPC角度继续思考
将根据底层协议进行讨论
https://www.secretflow.org.cn/zh-CN/docs/spu/main/reference/mpc_status
-
协议选择
https://www.secretflow.org.cn/zh-CN/docs/spu/main/reference/mpc_status
一般来说,对于 2pc,使用 cheetah 是安全的,而对于 3pc,ABY3 是唯一的选择。
- ABY3:一个诚实多数3PC协议,SPU提供半诚实实现,注意,如果两个以上的计算节点部署在一起,则很难抵御合谋攻击。
- Semi2k-SPDZ*:一种类似于 SPDZ 的半诚实MPC 协议,但需要可信的第三方来生成离线随机数。默认情况下,此协议现在使用可信的第一方。因此,它应该仅用于调试目的。不安全
- Cheetah*:一种快速的2pc半诚实协议,它使用精心设计的基于同态加密的协议进行算术运算,使用Ferret进行布尔运算。
-
SPU所能表示的最小正数为 2 − f 2^{-f} 2−f,因为 f = 18 f = 18 f=18,因此为 2 − 18 2^{-18} 2−18
-
Cheetah是一个快速的两方半诚实协议,使用PHE来激素计算,但是在进行乘法运算时会产生0-2位错误。
如果是64位ring,使用大约18位的定点数,因此spu可以表示的最小正浮点数是 1 2 18 ≈ 3.8147 × 1 0 − 6 \frac{1}{2^{18}} \approx 3.8147\times 10^{-6} 2181≈3.8147×10−6
-
因此最开始进行两方模拟时导致auc小于0.5,很可能出现了计算溢出。若梯度很小(接近 2 − 18 2^{-18} 2−18),则平方运算造成的0-2bit误差可能会对结果造成显著的影响,甚至可能从最小正数溢出到最小负数
通过对一组很小的数求平方和来验证乘法是否出现溢出:
# Let's test this def test_square_and_sum_when_x_small(x): return jnp.sum(jnp.square(x)) spsim.sim_jax(sim, test_square_and_sum_when_x_small)(np.array([1e-5] * 10))
结果只有 − 3.0517578 × 1 0 − 5 -3.0517578 \times 10 ^{-5} −3.0517578×10−5,近似于 − 2 − 18 × 10 -2^{-18} \times 10 −2−18×10咦,结果符号都变了。
改进思路仍然会以减少梯度为0入手:增大fxp和环大小:提高表示精度,减少下溢到0点概率
-
我们保持batch_size不变,增加环的大小
sim128 = spsim.Simulator.simple( 2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM128 ) result = spsim.sim_jax(sim128, fit_and_predict)(X.values, y.values.reshape(-1, 1)) auc = roc_auc_score(y, result)
结果为0.990308387。
3. 在SPU上验证性能
我们主要想论证在 MPC 设置中 policy-sgd
优于 naive-sgd
,因此我们可以设计以下实验:
dk_method
为policy-sgd找到最佳eps
:对于所有数据集,比较准确性和效率。sig_type
比较 policy-sgd 和 naive-sgd切换时的准确率和效率。- 为了比较 policy-sgd 和 naive-sgd,我们固定
epochs
并测试learning_rate
和batch_size
的影响
当前只支持多进程模式,且需要从spu源码编译运行!
五、常见问题
1. 怎么知道SPU支持哪些算子?
实践出真知,自己动手: simulation, emulation
- https://www.secretflow.org.cn/zh-CN/docs/spu/0.9.1b0/reference/np_op_status
- https://www.secretflow.org.cn/zh-CN/docs/spu/0.9.1b0/reference/xla_status
2. 怎么知道非线形算子大致的误差范围
误差来自两个方面:
- 系统设定误差(如环大小,fxp大小,truncation协议等)
- 非线性算子拟合误差
因此很难给出如浮点数的误差估计,以下文档给出了一些数学算子的大致误差:
https://www.secretflow.org.cn/docs/spu/latest/en-US/development/fxp
3. 为什么明文下运行正常,密态下运行错误
-
若运行报错:
- 实现的算法是否jitable(即使用@jax.jit是否能运行)
- 是否使用了SPU不支持的算子
-
若能运行,但误差极大,可以自查是否有以下情况:
- 是否可能发生溢出:输入数据或参数是否太大或太小
- SPU内部是否使用了浮点随机数生成器
- 是否调用了线形代数算子(如矩阵分解,奇异值分解等)
-
若误差适中:
可以考虑增大环的大小,提高fxp精度
4. 为什么Emulation的速度比simulation快很多?
数据没有seal(即load到PYU),SPU将其视为Public数据,所有计算在明文下进行
5. 如何对密态算法进行优化
有以下几个思路可以参考:
-
减少耗时算子的调用(计算公式重写,多项式近似等)
-
避免重复计算(空间换效率)
-
并行化
实际上,SPU内部已经做了大量的并行操作,若希望进一步优化,可以尝试:
- 算法层:for循环很多时候可以通过高阶tensor运算来代替,也可以考虑使用
jax.vmap
进行自动向量化 - Runtime:尝试开启更多并行(experimental feature),如experimental_enable_inter_op_par(即DAG并行)
六、如何使用SPU开发
例如,计算policy_sgd的 d k = 1 t 0 2 + t 1 2 + ϵ dk = \frac{1}{\sqrt{t_0^2+t_1^2}+\epsilon} dk=t02+t12+ϵ1 。
- 首先生成明文数据
import numpy as np
import jax.lax
# Emulation should be run from source in SPU, so we use Secretflow here to do the efficiency experiments.
# You can use the similar trick for emulation directly in SPU.
def compute_dk_func(x, eps=1e-6):
return jax.lax.rsqrt(jnp.sum(jnp.square(x)) + eps)
x = np.random.rand(1_000_000)
alice_device = sf.PYU("alice")
bob_device = sf.PYU("bob")
spu_obj = sf.SPU(
cluster_def=cluster_conf,
link_desc={
"connect_retry_times": 60,
"connect_retry_interval_ms": 1000
},
)
# cluster_config参考之前笔记的设置,这里不再描述
-
将明文数据放到alice_device下
# first, load data to PYU alice_data = alice_device(lambda x: x)(x)
-
在spu_obj上执行
compute_dk_func
计算 d k dk dk:ret = spu_obj(compute_dk_func)(alice_data) sf.reveal(ret) # array(0.00173242, dtype=float32)