裁判函数 + 训练策略(可落地模板)

目录

  1. 总览:从“能判”到“判得准”的三层结构

  2. 裁判模块设计:Math / Code / OpenQA(RAG)/ NLI 一致性 / 风格合规

  3. 多指标聚合与分数校准(胜出判定 → DPO pairs)

  4. 训练策略:数据蒸馏、困难样本挖掘、温度日程与博弈编队

  5. 评测与调参清单(可复制指标表)

  6. 可直接运行的示例代码(≥30 行×3块)

  7. 参考链接


1. 总览:从“能判”到“判得准”的三层结构

目标:把自我博弈产生的多条路径,稳定、可解释地判出优劣,形成高质量 (prompt, chosen, rejected) 偏好对。

三层结构

  • L1 可验证正确性(Verifiable):数学可直接验算;代码可在沙箱运行单测;抽取/分类任务可规则对齐。

  • L2 事实一致性(Factuality):开放问答接入检索(RAG),用证据驱动一致性评分;可辅以轻量 NLI。

  • L3 结构与风格(Form & Style):条理性、覆盖率、引用格式、安全合规等,作为兜底的细化维度。

实战建议:优先 L1(最强客观性),再 L2(事实性),最后 L3(可解释的启发式)。确保每层都能独立工作,组合时分配权重并做分数校准。


2. 裁判模块设计

2.1 数学(Math Judge)

  • 提取最终数值(正则/指示语)

  • 可选:Sympy 复核、单位换算(cm²/㎡ 等)

  • 评分:正确=1,错误按误差衰减(1 − |pred−gold|/scale)

2.2 代码(Code Judge)

  • 受限沙箱(time/memory 限制)运行小单测,捕获异常

  • 评分:通过全部测试=1,部分通过按比例给分;运行错误给负分

