【隐语理论+实践】SML入门/基于SPU迁移机器学习算法实践

介绍PPML in SPU,定点数和浮点数的区别,迁移现有算法全流程

讲师:周金金

学习链接:https://www.bilibili.com/video/BV1os421T7vB

一、PPML in SPU

1. ML和MPC技术栈对比

ML和MPC之间存在巨大鸿沟!

在这里插入图片描述

  • ML:关注于具体的某个算法、optimizer、网络结构
  • MPC:底层密码学协议

2. SPU是什么

在这里插入图片描述

二、浮点数和定点数

​ 在mpc中,一般会将浮点数encoding为定点数

1. 浮点数表示

就是计算机组成原理中学到的知识
在这里插入图片描述

​ FP32标是的数据有32个bit,其中:

  • S S S:符号,1bit,0表示正数,1表示负数
  • E E E:指数,8bit,以2为底
  • f f f:尾数,23bit,决定了精度
(1) 规格数
  • 取值范围: ( − 2 × 2 127 , − 1 × 2 − 126 ) ∪ [ 1 ∗ 2 − 126 , 2 × 2 127 ) (-2 \times 2 ^{127}, -1 \times 2 ^{-126}) \cup[1 * 2^{-126}, 2\times 2^{127}) (2×2127,1×2126)[12126,2×2127)

在这里插入图片描述

  • E E E的范围 [ 1 , 254 ] [1,254] [1,254]

    例如上述 E E E占了 8 b i t 8bit 8bit,原始可取范围为 [ 0 , 255 ] [0,255] [0,255],但是对于0和255作为特殊标记。

    取0时表示非规格数,取255时表示无穷或者NaN

  • 后面的23bit为小数部分

(2) 非规格数

F = ( − 1 ) S × 0. f × 2 1 − 127 F = (-1)^{S} \times0.f\times 2^{1-127} F=(1)S×0.f×21127

  • E = 0 E=0 E=0
  • 范围: ( − 1 × 2 − 126 , 1 × 2 − 126 ) (-1 \times 2^{-126}, 1 \times 2^{-126}) (1×2126,1×2126)
(3) 特殊值

在这里插入图片描述

2. 定点数表示

在这里插入图片描述
在这里插入图片描述

​ 因此,上述表示的数据是 5.625 5.625 5.625

  • 小数和整数的bit比较固定

  • 一般的,对于定义在环上的 L b i t L bit Lbit的数据,小数位数为 f x p = F fxp=F fxp=F

    • 取值范围: ( − 2 L − 1 − F , 2 L − 1 − F ) (-2^{L-1-F}, 2^{L-1-F}) (2L1F,2L1F)

      ​ 例如上图,最小值为11111111全1表示的数据,取值 − 7.9375 -7.9375 7.9375,表示为 − [ 2 L − 1 − F − 1 + ( 1 − 2 − f ) ] -[2^{L-1-F}-1+(1-2^{-f})] [2L1F1+(12f)](大家可以代入上式计算一下),因为 − [ 2 L − 1 − F − 1 + ( 1 − 2 − f ) ] = − [ 2 L − 1 − F − 2 − f ] > − 2 L − 1 − F -[2^{L-1-F}-1+(1-2^{-f})]=-[2^{L-1-F}-2^{-f}] > -2^{L-1-F} [2L1F1+(12f)]=[2L1F2f]>2L1F,所以表示为上述开区间。

      ​ 最大表示为01111111(除了符号为0),取值为 7.9375 7.9375 7.9375,表示为 2 L − 1 − F − 1 + ( 1 − 2 − f ) = 2 L − 1 − F − 2 − f 2^{L-1-F}-1+(1-2^{-f})=2^{L-1-F}-2^{-f} 2L1F1+(12f)=2L1F2f,因为该值小于 2 L − 1 − F 2^{L-1-F} 2L1F,因此表示为开区间。

    • SPU中的取值范围: ( − 2 L − 2 − F , 2 L − 2 − F ) (-2^{L-2-F}, 2^{L-2-F}) (2L2F,2L2F)

      SPU中进行的compare是基于msb,因此为了使得基于msb的比较能够work。

    • 两个数的间距为 2 − F 2^{-F} 2F

