大模型微调入门

大模型微调入门

目录


业界应用大模型的风潮如火如荼,已经从科技圈渗透到各行各业。

自昨日与同事交流,感觉大家对大模型是什么,能做什么,怎么用还有很多疑惑。我切以自己的浅薄学识用此文向大家介绍,希望可以给大家一定的启发。如果有错漏之处,欢迎大家批评指正。

文章撰写的思路由浅及深,以帮助可以更清晰的介绍大模型的训练过程。第一章讲解了什么是大模型,大模型相较于普通的算法强在哪里以及大模型能做什么。第二章向大家介绍大模型的微调和知识库的区别,以及推荐用哪种方式来训练自己的大模型。第三章介绍了微调大模型的环境配置。第四章讲解怎么构建微调大模型的数据集。第五章讲大模型怎么微调,参数是什么意思。第六章讲解模型的训练和评估过程。第七章来聊一聊模型的本地化部署。

在这里插入图片描述

一、什么是大模型以及大模型能做什么

1.1 大模型是什么

大模型(Large Language Model, LLM)是值参数规模达到亿级(0.1B)甚至千亿(100B)的深度学习模型,如GPT-4、LLaMa、Qwen等。

说到大模型就不得不说深度学习,说到深度学习就不得不说神经网络。昨天薇姐的例子很好,一个简单的神经网络模型算法可能就是
y = a x 2 + b x + c y=ax^2+bx+c y=ax2+bx+c
x是输入,y是输出,其模型的结构就可以设计为:
在这里插入图片描述

图中的a、b、c是需要训练的参数,众所周知只要有三个确定的点 ( x 1 , y 1 ) , ( x 2 , y 2 ) , ( x 3 , y 3 ) (x_1,y_1),(x_2,y_2),(x_3,y_3) (x1,y1),(x2,y2),(x3,y3)就能确定一个二次函数。这就引申出来三个比较重要的知识点:

  • 参数量与训练数据量是挂钩的。比如两个点能确定一个一次函数、三个点能确定一个二次函数(如上图)、4个数能确定一个三次函数。一个1.5b(15亿参数两)的大模型自然需要很大的数据集来训练。

  • 不管训练数据量有多少,其参数量都是不变的。如上图不管输入多少个点,其训练的参数量还是只有三个。不过输入的正确数据量越多,所预测的二次函数曲线就会越趋近真实曲线

  • 上面的模型是线性的,也就是说仅通过加法、乘法和矩阵运算就能得到预测结果,适用于比较简单的任务,比如根据房屋面积和数量预测房价、邮件分类等。但是还有一些任务是非线性的,比如图像识别、文本生成、股票预测等。非线性任务需要加入激活层来拟合,篇幅有限不在此多做介绍,读者可以使用上面模型来理解,但是需要知道大模型的结构要复杂得多,不仅仅是线性结构的堆叠。

神经网络可以简单理解为上面参数层(a,b,c)和激活层的多层堆叠。

而深度学习则是多个神经网络的组合和堆叠。

大模型则是多个深度学习算法的组合和堆叠。

1.2 大模型能做什么

在讲解大模型能做什么之前先来聊一聊都有哪些大模型。

按照应用方向划分,有:

  • 文本大模型。以自然语言处理为核心,具备语言理解、生成和复杂推理能力,可以处理文本生成、智能客服、代码生成等任务。自GPT-4发布,标志了文本通用大模型的诞生,一统天下。代表模型:GPT-4、DeepSeek、LLama等。
  • 图像大模型。专注于计算机视觉任务,包括图像分类、目标检测、图像处理等。代表模型:Stable Diffusion
  • 视频大模型。结合时序和空间建模,可以处理视频理解、视频生成和视频推理任务。代表模型:Sora
  • 语音大模型。语音输入和回复,支持对话、歌唱等风格。代表模型:心辰Lingo
  • 多模态大模型。模型能力包括输入文本、图像、语音和视频,其输出也可以是文本、图像、语音和视频。代表模型:GPT-4、DeepSeek、LLama等

大家可以很明显的看出来,目前发展最好、技术最先进的模型是文本大模型、其次是图像大模型,其他几种模型都是基于文本大模型的技术来训练的,主要原因可能是因为积累的标注文本数据要远远多于其他几种模型。虽然图像、视频等领域号称有通用大模型,但是根据用户的反馈其实并不好,个人认为通用智能大模型还是文本大模型所衍生的。

