for i in range(len(layers)):
    logger.info(f"=== Start quantize layer {i} ===")
    layer = layers[i].to(dev)
    if "mixtral" in args.net.lower():
        # For Mixtral we only leverage LWC, which can be achieved by simply replacing Linear with QuantLinear
        qlayer = copy.deepcopy(layer)
        for name, module in qlayer.named_modules():
            if isinstance(module, torch.nn.Linear) and "gate" not in name:  # do not quantize the gate
                quantlinear = QuantLinear(module, args.weight_quant_params, args.act_quant_params)
                add_new_module(name, qlayer, quantlinear)
    else:
        qlayer = DecoderLayer(lm.model.config, layer, args)
    qlayer = qlayer.to(dev)
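    # The block below runs the calibration samples through the still-unquantized layer
    # to cache its full-precision outputs; these serve as the reconstruction target for
    # the block-wise training further down. fp_inps / quant_inps are assumed to be
    # pre-allocated calibration buffers prepared earlier in the script.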
    # obtain output of full-precision model
    set_quant_state(qlayer, weight_quant=False, act_quant=False)
    if args.epochs > 0:
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                for j in range(args.nsamples):
                    fp_inps[j] = qlayer(fp_inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
                    if args.aug_loss:
                        fp_inps_2[j] = qlayer(quant_inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
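    # The LET (learnable equivalent transformation) branch below initializes per-channel
    # smoothing scales/shifts that are later trained jointly with the LWC weight-clipping
    # factors; activation quantization is enabled here, while weights are quantized on
    # the fly at each training step.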
    # init smooth parameters
    set_quant_state(qlayer, weight_quant=False, act_quant=True)  # weight will be manually quantized before forward
    qlayer.let = args.let
    use_shift = True
    if is_llama or args.abits == 16:
        use_shift = False  # deactivate channel-wise shifting for llama models and for weight-only quantization
    if args.let:
        # init channel-wise scaling and shift
        qlayer.register_parameter("qkt_smooth_scale", torch.nn.Parameter(torch.ones(layer.self_attn.q_proj.out_features, device=dev, dtype=dtype)))
        for name, module in qlayer.named_modules():
            if isinstance(module, QuantLinear):
                for key in pairs.keys():
                    if key in name:
                        act = act_scales[f"{layer_name_prefix}.{i}.{name}"].to(device=dev, dtype=dtype).clamp(min=1e-5)
                        weight = module.weight.abs().max(dim=0)[0].clamp(min=1e-5)
                        scale = (act.pow(args.alpha) / weight.pow(1 - args.alpha)).clamp(min=1e-5)
                        if use_shift and not is_llama:
                            shift = act_shifts[f"{layer_name_prefix}.{i}.{name}"].to(device=dev, dtype=dtype)
                        else:
                            shift = torch.zeros_like(scale)
                        qlayer.register_parameter(f"{pairs[key]}_smooth_shift", torch.nn.Parameter(shift))
                        qlayer.register_parameter(f"{pairs[key]}_smooth_scale", torch.nn.Parameter(scale))
    if args.resume:
        qlayer.load_state_dict(omni_parameters[i], strict=False)

    if args.epochs > 0:
        with torch.no_grad():
            qlayer.float()  # required for AMP training
        # create optimizer
        optimizer = torch.optim.AdamW(
            [{"params": let_parameters(qlayer, use_shift), "lr": args.let_lr},
             {"params": lwc_parameters(qlayer), "lr": args.lwc_lr}],
            weight_decay=args.wd)
        loss_scaler = utils.NativeScalerWithGradNormCount()
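        # Two parameter groups with separate learning rates: the LET smoothing
        # scales/shifts (let_lr) and the LWC clipping factors (lwc_lr).
        # NativeScalerWithGradNormCount is assumed to wrap torch.cuda.amp.GradScaler
        # and to return the gradient norm of the updated parameters.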
        for epoch in range(args.epochs):
            loss_list = []
            norm_list = []
            for j in range(args.nsamples // args.batch_size):
                index = j * args.batch_size
                # obtain output of quantization model
                with traincast():
                    smooth_and_quant_temporary(qlayer, args, is_llama)
                    quant_out = qlayer(quant_inps[index:index + args.batch_size,], attention_mask=attention_mask_batch, position_ids=position_ids)[0]
                    loss = loss_func(fp_inps[index:index + args.batch_size,], quant_out)
                    if args.aug_loss:
                        loss += loss_func(fp_inps_2[index:index + args.batch_size,], quant_out)
                if not math.isfinite(loss.item()):
                    logger.info("Loss is NaN, stopping training")
                    pdb.set_trace()

                loss_list.append(loss.detach().cpu())
                optimizer.zero_grad()
                norm = loss_scaler(loss, optimizer, parameters=get_omni_parameters(qlayer, use_shift)).cpu()
                norm_list.append(norm.data)

            loss_mean = torch.stack(loss_list).mean()
            norm_mean = torch.stack(norm_list).mean()
            logger.info(f"layer {i} epoch {epoch} loss:{loss_mean} norm:{norm_mean} max memory_allocated {torch.cuda.max_memory_allocated(lm._device) / 1024**2}")
        clear_temp_variable(qlayer)
        del optimizer
    qlayer.half()
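    # Training for this block is finished: cast back to half precision and fold the
    # learned smoothing and clipping into the weights permanently (in place), replacing
    # the temporary per-step quantization used during optimization.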
    # real smooth and quantization
    smooth_and_quant_inplace(qlayer, args, is_llama)
    if args.epochs > 0:
        # update the inputs to the quantized model (these feed the next layer)
        with torch.no_grad():
            with traincast():
                for j in range(args.nsamples):
                    quant_inps[j] = qlayer(quant_inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        register_scales_and_zeros(qlayer)
        layers[i] = qlayer.to("cpu")
        omni_parameters[i] = omni_state_dict(qlayer)
        torch.save(omni_parameters, os.path.join(args.output_dir, "omni_parameters.pth"))
    else:
        register_scales_and_zeros(qlayer)
        layers[i] = qlayer.to("cpu")
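    # Optional repacking into real low-bit storage. qlinear_cuda / qlinear_triton are
    # assumed to be AutoGPTQ-style packed QuantLinear kernels; 3-bit weights go through
    # the CUDA kernel, the remaining bit widths through the Triton kernel.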
    if args.real_quant:
        assert args.wbits in [2, 3, 4] and args.abits >= 16  # real packing only supports weight-only quantization
        named_linears = get_named_linears(qlayer)
        for name, module in named_linears.items():
            scales = module.weight_quantizer.scales
            zeros = module.weight_quantizer.zeros
            group_size = module.weight_quantizer.group_size
            dim0 = module.weight.shape[0]
            scales = scales.view(dim0, -1)
            zeros = zeros.view(dim0, -1)
            if args.wbits == 3:
                q_linear = qlinear_cuda.QuantLinear(args.wbits, group_size, module.in_features, module.out_features, module.bias is not None)
            else:
                q_linear = qlinear_triton.QuantLinear(args.wbits, group_size, module.in_features, module.out_features, module.bias is not None)
            q_linear.pack(module.cpu(), scales.float().cpu(), zeros.float().cpu())
            add_new_module(name, qlayer, q_linear)
            print(f"pack quantized {name} finished")
            del module
    del layer
    torch.cuda.empty_cache()
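# End of the per-layer loop: every entry in `layers` now holds its quantized (and,
# with args.real_quant, packed) counterpart on CPU. The surrounding script (not shown
# here) can restore the model's config and proceed to evaluation or saving from here.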