3. 比较

在这里插入图片描述

定点数的范围: ( − 2 44 , 2 44 ) (-2^{44}, 2^{44}) (244,244)

三、将明文算法迁移到SPU的流程

https://github.com/secretflow/spu/blob/0.9.0b2/docs/tutorials/develop_your_first_mpc_application.ipynb

在这里插入图片描述

1. 使用JAX实现算法

  • 可以将算法用JAX实现:
    • JAX.numpy:Numpy的平替
    • JAX.law:low level算子
    • JAX.scipy:Scipy的平替

2. SPU下验证数值精度

  • 先在明文环境下计算数值
  • 模拟定点数上运行的所有操作
  • 提供真实的数值计算精度环境,运行速度更快、快速实验

3. SPU下验证性能

​ 测试密态下的实际性能

  • 在真实的MPC协议上通过多进程/Docker进行仿真
  • 提供算法有效的性能结果

四、实践

1. 使用JAX编写明文算法

(1) 准备数据集

​ 这里使用Cancer数据集,并对其进行了标准化。

from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import MinMaxScaler
import pandas as pd

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
# normally, LR works only when the features have been normalized!
scalar = MinMaxScaler(feature_range=(-2, 2))
cols = X.columns
X = scalar.fit_transform(X)
X = pd.DataFrame(X, columns=cols)
(2) 实现明文算法

Implement algorithm in plaintext