大模型的特点和核心优势体现在以下几个方面:

  • 通用特征提取:通过海量无标注数据预训练,掌握语言规律和知识关联能力(如词语接龙和上下文理解)
  • 任务适应性强:支持文本生成、代码编写、多轮对话等复杂场景,可适应能源行业设备运维报告生成、工单分类等需求
  • 微调增效显著:相比传统算法,通过少量领域数据微调即可提升特定任务准确率30%以上。1

典型应用场景:

  • 智能问答:构建电力设备故障知识库问答系统
  • 文档生成:自动生成输变电设备巡检报告
  • 数据分析:从电网线路传感器日志中提取异常模式

二、微调和知识库

2.1 定义

**模型微调(Fine-Tuning)**是一种基于预训练模型的迁移学习技术,通过在特定任务数据集上进一步训练,调整模型参数以优化其在特定领域的性能。这一过程能够利用预训练模型已有的通用知识,同时适应新任务的需求。
**知识库(Knowledge Base)**是结构化存储、组织和管理知识的系统,通常包含事实、规则及策略,支持智能检索与应用。其核心是促进知识的积累、共享与复用。

2.2 通俗讲解

  1. 模型微调:让大模型“学得更专”
    想象一下,你有一个非常聪明的助手(大模型),它已经学会了人类的语言和很多知识,但还不够专业。为了让它在某个特定领域(比如医疗或法律)表现得更好,你可以给它一些专业书籍(领域数据)让它学习。这个过程就是微调。微调后,助手在这个领域的表现会更精准。

  2. 知识库:给大模型一个“外接大脑”
    知识库就像是一个装满专业知识的图书馆。当大模型遇到不懂的问题时,它可以快速从这个图书馆中找到相关信息来回答。比如,当用户问“变压器过热怎么办?”时,大模型可以从知识库中检索出“检查油泵和散热器”的建议。

2.3 微调与知识库的优劣对比

维度微调知识库
通俗理解让大模型变得更专业,从“通才”变成“专家”。给大模型一个“外接大脑”,遇到问题随时查资料。
优势1. 在特定任务中表现更精准(如医疗诊断)。
2. 不需要每次都去查资料,直接给出答案。
3. 适合需要深度理解的场景(如法律分析)。
1. 信息更新方便,随时添加新知识。
2. 适合需要最新信息的场景(如政策查询)。
3. 成本较低,不需要大量算力。
劣势1. 需要大量专业数据,成本较高。
2. 如果数据不够,可能会“学偏”。
3. 更新麻烦,每次新知识都要重新训练。
1. 依赖检索,回答速度可能较慢。
2. 如果知识库没有该问题的答案,可能回复不出来。
3. 需要人工维护,确保信息准确。
适用场景1. 专业领域(如医疗、法律)。
2. 需要深度理解的复杂任务。
1. 需要最新信息的场景(如政策法规)。
2. 知识频繁更新的领域(如科技前沿)。

2.4 举例说明

  1. 微调的例子
    假设你是一家医院的医生,希望大模型能快速诊断患者的病情。你可以用大量病历数据微调大模型,让它学会如何根据症状判断疾病。微调后,它就能直接给出诊断建议,而不需要每次都去查资料。
  2. 知识库的例子
    假设你是一家企业的客服,用户经常问“最新的税务政策是什么?”你可以把相关政策文件存入知识库。当用户提问时,大模型会从知识库中检索出最新政策并回答。如果政策更新了,你只需要更新知识库,而不需要重新训练模型。

2.5 主流算法

模型微调:

  1. 全参数微调(FFT):更新所有参数,适合数据充足场景。
  2. 参数高效微调(PEFT):
    • LoRA:通过低秩矩阵分解调整权重,减少参数量。
    • Adapter:插入小型网络模块,仅训练新增参数。
    • Prefix Tuning:在输入层添加可学习前缀向量,引导模型适应任务。
  3. 知识蒸馏:用大模型指导小模型训练,平衡性能与效率。

知识库:

  1. 多模态融合:整合文本、图像、语音等多源数据(如医疗影像+病历分析)。
  2. 知识图谱:构建语义网络,增强推理能力(如金融风控路径分析)。

2.6 总结

微调适合需要深度理解和专业知识的场景,但成本较高,更新麻烦。
知识库适合需要最新信息和快速检索的场景,但依赖人工维护,回答可能不够精准。

在实际应用中,可以根据需求选择合适的方式,甚至将两者结合使用。比如,用微调让大模型在专业领域表现更好,同时用知识库补充最新信息,实现最佳效果。

