For this final quantization step, can't we just replace the original weights directly, the way RTN does? Why do we have to write a whole new Linear class?

import torch
import torch.nn as nn

import awq_inference_engine  # compiled AWQ CUDA extension (4-bit GEMM/GEMV kernels)
# calculate_zeros_width and pack_intweight are helper functions defined alongside
# this class in the AWQ codebase: they pad the scale/zero layout and pack the
# int4 weights into the interleaved int16 format the kernels expect.


class WQLinear(nn.Module):
    def __init__(
        self, w_bit, group_size, in_features, out_features, bias, dev, dtype=torch.float16
    ):
        super().__init__()

        if w_bit not in [4]:
            raise NotImplementedError("Only 4-bit are supported for now.")

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features
        self.split_k_iters = 8
        self.interleave = 4

        # quick sanity checks (make sure the shapes line up with the packing)
        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0
        pack_num = 32 // self.w_bit          # 4-bit values per int32
        int16_pack_num = 16 // self.w_bit    # 4-bit values per int16
        assert out_features % (self.interleave) == 0

        # packed 4-bit weights, stored as interleaved int16 words
        self.register_buffer(
            "qweight",
            torch.zeros(
                (
                    out_features // self.interleave,
                    in_features // int16_pack_num * self.interleave,
                ),
                dtype=torch.int16,
                device=dev,
            ),
        )
        # per-group scales, padded to the width the kernels expect
        self.register_buffer(
            "scales",
            torch.zeros(
                (
                    calculate_zeros_width(in_features, self.group_size) * pack_num,
                    out_features,
                ),
                dtype=dtype,
                device=dev,
            ),
        )
        # per-group zero points already multiplied by the scales (filled in from_linear)
        self.register_buffer(
            "scaled_zeros",
            torch.zeros(
                (
                    calculate_zeros_width(in_features, self.group_size) * pack_num,
                    out_features,
                ),
                dtype=dtype,
                device=dev,
            ),
        )

        if bias:
            self.register_buffer(
                "bias", torch.zeros((out_features), dtype=dtype, device=dev)
            )
        else:
            self.bias = None
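    # Illustrative note (not in the original code), with hypothetical sizes
    # in_features=4096, out_features=4096, w_bit=4, group_size=128:
    #   pack_num = 8, int16_pack_num = 4, so
    #   qweight  -> (4096 // 4, 4096 // 4 * 4) = (1024, 4096) int16 entries,
    #               each entry packing four 4-bit weights;
    #   scales / scaled_zeros -> roughly one row per weight group along
    #               in_features (padded up by calculate_zeros_width for kernel
    #               alignment) and one column per output channel.
    # This packed layout is what the CUDA kernels in forward() consume directly.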
    @classmethod
    def from_linear(
        cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
    ):
        awq_linear = cls(
            w_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            dtype=linear.weight.data.dtype,
        )
        if init_only:  # just prepare the empty buffers for loading a state dict
            return awq_linear

        # need scales and zeros info for real quantization
        assert scales is not None and zeros is not None
        scale_zeros = zeros * scales

        dtype = scales.dtype
        pack_num = 32 // awq_linear.w_bit
        # pad the scales to the kernel-aligned width, then store them transposed
        qscales = torch.zeros(
            (
                scales.shape[0],
                calculate_zeros_width(linear.in_features, group_size) * pack_num,
            ),
            dtype=dtype,
            device=scales.device,
        )
        qscales[:, : scales.shape[1]] = scales
        # awq_linear.scales = scales.clone().half()
        awq_linear.scales = qscales.transpose(1, 0).contiguous()
        if linear.bias is not None:
            awq_linear.bias = linear.bias.clone().to(dtype)

        # quantized integer weights: round((w + zero * scale) / scale), column by column
        intweight = []
        for idx in range(awq_linear.in_features):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
                    / qscales[:, idx // group_size]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        # intweight = intweight.t().contiguous()
        intweight = intweight.to(dtype=torch.int32)
        # pack the int4 values into the interleaved int16 layout the kernels expect
        awq_linear.qweight = pack_intweight(
            intweight.contiguous(), interleave=4, kstride=64
        )

        # store -(zero * scale) so dequantization reduces to q * scale + scaled_zeros
        zeros = zeros.to(dtype=torch.int32)
        scaled_zeros = torch.zeros_like(qscales)
        # scaled_zeros[:, :scales.shape[1]] = -(qscales[:, :scales.shape[1]] * (zeros.to(torch.float32) - 8.0)).to(torch.float16)
        scaled_zeros[:, : scales.shape[1]] = -(
            qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
        ).to(dtype)
        awq_linear.scaled_zeros = scaled_zeros.transpose(1, 0).contiguous()

        return awq_linear
    @torch.no_grad()
    def forward(self, x):
        # out_shape = x.shape[:-1] + (self.out_features,)
        # inputs = x.reshape(-1, x.shape[-1])
        inputs = x

        if inputs.numel() / inputs.shape[-1] < 8:
            # fewer than 8 rows (e.g. single-token decoding): the GEMV kernel is faster
            out = awq_inference_engine.gemv_forward_cuda_new(
                inputs,
                self.qweight,
                self.scales,
                self.scaled_zeros,
                inputs.numel() // inputs.shape[-1],
                self.out_features,
                self.in_features,
                self.group_size,
            )
        else:
            # larger batches: use the GEMM kernel
            out = awq_inference_engine.gemm_forward_cuda_new(
                inputs, self.qweight, self.scales, self.scaled_zeros
            )  # - 8.0 * self.scales)
        out = out + self.bias if self.bias is not None else out
        return out
    def extra_repr(self) -> str:
        return (
            "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
                self.in_features,
                self.out_features,
                self.bias is not None,
                self.w_bit,
                self.group_size,
            )
        )
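For reference, here is a minimal usage sketch of how such a module typically replaces an existing nn.Linear instead of having its weights overwritten in place. It assumes the per-group scales and zeros (shaped out_features x number-of-groups, as from_linear expects) have already been computed by the surrounding AWQ quantization pass; replace_linear_with_wqlinear, parent, and child_name are hypothetical names used only for this illustration.

import torch
import torch.nn as nn


def replace_linear_with_wqlinear(parent: nn.Module, child_name: str,
                                 scales: torch.Tensor, zeros: torch.Tensor,
                                 w_bit: int = 4, group_size: int = 128) -> None:
    """Swap parent.<child_name> (an nn.Linear) for a packed WQLinear module."""
    linear = getattr(parent, child_name)
    assert isinstance(linear, nn.Linear)

    # from_linear rounds the float weights to int4 with the given per-group
    # scales/zeros, packs them into the int16 qweight buffer, and copies the bias.
    q_linear = WQLinear.from_linear(
        linear, w_bit, group_size, scales=scales, zeros=zeros
    )

    # the module swap: the original float weight is discarded, and the packed
    # buffers plus the custom CUDA forward take its place
    setattr(parent, child_name, q_linear)
    del linear
    torch.cuda.empty_cache()

With init_only=True, from_linear instead creates only the empty buffers, so a checkpoint that was already quantized and packed can be loaded directly into the swapped-in modules.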