摘要
这篇文章将激活、权重和梯度
量化为 4 位以加速神经网络训练。但是,现有的 4 位训练方法 需要当代不支持的自定义数字格式 硬件。在这项工作中,我们提出了一种变压器的培训方法,包括 使用 INT4 算法实现的矩阵乘法。培训与 超低 INT4 精度具有挑战性。为了实现这一目标,我们仔细分析 变压器中活化和梯度的具体结构提出 专用的量化器。对于前向传播,我们确定 挑战异常值并提出一个哈达玛量化器来抑制 异常。对于反向传播,我们利用梯度的结构稀疏性 通过提出位拆分并利用分数采样技术来量化 精确梯度。我们的算法在广泛的范围内实现了竞争性的准确性 任务范围包括自然语言理解、机器翻译、 和图像分类。与以前的 4 位训练方法不同,我们的算法 可以在当前一代的 GPU 上实现。我们的原型线性 操作员实施速度比 FP2 快 2.16 倍 并将培训速度提高了 35.1%。
正向传播
正向传播能以线性和非线性(GeLU, normalization, softmax等)算子的组合来实现。在我们的训练过程中,我们用INT4算术加速所有线性运算符,并将所有计算量较小的非线性运算符保留在16位浮 点(FP16)格式中。
线性算子
Transformer中的所有线性运算都可以写成矩阵乘法(MM)的形式。为了便于表述,本文考虑以下简单矩阵乘法的加速(全连接层可以表述成以下的公式,其中X是N = STtoken的激活,W是权重矩阵。):
Z = X W ⊤ , w h e r e Z ∈ R N × C , X ∈ R N × D a n d W ∈ R C × D . (1) Z = XW^⊤, where Z ∈ {R^{N×C}} , X ∈ R^{N×D}and \ W ∈ R^{C×D}. \tag{1} Z=XW⊤,whereZ∈RN×C,X∈RN×Dand W∈RC×D.(1)
- 注:对于注意力层,可能需要批量矩阵乘法(BMMS)。我们提出的技术可以应用于BMMS,详见 Appendix. A.1.。
为了加速训练,必须使用整数运算来计算前向传播。为此研究人员使用了学习步长量化器(LSQ, Steven K Esser, Jeffrey L McKinstry, Deepika Bablani, Rathinakumar Appuswamy, and Dharmendra S Modha. Learned step size quantization. In International Conference on Learning Representations, 2019)。LSQ是静态量化,他的量化尺度不依赖于输入的方法,因此比动态方法消耗更小,量化方法,需要在每次迭代时动态计算量化尺度。对于X的量化与反量化可以写为 f l o a t ( i n t s X ( X ) ) = s X i n t s X ( X ) ≈ X \color{red} float(int_{sX} (X)) = s_Xint_{sX} (X) ≈ X float(intsX(X))=sXintsX(X)≈X,关于细节描述可见原文或本文相关部分。
如此, Eq. (1) 可写为
Y
=
X
W
⊤
≈
s
X
s
W
i
n
t
s
X
(
X
)
i
n
t
s
W
(
W
)
⊤
(2)
Y = XW^⊤ ≈ s_{X}s_W int_{sX} (X)int_{sW} (W)^⊤\tag{2}
Y=XW⊤≈sXsWintsX(X)intsW(W)⊤(2)
非线性算子的HQ
激活异常值 简单地将LSQ应用到具有4位激活/权重的FQT会导致精度下降,因为激活异常值现象[57]。激活有一些离群值条目,我们必须截断范围[−QN sX,QP sX]之外的条目。不幸的是,Transformers倾向于将信息存储在这些异常值中,而且这样的截断会严重损害准确性。当训练任务要在一些新的下游任务上微调预先训练的模型时预训练模型比随机初始化包含更多的异常值[57]。
存在一些处理训练后量化(PTQ)的激活异常值的工作。Suppression[55]发现LayerNormals放大了异常值,并提出了Gamma迁移和Token Wise Clipping解决了这个问题,并在没有太多降级的情况下实现了6位BERT PTQ。SmoothQuant[57]将激活异常值的量化难度迁移到权重,并实现用于大型语言模型(如OPT-175B)的8位PTQ。异常值通道拆分[65]重复包含异常值的信道在网络大小上具有较小开销。然而,这些方法主要关注PTQ或QAT,很少成功处理超低4位训练。
哈达玛变换(Hardamand transform)是一个线性变换,它可以将异常值分摊到其他条目中。HQ使用块对角变换矩阵 H ∈ R D × D : H = B l o c k D i a g ( H k , … , H k ) H∈R^{D×D}:H=BlockDiag(H_k,…,H_k) H∈RD×D:H=BlockDiag(Hk,…,Hk),其中D是 2 k 2^k 2k的倍数.为了抑制异常值,我们量化X和W的变换版本:
- X = ( X H ) H T ≈ s X i n t s X ( X H ) H T X=(XH)H^T≈s_Xint_{sX}(XH)H^T X=(XH)HT≈sXintsX(XH)HT
-
W
=
(
W
H
)
H
T
≈
s
W
i
n
t
s
W
(
W
H
)
H
T
W=(WH)H^T≈s_Wint_{sW}(WH)H^T
W=(WH)HT≈sWintsW(WH)HT
结合量化矩阵,我们得到
Y = s X s W i n t s X ( X H ) H T H i n t s W ( H T W T ) = s X s W i n t s X ( X H ) i n t s W ( H T W T ) , (3) Y=s_X s_W int_{sX}(XH)H^T Hint_{sW}(H^TW^T) \\ \ \ \ \ \ \ \ \ =s_X s_W int_{sX}(XH) int_{sW}(H^TW^T),\tag{3} Y=sXsWintsX(XH)HTHintsW(HTWT) =sXsWintsX(XH)intsW(HTWT),(3)
反向传播
梯度计算
我们现在考虑使用INT4操作来加速线性层的反向传播。等式(3)中定义的线性算子HQ-MM具有四个输入:激活X、权重W和步长sX,sW。给定输出梯度
∇
Y
L
\nabla_{Y} L
∇YL,我们需要计算所有四个输入的梯度
。我们在本节中讨论了激活/权重梯度的计算,并将步长梯度的讨论留给附录A.3。
根据直通估计(Straight-through estimator是quantization中常见的求导方式,quantization是一个离散的方程,无法计算它的导数,所以STE就简单粗暴地直接把输出的导数作为了对输入的导数)和链式规则,我们得到:
∇ W = s X ( ∇ Y T X ^ ◦ I W ) H ⊤ , ∇ X = s W I X ◦ ∇ Y W ^ H ⊤ , (4) ∇_W = s_X(∇^T_Y\hat X ◦ I_W)H^⊤, ∇_X = s_W I_X ◦ ∇_Y\hat WH^⊤,\tag{4} ∇W=sX(∇YTX^◦IW)H⊤,∇X=sWIX◦∇YW^H⊤,(4)
- ◦ ◦ ◦ 为逐元素相乘, X ^ = i n t s X ( X H ) , W ^ = i n t s W ( W H ) \hat X = int_{sX} (XH), \hat W = int_{sW} (WH) X^=intsX(XH),W^=intsW(WH), I X = I ( − Q N ≤ X / s X ≤ Q P ) I_X = I(−QN ≤ X/sX ≤ QP ) IX=I(−QN≤X/sX≤QP), I W = I ( − Q N ≤ W / s W ≤ Q P ) I_W = I(−QN ≤ W/sW ≤ QP ) IW=I(−QN≤W/sW≤QP)
- 1.元素乘法◦ 0/1矩阵
I
X
I_X
IX(或
I
W
I_W
IW)与另一INT4(或INT32)矩阵的乘积。此操作的时间复杂度较低。
2.INT32矩阵与FP16块Hadamard矩阵 s W H ⊤ s_WH^⊤ sWH⊤的乘积,该矩阵也具有低时间复杂度,如第3.3节所述。
3.FP16梯度 ∇ Y ∇_Y ∇Y与INT4矩阵 X ^ \hat X X^或 W ^ \hat W W^的乘积,我们将通过将 ∇ Y ∇_Y ∇Y量化为INT4来加速。
4.1 Structural Sparsity of Gradients 梯度的结构稀疏性
4.2 Bit Splitting and Leverage Score Sampling
-
首先,我们提出了比特分割(BS),它将全精度矩阵分割为较高和较低的4个比特:
∇ Y ≈ s ↑ ∇ Y ↑ + s ↓ ∇ Y ↓ , (5) ∇_Y ≈ s_↑∇^↑_Y + s_↓∇^↓_Y,\tag{5} ∇Y≈s↑∇Y↑+s↓∇Y↓,(5)
BS可以通过首先将 ∇ Y ∇_Y ∇Y量化到INT4来实现,然后将残差量化为INT4, ∇ Y ↑ ∇^↑_ Y ∇Y↑ and ∇ Y ↓ ∇^↓_Y ∇Y↓是INT8的高位和低位4位代表。 -
如前所述,表示一个INT8和一个INT4的乘积,并且可以由两个INT4 MM实现,权重梯度涉及矩阵乘法可以写为
∇ Y T X ^ ≈ ( s ↑ ∇ Y ↑ T + s ↓ ∇ Y ↓ T ) X ^ = ∇ Y ↕ T X ↕ (6) ∇^T_Y\hat X ≈ ({s_↑∇^↑_Y}^T + {s_↓∇^↓_Y}^T)\hat X={∇^{\updownarrow}_Y}^T X^{\updownarrow} \tag{6} ∇YTX^≈(s↑∇Y↑T+s↓∇Y↓T)X^=∇Y↕TX↕(6)
-
∇ Y ↕ = [ s ↑ ∇ Y ↑ s ↓ ∇ Y ↓ ] ∈ R 2 N × C ∇^{\updownarrow}_Y=\begin{bmatrix}s_↑∇^↑_Y\\s_↓∇^↓_Y\end{bmatrix} \in R^{2N\times C} ∇Y↕=[s↑∇Y↑s↓∇Y↓]∈R2N×C
-
X ↕ = [ X ^ X ^ ] ∈ R 2 N × d X^{\updownarrow}=\begin{bmatrix}\hat X\\\hat X\end{bmatrix} \in R^{2N\times d} X↕=[X^X^]∈R2N×d
-
but BS doubles the amount of INT4 operations for MM.
-
然后文章提出了 leverage score sampling (LSS),Noticing that the MM Eq. (6) can be written as the sum of 2N rank-1 matrices:
相关
LSQ
LSQ(Learned Step Quantization)Quantization是一种用于神经网络模型压缩和加速的量化方法。它采用了一种自适应的量化策略,即在训练过程中学习参数的量化步长,并使用基于学习的量化步长对参数进行量化。通过这种方法,LSQ Quantization可以同时实现精度和压缩率的平衡。
在LSQ Quantization中,每个参数都被量化为一个具有固定位宽的整数,而量化步长则是在训练过程中学习的。具体地,对于每个参数,LSQ Quantization首先使用一个类似于二进制搜索的方法来确定最优的量化步长。然后,它使用基于学习的量化步长将参数量化为一个整数,并在反向传播过程中使用梯度下降法来更新量化步长。
通过这种方法,LSQ Quantization可以在保持精度的同时,将神经网络模型的大小和计算开销显著减少。
i
n
t
s
X
(
X
)
:
=
c
l
a
m
p
(
X
/
s
X
,
N
,
P
)
int_{sX}(X):=clamp(X/sX,N,P)
intsX(X):=clamp(X/sX,N,P)
hadamard matrix
CODE
quantize_forward_LSQ+HQ
核心及打包部分
- 主要通过函数quantize_cuda实现,函数结构如下:
量化
// https://github1s.com/xijiu9/Train_Transformers_with_INT4/blob/main/Cuda_LSS_HQ/quantize_forward_LSQ+HQ/quantize_forward_easy_kernel.cu#L21-L35
template<typename scalar_t>
__global__ void quantize_cuda_kernel(const scalar_t * __restrict__ MatI, int8_t * MatO, scalar_t * __restrict__ MatLSQ, const float scale, long long int size){
long long int x = threadIdx.x + blockIdx.x * blockDim.x;
if (x<size){
float temp = MatI[x] / scale;
MatLSQ[x] = temp;
MatO[x] = std::clamp((int)(round(temp)), -8, 7);
}
}
反量化
// https://github1s.com/xijiu9/Train_Transformers_with_INT4/blob/main/Cuda_LSS_HQ/quantize_forward_LSQ+HQ/quantize_forward_easy_kernel.cu#L146-L153
template<typename scalar_t>
__global__ void dequantize_cuda_kernel(const int32_t * gemm, scalar_t * __restrict__ output, const float scale, long long int size){
long long int x = threadIdx.x + blockIdx.x * blockDim.x;
if (x<size){
output[x] = scale * gemm[x];
// output[x] = 0;
}
}
int4乘法在cutluss的支持下进行
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<double>, torch::Tensor, long long int> quantize_cuda(torch::Tensor hx, torch::Tensor hy, float scale_x, float scale_y){
……
return std::make_tuple(output, q_x, q_y, lsq_x, lsq_y, time_vector, gemm, size);
}
函数接受输入张量hx和hy,以及两个缩放因子scale_x和scale_y作为参数。函数返回一个包含多个结果的元组。
首先,代码获取输入张量的维度信息,并为输出张量和中间结果张量创建了合适的选项。
然后,代码使用CUDA进行量化操作。量化操作分为两步:将浮点数数据转换为8位整数数据,并将其缩放到指定的范围内。代码使用CUDA的并行计算能力,分别对输入张量hx和hy进行量化操作,并将结果保存在q_x和q_y中。同时,还将缩放后的浮点数数据保存在lsq_x和lsq_y中。
接下来,代码将8位整数数据打包成4位整数数据。通过调用pack_cuda_kernel函数,将q_x和q_y中的数据打包到pack_qx和pack_qy中。
然后,代码使用Cutlass库进行整数矩阵乘法操作。通过调用CutlassSgemmNN函数,将pack_qx和pack_qy中的数据进行矩阵乘法运算,并将结果保存在gemm中。
最后,代码进行反量化操作,将gemm中的整数数据转换回浮点数数据,并将结果保存在output中。
代码还计算了量化、打包、矩阵乘法和反量化操作的时间,并将结果保存在time_vector中。
最后,代码返回一个包含output、q_x、q_y、lsq_x、lsq_y、time_vector、gemm和size的元组。
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<double>, torch::Tensor, long long int> quantize_cuda(torch::Tensor hx, torch::Tensor hy, float scale_x, float scale_y){
std::vector<double> time_vector; // 用于保存量化、打包、矩阵乘法和反量化操作的时间
cudaError_t result; // 用于保存CUDA函数的返回结果
long long int nx = hx.size(0); // 获取输入张量hx的第一个维度大小
long long int nz = hx.size(1); // 获取输入张量hx的第二个维度大小
long long int ny = hy.size(0); // 获取输入张量hy的第一个维度大小
// 设置输出张量和中间结果张量的选项
auto option_quantize = torch::TensorOptions().dtype(torch::kInt8).device(hx.device()); // 量化操作的选项
auto option_gemm = torch::TensorOptions().dtype(torch::kInt32).device(hx.device()); // 矩阵乘法操作的选项
auto option_dequantize = torch::TensorOptions().dtype(hx.dtype()).device(hx.device()); // 反量化操作的选项
dim3 block(N_THREADS); // 定义CUDA的block大小
cudaDeviceSynchronize(); // 同步CUDA设备
// 记录量化操作的开始时间
clock_t time_quantize_start = clock();
dim3 grid1((nx*nz-1)/(block.x)+1); // 计算量化操作的grid大小
torch::Tensor q_x = torch::empty({nx,nz}, option_quantize); // 创建量化结果张量q_x
torch::Tensor lsq_x = torch::empty({nx,nz}, option_dequantize); // 创建缩放后的浮点数结果张量lsq_x
long long int hx_size = (nx*nz); // 计算输入张量hx的总大小
// 调用CUDA的量化操作核函数
AT_DISPATCH_FLOATING_TYPES_AND_HALF(hx.scalar_type(), "quantize_cuda", ([&] {
quantize_cuda_kernel<scalar_t><<<grid1, block>>>(
hx.data_ptr<scalar_t>(),
q_x.data_ptr<int8_t>(),
lsq_x.data_ptr<scalar_t>(),
scale_x,hx_size);
}));
// 记录量化操作的结束时间
cudaDeviceSynchronize();
clock_t time_quantize_end = clock();
dim3 grid2((ny*nz-1)/(block.x)+1); // 计算量化操作的grid大小
torch::Tensor q_y = torch::empty({ny,nz}, option_quantize); // 创建量化结果张量q_y
torch::Tensor lsq_y = torch::empty({ny,nz}, option_dequantize); // 创建缩放后的浮点数结果张量lsq_y
long long int hy_size = (ny*nz); // 计算输入张量hy的总大小
// 调用CUDA的量化操作核函数
AT_DISPATCH_FLOATING_TYPES_AND_HALF(hy.scalar_type(), "quantize_cuda", ([&] {
quantize_cuda_kernel<scalar_t><<<grid2, block>>>(
hy.data_ptr<scalar_t>(),
q_y.data_ptr<int8_t>(),
lsq_y.data_ptr<scalar_t>(),
scale_y,hy_size);
}));
cudaDeviceSynchronize(); // 同步CUDA设备
clock_t time_pack_end = clock();
// 将8位整数数据打包成4位整数数据
dim3 grid_pack_x((nx*nz/2-1)/block.x+1); // 计算打包操作的grid大小
dim3 grid_pack_y((ny*nz/2-1)/block.x+1); // 计算打包操作的grid大小
torch::Tensor pack_qx = torch::empty({nx,nz>>1}, option_quantize); // 创建打包结果张量pack_qx
torch::Tensor pack_qy = torch::empty({ny,nz>>1}, option_quantize); // 创建打包结果张量pack_qy
int qx_size = hx_size >> 1; // 计算打包后的大小
int qy_size = hy_size >> 1; // 计算打包后的大小
// 调用CUDA的打包操作核函数
pack_cuda_kernel<<<grid_pack_x,block>>>(q_x.data_ptr<int8_t>(), pack_qx.data_ptr<int8_t>(), qx_size);
pack_cuda_kernel<<<grid_pack_y,block>>>(q_y.data_ptr<int8_t>(), pack_qy.data_ptr<int8_t>(), qy_size);
cudaDeviceSynchronize(); // 同步CUDA设备
clock_t time_pack_end = clock();
// 使用Cutlass库进行整数矩阵乘法操作
int lda = nz;
int ldb = nz;
int ldc = ny;
torch::Tensor gemm = torch::empty({nx, ny}, option_gemm); // 创建矩阵乘法结果张量gemm
result = CutlassSgemmNN(nx, ny, nz, reinterpret_cast<cutlass::int4b_t *>(pack_qx.data_ptr<int8_t>()), lda,
reinterpret_cast<cutlass::int4b_t *>(pack_qy.data_ptr<int8_t>()), ldb, gemm.data_ptr<int32_t>(), ldc);
cudaDeviceSynchronize(); // 同步CUDA设备
clock_t time_gemm_end = clock();
// 反量化操作
dim3 grid3((nx*ny-1)/block.x+1); // 计算反量化操作的grid大小
torch::Tensor output = torch::empty({nx, ny}, option_dequantize); // 创建反量化结果张量output
float scale = scale_x * scale_y; // 计算缩放因子
long long int size = nx * ny; // 计算输出张量的总大小
// 调用CUDA的反量化操作核函数
AT_DISPATCH_FLOATING_TYPES_AND_HALF(output.scalar_type(), "dequantize_cuda", ([&] {
dequantize_cuda_kernel<scalar_t><<<grid3, block>>>(
gemm.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(),
scale,size);
}));
cudaDeviceSynchronize(); // 同步CUDA设备
clock_t time_dequantize_end = clock();
// 计算各个操作的时间
double quantize_time = (double)(time_quantize_end - time_quantize_start) / CLOCKS_PER_SEC;
double pack_time = (double)(time_pack_end - time_quantize_end) / CLOCKS_PER_SEC;
double gemm_time = (double)(time_gemm_end - time_pack_end) / CLOCKS_PER_SEC;
double dequantize_time = (double)(time_dequantize_end - time_gemm_end) / CLOCKS_PER_SEC;
// 将各个操作的时间保存在time_vector中
time_vector.push_back(quantize_time);
time_vector.push_back(pack_time);
time_vector.push_back(gemm_time);
time_vector.push_back(dequantize_time);
// 返回包含多个结果的元组
return std::make_tuple(output, q_x, q_y, lsq_x, lsq_y, time_vector, gemm, size);
}
- 用pybind打包设置
// https://github1s.com/xijiu9/Train_Transformers_with_INT4/blob/main/Cuda_LSS_HQ/quantize_forward_LSQ+HQ/quantize_forward_easy.cpp#L1-L20
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("quantize", &quantize);// quantize调用quantize_cuda
}
- setup 打包,用的时候python setup_easy.py install生成动态链接库
// https://github1s.com/xijiu9/Train_Transformers_with_INT4/blob/main/Cuda_LSS_HQ/quantize_forward_LSQ+HQ/setup_easy.py#L1-L20
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os
file_path = os.path.dirname(os.getcwd())
cutlass_path = "cutlass/cutlass/include"
file_path = os.path.join(file_path, cutlass_path)
setup(
name='quantize_forward_easy',
ext_modules=[
CUDAExtension(name='quantize_forward_easy', sources=[
'quantize_forward_easy.cpp',
'quantize_forward_easy_kernel.cu',
], include_dirs=[file_path],
extra_compile_args=["-std=c++17"])
],
cmdclass={
'build_ext': BuildExtension
})
在bert的使用见linear_act_python
import quantize_forward_easy as qfe
import quantize_grad_weight_speed as qgw // quantize_grad_weight_LSS+LSQ
import quantize_grad_input_speed as qgi
import shutil
class linear_act_python(Function):
@staticmethod
def forward(ctx, h_input, h_weight, scale_input, scale_weight, bias, layer_name):
* * *
h_input_flatten = h_input.reshape(-1,C_in)
out = qfe.quantize(h_input_flatten, h_weight, scale_input, scale_weight)
output = out[0].reshape(out_shape)
ctx.saved = out[1], out[2], out[3], out[4], scale_input, scale_weight, bias, input_shape, layer_name
return output
@staticmethod
def backward(ctx, grad_output):
# input, weight, bias = ctx.saved
q_input_flatten, q_weight, lsq_input_flatten, lsq_weight, scale_input, scale_weight, bias, input_shape, layer_name = ctx.saved
C_out = grad_output.shape[-1]
grad_output_flatten = grad_output.reshape(-1, C_out)
flag_weight = (q_input_flatten.shape[1] % 4 != 0)
if flag_weight:
# if True:
if qconfig.fp16:
grad_weight, grad_scale_weight, first_transform, second_transform, x1_len, x2_len, scale1 = special_quantize_grad_weight(grad_output_flatten, q_input_flatten, scale_input, 4, lsq_weight)
else:
grad_weight, grad_scale_weight, first_transform, second_transform, x1_len, x2_len, scale1 = special_quantize_grad_weight_float(grad_output_flatten, q_input_flatten, scale_input, 4, lsq_weight)
else:
weight_out = qgw.quantize(grad_output_flatten, 4, q_input_flatten, scale_input, lsq_weight)
grad_weight, grad_scale_weight = weight_out[0], weight_out[1]
first_transform, second_transform, x1_len, x2_len, scale1 = weight_out[-5], weight_out[-4], weight_out[-3], weight_out[-2], weight_out[-1]
flag_input = (grad_output_flatten.shape[1] % 32 != 0) or (q_weight.shape[1] % 4 != 0)
if flag_input or flag_weight:
# if True:
if qconfig.fp16:
grad_input_flatten, grad_scale_input = special_quantize_grad_input(grad_output_flatten, q_weight, scale_weight, 4, lsq_input_flatten)
else:
grad_input_flatten, grad_scale_input = special_quantize_grad_input_float(grad_output_flatten, q_weight, scale_weight, 4, lsq_input_flatten)
else:
activation_out = qgi.quantize(grad_output_flatten, 4, q_weight, scale_weight, lsq_input_flatten, first_transform, second_transform, x1_len, x2_len, scale1)
grad_input_flatten, grad_scale_input = activation_out[0], activation_out[1]
# grad_input_transform = grad_input.reshape(input.size())
grad_input = grad_input_flatten.reshape(input_shape)
return grad_input, grad_weight, grad_scale_input, grad_scale_weight, grad_bias, None
quantize_grad_input_LSS
……
quantize_grad_weight_LSS+LSQ
……
参考
-
[论文研读] Training Transformers with 4-bit Integers 理论部分:https://www.bilibili.com/video/BV1o94y1v7VF/?
-
https://github.com/xijiu9/Train_Transformers_with_INT4
-
https://arxiv.org/pdf/2306.11987.pdf
-
哈哈