本文主要讲解的是微调方法,所以以下几章都是对于模型微调的介绍。

三、微调环境配置

3.1 软硬件配置需求

在配置环境之前,应该首先介绍一下所需要的硬件配置和环境依赖。

本次微调选用的基础模型是Deepseek,使用的微调平台是LLama Factory,以下是官网的硬件推荐。模型精度越高,性能越好,也会占用更大的显存和内存。模型参数量越大,效果越好,也会占用更大的显存和内存。

在这里插入图片描述

还有官网的软件包版本依赖:

在这里插入图片描述

一般而言,微调1.5b的模型电脑显存应不低于6G。

3.2 服务器介绍

如果个人电脑配置不足以微调模型或者电脑还有其他用途,最好的方式就是使用云服务器。这种方式可以按使用时间计费,预安装了所需的基本软件环境,价格亲民(如:32G显存不到2元/时)。

我选用的服务器是AutoDL 2

在这里插入图片描述

租用服务器时可以选择软件依赖版本或者使用他人已经预制好的镜像环境。

在这里插入图片描述

然后打开控制台可以查看自己租用的机器信息。加载完镜像后即可开机进行微调了。微调完记得关机,如果不使用这个容器就可以释放实例,防止在空闲时间扣钱。

在这里插入图片描述

开机后点击JupyterLab打开在线环境,也可以使用SSH在本地开发。

在这里插入图片描述

3.3 环境配置

首先拉取LLaMa-Factory项目并安装,官网地址:https://github.com/hiyouga/LLaMA-Factory/tree/main

最好使用torch = 2.3的镜像,更新的包会少一些

git clone https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory/
pip install -e '.[torch,metrics]' # 如果下载速度过慢加上 -i https://pypi.tuna.tsinghua.edu.cn/simple some-package

然后下载适合的模型,以DeepSeek-R1-1.5B为例:

下载方式一:huggingface

export HF_ENDPOINT=https://hf-mirror.com
export HF_HOME=/mnt/workspace/huggingface
echo $HF_HOME
huggingface-cli download --resume-download deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --local-dir Models/deepseek-1.5b

下载方式二:modelscope(国内,网速更好)

export USE_MODELSCOPE_HUB=1
modelscope download deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --local_dir .

启动LLamafactory

llamafactory-cli webui

如果提示“To create a public link, set share=True in launch().”,

就打开并更改以下内容LLaMA-Factory/src/llamafactory/webui/interface.py用于开启远程UI端口

def run_web_ui() -> None:
    gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
    # gradio_share = is_env_enabled("GRADIO_SHARE") // 将这个改为True
    gradio_share = True
    server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
    create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)

如果提示“Could not create share link. Missing file: /usr/local/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3.

Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps:

  1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
  2. Rename the downloaded file to: frpc_linux_amd64_v0.3
  3. Move the file to this location: /usr/local/lib/python3.10/site-packages/gradio”

就按照提示获取远程文件并重命名即可

wget https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 # 网络异常就自己下载后拖到服务器上
mv frpc_linux_amd64 frpc_linux_amd64_v0.3
mv frpc_linux_amd64_v0.3 /usr/local/lib/python3.10/site-packages/gradio # 注意这个地址可能会根据镜像不同而改编
chmod +x /usr/local/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.3 # 增加执行权限

配置成功后启动应该如下所示:

在这里插入图片描述

其中标记的网址是公网地址可以直接点击访问可视化界面

四、数据集准备

4.1 数据集来源

数据集是微调模型的核心,只有构建出高质量的数据集才能微调好模型。

微调数据集一般来自:

  • 公开数据集。从公开数据集网站上获取,例如Huggingface、Modelscope以及特定行业的数据集网站。优势:多、质量高。劣势:行业数据集很难获得。
  • 人为标注。从公司内部或者行业内部获得数据进行数据处理、数据清洗和标注得到。优势:专业度高、一般储备原数据多。缺点:需要花费大量时间和精力。
  • AI生成。通过与已经训练好的知识库AI对话得到数据集用来训练微调AI模型。优势:时间快、质量尚可。缺点:API付费、质量受知识库AI效果影响。

数据集格式:

LLama Factory支持的数据格式主要有两种:

Alpaca(默认):

[
  {
    "instruction": "人类指令(必填)",
    "input": "人类输入(选填)",
    "output": "模型回答(必填)",
    "system": "系统提示词(选填)",
    "history": [
      ["第一轮指令(选填)", "第一轮回答(选填)"],
      ["第二轮指令(选填)", "第二轮回答(选填)"]
    ]
  }
]