在这一部分中,我们首先忘记 MPC 设置(数据分割方式、协议等)并以plain text形式实现算法。

  • Optimizer我们选择简单有效的SGD(Stochastic Gradient Descent,随机梯度下降)

  • model我们选择使用逻辑回归Logistic Regression

    • 逻辑回归问题的梯度计算公式:
      g = 1 n ∑ i ( s i g m o i d ( w T x i ) − y i ) x i g = \frac{1}{n}\sum_{i}(sigmoid(\mathbf{w}^Tx_i)-y_i)x_i g=n1i(sigmoid(wTxi)yi)xi

    • 对于SGD选择经过改进的policy-sad,能对大多数场景中的训练进行加速。

      https://github.com/secretflow/spu/blob/0.9.1dev20240614/docs/development/policy_sgd_insight.rst

      如果用 k k k表示第k次迭代,用 i i i表示第i个epoch

      • 在第一个epoch计算 d k d_k dk
        d k = 1 ∑ j p g r a d j 2 + ϵ d_k=\frac{1}{\sqrt{\sum_j^pgrad_j^2}+\epsilon} dk=jpgradj2 +ϵ1

      • 权重更新方法:
        w i , k = w i , k − 1 − d k × α × g r a d w_{i,k} = w_{i,k-1} - d_k \times \alpha \times grad wi,k=wi,k1dk×α×grad

  • 实现sigmoid的近似计算

    ​ 逻辑回归的原始响应函数是 Sigmoid函数,但是由于其中包含 MPC 中比较耗时的操作(如 exp 和除法)。因此,通常使用其他 MPC 友好函数来近似 Sigmoid函数。这里我们给出两种方法,即一阶泰勒近似sigmoid_t1和平方根近似sigmoid_sr

    • 一阶泰勒近似
      y = 1 2 + 1 4 x y = \frac{1}{2}+\frac{1}{4}x y=21+41x

      如果 x x x趋近于-2,则 y y y趋近于0

      如果 x x x趋于2,则 y y y将趋近于1

      在这里插入图片描述

    • 平方根近似
      y = 1 2 ( x 1 + x 2 ) + 1 2 y = \frac{1}{2}(\frac{x}{\sqrt{1+x^2}}) + \frac{1}{2} y=21(1+x2 x)+21

    # import some basic library
    # use jnp just like np
    import jax.numpy as jnp
    import jax.lax
    
    import numpy as np
    from functools import partial
    
    def sigmoid_t1(x, limit: bool = True):
        '''
        taylor series referenced from:
        https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/
        '''
        T0 = 1.0 / 2
        T1 = 1.0 / 4
        ret = T0 + x * T1
        if limit:
            return jnp.select([ret < 0, ret > 1], [0, 1], ret)
        else:
            return ret
    
    
    def sigmoid_sr(x):
        """
        https://en.wikipedia.org/wiki/Sigmoid_function#Examples
        Square Root approximation functions:
        F(x) = 0.5 * ( x / ( 1 + x^2 )^0.5 ) + 0.5
        sigmoid_sr almost perfect fit to sigmoid if x out of range [-3,3]
        highly recommended use this appr as GDBT's default sigmoid method.
        """
        return 0.5 * (x / jnp.sqrt(1 + jnp.power(x, 2))) + 0.5
    
    
    def sigmoid(x, method='t1'):
        if method == 't1':
            return sigmoid_t1(x)
        else:
            return sigmoid_sr(x)
    
  • policy-sgd 需要在第一个时期调整学习率

    类似与adam,对梯度进行一定标准化

    # Note: we leave a method param in this function for next part, in plaintext, we won't invoke low-level op in most conditions.
    def compute_dk_func(x, eps=1e-6, method='norm'):
        # Same as Adam, need add small eps to avoid zero-division error
        if method == 'norm':
            return 1 / (jnp.linalg.norm(x) + eps)
        else:
            # invoke low-level rsqrt op by hand
            return jax.lax.rsqrt(jnp.sum(jnp.square(x)) + eps)
    
  • SGDClassifier的实现

    1. 初始化时提供了6个超参数,dk_method用于policy-sgd
    class SSLRSGDClassifier:
        def __init__(
            self,
            epochs: int,
            learning_rate: float,
            batch_size: int,
            sig_type: str = 't1',
            eps: float = 1e-6,  # eps is the small number for computing dk
            dk_method: str = 'norm',  # method to compute dk, default is use jnp.linalg.norm function
        ):
            # parameter check.
            assert epochs > 0, f"epochs should >0"
            assert learning_rate > 0, f"learning_rate should >0"
            assert batch_size > 0, f"batch_size should >0"
            assert sig_type in ['t1', 'sr'], f"sig_type should one of ['t1', 'sr']"
            assert eps > 0, f"eps should >0"
            assert dk_method in [
                'norm',
                'rsqrt',
            ], f"dk_method should one of ['norm', 'rsqrt']"
    
            self._epochs = epochs
            self._learning_rate = learning_rate
            self._batch_size = batch_size
            self._sig_type = sig_type
            self._eps = eps
            self._dk_method = dk_method
    
            self._weights = jnp.zeros(())
    
    1. 原生更新梯度的方式

      • 根据 x w \mathbf{x}\mathbf{w} xw计算 y ^ i = S i g m o i d ( x w ) \hat{y}_i = Sigmoid(\mathbf{x}\mathbf{w}) y^i=Sigmoid(xw)

      • 计算梯度:
        g r a d = 1 n ∑ i ( y ^ i − y i ) x i grad = \frac{1}{n}\sum_{i}(\hat{y}_i-y_i)x_i grad=n1i(y^iyi)xi

      • 批量梯度下降法更新特征权重 w \mathbf{w} w
        w i , k = w i , k − 1 − d k × α × g r a d w_{i,k} = w_{i,k-1} - d_k \times \alpha \times grad wi,k=wi,k1dk×α×grad

      def _update_weights(
          self,
          x,  # array-like
          y,  # array-like
          w,  # array-like
          total_batch: int,
          batch_size: int,
          dk_arr,  # array-like
      		):
          num_feat = x.shape[1]
          assert w.shape[0] == num_feat + 1, "w shape is mismatch to x"
          assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
          w = w.reshape((w.shape[0], 1))
      
          compute_dk = False
          if dk_arr is None:
            	compute_dk = True
            	dk_arr = []
      
          for idx in range(total_batch):
              begin = idx * batch_size
              end = min((idx + 1) * batch_size, x.shape[0])
              rows = end - begin
              
              # padding one col for bias in w
              x_slice = jnp.concatenate((x[begin:end, :], jnp.ones((rows, 1))), axis=1)
              y_slice = y[begin:end, :]
      
              pred = jnp.matmul(x_slice, w)
              pred = sigmoid(pred, method=self._sig_type)
      
              err = pred - y_slice
              grad = jnp.matmul(jnp.transpose(x_slice), err) / rows
      
              if compute_dk:
                  dk = compute_dk_func(grad, self._eps, self._dk_method)
                  dk_arr.append(dk)
              else:
              		dk = dk_arr[idx]
      
                  step = self._learning_rate * grad * dk
      
                  w = w - step
      
           if compute_dk:
              dk_arr = jnp.array(dk_arr)
      
           return w, dk_arr
      
    2. 训练

      进行epochs轮计算,每个epoch中每次选择batch_size个样本进行迭代。

      def fit(self, x, y):
        	"""Fit LR with policy-sgd."""
          assert len(x.shape) == 2, f"expect x to be 2 dimension array, got {x.shape}"
          assert len(y.shape) == 2, f"expect y to be 2 dimension array, got {y.shape}"
      
          num_sample = x.shape[0]
          num_feat = x.shape[1]
          batch_size = min(self._batch_size, num_sample)
          total_batch = (num_sample + batch_size - 1) // batch_size
      
          # always fit intercept
          weights = jnp.zeros((num_feat + 1, 1))
          dk_arr = None
      
          # do train
          for _ in range(self._epochs):
      				weights, dk_arr = self._update_weights(
              x, y, weights, total_batch, batch_size, dk_arr
            )
      
            	self._weights = weights
            	self.dk_arr = dk_arr
      	
        return
      
    3. 预测

      根据 w \mathbf{w} w b i a s bias bias计算预测值:
      y ^ = s i g m o d ( w x + b i a s ) \hat{y} = sigmod(\mathbf{w}\mathbf{x}+bias) y^=sigmod(wx+bias)

      def predict_proba(self, x):
      		"""Probability estimates"""
          num_feat = x.shape[1]
          w = self._weights
          assert w.shape[0] == num_feat + 1, f"w shape is mismatch to x={x.shape}"
          assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
          w.reshape((w.shape[0], 1))
      
          bias = w[-1, 0]
          w = jnp.resize(w, (num_feat, 1))
      
          pred = jnp.matmul(x, w) + bias
          pred = sigmoid(pred, method=self._sig_type)
      
          return pred
      
    4. 使用

      plain_model = SSLRSGDClassifier(
          epochs=3, learning_rate=0.1, batch_size=8, sig_type='t1', eps=1e-6, dk_method='norm'
      )
      
      # train
      plain_model.fit(X.values, y.values.reshape(-1, 1))  # X, y should be two-dimension array
      
      # predict
      predict_prob = plain_model.predict_proba(X.values)
      
      # score
      from sklearn.metrics import roc_auc_score
      print(f"auc = {roc_auc_score(y.values, predict_prob)}")
      
      auc = 0.9903083
      

