Code quoted from https://fengyao.notion.site/off-policy-rl#246721e3f6c480259e6ff598ac4c317b: it generates completions with vLLM while recording per-token log-probs, re-scores the same token sequences with the HuggingFace model at the same sampling temperature, and reports the relative log-prob errors and absolute probability errors between the two.
# VLLM Side
import torch
from vllm import LLM, SamplingParams
import math

if __name__ == '__main__':
    TEMPERATURE = 0.7
    DTYPE = torch.bfloat16
    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", dtype=DTYPE, enforce_eager=True)

    # model = llm.llm_engine.model_executor.driver_worker.model_runner.model
    # saved_outputs = []
    # def logits_processor_hook(module, input, output):
    #     assert isinstance(output, torch.Tensor)
    #     saved_outputs.append(output.clone())
    # model.logits_processor.register_forward_hook(logits_processor_hook)

    prompts = [
        "One of the most important things in life is to",
        "The answer to 1 + 1 is",
    ]
    outputs = llm.generate(
        prompts,
        sampling_params=SamplingParams(
            max_tokens=512,
            temperature=TEMPERATURE,
            logprobs=2,  # per-step dict of token_id -> Logprob for the top-2 tokens (plus the sampled token)
        ),
    )

    save_stuff = []
    for output in outputs:
        assert len(output.outputs[0].token_ids) == len(output.outputs[0].logprobs)
        # for token, logprob in zip(output.outputs[0].token_ids, output.outputs[0].logprobs):
        #     print(token, logprob)
        save_stuff.append(
            {
                "input_ids": output.prompt_token_ids,
                "output_ids": output.outputs[0].token_ids,
                "logprobs": output.outputs[0].logprobs,
            }
        )
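    # (Added note, not in the quoted code:) the "VLLM Side" and "HF Side" can be
    # run as separate scripts by handing save_stuff over via pickle; the file
    # name below is a placeholder.
    # import pickle
    # with open("vllm_logprobs.pkl", "wb") as f:   # at the end of the vLLM side
    #     pickle.dump(save_stuff, f)
    # with open("vllm_logprobs.pkl", "rb") as f:   # at the start of the HF side
    #     save_stuff = pickle.load(f)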
    # HF Side
    torch.cuda.set_device(1)
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch
    import torch.nn.functional as F

    model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", torch_dtype=DTYPE, device_map="cuda")
    tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")

    seq_id = 0
    vllm_errs = []
    # hook_errs = []
    vllm_prob_errs = []
    # hook_prob_errs = []
    for output in save_stuff:
        # Re-score prompt + completion with the HF model in a single forward pass.
        token_ids = torch.tensor([*output["input_ids"], *output["output_ids"]], device="cuda").unsqueeze(0)
        print(token_ids.shape)
        with torch.inference_mode():
            model_outputs = model(token_ids)
        print(model_outputs[0].shape)
        # Scale the logits by the sampling temperature so the log-probs are comparable to vLLM's.
        real_logprobs = F.log_softmax(model_outputs[0] / TEMPERATURE, dim=-1)
        print(real_logprobs.shape)
        for i in range(len(output["logprobs"])):
            print("===", output["output_ids"][i], "===")
            # hook_logprobs = F.log_softmax(saved_outputs[i][seq_id] / TEMPERATURE, dim=-1)
            for key in output["logprobs"][i]:
                # The logits at position len(input_ids) - 1 + i predict output token i.
                _real_logprobs = real_logprobs[0, i - 1 + len(output["input_ids"])]
                vllm_rel_err = abs((output["logprobs"][i][key].logprob - _real_logprobs[key].item()) / (_real_logprobs[key].item() + 1e-10))
                # hook_rel_err = abs((hook_logprobs[key].item() - _real_logprobs[key].item()) / (_real_logprobs[key].item() + 1e-10))
                vllm_errs.append(vllm_rel_err)
                # hook_errs.append(hook_rel_err)
                vllm_prob = math.exp(output["logprobs"][i][key].logprob)
                # hook_prob = math.exp(hook_logprobs[key].item())
                real_prob = math.exp(_real_logprobs[key].item())
                vllm_prob_err = abs(vllm_prob - real_prob)
                # hook_prob_err = abs(hook_prob - real_prob)
                vllm_prob_errs.append(vllm_prob_err)
                # hook_prob_errs.append(hook_prob_err)
                if (vllm_rel_err > 0.1) and real_prob < 0.9:
                    print(
                        key, output["logprobs"][i][key],
                        "HF logprobs:", real_logprobs[0, i - 1 + len(output["input_ids"])][key].item()
                    )
                    print(f"Prob: {real_prob}, VLLM: {vllm_prob}")
                # if (vllm_rel_err > 0.1 or hook_rel_err > 0.1) and real_prob < 0.9:
                #     print(
                #         key, output["logprobs"][i][key],
                #         "HF logprobs:", real_logprobs[0, i - 1 + len(output["input_ids"])][key].item(),
                #         "Hook logprobs:", hook_logprobs[key].item(),
                #     )
                #     print(f"Prob: {real_prob}, VLLM: {vllm_prob}, Hook: {hook_prob}")
        seq_id += 1

    from statistics import mean, stdev, median
    print("Relative logprob errors")
    print(f"VLLM: max={max(vllm_errs)}, mean={mean(vllm_errs)}, stdev={stdev(vllm_errs)}, median={median(vllm_errs)}, min={min(vllm_errs)}")
    # print(f"Hook: max={max(hook_errs)}, mean={mean(hook_errs)}, stdev={stdev(hook_errs)}, median={median(hook_errs)}, min={min(hook_errs)}")
    print("Absolute prob errors")
    print(f"VLLM: max={max(vllm_prob_errs)}, mean={mean(vllm_prob_errs)}, stdev={stdev(vllm_prob_errs)}, median={median(vllm_prob_errs)}, min={min(vllm_prob_errs)}")
    # print(f"Hook: max={max(hook_prob_errs)}, mean={mean(hook_prob_errs)}, stdev={stdev(hook_prob_errs)}, median={median(hook_prob_errs)}, min={min(hook_prob_errs)}")