shareGPT:

[
  {
    "conversations": [
      {
        "from": "human",
        "value": "<audio>人类指令"
      },
      {
        "from": "gpt",
        "value": "模型回答"
      }
    ],
  }
]

4.2 构建数据集

本次微调为了节省时间成本,决定采用AI生成数据集的方式。

4.2.1 获取问题

首先获取本地数据库的文件另存为CSV文件(用户也可以使用自己的文本文件切片获取数据集)

然后选择一个模型,上传数据库.csv,让它来生成问题,并保存

你是一个大模型数据生成专家,请仔细阅读文件里的架空电线杆塔型号资源库信息。
接下来,请你站在用户(设计人员)的角度思考,如果用户向AI杆塔选型系统提问,都会提哪些问题。
列举50个。

在这里插入图片描述

4.2.2 构建智能体

有了问题就需要生成的回复,选用的方式是构建一个智能体,让它来回答我们的问题。至于原因后续会有提及。

选用的智能体平台是扣子3

打开扣子官网https://www.coze.cn/ ,创建智能体

在这里插入图片描述

输入名字后,就可以设置智能体了。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

1)首先设置它的人设

# 角色
你是一位资深的杆塔型号智能推荐专家,在电力工程领域拥有深厚的专业知识与大量的实践经验。你的职责是依据用户的需求,提供专业、精准的杆塔型号推荐,确保推荐的杆塔既能满足各项技术要求,又能兼顾经济性与安全性。

## 技能
### 技能 1: 检索并推荐杆塔型号
1. 当收到用户提问后,首先检索知识库,从知识库中提取与用户需求相关的信息。若知识库中有匹配信息,直接回复用户相应的杆型信息,包含杆塔的型号名称、类别、材质和回路数等等。
2. 如果知识库中没有相关信息,回答“未找到相关信息”。

### 技能 2: 详细说明选型依据
1. 若成功推荐杆塔型号,需向用户详细阐述推荐该型号的依据,包括满足的技术要求、经济性考量以及安全性保障等方面。
2. 用通俗易懂的语言,确保用户能清楚方便地获取杆塔选型的相关信息。

## 限制:
- 推荐的杆塔型号必须基于现有的市场产品和技术规范。
- 不得超出大模型的知识范围,若需获取最新的市场信息或特定地区的特殊要求,应明确告知用户需要调用外部工具或数据库。
- 确保推荐过程透明公正,不偏向任何特定品牌或供应商,除非用户有明确要求。
- 回答内容应条理清晰,尽量使用有序列表或无序列表等形式呈现关键信息。 

2)然后将我们的资源库文件作为它的知识库

在这里插入图片描述

3)在右侧的预览与调试询问它关于杆塔选型的问题了。可以将答案复制出来构建数据集。但是这种一个问题一个问题的问的太过于麻烦,所以我们需要将这个智能体发布出去,然后获取它的API接口,编写脚本来自动获取答案。这也是为什么选择搭建一个智能体的原因,就是为了获取其API。

点击发布后要确保勾选了API选项

在这里插入图片描述

4.2.4 获取API授权

在这里插入图片描述

添加完成后切记要保存这个令牌!

添加完成后切记要保存这个令牌!

添加完成后切记要保存这个令牌!

因为这个令牌只会出现一次

在这里插入图片描述

4.2.5 脚本获取回复

编写脚本来获取回复就不需要手动输入、复制了。下面的API_TOKEN就是上面保存的令牌。

BOT_ID是构建的智能体页面中

在这里插入图片描述

import os
from coze import Coze

API_TOKEN = '' # 自己的Token
BOT_ID = "7477397675760599059"

os.environ['COZE_API_TOKEN'] = API_TOKEN
os.environ['COZE_BOT_ID'] = BOT_ID
chat = Coze(api_token= API_TOKEN,
            bot_id = BOT_ID,
            max_chat_rounds=20,
            stream=True)

import json

def process_questions(input_file, output_file):
    # 读取问题文件
    with open(input_file, 'r', encoding='utf-8') as f:
        questions = f.readlines()
    
    i = 1
    alpaca_data = []
    for q in questions:
        q = q.strip()
        if not q:
            continue

        print(i)
        i += 1
        # 调用API获取答案

        try:
            answer = chat(q)
            if answer == "" or "未找到相关信息" in answer:
                continue
            # 构建Alpaca格式条目
            entry = {
                "instruction": q,
                "input": '',
                "output": answer
            }
            alpaca_data.append(entry)
        except Exception as e:
            continue    
    
    # 保存为JSON文件
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(alpaca_data, f, ensure_ascii=False, indent=2)
    print("数据处理完成!")