2. 在SPU上的适配

SPU下验证精度(simulation)

​ 通常情况下,您只需使用 spu 执行类似 LR 的操作即可在安全上下文中运行程序:将数据集移动到 PYU 或 SPU,使用您声明的 SPU 运行程序并reveal您需要的一些信息(但是reveal是一个非常危险的操作,您应该在实际应用中非常小心地使用它)

​ 但是,我们稍后会看到,您可能会遇到明文和秘密之间的巨大度量差距(如 LR 中的 auc)。开发人员可以更简单地运行 MPC 程序,并具有高度的灵活性来调整超参数,例如环的大小、fxp 或底层 MPC 协议等,这将是一个更好的选择。

​ 因此,在这一部分中,我们将展示如何使用simulator运行我们的算法,就像运行普通的 MPC 程序一样,并进行最少的实验来关注和验证程序的缺陷。使用simulator而不是直接使用 SPU 设备运行程序有两个优点:

  • 更少的代码:无需处理大量的 DeviceObject 并在 SPU 之间从 PYU 移动数据。
  • 更快的实验:没有连接ray cluster,端到端运行实验。
(1) 两方模拟

​ 我们首先定义一个简单的simulator,使用 CHEETAH 协议和 64 位环,采用 2pc 设置。

  • 支持多种MPC协议
  • 支持改变ring大小
  • 支持改变fxp
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim = spsim.Simulator.simple(2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64)

