以当前热点GPT模型的GPT2为例演示SPU的扩展性,支持快速实现密文大模型预测
讲师:吴豪奇
学习链接:https://www.bilibili.com/video/BV1cT421Y71k
一、隐私保护机器学习背景
1. 数据的重要性和安全性
-
数据至关重要
- 训练高质量模型需要大量的数据(依赖于数据的质量和数量)
- 模型服务商需要用户输入数据作为推理输入
-
数据中包含大量敏感信息
-
生物数据:图像、声音、基因信息等
-
金融数据:收入、支出、信贷等
-
法律法规监管:《个人信息保护法》、GDPR
机器学习中的数据隐私问题日益受到关注
-
那么,接下来就要考虑如何在发挥数据价值的同时保护数据安全
2. 解决方案:安全多方计算(MPC)
也可以基于TEE、基于FL
(1) MPC
多个参与方可以在互不泄露任何信息(除结果外)的情况下协作计算一个函数,并得到正确的结果。
但是存在风险,计算结果会泄漏给参与方。但是,一般我们认为这种反推结果的方式存在一定难度。例如两个参与方执行求和运算,Alice输入 a = 10 a = 10 a=10,Bob输入 b = 7 b = 7 b=7,求和得到 a + b = 17 a+b=17 a+b=17,那么Bob通过对结果进行反推,便可得到另一参与方的值( 17 − 7 = 10 17-7=10 17−7=10)。
例如三方求和。Alice拥有10,Bob拥有7,Charlie拥有25。各个参与方将数据分片,然后发送给参与方。然后各自计算本地的分片数据和(party_sum)。最后累加所有参与方的分片数据和,就可得到结果42。
(2) 基于MPC的隐私保护机器学习PPML
Privacy-Preserving Machine Learning (PPML)
-
隐私训练
输入方有多个,期望通过隐私保护机器学习扩充数据维度,训练一个表现更好的模型。
-
隐私推理
一个作为数据提供方的Alice,另一个为模型提供方Bob。期望通过隐私推理得到Alice提供的数据在Bob的模型下的推理结果,同时保护Alice的数据不被泄漏,Bob的模型不会泄露给Alice。
- 期望可以直接用 MPC 的方式高效地运行已有的机器学习程序,对ML工程师十分友好。
二、SPU架构简介
核心系统组件:
- 前端:机器学习程序
- 编译器:生成并优化SPU的IR(PPHLO)
- 运行时:以MPC协议的方式执行PPHLO
1. 前端:机器学习程序
- 基于JAX、TensorFlow和PyTorch开发机器学习程序。
- 通过Accelerated Linear Algebra(XLA)进行表示
2. 编译器:生成并优化 SPU 的 IR(PPHLO)
通过SPU的编译器将明文的计算过程转换为隐私保护的算子
3. 运行时:以MPC协议的方式执行PPHLO
- SPU支持的MPC协议很多,例如这里的ABY3、Cheetah和SPDZ2K
4. SPU的设计目标
易用、可扩展、高性能
三、NN密态训练/推理示例
1. 逻辑回归
https://github.com/secretflow/spu/blob/0.9.1b0/examples/python/ml/jax_lr/jax_lr.py
(1) 数据从哪来?
- 数据提供方Alice提供了特征 X 1 \mathbf{X}_1 X1,数据提供方Bob提供了特征 X 2 \mathbf{X}_2 X2
- Alice对应device P1,然后加载了前50个特征和标签 y y y
- Bob对应device P2,加载了后50个特征
(2) 如何加密保护数据?
- 数据方对数据加密发送到MPC计算方
- 外包模式,计算方拿到的是密文
import spu.utils.distributed as ppd
x1 = ppd.device("P1")(lambda x: x[:, :50])(x)
x2 = ppd.device("P2")(lambda x: x[:, 50:])(x)
y = ppd.device("P1")(lambda x: x)(y)
(3) 如何定义模型计算?
-
用JAX实现明文算法
参数为
n_epochs
迭代总轮次,n_iters
每个epoch迭代次数,step_size
为学习率class LogitRegression: def __init__(self, n_epochs=10, n_iters=10, step_size=0.1): self.n_epochs = n_epochs self.n_iters = n_iters self.step_size = step_size def fit_auto_grad(...): pass def fit_manual_grad(...): pass
-
定义
fit_auto_grad
:使用自动梯度计算(通过jax.grad
)来更新权重。-
权重
w
和偏置b
初始化为0。 -
如果启用了缓存
use_cache
,特征数据会被缓存。 -
使用
jnp.array_split
将特征和标签数据被分成小批次处理,记为xs
和ys
。 -
使用
jax.lax.fori_loop
进行指定轮次n_epochs
的训练,每轮使用所有批次的数据更新权重。 -
这里的梯度计算由
body_fun
函数定义,主要使用梯度下降法进行参数更新。梯度计算使用jax.grad
生成,梯度更新使用指定的学习率step_size
。grad = jax.grad(loss, argnums=(2, 3))(x, y, w_, b_, use_cache)
。这里的argnums
参数指定应该对哪些参数计算梯度,这里我们希望对w_
和b_
计算梯度,因此设置为(2, 3)
-
如果使用了缓存,训练结束后将放弃缓存。
def fit_auto_grad(self, feature, label, use_cache=False): w = jnp.zeros(feature.shape[1]) b = 0.0 if use_cache: feature = spu.experimental.make_cached_var(feature) xs = jnp.array_split(feature, self.n_iters, axis=0) ys = jnp.array_split(label, self.n_iters, axis=0) def body_fun(_, loop_carry): w_, b_ = loop_carry for x, y in zip(xs, ys): grad = jax.grad(loss, argnums=(2, 3))(x, y, w_, b_, use_cache) w_ -= grad[0] * self.step_size b_ -= grad[1] * self.step_size return w_, b_ ret = jax.lax.fori_loop(0, self.n_epochs, body_fun, (w, b)) if use_cache: feature = spu.experimental.drop_cached_var(feature, *ret) return ret
-
-
定义
fit_manual_grad
:手动计算梯度并更新权重使用JAX手动实现前向与反向传播
除了梯度由手动计算外,其余实现和上述基本一致。
def fit_manual_grad(self, feature, label, use_cache=False): w = jnp.zeros(feature.shape[1]) b = 0.0 if use_cache: feature = spu.experimental.make_cached_var(feature) xs = jnp.array_split(feature, self.n_iters, axis=0) ys = jnp.array_split(label, self.n_iters, axis=0) def body_fun(_, loop_carry): w_, b_ = loop_carry for x, y in zip(xs, ys): pred = predict(x, w_, b_) err = pred - y w_ -= jnp.matmul(jnp.transpose(x), err) / y.shape[0] * self.step_size b_ -= jnp.mean(err) * self.step_size return w_, b_ ret = jax.lax.fori_loop(0, self.n_epochs, body_fun, (w, b)) if use_cache: feature = spu.experimental.drop_cached_var(feature, *ret) return ret
(4) 如何执行密态模型计算?
- 计算方以密文数据作为输入
- 将模型的训练/推理计算图通过SPU编译器Compiler转换为相应的密态算子计算图
- 由SPU device按照MPC协议逐个执行
@ppd.device("SPU")
def train(x1, x2, y):
x = jnp.concatenate((x1, x2), axis=1)
lr = LogitRegression()
if auto_grad:
return lr.fit_auto_grad(x, y, use_cache)
else:
return lr.fit_manual_grad(x, y, use_cache)
W, b = train(x1, x2, y)
需要将函数使用
@ppd.device("SPU")
装饰器。
(5) 思考
-
整个密态训练流程和明文ML类似,除了多了
ppd.device
的装饰器。"devices": { "SPU": { "kind": "SPU", "config": { "node_ids": [ "node:0", "node:1" ], "experimental_data_folder": [ "/tmp/spu_data_0/", "/tmp/spu_data_1/" ], "spu_internal_addrs": [ "127.0.0.1:61330", "127.0.0.1:61331" ], "runtime_config": { "protocol": "CHEETAH", "field": "FM64", "enable_pphlo_profile": true, "enable_hal_profile": true } } }
-
P1和P2对应明文输入设备(PYUObject):数据加载完后会加密到SPU device
"P1": { "kind": "PYU", "config": { "node_id": "node:0" } }, "P2": { "kind": "PYU", "config": { "node_id": "node:1" } }
-
SPU对应由两方Cheetah协议实现的密态计算设备
-
通过SPU device抽象来实现PPML中的数据输入以及密态训练/推理
2. stax/flax
应对更为复杂的建模
上面介绍的逻辑回归计算简单,手动实现可行,但是如果是深度神经网络这种复杂的结构,手动实现就变得不太可能(或者很低效),此时可以考虑使用stax或者flax。
-
Flax demo
https://github.com/secretflow/spu/blob/0.9.0b2/examples/python/ml/flax_mlp/flax_mlp.py#L37
Flax 是一个高性能和灵活的神经网络库,构建在 Google 的 JAX 库之上。JAX 提供了基于 NumPy 的 API,支持自动微分和硬件加速(GPU 或 TPU)。Flax 旨在提供简洁而高效的模型定义和训练工具,让研究者和开发者可以轻松实验和部署复杂的机器学习模型。
import flax.linen as nn import jax import jax.numpy as jnp class MLP(nn.Module): features: Sequence[int] @nn.compact def __call__(self, x): for feat in self.features[:-1]: x = nn.relu(nn.Dense(feat)(x)) x = nn.Dense(self.features[-1])(x) return x def predict(params, x): return MLP(FEATURES).apply(params, x) def loss_func(params, x, y): pred = predict(params, x) def mse(y, pred): def squared_error(y, y_pred): # TODO: check this return jnp.multiply(y - y_pred, y - y_pred) / 2.0 # return jnp.inner(y - y_pred, y - y_pred) / 2.0 # fail, (10, 1) inner (10, 1) -> (10, 10), have to be (10,) inner (10,) -> scalar return jnp.mean(squared_error(y, pred)) return mse(y, pred)
-
stax demo
https://github.com/secretflow/spu/blob/0.9.0b2/examples/python/ml/stax_nn/models.py#L65
def lenet(): nn_init, nn_apply = stax.serial( Conv(out_chan=20, filter_shape=(5, 5), strides=(1, 1), padding='valid'), MaxPool(window_shape=(2, 2), strides=(2, 2)), Relu, Conv(out_chan=50, filter_shape=(5, 5), strides=(1, 1), padding='valid'), MaxPool(window_shape=(2, 2), strides=(2, 2)), Relu, Flatten, Dense(500), Relu, Dense(10), ) return nn_init, nn_apply
3. 复用构建好的模型
例如Huggingface提供的模型
- 安装
transformers
(1) 明文实现
-
加载分词器和模型
这里我采用离线保存好的
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config llm_model = "openai-community/gpt2" tokenizer = AutoTokenizer.from_pretrained(llm_model) pretrained_model = FlaxGPT2LMHeadModel.from_pretrained(llm_model)
-
进行分词
inputs_ids = tokenizer.encode( 'I enjoy walking with my cute dog', return_tensors='jax' ) print(inputs_ids)
-
定义基于gpt2预训练模型的文本生成函数
def text_generation(input_ids, params, token_num=10): config = GPT2Config() model = FlaxGPT2LMHeadModel(config=config) for _ in range(token_num): outputs = model(input_ids=input_ids, params=params) next_token_logits = outputs[0][0, -1, :] next_token = jnp.argmax(next_token_logits) input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1) return input_ids
文本生成循环:
- 函数通过
token_num
次迭代来生成指定数量的 tokens。 - 每次迭代中,模型根据当前的
input_ids
和params
生成输出。 - 输出
outputs
中的 logits 用于确定下一个 token。outputs[0][0, -1, :]
获取最后一个生成的 token 的 logits。 - 使用
jnp.argmax
选取 logits 中最可能的 token 作为下一个 token。 jnp.concatenate
将新 token 添加到input_ids
中,为下一次生成迭代做准备。
- 函数通过
-
将分词的结果传入
text_generation
函数:outputs_ids = text_generation(inputs_ids, pretrained_model.params)
(2) 如何迁移到密文计算中
- 模型无需改变
- 我们这里模拟了P1即alice方拥有模型,P2即Bob方提供输入
- 之后,在SPU上运行对应的
text_generation
函数。
def run_on_spu():
# encode context the generation is conditioned on
inputs_ids = tokenizer.encode(
'I enjoy walking with my cute dog', return_tensors='jax'
)
input_ids = ppd.device("P1")(lambda x: x)(inputs_ids)
params = ppd.device("P2")(lambda x: x)(pretrained_model.params)
outputs_ids = ppd.device("SPU")(
text_generation,
)(input_ids, params)
outputs_ids = ppd.get(outputs_ids)
return outputs_ids
我本地用ppd有点问题,后来改用下述方式:
alice = sf.PYU("alice")
bob = sf.PYU("bob")
token = alice(lambda :inputs_ids)()
model = bob(lambda: pretrained_model.params)()
res = spu_obj(text_generation)(token, model)
不过密文下运行很慢,明文下309ms就推理完的结果,这里大约花了14min。
- 代码改造成本很低(low cost)
4. 如何支持不同模型
-
实现所需的密态算子
-
SPU性能有优化空间