process_questions(r"问题.txt", "alpaca_format_dataset.json")

最后的数据集大概如下:

在这里插入图片描述

4.3 修改数据配置文件

在模型训练前,需要修改数据集的配置文件,这样在后面微调的时候才会显示我们自制的数据集。首先将dataset_info.json文件下载到本地,增加我们的自制数据集的信息。然后将自制数据集和dataset_info.json文件都上传回原文件夹。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

五、“炼丹”——调参数

现在我们有了数据集就可以准备开始训练模型了。除了数据集之外,微调的另一个核心操作就是调整参数,由于参数很多,设计的算法也很多,排列组合之后有很多可能,经常会遇到调了参数效果并不好,然后反复调整几十次才有比较好效果的情况,所以行内也称调参为“炼丹”。

运行之前提到的命令 llamafactory-cli webui来启动可视化界面,界面如下:

在这里插入图片描述

接下来介绍各个参数代表的含义:

5.1 基础参数

  • 语言:切换界面展示的语言。
  • 模型名称:你微调的基础模型。支持多种开源大模型。
  • 模型路径:可以是云端模型仓库的模型路径(默认)也可以是本地的模型路径。
  • 微调方法:
    • Full: 全参数微调,更新模型的所有参数,适合数据集大、任务复杂的情况。
    • Freeze: 冻结部分参数,只更新特定层,适合资源有限的情况。
    • LoRA: 低秩自适应,只更新少量参数,适合资源有限且需要高效微调的场景,默认算法。
  • 量化等级:选择模型的量化等级,如8 位 (INT8) 或 4 位 (INT4),减少显存占用。
  • 加速方式:选择加速训练的方式。默认选择auto
  • RoPE插值方法:用于扩展模型上下文的能力。

下面分为四个子界面:训练(train)、评估(Evaluate&Predict)、对话(chat)、导出(Export)。

5.2 训练相关参数

在这里插入图片描述

  • 训练阶段:选择训练的方式。如监督微调SFT(默认),奖励建模PM,预训练Pre-Tuning。
  • 数据路径:存放数据集的文件夹。可以使用相对路径和绝对路径。
  • 数据集:选择数据集。有很多自带的数据集,这里我们选择自制的数据集。
  • 学习率:控制模型学习的速度。通常1e-5到5e-5。
  • 训练轮数:对数据集训练的次数,通常2到10之间。
  • 截断长度:输入序列的最大长度,超过就会被阶段。
  • 批处理大小:每次输入模型的样本数量,较大的批处理会占用更多显存。显存较大就提高批处理大小,否则反之。
  • 梯度累计:输入模型多少次样本才计算一次梯度。跟批处理大小一起作用,如果显存小就可以提高该数值,模拟更大的批处理大小。
  • 最大样本数:数据集的样本量限制,超出则截断。
  • 计算类型:模型使用的精度。
  • 验证集比例:划分数据集为验证集的比例。
  • 学习率调节器:调整学习率的算法。

其他高级参数此处不一一列举,有兴趣的读者可以自行学习。一般微调模型最重要的参数就是训练轮次、学习率、批处理大小

设置完参数后可以点击预览命令,该命令可以粘贴到命令行进行训练。

也可以选择保存这些设置好的参数、以及加载之前保存的参数。

最后就是开始训练按钮和中断训练。开始训练后下面会出现损失函数曲线,损失是指期望值与实际值之间的差异,损失越大,模型效果越差。

在这里插入图片描述

5.3 模型的评估

模型训练完,怎么评估模型的效果呢?有两种方式:1)使用验证数据集获取损失得分来评估。2)跟模型对话来评估。

Evaluate&Predict就是第一种方式。

在这里插入图片描述

  • 数据集:数据集选择与训练数据集相同
  • 截断长度、最大样本数、批处理大小:与训练阶段相同
  • 最大生成长度:模型生成的最大长度。
  • Top-p采样值:模型选择答案的置信度。
  • 温度系数:模型创造文本的随机性,温度系数越低,越有可能基于已经训练的知识回答,温度系数越高,越可能自己随机创造。
  • 预览命令、开始训练和中断:与训练阶段相同

5.4 模型对话

在这里插入图片描述