也可以使用spu.RuntimeConfig进行自定义。设置enable_hal_profileenable_pphlo_profile可以在程序运行后看到对应的详细调用情况,比如通信量和总的通信量

[2024-06-25 05:03:03.393] [info] [api.cc:222] Link details: total send bytes 240, recv bytes 400, send actions 3
config_aby = spu.RuntimeConfig(
 protocol = spu_pb2.ProtocolKind.ABY3,
 field = spu.FieldType.FM64,
 fxp_fraction_bits = 18,
 enable_hal_profile=True,
 enable_pphlo_profile=True
)
sim_aby = spsim.Simulator(3, config_aby)

​ 定义训练和预测函数:就是根据之前的plaintext进行了一次封装

def fit_and_predict(
    x,
    y,
    epochs=3,
    learning_rate=0.1,
    batch_size=8,
    sig_type='t1',
    eps=1e-6,
    dk_method='norm',
):
    model = SSLRSGDClassifier(
        epochs=epochs,
        learning_rate=learning_rate,
        batch_size=batch_size,
        sig_type=sig_type,
        eps=eps,
        dk_method=dk_method,
    )
    model.fit(x, y)
    return model.predict_proba(x)	
result = spsim.sim_jax(sim, fit_and_predict)(
    X.values, y.values.reshape(-1, 1)
)  # X, y should be two-dimension array

​ 评估拟合后的模型:

roc_auc_score(y, result)  

​ emm,最后的auc只有0.4…(不用慌,可以选择其他协议或者调参来优化)。

(2) 三方模拟

​ 更换了三方协议ABY3,使用64bit的环。

sim_aby = spsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64)

result = spsim.sim_jax(sim_aby, fit_and_predict)(X.values, y.values.reshape(-1, 1))

auc = roc_auc_score(y, result)

​ 这样,得到的auc约为0.97。

当程序在不进行任何修改的情况下以分片形式运行计算,经过 3 个 epoch 的训练后,auc 可能会急剧下降(对于 cheetah 来说从 0.990 下降到 0.490)!

​ 我们将进行一些分析,并尝试首先从应用程序的角度对其进行修复,然后从 MPC 的角度进行更深入的思考。

(3) 算法角度继续思考
  • 一阶泰勒近似Sigmoid计算
    y ^ = 1 2 + 1 4 ( w x ) \hat{y} = \frac{1}{2}+\frac{1}{4}(\mathbf{wx}) y^=21+41(wx)

    如果内积 w x \mathbf{wx} wx趋近于-2,则 y y y趋近于0

    如果内积 w x \mathbf{wx} wx大于等于2,则 y y y将限制在1;(带截断)

  • grad的计算:
    g r a d = 1 n ∑ i ( s i g m o i d ( w T x i ) − y i ) x i grad = \frac{1}{n}\sum_i(sigmoid(\mathbf{w}^Tx_i)-y_i)x_i grad=n1i(sigmoid(wTxi)yi)xi
    如果使用t1的公式近似sigmoid,很有可能grad=0。

  • dk的计算
    d k = 1 ∑ j p g r a d j 2 + ϵ d_k=\frac{1}{\sqrt{\sum_j^pgrad_j^2}+\epsilon} dk=jpgradj2 +ϵ1

