Background
When fine-tuning a model with GRPO, the reward sometimes requires complex judging rules or multimodal inputs (for example, whether a generated text is funny, whether a generated SVG renders into an image that matches its caption, or whether a rendered web page has a reasonable layout). In such cases a large model can be brought in as the scorer. Ouyang et al. (2022) show that LLMs are increasingly aligned with human values and reasoning processes; this alignment lets LLMs move from generative tasks into evaluation roles. At its core, LLM-as-a-Judge means using an LLM to evaluate objects, actions, or decisions against predefined rules, criteria, or preferences.
TRL's GRPO Trainer allows a reward model (RM) to be plugged in as the scorer, but that model has to run locally and consumes extra GPU memory. We therefore look for a scheme that calls the scorer through an API, running inference in parallel on other servers to compute the rewards.
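For orientation, this is roughly how a custom, API-backed reward plugs into TRL's `GRPOTrainer`: any Python callable with the signature `(completions, **kwargs) -> list[float]` can be passed via `reward_funcs`, and extra dataset columns are forwarded to it as keyword arguments. The model id and toy dataset below are placeholders; the actual scoring logic is developed in the rest of this post.

```python
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# Hypothetical toy dataset: GRPOTrainer expects a "prompt" column; any extra column
# (here "description") is forwarded to reward functions as a keyword argument.
dataset = Dataset.from_dict({
    "prompt": [[{"role": "user", "content": "Generate SVG code for: a red apple on a wooden table"}]],
    "description": ["a red apple on a wooden table"],
})

def api_reward_func(completions, **kwargs):
    # Must return one float per completion; the real LLM-judge logic is defined later in this post.
    return [0.0 for _ in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",             # placeholder model id
    reward_funcs=api_reward_func,                 # a plain callable works as a reward function
    args=GRPOConfig(output_dir="grpo-svg-demo"),
    train_dataset=dataset,
)
trainer.train()
```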
Problem formulation
GRPO training provides the original prompt $p$ and a group of completion texts $c_1, \dots, c_G$, together with a scoring rule for each generated sample. We want to evaluate this rule through an LLM API, processing the samples in parallel. Both $p$ and the $c_i$ may contain images.
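Formally (the notation here is mine, not TRL's): writing the judging instruction as $s$, each reward is computed independently of the other completions in the group, which is exactly what makes parallel API calls natural:

$$
r_i = f_{\mathrm{LLM}}(s,\; p,\; c_i), \qquad i = 1, \dots, G
$$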
Solution
The GRPO training script can be found in my earlier blog post, which covers the dataset setup and basic rule-based reward functions.
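For reference, a basic rule-based reward is just a plain callable with the same signature. A minimal, hypothetical example that rewards completions containing a fenced SVG code block:

```python
import re

def svg_format_reward_func(completions, **kwargs):
    # Hypothetical rule-based reward: 1.0 if the completion contains a fenced svg block, else 0.0
    contents = [completion[0]["content"] for completion in completions]
    return [1.0 if re.search(r"```svg.*?```", c, re.DOTALL) else 0.0 for c in contents]
```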
LLM reward function setup (Gu, 2024):
This pipeline illustrates the general flow of using an LLM as a reward function; its workflow and prompt design are worth referring to.

The following function uses SVG quality judging as an example and calls an LLM to produce the score. Here `completions` is the group of texts generated in one step and does not include the input $p$. To recover the input, we use `prompts = [prompt for prompt in kwargs['description']]` to read the corresponding 'description' column from the dataset, then concatenate the scoring prompt, the caption, and the completion.
The score texts `results_aes` are returned by `my_api.process_batch()`, a custom object that queries the LLM API in parallel, with each scoring call independent of the others. Note ⚠️ that the returned text may not strictly follow the requested format, so when parsing the number fails the score falls back to 0.5.
The two lines `from get_response_async import async_api` and `my_api = async_api()` will be explained in the next section.
```python
from get_response_async import async_api
my_api = async_api()

import os
from PIL import Image
from io import BytesIO
import cairosvg

def svg_to_png(svg_code: str):
    # Strip a possible ```svg ... ``` fence around the generated code
    if '```svg' in svg_code:
        svg_code = svg_code.split('```svg')[-1]
        svg_code = svg_code.split('```')[0]
    # Ensure the SVG has a viewBox so it renders at a predictable size
    if 'viewBox' not in svg_code:
        svg_code = svg_code.replace('<svg', '<svg viewBox="0 0 384 384"')
    # Convert SVG to PNG in memory, without writing any file
    png_data = cairosvg.svg2png(bytestring=svg_code.encode('utf-8'))
    return Image.open(BytesIO(png_data)).convert('RGB').resize((384, 384))

def image_reward_func(completions, **kwargs):
    SCORE_PROMPT = '''
You are a professional critic. Please give a score to the following painting based on the given caption and description.
<constraints>
If the image content matches the caption, give a score near 1.
Otherwise, give a score near 0. If there is no image input, you should give score 0.
The considerations should include:
1. Completeness. Every object should appear in the image. For each missing object, -0.1 score.
2. Unrecognizable object. If an unrecognizable object appears, -0.5 score.
3. Color. Each color should be consistent with the prompt. For each mismatch, -0.1 score.
4. Absolute and relative position. For each element, the absolute position should be appropriate.
For every two elements, the relative position should match the description. For each mismatch, -0.1 score.
</constraints>
<format>
You should output the score in this format, before any analysis. The score ranges 0.0-1.0.
<Score>...</Score> ...
</format>
'''
    contents = [completion[0]["content"] for completion in completions]
    prompts = [prompt for prompt in kwargs['description']]
    results_score = []
    all_eval_prompts = []
    all_eval_image = []
    for i, single_content in enumerate(contents):
        # Render each completion to an image; scoring degrades gracefully if rendering fails
        try:
            img_obj = svg_to_png(single_content)
        except Exception:
            img_obj = None
        messages = SCORE_PROMPT + '###Caption: ' + prompts[i]
        all_eval_prompts.append(messages)
        all_eval_image.append(img_obj)
    # Query the judge LLM for all samples in parallel
    results_aes = my_api.process_batch(all_eval_prompts, all_eval_image)
    for i, result in enumerate(results_aes):
        try:
            final_score = float(result.split('<Score>')[-1].split('</Score>')[0])
        except Exception:
            final_score = 0.5  # the judge did not follow the format; fall back to a neutral score
        results_score.append(final_score)
    return results_score
```
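Before wiring this into the trainer it is worth smoke-testing the function on its own. A hypothetical call with a single fake completion (the `description` column is passed the way TRL would pass it, as a keyword list):

```python
fake_completion = [[{
    "role": "assistant",
    "content": '```svg\n<svg xmlns="http://www.w3.org/2000/svg"><circle cx="192" cy="192" r="100" fill="red"/></svg>\n```'
}]]
scores = image_reward_func(fake_completion, description=["a red circle"])
print(scores)  # one float per completion, parsed from the judge's <Score> tag
```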
If your dataset already contains images and you do not need the SVG-text-to-image step in this routine, you can replace the image-loading statement with:
```python
# try:
#     img_obj = svg_to_png(single_content)
# except Exception:
#     img_obj = None
try:
    # The dataset must contain an 'img_path' column storing the image path of each sample
    img_obj = Image.open(kwargs['img_path'][i])
except Exception:
    img_obj = None
```
Parallel LLM API calls (multimodal supported 🔥):
Save the following code as `get_response_async.py` so that the code above can import it.
The core function is `process_batch(self, batch_title: list, batch_img: list)`, which takes a list of scoring instructions and content texts, and supports one image per sample (for multiple images, refer to the OpenAI SDK documentation of your target model). It uses `AsyncOpenAI` for parallel inference and also works inside event-loop environments such as Jupyter notebooks. Each image must be base64-encoded as PNG before being sent; here this is done by saving a temporary image file and reading it back.
```python
import asyncio
import nest_asyncio
from openai import AsyncOpenAI
import time
import base64
from PIL import Image
import random
import os

class async_api:
    def __init__(self) -> None:
        nest_asyncio.apply()
        self.aclient = AsyncOpenAI(
            base_url="https://xxx/api/v3",  # replace with your base_url
            api_key=""                      # replace with your API key
        )

    # Handle a single request and return a single result
    async def async_query_openai(self, query, image: Image.Image):
        # No image input: return a zero score directly
        if image is None:
            return '<Score>0.0</Score>'
        # Convert the image to base64 in PNG format via a temporary file
        tmp_image_name = str(random.randint(100000, 999999)) + 'tmp.png'
        image.save(tmp_image_name)
        with open(tmp_image_name, "rb") as image_file:
            image_bytes = image_file.read()
        base64_image = base64.b64encode(image_bytes).decode('utf-8')
        os.remove(tmp_image_name)
        completion = await self.aclient.chat.completions.create(
            model="xxx",  # replace with your model_name
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            temperature=0.5,
            top_p=0.9,
            max_tokens=512
        )
        return completion.choices[0].message.content  # make sure the returned structure is what you expect

    # Take a list of requests and return the list of all results
    async def async_process_queries(self, queries):
        results = await asyncio.gather(*(self.async_query_openai(query[0], query[1]) for query in queries))
        return results

    def process_batch(self, batch_title: list, batch_img: list):
        '''Process a batch of text and image prompts. Each image element should be a PIL.Image.'''
        # Patch asyncio so this also works inside an already-running event loop (e.g. Jupyter)
        nest_asyncio.apply()
        if len(batch_title) != len(batch_img):
            raise Exception('Feedback titles can not match images one-by-one')
        combined_data = list(zip(batch_title, batch_img))
        loop = asyncio.get_event_loop()
        results = loop.run_until_complete(self.async_process_queries(combined_data))
        return results

if __name__ == '__main__':
    myapi = async_api()
    image = Image.open('./images1/0007.png')
    image2 = Image.open('./images1/0009.png')
    out1 = myapi.process_batch(['Describe the content of this image', 'What is special about this image'], [image, image2])
    print(out1)
```
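One design note: the temporary-file step is not strictly necessary. The same base64 PNG string can be produced entirely in memory with `BytesIO`, which avoids disk I/O when many completions are scored per step; a drop-in sketch for the encoding part of `async_query_openai`:

```python
import base64
from io import BytesIO
from PIL import Image

def pil_to_base64_png(image: Image.Image) -> str:
    # Encode a PIL image as a base64 PNG string entirely in memory (no temporary file)
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
```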