GRPO多模态奖励函数:利用大模型API接入

背景

使用GRPO强化微调模型时,有时需要复杂的评判规则、多模态输入下给出奖励分数(如一段生成文本是否搞笑,一个生成的SVG代码渲染后是否符合标题,一个网页渲染后是否排版合理),此时可以引入大模型作为打分器。Ouyang et al.,2022 的研究表明,LLM 越来越与人类价值观和推理过程保持一致。这种一致性使 LLM 能够从生成性任务过渡到评估角色。LLM-as-a-Judge 的核心是指使用 LLM 根据预定义的规则、标准或偏好来评估对象、行动或决策。

在TRL的GRPO Trainer中,允许接入奖励模型(RM)作为打分器,然而这个Model需要在本地运行,将引起额外的显存消耗。我们寻求一种方案,利用API接入的方式,在其他服务器上并行推理,计算奖励。

问题建模

GRPO训练中提供原始prompt x_p,多组补全文本c_i,每个生成样本打分规则s_i = f(x_p,c_i),需要利用大模型API,并行处理该规则。x_pc_i中均有可能包含图片。

解决方案

GRPO训练脚本可参考我的这篇博客,包括数据集和基本类型奖励函数设定

使用GRPO微调VLM模型(QWen 2.5 VL)_grpo vlm-优快云博客文章浏览阅读1.9k次,点赞39次,收藏41次。使用jupyter notebook载入、推理和LoRA强化微调一个QWen 2.5模型* 使用GRPO强化微调,奖励函数等设定* 私有数据集上需要做微调适配,除了SFT,强化微调提供了其他可行方案。* 很多情况下数据集的图文对包含答案简短,推理信息需要模型自行补全。然而,一般的SFT训练决定了模型输出必须是数据集中简短的答案形式。GRPO训练有助于激发模型的推理潜能。_grpo vlm https://blog.youkuaiyun.com/apd_csdn/article/details/146592312

LLM奖励函数设置(Gu, 2024)

这个pipeline展示了LLM作为奖励函数的一般流程,可参考其中的流程和prompt。

以下函数以评判SVG质量为例,调用LLM进行打分。其中completions是一个step中多组生成的文本\{c_i\},不包含输入x_p。为了获取输入,我们利用`prompts = [prompt for prompt in kwargs['description']]`从数据集中读取对应的'description'字段。然后,将打分prompt、标题和completion拼接。

打分文本`result_aes`由`my_api.process_batch()`给出,这是一个自定义的对象,用于并行的获取LLM API返回值,每个打分过程彼此独立。注意⚠️返回值可能并未严格遵循格式,所以在获取数值失败情况下打分0.5。

from get_response_async import async_api
my_api = async_api()两行将在下一节介绍

from get_response_async import async_api
my_api = async_api()

import os
from PIL import Image
from io import BytesIO
import cairosvg

def svg_to_png(svg_code: str):
    if '```svg' in svg_code:
        svg_code = svg_code.split('```svg')[-1]
        svg_code = svg_code.split('```')[0]
    # Ensure SVG has proper size attributes
    if 'viewBox' not in svg_code:
        svg_code = svg_code.replace(
            '<svg', f'<svg viewBox="0 0 384 384"'
        )
    # Convert SVG to PNG,no output
    png_data = cairosvg.svg2png(bytestring=svg_code.encode('utf-8'))
    return Image.open(BytesIO(png_data)).convert('RGB').resize((384,384))