​ 如果碰巧 grad 的所有元素非常接近 0(MPC 中可能几乎没有误差),那么第一个 epoch 计算出的 d k dk dk就会变得非常大(分母非常接近于0)。

  • 参数更新
    w i , k = w i , k − 1 − d k × α × g r a d w_{i,k} = w_{i,k-1} - d_k \times \alpha \times grad wi,k=wi,k1dk×α×grad

​ 参数更新依赖于 d k d_k dk,如果 d k d_k dk过大,那么参数更新会变得异常剧烈,可能会导致训练失败。

  • 可以考虑的解决方案:

    • 增大batch_size:减少梯度为0的概率
    • 增大 d k d_k dk计算使用的 ϵ \epsilon ϵ:减少极端情况下学习率的抖动
    • 使用不带截断的Sigmoid近似方法(SR):减少梯度为0的概率
  • 例如,对于上面使用CHEETAH的simulator,这里我们修改了batch_size重新执行

    result = spsim.sim_jax(sim, partial(fit_and_predict, batch_size=64))(
        X.values, y.values.reshape(-1, 1)
    )
    roc_auc_score(y, result)  # rather pool under cheetah protocol!
    
    0.9892579673378785
    

    Congratulations~结果正常啦

(4) MPC角度继续思考

将根据底层协议进行讨论

https://www.secretflow.org.cn/zh-CN/docs/spu/main/reference/mpc_status

  • 协议选择

    https://www.secretflow.org.cn/zh-CN/docs/spu/main/reference/mpc_status

    一般来说,对于 2pc,使用 cheetah 是安全的,而对于 3pc,ABY3 是唯一的选择。

    • ABY3:一个诚实多数3PC协议,SPU提供半诚实实现,注意,如果两个以上的计算节点部署在一起,则很难抵御合谋攻击。
    • Semi2k-SPDZ*:一种类似于 SPDZ 的半诚实MPC 协议,但需要可信的第三方来生成离线随机数。默认情况下,此协议现在使用可信的第一方。因此,它应该仅用于调试目的。不安全
    • Cheetah*:一种快速的2pc半诚实协议,它使用精心设计的基于同态加密的协议进行算术运算,使用Ferret进行布尔运算。
  • SPU所能表示的最小正数为 2 − f 2^{-f} 2f,因为 f = 18 f = 18 f=18,因此为 2 − 18 2^{-18} 218

  • Cheetah是一个快速的两方半诚实协议,使用PHE来激素计算,但是在进行乘法运算时会产生0-2位错误。

    如果是64位ring,使用大约18位的定点数,因此spu可以表示的最小正浮点数是 1 2 18 ≈ 3.8147 × 1 0 − 6 \frac{1}{2^{18}} \approx 3.8147\times 10^{-6} 21813.8147×106

  • 因此最开始进行两方模拟时导致auc小于0.5,很可能出现了计算溢出。若梯度很小(接近 2 − 18 2^{-18} 218),则平方运算造成的0-2bit误差可能会对结果造成显著的影响,甚至可能从最小正数溢出到最小负数

    通过对一组很小的数求平方和来验证乘法是否出现溢出:

    # Let's test this
    def test_square_and_sum_when_x_small(x):
        return jnp.sum(jnp.square(x))
        
    spsim.sim_jax(sim, test_square_and_sum_when_x_small)(np.array([1e-5] * 10))
    

    结果只有 − 3.0517578 × 1 0 − 5 -3.0517578 \times 10 ^{-5} 3.0517578×105,近似于 − 2 − 18 × 10 -2^{-18} \times 10 218×10咦,结果符号都变了。

    改进思路仍然会以减少梯度为0入手:增大fxp和环大小:提高表示精度,减少下溢到0点概率

  • 我们保持batch_size不变,增加环的大小

    sim128 = spsim.Simulator.simple(
        2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM128
    )
    result = spsim.sim_jax(sim128, fit_and_predict)(X.values, y.values.reshape(-1, 1))
    auc = roc_auc_score(y, result)
    

    结果为0.990308387。