模型对话需要先选择上方的检查点路径,来加载我们训练好的模型,这个检查点路径名字和训练阶段的保存目录相同。

加载模型:加载训练好的模型后才可以对话。

卸载模型:卸载已加载的模型。

5.5 模型导出

在这里插入图片描述

分块大小:模型文件的最大大小

量化等级:模型的精度,量化等级越低,性能越差。

导出目录:导出模型的文件夹路径

六、模型的本地化部署

本地化部署是模型落地的重要环节,这样可以很有效的保证数据隐私和响应速度。我们选择使用Ollama来管理本地模型,Ollama是一个开源的本地大语言模型运行框架,到官网4下载安装即可

将导出的训练好的模型下载到本地,然后运行命令ollama create model_name -f model_dir,其中model_name和model_dir是自己模型的名字和下载模型文件中Modelfile路径

在这里插入图片描述

在这里插入图片描述

至此模型安装成功,就可以本地化与模型问答聊天了!由于本文许多资料基于AI联网搜索获得,所以很多参考文献都没有列出来,在此感谢诸位大佬的分享,如果发现遗漏了您的引用,请和我沟通。

参考文献

1解析大模型常用微调方法:P-Tuning、Prefix Tuning、Adapter、LoRA

2AutoDL算力云 | 弹性、好用、省钱。租GPU就上AutoDL

3https://www.coze.cn/

4Ollama

5教程1:【秒懂教程】30分钟学会DeepSeek R1模型Lora微调训练,适合借鉴学习,保姆级教程,全程干货无废话,草履虫都能学!_哔哩哔哩_bilibili

6教程2:241106_数据如何转为数据集-大模型问答对制作技巧_哔哩哔哩_bilibili

### Transformer 大模型微调入门教程 #### 一、理解微调的概念及其重要性 微调是指在预训练好的大规模语言模型基础上,针对特定任务或领域数据集进行进一步训练的过程。通过这种方式可以使得通用性强的大规模预训练模型更好地适应具体应用场景下的需求[^1]。 #### 二、准备环境与工具包安装 为了能够顺利开展基于Transformer架构的大模型微调工作,需要提前准备好相应的开发环境并安装必要的Python库文件。这里推荐使用`transformers`库来加载预训练模型,并利用`datasets`处理自定义的数据集合;同时借助于`accelerate`加速分布式训练过程,提高效率[^2]。 ```bash pip install transformers datasets accelerate torch ``` #### 三、获取预训练模型及配置参数调整 从Hugging Face Model Hub下载所需的预训练权重文件(例如BERT、RoBERTa等),并通过修改部分超参设置以满足目标任务的要求。比如,在情感分类场景下可适当减少隐藏层维度大小从而降低过拟合风险的同时保持较好的泛化能力[^3]。 ```python from transformers import AutoModelForSequenceClassification, AutoTokenizer model_name = "bert-base-uncased" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained( model_name, num_labels=2, # 设置标签数量为2表示二元分类问题 ignore_mismatched_sizes=True # 当前版本可能与原始预训练不完全匹配时忽略尺寸差异 ) ``` #### 四、构建适配器结构实现高效迁移学习 考虑到直接全量更新整个网络可能会破坏原有良好特性的情况,因此建议采用Parameter-Efficient Fine-Tuning (PEFT)技术中的LoRA(低秩分解)方案来进行局部优化操作。这样既能在一定程度上保留住原生性能优势又不会引入过多额外计算开销。 ```python from peft import LoraConfig, get_peft_model lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["query", "value"]) peft_model = get_peft_model(model, lora_config) for name, param in peft_model.named_parameters(): if 'adapter' not in name: param.requires_grad_(False) ``` #### 五、编写训练脚本完成最终部署上线 最后一步则是按照标准流程组织好输入输出格式后执行迭代式的监督式学习循环直至收敛稳定为止。期间还需注意监控各项指标变化趋势以便及时作出相应调整措施确保整体效果达到预期目标水平之上。 ```python import evaluate metric = evaluate.load("accuracy") def compute_metrics(eval_pred): logits, labels = eval_pred predictions = np.argmax(logits, axis=-1) return metric.compute(predictions=predictions, references=labels) training_args = TrainingArguments(output_dir="./results", evaluation_strategy="epoch", learning_rate=5e-5, per_device_train_batch_size=8, per_device_eval_batch_size=8, weight_decay=0.01, save_total_limit=2) trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=test_dataset, tokenizer=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics ) trainer.train() ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值