2.3 开放问答(RAG Judge)

  • 轻量检索:BM25/向量(可用 rank_bm25 / sentence-transformers

  • 证据打分:答案句与证据的重叠/相似度 → 归一化

  • 引用检查:是否给出出处/引用标识

2.4 NLI 一致性(可选)

  • 轻量模型(如 microsoft/deberta-v3-small NLI 头)做证据 → 结论的蕴含判定

  • 评分:蕴含 +1,矛盾 −1,中立 0.2

2.5 结构与风格(Form & Style)

  • 条理性:是否包含列表/小标题

  • 覆盖度:问题关键词/术语覆盖(启发式)

  • 合规:敏感/不当词过滤(负分)


3. 多指标聚合与分数校准(胜出判定 → DPO pairs)

把不同任务的分数统一到 [0,1][−1,1] 区间,然后做加权和;为避免某一路径在某些任务上天然分高,采取分位数归一化 + 温度缩放(Platt/温度)做校准。

  • 示例聚合
    score = 0.55 * verifiable + 0.30 * factuality + 0.15 * form_style

  • 平局化解:看“生成长度/多样性/覆盖度”微调 0.02–0.05 作为细小偏置

  • 不可判样本:直接丢弃(不要把噪声送进 DPO)


4. 训练策略(让 win-rate 稳步上升)

4.1 数据蒸馏与困难样本挖掘(Hard Negative Mining)

  • 先用温和参数(较低 temperature)生成 A/B,得到基础偏好集

  • 周期性重新生成 C/D…,并与当前策略对战,仅保留“险胜/险负”的对,最提升边际

4.2 温度日程(Temperature Schedule)

  • 前期高温(0.9~1.0):制造分歧、扩大搜索面

  • 中后期降温(0.6~0.7):稳定风格,减少无意义差异

  • 对不同任务分桶设温度:数学/代码较低,开放问答略高

4.3 博弈编队(League Training)

  • 同一底模训练出 policy-A/B 两个分支,交叉对战

  • 定期“晋级/降级”:把胜率低的替换为新初始化的探索者,维持生态多样

4.4 DPO 细节

  • beta:0.05–0.2 常用,边挖 hard negative 边适度增大(对偏好差异更敏感)

  • 混合训练:DPO 样本混入少量 SFT “黄金链路”样本(5–15%)稳风格

4.5 PPO 细节(若用在线)

  • target_kl 控制在 0.05–0.15,防漂移

  • reward clippingvalue clip 保守设置

  • 短 episode,勤评估:每 N 步在固定 dev-set 看 win-rate


5. 评测与调参清单(可复制)

指标说明目标值/趋势
Win-Rate (dev)新策略 vs 旧策略/基线 的胜率逐轮上升(>55%)
Math Acc有标准答案任务的准确率稳步上涨
Factual@Top-kRAG 证据一致性(Top-k)>0.7
Length-Penalty冗长惩罚后的平均长度适中、不过长
Toxicity Rate不当输出占比下降、<0.5%
Pair Reject Rate裁判放弃率(不可判)<15%(太高说明规则太苛刻)

6. 可直接运行的示例代码

下述代码分三块:(A) 裁判模块(B) 聚合与写 DPO 对(C) 训练策略管线。你可以直接把它们放入你的工程(judges/, pipelines/)并按需裁剪。

(A) 裁判模块:math/code/openqa + 聚合(≈170 行)

# judges/judge_suite.py
# -*- coding: utf-8 -*-
"""
可组合裁判模块:Math / Code / OpenQA(RAG) / NLI(可选) / FormStyle
聚合为一个统一评分接口 judge_pair(prompt, ans_a, ans_b, meta)
"""
import re
import math
import json
import random
from dataclasses import dataclass
from typing import Optional, Dict, Tuple, List

import numpy as np

try:
    import sympy as sp
except Exception:
    sp = None  # 可选

# ---- 工具 ----
def safe_last_number(text: str) -> Optional[float]:
    nums = re.findall(r"-?\d+(?:\.\d+)?", text)
    if nums:
        try:
            return float(nums[-1])
        except Exception:
            return None
    return None

def normalize_score(x: float, lo: float, hi: float) -> float:
    x = (x - lo) / (hi - lo + 1e-9)
    return float(np.clip(x, 0.0, 1.0))

# ---- Math Judge ----
def math_judge(answer_text: str, gold: Optional[float]) -> float:
    if gold is None:
        return 0.0
    pred = safe_last_number(answer_text)
    if pred is None:
        return 0.0
    err = abs(pred - gold)
    if err <= 1e-6:
        return 1.0
    # 误差缩放:与 gold 的量级相关,避免过严
    scale = max(abs(gold), 10.0)
    return max(0.0, 1.0 - err / scale)

def sympy_verify(expr_text: str, gold: Optional[float]) -> float:
    if sp is None or gold is None:
        return 0.0
    # 粗糙解析:提取等式右端的数或表达式
    try:
        exprs = re.findall(r"= *(.+)", expr_text)
        if not exprs:
            return 0.0
        val = float(sp.N(sp.sympify(exprs[-1])))
        return 1.0 if abs(val - gold) <= 1e-6 else 0.0
    except Exception:
        return 0.0

# ---- Code Judge(伪沙箱示例)----
def code_judge(answer_text: str, unit_tests: Optional[List[str]] = None) -> float:
    if not unit_tests:
        return 0.0
    # 简化:从文本中抓取 ```python ... ``` 代码块
    m = re.search(r"```python(.*?)```", answer_text, flags=re.S)
    if not m:
        return 0.0
    code = m.group(1)
    passed, total = 0, len(unit_tests)
    # 仅演示:真实工程需用 subprocess + 限时/限内存沙箱
    local_env = {}
    try:
        exec(code, {}, local_env)
        for t in unit_tests:
            try:
                ok = eval(t, {}, local_env)
                passed += 1 if ok else 0
            except Exception:
                pass
    except Exception:
        return 0.0
    return passed / max(1, total)

# ---- OpenQA / RAG Judge(轻量启发式)----
def overlap_score(text: str, evidence: str) -> float:
    text_tokens = set(re.findall(r"[A-Za-z0-9\u4e00-\u9fa5]+", text))
    ev_tokens = set(re.findall(r"[A-Za-z0-9\u4e00-\u9fa5]+", evidence))
    inter = len(text_tokens & ev_tokens)
    return normalize_score(inter, 0, max(len(ev_tokens), 6))

def rag_judge(answer_text: str, evidences: Optional[List[str]]) -> float:
    if not evidences:
        return 0.0
    # 取与最佳证据的重叠得分
    best = 0.0
    for ev in evidences[:3]:
        best = max(best, overlap_score(answer_text, ev))
    # 是否显式标注引用(极简)
    cite_bonus = 0.05 if re.search(r"\[(参考|引用|source)\]", answer_text) else 0.0
    return float(np.clip(best + cite_bonus, 0.0, 1.0))

# ---- 结构与风格 Judge ----
def form_style_judge(answer_text: str, question: str) -> float:
    has_list = 0.2 if re.search(r"(\n- |\n\d+\.)", answer_text) else 0.0
    has_head = 0.2 if re.search(r"(###|####|【|——)", answer_text) else 0.0
    # 覆盖度(启发式关键词)
    q_kw = set([w for w in re.split(r"[,。、;:,\s/]+", question) if len(w) >= 2])
    cov = sum(1 for w in q_kw if w in answer_text) / (len(q_kw) + 1e-6)
    cov = float(np.clip(cov, 0.0, 1.0))
    length_ok = float(np.clip(len(answer_text) / 600, 0.0, 1.0))  # 长度上限约束
    return 0.4 * cov + 0.2 * has_list + 0.2 * has_head + 0.2 * length_ok

# ---- 聚合打分 ----
@dataclass
class JudgeWeights:
    verifiable: float = 0.55
    factual: float    = 0.30
    formstyle: float  = 0.15

def judge_score(
    question: str,
    answer_text: str,
    meta: Dict,
    w: JudgeWeights = JudgeWeights()
) -> float:
    gold = meta.get("gold")  # 数学答案(可无)
    unit_tests = meta.get("unit_tests")
    evidences = meta.get("evidences")

    s_math = max(math_judge(answer_text, gold), sympy_verify(answer_text, gold))
    s_code = code_judge(answer_text, unit_tests)
    verifiable = max(s_math, s_code)  # 可验证层,谁能判就用谁

    factual = rag_judge(answer_text, evidences)
    form = form_style_judge(answer_text, question)

    base = w.verifiable * verifiable + w.factual * factual + w.formstyle * form
    # 轻微冗长惩罚(>1200字开始降分)
    penalty = 0.0
    if len(answer_text) > 1200:
        penalty = min(0.1, (len(answer_text) - 1200) / 6000.0)
    return float(np.clip(base - penalty, 0.0, 1.0))

def judge_pair(
    question: str,
    ans_a: str,
    ans_b: str,
    meta: Dict
) -> Tuple[str, str, float, float]:
    sa = judge_score(question, ans_a, meta)
    sb = judge_score(question, ans_b, meta)
    if abs(sa - sb) < 1e-4:
        # 平局:以覆盖度/多样性微调
        sa += 0.01 if len(ans_a) > len(ans_b) else 0.0
    if sa >= sb:
        return ans_a, ans_b, sa, sb
    else:
        return ans_b, ans_a, sb, sa

(B) 聚合与写 DPO 对:加上“不可判样本丢弃 + 分位数校准”(≈100 行)

# pipelines/build_dpo_pairs.py
# -*- coding: utf-8 -*-
"""
用裁判套件批量生成 DPO 偏好对
- 支持不可判样本丢弃
- 简单分位数校准(按批归一)
"""
import json
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple
from judges.judge_suite import judge_pair

def quantile_scale(scores: List[float], eps=1e-6) -> List[float]:
    arr = np.array(scores)
    lo, hi = np.quantile(arr, 0.05), np.quantile(arr, 0.95)
    hi = max(hi, lo + eps)
    return list(np.clip((arr - lo) / (hi - lo), 0.0, 1.0))

def build_pairs(examples: List[Dict], out_path: str):
    chosen_scores, rejected_scores = [], []
    keep_rows = []
    for ex in examples:
        q = ex["prompt"]
        a, b = ex["cand_a"], ex["cand_b"]
        meta = ex.get("meta", {})
        chosen, rejected, sc, sr = judge_pair(q, a, b, meta)
        # 丢弃不可判/噪声对
        if max(sc, sr) < 0.15 or abs(sc - sr) < 1e-3:
            continue
        keep_rows.append({
            "prompt": q,
            "chosen": chosen,
            "rejected": rejected,
            "score_chosen": sc,
            "score_rejected": sr,
            "meta": meta
        })
        chosen_scores.append(sc)
        rejected_scores.append(sr)

    # 简易批校准:把分数映射到分位数区间便于后续分析(DPO 不强制需要)
    if keep_rows:
        cs = quantile_scale(chosen_scores)
        rs = quantile_scale(rejected_scores)
        for r, csc, rsc in zip(keep_rows, cs, rs):
            r["score_chosen_q"] = csc
            r["score_rejected_q"] = rsc

    Path(out_path).write_text(
        "\n".join(json.dumps(r, ensure_ascii=False) for r in keep_rows),
        encoding="utf-8"
    )
    print(f"[DPO] kept={len(keep_rows)} write -> {out_path}")

if __name__ == "__main__":
    # 演示数据结构
    toy = [
        {
            "prompt": "计算:27 + 15 = ?",
            "cand_a": "…因此答案为:42",
            "cand_b": "…最终答案:41",
            "meta": {"gold": 42.0}
        },
        {
            "prompt": "简述 HTTP/2 的关键改进点",
            "cand_a": "- 多路复用…\n- 头部压缩(HPACK)…\n[参考]",
            "cand_b": "HTTP/2 更快。",
            "meta": {"evidences": ["HTTP/2 提供多路复用、头部压缩(HPACK)、服务端推送、流优先级…"]}
        }
    ]
    build_pairs(toy, "dpo_pairs_v2.jsonl")

(C) 训练策略管线:温度日程 + 困难样本挖掘 + DPO 循环(≈140 行)

# pipelines/selfplay_cycle.py
# -*- coding: utf-8 -*-
"""
自博弈循环:
1) 以温度日程生成 A/B
2) 裁判筛选 → DPO pairs
3) DPO 训练一小轮 → 更新策略
4) 硬负样本挖掘(对险胜/险负继续扩展 C/D)
"""
import os
import json
import random
from pathlib import Path
from typing import Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from pipelines.build_dpo_pairs import build_pairs

def gen_answer(model, tok, sys, q, temp, top_p=0.9):
    prompt = f"{sys}\n\n题目:{q}\n请逐步推理并给出最终结论:"
    ids = tok(prompt, return_tensors="pt").to(model.device)
    out = model.generate(
        **ids, max_new_tokens=256, do_sample=True, temperature=temp,
        top_p=top_p, repetition_penalty=1.05,
        eos_token_id=tok.eos_token_id, pad_token_id=tok.eos_token_id
    )
    return tok.decode(out[0], skip_special_tokens=True).split("结论")[-1]

SYS_A = "你是严谨的助教,条理清晰。"
SYS_B = "你是细致的评审,先列要点。"

def temperature_schedule(step: int, warm_steps=3) -> float:
    # 前几轮高温,随后缓降
    if step < warm_steps:
        return 0.95
    return max(0.65, 0.95 - 0.05 * (step - warm_steps))

def selfplay_round(
    base_model: str,
    questions: List[Dict],
    step: int,
    out_dir: str
):
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    tok = AutoTokenizer.from_pretrained(base_model, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(base_model, device_map="auto", torch_dtype=torch.float16)

    temp = temperature_schedule(step)
    print(f"[GEN] step={step} temp={temp:.2f}")
    pairs_raw = []
    for i, ex in enumerate(questions):
        q = ex["prompt"]
        meta = ex.get("meta", {})
        set_seed(42 + step * 100 + i)
        a = gen_answer(model, tok, SYS_A, q, temp)
        set_seed(1234 + step * 100 + i)
        b = gen_answer(model, tok, SYS_B, q, temp * 0.9)
        pairs_raw.append({"prompt": q, "cand_a": a, "cand_b": b, "meta": meta})

    # 裁判 + 构造 DPO 对
    out_pairs = os.path.join(out_dir, f"pairs_step{step}.jsonl")
    build_pairs(pairs_raw, out_pairs)
    return out_pairs

def mine_hard_negatives(pairs_path: str, k=20) -> List[Dict]:
    rows = [json.loads(l) for l in Path(pairs_path).read_text(encoding="utf-8").splitlines()]
    # 选择“分差最小”的前 k 条,进一步扩展 C/D 候选
    rows.sort(key=lambda r: abs(r["score_chosen"] - r["score_rejected"]))
    return rows[:k]

if __name__ == "__main__":
    base = os.environ.get("BASE", "Qwen/Qwen2.5-1.5B-Instruct")
    # 演示题库(可换 GSM8K/MATH 或业务题)
    questions = [
        {"prompt": "计算:27 + 15 = ?", "meta": {"gold": 42.0}},
        {"prompt": "如果 x=5,y=3,求 2*x^2 - y = ?", "meta": {"gold": 47.0}},
        {"prompt": "简述 HTTP/2 相对 HTTP/1.1 的关键改进点", "meta": {
            "evidences": ["HTTP/2 的关键特性包括多路复用、头部压缩(HPACK)、服务端推送、流优先级与更低延迟。"]
        }},
    ]

    work = "./sp_cycle"
    all_pairs = []
    for step in range(1, 5):
        p = selfplay_round(base, questions, step, work)
        all_pairs.append(p)
        # 你可以在此处调用前文的 DPOTrainer 做一小轮训练,然后继续下一轮
        # train_dpo_once(pairs_jsonl=p, base_or_adapter=..., output_dir=...)
        hard = mine_hard_negatives(p)
        print(f"[HARD] mined={len(hard)} top example prompt={hard[0]['prompt'] if hard else 'N/A'}")

7. 参考链接

参考链接(可靠来源)


最后小结(可操作要点)

1)裁判优先级:先做“可验证”(Math/Code),再做“事实一致”(RAG),最后用“结构/风格”兜底;不可判样本宁可丢弃。
2)生成策略:用温度日程制造分歧→收敛稳定;交叉对战维持多样性。
3)训练闭环:自博弈 → 裁判 → 偏好对(DPO)→ 小步快跑迭代;需要在线反馈时接入 PPO。
4)评测看趋势:dev win-rate、MathAcc、Factual@Top-k 持续上升,Toxicity 下降。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值