def image_reward_func(completions, **kwargs):
    SCORE_PROMPT = '''
    You are a professional critic commenter. Please give a score to the following painting based on given caption and description.
    <constrains>
    If the image content matches the captain, give a score near to 1. 
    Otherwise, give a score near to 0. If no image input, you should give score 0. 
    The consideration should includes: 
    1. Completeness. Every object should appear in the image. For each missing, -0.1 score.
    2. Unrecognizable object. If an unrecognizable object appears, you should -0.5 score.
    3. Color. Each color should be consistent to prompt. For each mismatch, -0.1 score.
    4. Absolute and Relative position. For each element, the absolute position should be appropriate. 
    For every two elements, the relative position should match the description. For each mismatch, -0.1 score.
    </constrains>
    <format>
    You should output with format, output it first before any analysis. The score number ranges 0.0-1.0. 
    <Score>...</Score> ...
    </format>
    '''

    
    contents = [completion[0]["content"] for completion in completions]
    prompts = [prompt for prompt in kwargs['description']]
    results_score = []
    all_eval_prompts = []
    all_eval_image = []
    for i,single_content in enumerate(contents):
        try:
            img_obj = svg_to_png(single_content)
        except:
            img_obj = None
        messages = SCORE_PROMPT+'###Captain: ' +prompts[i]
        all_eval_prompts.append(messages)
        all_eval_image.append(img_obj)
    results_aes = my_api.process_batch(all_eval_prompts,all_eval_image)
    for i,result in enumerate(results_aes):
        try:
            final_score = float(result.split('<Score>')[-1].split('</Score>')[0])
        except:
            final_score = 0.5
        results_score.append(final_score)
    return results_score 

 如果你的数据集中直接包含了图像,不需要这个例程中svg文本->图像的步骤,则可以替换读取图像的语句为:

#try:
#    img_obj = svg_to_png(single_content)
#except:
#    img_obj = None

try:
    img_obj = Image.open(kwargs['img_path']) # 需要保证数据集中有一个字段img_path用来存储图像地址
except:
    img_obj = None

并行LLM API调用(支持多模态🔥):

将以下代码保存成`get_response_async.py`,使得上述代码能够调用。

核心函数是`process_batch(self,batch_title:list,batch_img:list)` ,该函数接收列表中的一系列打分指令和内容文本,同时支持每个样本输入一张图片(如果是多图,请参考openai SDK接入指定模型的文档)。函数调用`AsyncOpenAI`来进行并行推理,支持在jupyter notebook等loop环境中调用。图片在输入过程中需要进行PNG格式的base64编码,这里采用保存一个临时图片后再读入方案。

import asyncio
import nest_asyncio
from openai import AsyncOpenAI
import time
import base64
from PIL import Image
import random
import os

# 这个函数处理单个请求,返回单个结果
class async_api:
    def __init__(self) -> None:
        nest_asyncio.apply()    
        self.aclient = AsyncOpenAI(
            base_url="https://xxx/api/v3", # 替换为你的 base_url
            api_key=""  # 替换为你的 API 密钥
        )
    async def async_query_openai(self,query,image:Image):

        # convert image to base64 with png format
        if image == None:
            return '<Score>0.0</Score>'
        tmp_image_name = str(random.randint(100000,999999))+'tmp.png'
        image.save(tmp_image_name)
        with open(tmp_image_name, "rb") as image_file:
            image = image_file.read()
        base64_image = base64.b64encode(image).decode('utf-8')
        os.remove(tmp_image_name)
        completion = await self.aclient.chat.completions.create(
            model="xxx",    # 替换为你的 model_name
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            temperature=0.5,
            top_p=0.9,
            max_tokens=512
        )
        return completion.choices[0].message.content  # 请确保返回的数据结构正确
    
    # 这个函数接收一个请求列表,返回所有请求的结果列表
    async def async_process_queries(self,queries):
        results = await asyncio.gather(*(self.async_query_openai(query[0],query[1]) for query in queries))
        return results
    
    
    def process_batch(self,batch_title:list,batch_img:list):
        ''' Process a batch of text and image prompt. Image element should be PIL.Image  
        '''
        # 修补 asyncio
        nest_asyncio.apply()
        if len(batch_title)!=len(batch_img):
            raise Exception('Feedback titles can not match images one-by-one')
        combined_data = list(zip(batch_title, batch_img))

        loop = asyncio.get_event_loop()
        results = loop.run_until_complete(self.async_process_queries(combined_data))
        return results


if __name__=='__main__':
    myapi = async_api()
    image = Image.open('./images1/0007.png')
    image2 = Image.open('./images1/0009.png')
    out1 = myapi.process_batch(['描述一下这张图片的内容','这张图有什么特色'],[image,image2])
    print(out1)

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值