DistillKit 使用教程
1. 项目的目录结构及介绍
DistillKit 项目的目录结构如下:
DistillKit/
├── accelerate_config.sh
├── deepspeed_configs/
│ ├── deepspeed_config.json
│ └── deepspeed_config ordinance.json
├── distil_hidden.py
├── distil_logits.py
├── LICENSE
├── README.md
├── requirements.txt
├── setup.sh
└── train_distil.py
accelerate_config.sh
:用于生成加速训练的配置文件。deepspeed_configs/
:包含 DeepSpeed 配置文件,用于优化训练性能。distil_hidden.py
:实现隐藏状态蒸馏方法的 Python 脚本。distil_logits.py
:实现 logit 蒸馏方法的 Python 脚本。LICENSE
:项目的 Apache-2.0 许可文件。README.md
:项目说明文件,包含项目介绍和基本使用说明。requirements.txt
:项目依赖的 Python 包列表。setup.sh
:项目安装脚本,用于快速设置项目环境。train_distil.py
:用于训练蒸馏模型的 Python 脚本。
2. 项目的启动文件介绍
项目的启动文件主要包括 setup.sh
脚本,用于快速安装项目依赖。以下是启动文件的简单介绍:
setup.sh
:运行此脚本会自动安装项目所需的 Python 包。脚本内容如下:
#!/bin/bash
# 安装基本要求
pip install torch wheel ninja packaging
# 安装 Flash Attention
pip install flash-attn
# 安装 DeepSpeed
pip install deepspeed
# 安装其余要求
pip install -r requirements.txt
要启动项目,你需要先运行 setup.sh
脚本,确保所有依赖都已正确安装。
3. 项目的配置文件介绍
项目的配置文件位于项目的根目录中,主要是一个 JSON 格式的配置文件,用于定义项目运行时的各种参数。以下是配置文件的简单介绍:
- 配置文件通常名为
config.json
,在项目根目录中。它包含以下部分:
{
"project_name": "distil-logits",
"dataset": {
"name": "mlabonne/FineTome-100k",
"split": "train",
"seed": 42
},
"models": {
"teacher": "arcee-ai/Arcee-Spark",
"student": "Qwen/Qwen2-1.5B"
},
"tokenizer": {
"max_length": 4096,
"chat_template": "{% for message in messages %}..."
},
"training": {
"output_dir": "./results",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine",
"resume_from_checkpoint": None
},
"distillation": {
"temperature": 2.0,
"alpha": 0.5
},
"model_config": {
"use_flash_attention": True
}
}
这个配置文件包含了项目运行时的基本设置,如项目名称、数据集、模型、分词器设置、训练设置、蒸馏设置和模型配置等。在实际运行项目前,你可能需要根据实际情况调整这些配置以适应不同的需求。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考