3. 在SPU上验证性能

​ 我们主要想论证在 MPC 设置中 policy-sgd 优于 naive-sgd,因此我们可以设计以下实验:

  • dk_method为policy-sgd找到最佳eps:对于所有数据集,比较准确性和效率。
  • sig_type比较 policy-sgd 和 naive-sgd切换时的准确率和效率。
  • 为了比较 policy-sgd 和 naive-sgd,我们固定epochs并测试learning_ratebatch_size的影响

当前只支持多进程模式,且需要从spu源码编译运行!

五、常见问题

1. 怎么知道SPU支持哪些算子?

​ 实践出真知,自己动手: simulation, emulation

  • https://www.secretflow.org.cn/zh-CN/docs/spu/0.9.1b0/reference/np_op_status
  • https://www.secretflow.org.cn/zh-CN/docs/spu/0.9.1b0/reference/xla_status

2. 怎么知道非线形算子大致的误差范围

​ 误差来自两个方面:

  • 系统设定误差(如环大小,fxp大小,truncation协议等)
  • 非线性算子拟合误差

​ 因此很难给出如浮点数的误差估计,以下文档给出了一些数学算子的大致误差:
https://www.secretflow.org.cn/docs/spu/latest/en-US/development/fxp

3. 为什么明文下运行正常,密态下运行错误

  • 若运行报错:

    • 实现的算法是否jitable(即使用@jax.jit是否能运行)
    • 是否使用了SPU不支持的算子
  • 若能运行,但误差极大,可以自查是否有以下情况:

    • 是否可能发生溢出:输入数据或参数是否太大或太小
    • SPU内部是否使用了浮点随机数生成器
    • 是否调用了线形代数算子(如矩阵分解,奇异值分解等)
  • 若误差适中:

    可以考虑增大环的大小,提高fxp精度

4. 为什么Emulation的速度比simulation快很多?

​ 数据没有seal(即load到PYU),SPU将其视为Public数据,所有计算在明文下进行

5. 如何对密态算法进行优化

​ 有以下几个思路可以参考:

  • 减少耗时算子的调用(计算公式重写,多项式近似等)

  • 避免重复计算(空间换效率)

  • 并行化

    ​ 实际上,SPU内部已经做了大量的并行操作,若希望进一步优化,可以尝试:

  1. 算法层:for循环很多时候可以通过高阶tensor运算来代替,也可以考虑使用jax.vmap进行自动向量化
  2. Runtime:尝试开启更多并行(experimental feature),如experimental_enable_inter_op_par(即DAG并行)

六、如何使用SPU开发

​ 例如,计算policy_sgd的 d k = 1 t 0 2 + t 1 2 + ϵ dk = \frac{1}{\sqrt{t_0^2+t_1^2}+\epsilon} dk=t02+t12 +ϵ1

  • 首先生成明文数据
import numpy as np
import jax.lax

# Emulation should be run from source in SPU, so we use Secretflow here to do the efficiency experiments.
# You can use the similar trick for emulation directly in SPU.

def compute_dk_func(x, eps=1e-6):
    return jax.lax.rsqrt(jnp.sum(jnp.square(x)) + eps)

x = np.random.rand(1_000_000)

alice_device = sf.PYU("alice")
bob_device = sf.PYU("bob")
spu_obj = sf.SPU(
    cluster_def=cluster_conf,
    link_desc={
        "connect_retry_times": 60,
        "connect_retry_interval_ms": 1000
    },
)
# cluster_config参考之前笔记的设置,这里不再描述
  • 将明文数据放到alice_device下

    # first, load data to PYU
    alice_data = alice_device(lambda x: x)(x)
    
  • 在spu_obj上执行compute_dk_func计算 d k dk dk

    ret = spu_obj(compute_dk_func)(alice_data)
    sf.reveal(ret) # array(0.00173242, dtype=float32)
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值