Setting Up a Project in Build Forge (BF)

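This article walks through a project's build process, covering how the source code is fetched, built, output, and cleaned up, and describes how to configure an automated build flow in Build Forge.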


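Step 1: Understand the build process

In our project, the basic build process is as follows:

(1) Copy the source code from the version-control server to the file server dedicated to production builds. Each build creates a temporary directory to hold the source code, and that directory must be deleted once the build finishes.

(2) Copy the source code to the build machine and call the build script to build it.

(3) Package the build output and logs and write them to the file server.

(4) Delete the temporary directory on the server and send a build-completed note.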
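The four steps above can be sketched as a small shell function. This is only an illustration under stated assumptions: local directories stand in for the version-control server, the file server, and the build machine, and the names run_build_flow, build.sh, and build-result.tar.gz are hypothetical, not taken from the project.

```shell
# Sketch of the four-step flow. All paths and names are hypothetical stand-ins;
# in a real setup the copies would cross machines instead of local directories.
run_build_flow() {
    vcs_dir=$1        # stands in for the version-control server
    file_server=$2    # stands in for the production-build file server
    build_dir=$3      # working directory on the build machine

    # (1) Stage the source in a temporary directory on the file server.
    staging=$(mktemp -d "$file_server/staging.XXXXXX") || return 1
    cp -R "$vcs_dir/." "$staging/" || return 1

    # (2) Copy the staged source to the build machine and run the build script.
    mkdir -p "$build_dir" || return 1
    cp -R "$staging/." "$build_dir/" || return 1
    status=0
    (cd "$build_dir" && sh ./build.sh > build.log 2>&1) || status=$?

    # (3) Package the output and the log and publish them to the file server.
    tar -czf "$file_server/build-result.tar.gz" -C "$build_dir" . || return 1

    # (4) Remove the temporary directory and report completion.
    rm -rf "$staging"
    echo "build completed with status $status"
    return "$status"
}
```

The function returns the build script's exit status, so a caller (for example a Build Forge step) can tell a failed build from a successful one even though packaging and cleanup still ran.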

 

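Step 2: Prepare the build machines

Some projects need cross-platform builds and must be compiled on several platforms, such as Red Hat 4 to 6, SLES 9, and Windows, so machines with the corresponding operating system versions have to be set up. The libraries the build depends on must be installed as well. Because the file server can only be accessed after logging in, each machine needs the file server's access client. And because the build machines have to communicate with Build Forge, each one also needs the Build Forge client.

In addition, create the build directories that will hold the source code and the build output.

Once the machines are ready, it is best to run a manual build on each of them first, to confirm that every dependency the build needs is installed.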
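Before the manual build, a quick check for missing tools can save time. The following is a minimal sketch, assuming the dependencies are command-line tools on PATH; check_deps is a hypothetical helper and the tool list in the usage note is only an example.

```shell
# check_deps: report every required tool that is not found on PATH.
# Returns 0 when all tools are present, 1 when at least one is missing.
check_deps() {
    missing=0
    for tool in "$@"; do
        if ! command -v "$tool" >/dev/null 2>&1; then
            echo "missing dependency: $tool"
            missing=1
        fi
    done
    return "$missing"
}
```

A call such as `check_deps gcc make perl tar` lists whatever is absent on that machine. Libraries and headers need a separate check, so a manual build remains the real test.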

 

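Step 3: Create the project in Build Forge

Before creating it, three concepts need to be clear:

(1) Project

(2) Lib

(3) Step

Broadly speaking, a project can contain several libs and multiple steps. A step can call a lib, and a lib is itself made up of steps.

The difference between a project and a lib is that a project has a selector and a lib does not.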

 

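With these concepts clear, decide which parts of the process belong in plain steps and which should become libs. The advantage of a lib is that it behaves like a function: you can pass it different parameters and environment variable values and run the same code on different platforms, which greatly improves code reuse.

Platforms also differ from one another, so when defining the process, first work out what the platforms have in common at build time and where they differ. Our project divides the platforms into two broad types, Linux and Windows. The Linux distributions differ from each other, but not by much, whereas Windows and Linux differ considerably. The project therefore has two main libs: linux_build_lib and win_build_lib.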
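The "lib as a function" idea can be illustrated with a shell function. This is an analogy rather than Build Forge syntax: the body of linux_build_lib below is invented for illustration, and BUILD_ROOT is a hypothetical environment variable.

```shell
# A lib behaves like a function: the same code runs for every platform,
# driven by a parameter and by environment variables.
linux_build_lib() {
    os_type=$1                             # parameter passed by the caller
    build_root=${BUILD_ROOT:-/tmp/build}   # environment variable with a default
    echo "build $os_type under $build_root/$os_type"
}

# Two callers reuse the same lib with different values.
BUILD_ROOT=/builds
linux_build_lib rhel6
linux_build_lib sles9
```

Only the values change between the two calls, which is why the Linux distributions can share one lib while Windows gets its own.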

 

When creating steps and libs, also create the corresponding environment variables and choose the appropriate selector.

 

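Step 4: Test

When a Build Forge project runs, the object that corresponds to the run is a job. The relationship between a job and a project is similar to that between an instance and a class. Every run of the project gets its own job ID. The job name is normally made up of the job tag and the job number.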

 

 


 

Setting up a process in Build Forge involves quite a few details, so they are listed separately in this section.

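(1) Create the project: choose Projects --> Add Project in Build Forge. The fields you normally need to set are the project name, access, environment, and selector. Setting a timeout is also a good idea. The tag is the name you want the job to have on each run.

(2) Add steps to the project. Click the project, and on the page that opens choose Add Step.

A step normally needs a step name and a command line. Sometimes you also have to specify whether a path is absolute or relative. When different steps must run on different machines, choose a selector for the step. If the step calls a lib, choose the lib to call in the Inline field.

When writing commands, it is best not to put several commands into one step, especially when return codes matter. If the work takes more than three commands, implement it as a script.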

 

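(3) Create a lib

Choose Libraries --> Add Library. The procedure is largely the same as for a project, except that no selector is chosen, so it is not repeated here.

(4) Add environment variables

An environment consists of multiple variables. A variable is a key-value pair, and a variable's value can itself be composed from other variables.

Choose Environments --> Add Environment and fill in the name and access, then select the environment you just created and add variables. Variable names are usually written entirely in uppercase.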
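The advice in (2) about moving multi-command steps into a script can be sketched as follows. It assumes the goal is to stop at the first failing command and hand its return code back to the step. run_all is a hypothetical helper, not a Build Forge feature.

```shell
# run_all: run each command in order, stop at the first failure, and
# return that command's exit code so the calling step is marked as failed.
run_all() {
    for cmd in "$@"; do
        sh -c "$cmd" || {
            rc=$?
            echo "failed ($rc): $cmd" >&2
            return "$rc"
        }
    done
}
```

A step would then hold a single call, for example `run_all "./configure" "make" "make package"`, instead of three separate command lines whose return codes are easy to lose.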

 

 


 

 

Some tips:

(1) When one command line chains several commands, for example build-linux.sh | tee /${OS_TYPE}.log, the result returned is always that of the last command, so the build can fail while the step still shows as passed.

There are two ways to fix this (both are bash features):

(1) Put set -o pipefail before the command;

(2) Right after the pipeline in the step, add:

exit ${PIPESTATUS[0]}
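The difference between the three variants can be seen in a small bash experiment. It is a sketch under assumptions: fake_build is a hypothetical stand-in for build-linux.sh that always fails with status 2, and the helper names are invented for the demonstration.

```shell
#!/usr/bin/env bash
# fake_build stands in for build-linux.sh and always fails with status 2.
fake_build() { echo "compiling..."; return 2; }

# Default behavior: the pipeline reports tee's status, so the failure is lost.
plain() { fake_build | tee /dev/null >/dev/null; }

# Fix 1: with pipefail, a failing command makes the whole pipeline fail.
# The body runs in a subshell so the option does not leak to the caller.
with_pipefail() ( set -o pipefail; fake_build | tee /dev/null >/dev/null )

# Fix 2: PIPESTATUS[0] holds the exit status of the first pipeline command.
with_pipestatus() {
    fake_build | tee /dev/null >/dev/null
    return "${PIPESTATUS[0]}"
}

# show: run a variant and print the exit status the step would see.
show() { rc=0; "$@" || rc=$?; echo "$1 exit status: $rc"; }

show plain
show with_pipefail
show with_pipestatus
```

The plain variant reports 0 even though the build failed, while the other two report 2. PIPESTATUS must be read immediately after the pipeline, because the next command overwrites it.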

 

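(2) To show the log in Build Forge and save it to a log file at the same time, use the tee command.

(3) On Linux, comments inside a command can start with #. On Windows this does not work and # causes an error. Use rem instead.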

 

(4) On Windows, xcopy can be used for remote copies:

xcopy  \\**.com\directory\file_name  local_directory_on_build_machine /v /q /y

 

(5) On Windows, use the type command to print a log file to the screen.

 

(6) REG ADD "HKCU\Software\Microsoft\Command Processor" /V DisableUNCCheck /T REG_DWORD /F /D 1

This command works around the UNC path problem on Windows: with DisableUNCCheck set to 1, cmd.exe no longer warns when it is started with a UNC path as the current directory.

  
   

(7)rmdir /S /Q  $BF_ROOT/%BF_PROJECTNAME%/

 

 

 

 

(8) Sometimes, when you use a Perl library and run a script from the server on the local build machine, some files may turn out to be missing. In that case, export the library location locally. Try the command "export BLDPERLLIB=directory name".

(9) Run the command "perl -c /path/code.pl" to do a compile check and catch any major coding issues in a Perl script.

 

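(10) To change into a directory from Perl, use chdir and always check its result. The right way to use it:

my $dir = "directory name";   # placeholder for the real directory
if (chdir($dir) == 0)
{
    # chdir failed: report why and exit with an error code
    print STDERR "cannot change to $dir: $!\n";
    exit 1;
}
else
{
    # chdir succeeded: do the work inside $dir here
}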

 

(11) Extract an archive with 7z on Windows:

7z x $targetDirectory//$downloadFile -o$targetDirectory

-o specifies where the extracted files are placed.

(12) Delete a file on Windows:

  del /s/f/q $targetDirectory/$downloadFile

(13) Compress with 7z on Windows:

 $cmd = "7z a -t7z  ${compressFileName} -o$target  $target/name";

(14) Create and remove a file link on Linux:

unlink link_name: remove the link

ln -s file_to_be_linked link_name: create the link

(15) Create nested directories on Linux:

mkdir -p directoryname

e.g. mkdir -p /test/test

(16) bset and tset in Build Forge:

.tset env "$vat=***"

If several threads run at the same time on different platforms and all use the same variable, but each needs a different value, do not use .tset, because .tset takes effect for all of them.
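The Linux tips on nested directories and links can be tried safely in a scratch directory. This is a minimal sketch: demo_links and the a/b/file.txt layout are hypothetical names chosen for the demonstration.

```shell
# demo_links: exercise mkdir -p, ln -s and unlink inside a given directory.
demo_links() {
    work=$1
    mkdir -p "$work/a/b" || return 1          # creates both levels in one call
    echo "data" > "$work/a/b/file.txt"
    ln -s "$work/a/b/file.txt" "$work/link_name" || return 1   # create the link
    cat "$work/link_name"                     # reads the target through the link
    unlink "$work/link_name" || return 1      # removes only the link
    [ -f "$work/a/b/file.txt" ] && echo "target still exists"
}

demo_links "$(mktemp -d)"
```

Removing the link leaves the linked file untouched, which is why unlink is safe for cleaning up links created during a build.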

 

 

 

 

 

 

 

 

 

 

 

 

 

 
