import os
import json
import psutil
import torch
import gc
import warnings
from dataclasses import dataclass, field
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
    AutoProcessor,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    Qwen2_5_VLForConditionalGeneration,
    set_seed,
    logging as transformers_logging,
)
from peft import LoraConfig, get_peft_model
import numpy as np
# Ignore non-critical warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
transformers_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ========== Configuration ==========
@dataclass
class ScriptArguments:
    train_path: str = field(default="train.jsonl")
    valid_path: str = field(default="valid.jsonl")
    model_name_or_path: str = field(default="Qwen/Qwen2.5-VL-7B-Instruct")
    output_dir: str = field(default="./output_lora_qwen25vl_instruct")
    per_device_train_batch_size: int = field(default=1)
    gradient_accumulation_steps: int = field(default=4)
    num_train_epochs: int = field(default=3)
    logging_steps: int = field(default=5)
    save_steps: int = field(default=100)
    eval_steps: int = field(default=100)
    image_size: int = field(default=672)
    learning_rate: float = field(default=2e-5)
    warmup_steps: int = field(default=50)
    weight_decay: float = field(default=0.01)
    lora_rank: int = field(default=16)
    lora_alpha: int = field(default=32)
    lora_dropout: float = field(default=0.05)
    fp16: bool = field(default=False)
    bf16: bool = field(default=False)
    max_steps: int = field(default=-1)
    gradient_checkpointing: bool = field(default=True)
    seed: int = field(default=42)
    report_to: str = field(default="none")
    enable_mps_fallback: bool = field(default=True)
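# NOTE: `enable_mps_fallback` is never consumed anywhere below. CPU fallback
# for ops that MPS does not implement is normally enabled through the
# PYTORCH_ENABLE_MPS_FALLBACK environment variable, which PyTorch generally
# needs to see before `import torch` runs, e.g. in the shell:
#
#     PYTORCH_ENABLE_MPS_FALLBACK=1 python "train_vlm copy.py"
#
# or, as a sketch, at the very top of the file before the torch import:
#
#     os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")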
# ========== Device management ==========
class DeviceManager:
    @staticmethod
    def get_device():
        if torch.backends.mps.is_available() and torch.backends.mps.is_built():
            print("✅ MPS device available")
            return torch.device("mps")
        elif torch.cuda.is_available():
            print("✅ CUDA device available")
            return torch.device("cuda")
        else:
            print("⚠️ CPU only")
            return torch.device("cpu")

    @staticmethod
    def clear_memory():
        gc.collect()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
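# NOTE: clear_memory() is defined but never called while training runs. If MPS
# memory pressure becomes a problem, one option (a sketch, not part of the
# original flow) is to invoke it periodically from the callback defined below:
#
#     class SafetyTrainingMonitor(TrainerCallback):
#         def on_step_end(self, args, state, control, **kwargs):
#             if state.global_step % 50 == 0:
#                 DeviceManager.clear_memory()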
# ========== Dataset ==========
class SafetyImageDataset(Dataset):
    def __init__(self, path, processor, tokenizer, image_token="<|im_start|>", image_size=672):
        self.data = []
        self.processor = processor
        self.tokenizer = tokenizer
        self.image_token = image_token
        self.image_size = image_size
        with open(path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                try:
                    example = json.loads(line)
                    if not isinstance(example, dict):
                        print(f"⚠️ Line {line_no} is malformed: {example}")
                        continue
                    required_keys = ["image", "instruction", "ground_truth"]
                    if not all(k in example for k in required_keys):
                        print(f"⚠️ Line {line_no} is missing required keys: {example}")
                        continue
                    self.data.append(example)
                except json.JSONDecodeError as e:
                    print(f"⚠️ JSON parse error on line {line_no}: {e}")
        print(f"✅ Loaded {len(self.data)} examples.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            item = self.data[idx]
            # Defensive: make sure every record is a dict with the required keys
            if not isinstance(item, dict):
                raise ValueError("record is not a dict")
            for key in ["image", "instruction", "ground_truth"]:
                if key not in item:
                    raise ValueError(f"missing key: {key}")
            return item
        except Exception as e:
            print(f"❌ __getitem__ idx={idx} failed: {e}")
            # Return an empty record so the DataLoader does not crash
            return {
                "image": "images/placeholder.jpg",  # point this at a real placeholder image
                "instruction": "",
                "ground_truth": "",
            }
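# Expected JSONL schema, one object per line. The values below are
# illustrative placeholders, not rows from the actual dataset:
#
#     {"image": "images/site_001.jpg",
#      "instruction": "Point out the safety hazards in this scene.",
#      "ground_truth": "The worker on the scaffold is not wearing a harness; ..."}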
class SafetyDataCollator:
    def __init__(self, processor, tokenizer, image_token="<|im_start|>", image_size=672):
        self.processor = processor
        self.tokenizer = tokenizer
        self.image_token = image_token
        self.image_size = image_size

    def _load_image(self, path):
        img = Image.open(path).convert("RGB")
        img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
        return img

    def __call__(self, examples):
        print(f"\n🟢 collator received examples type={type(examples)}")
        print(f"🔍 per-example types: {[type(e) for e in examples]}")
        print(f"🔍 per-example contents: {examples}")
        # Drop empty dicts
        examples = [e for e in examples if e and isinstance(e, dict) and "image" in e]
        if len(examples) == 0:
            raise ValueError("❌ All examples are empty dicts; cannot build a batch!")
        images_paths = [e["image"] for e in examples]
        images = [self._load_image(p) for p in images_paths]
        prompts = [
            f"<|user|>\nAs a construction-safety expert, please analyze the safety hazards in the following scene: {self.image_token}\n<|assistant|>"
            for _ in examples
        ]
        targets = [e["ground_truth"] for e in examples]
        texts = [p + t for p, t in zip(prompts, targets)]
        batch = self.processor(
            images=images,
            text=texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=1024,
        )
        labels = batch["input_ids"].clone()
        for i, prompt in enumerate(prompts):
            # Mask the prompt tokens so loss is computed on the answer only
            prompt_ids = self.tokenizer(prompt, return_tensors="pt").input_ids[0]
            labels[i, :len(prompt_ids)] = -100
        return {
            "input_ids": batch["input_ids"],
            "attention_mask": batch["attention_mask"],
            "pixel_values": batch["pixel_values"],
            "labels": labels,
        }
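# A hedged alternative to the hand-rolled "<|user|>" prompt above: Qwen2.5-VL
# ships a chat template on its processor, and "<|im_start|>" is a chat-control
# token rather than an image placeholder, so letting the template insert the
# vision tokens is usually safer. Sketch (assumes one image per example):
def build_prompt_with_chat_template(processor, instruction):
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": instruction},
        ],
    }]
    # tokenize=False returns the raw prompt string; the processor call in
    # __call__ above can then pair it with the PIL image.
    return processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )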
# ========== Monitoring callback ==========
class SafetyTrainingMonitor(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        logs = logs or {}  # guard: on_log may receive logs=None
        mem = psutil.virtual_memory().used / (1024**3)
        print(f"📊 Step {state.global_step} | Loss: {logs.get('loss', 'N/A')} | Memory: {mem:.2f}GB")
# ========== Main ==========
def main():
    args = ScriptArguments()
    set_seed(args.seed)
    device = DeviceManager.get_device()
    print("🛠️ Device:", device)
    print("🛠️ Arguments:", args)
    # Load the processor
    processor = AutoProcessor.from_pretrained(
        args.model_name_or_path, trust_remote_code=True, use_fast=True
    )
    tokenizer = processor.tokenizer
    # Load the model
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16 if args.fp16 else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    model.resize_token_embeddings(len(tokenizer))
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
    # Attach LoRA adapters
    peft_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj"],
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    # Datasets
    train_dataset = SafetyImageDataset(args.train_path, processor, tokenizer, image_size=args.image_size)
    valid_dataset = SafetyImageDataset(args.valid_path, processor, tokenizer, image_size=args.image_size)
    collator = SafetyDataCollator(processor, tokenizer, image_size=args.image_size)
    # Training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        weight_decay=args.weight_decay,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_total_limit=3,
        fp16=args.fp16,
        bf16=args.bf16,
        report_to=args.report_to,
        optim="adamw_torch",
        lr_scheduler_type="cosine",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=collator,
        tokenizer=tokenizer,
        callbacks=[SafetyTrainingMonitor()],
    )
    # Start training
    print("🚀 Starting training...")
    trainer.train()
    # Save the final adapter
    trainer.save_model(os.path.join(args.output_dir, "final_model"))
    print("✅ Training complete!")

if __name__ == "__main__":
    main()
(vlm) face8@jamesdeMac-Studio vlm % python train_vlm\ copy.py
✅ MPS device available
🛠️ Device: mps
🛠️ Arguments: ScriptArguments(train_path='train.jsonl', valid_path='valid.jsonl', model_name_or_path='Qwen/Qwen2.5-VL-7B-Instruct', output_dir='./output_lora_qwen25vl_instruct', per_device_train_batch_size=1, gradient_accumulation_steps=4, num_train_epochs=3, logging_steps=5, save_steps=100, eval_steps=100, image_size=672, learning_rate=2e-05, warmup_steps=50, weight_decay=0.01, lora_rank=16, lora_alpha=32, lora_dropout=0.05, fp16=False, bf16=False, max_steps=-1, gradient_checkpointing=True, seed=42, report_to='none', enable_mps_fallback=True)
Loading checkpoint shards: 100%|███| 5/5 [00:05<00:00, 1.05s/it]
trainable params: 35,090,432 || all params: 8,324,397,056 || trainable%: 0.4215
✅ Loaded 934 examples.
✅ Loaded 104 examples.
/Users/face8/works/vlm/train_vlm copy.py:258: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
trainer = Trainer(
🚀 Starting training...
/Users/face8/miniconda3/envs/vlm/lib/python3.9/site-packages/torch/utils/data/dataloader.py:683: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, then device pinned memory won't be used.
warnings.warn(warn_msg)
🟢 collator received examples type=<class 'list'>
🔍 per-example types: [<class 'dict'>]
🔍 per-example contents: [{}]
Traceback (most recent call last):
File "/Users/face8/works/vlm/train_vlm copy.py", line 277, in <module>
main()
File "/Users/face8/works/vlm/train_vlm copy.py", line 270, in main
trainer.train()
File "/Users/face8/miniconda3/envs/vlm/lib/python3.9/site-packages/transformers/trainer.py", line 2207, in train
return inner_training_loop(
File "/Users/face8/miniconda3/envs/vlm/lib/python3.9/site-packages/transformers/trainer.py", line 2503, in _inner_training_loop
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
File "/Users/face8/miniconda3/envs/vlm/lib/python3.9/site-packages/transformers/trainer.py", line 5301, in get_batch_samples
batch_samples.append(next(epoch_iterator))
File "/Users/face8/miniconda3/envs/vlm/lib/python3.9/site-packages/accelerate/data_loader.py", line 567, in __iter__
current_batch = next(dataloader_iter)
File "/Users/face8/miniconda3/envs/vlm/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 733, in __next__
data = self._next_data()
File "/Users/face8/miniconda3/envs/vlm/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 789, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/Users/face8/miniconda3/envs/vlm/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
return self.collate_fn(data)
File "/Users/face8/miniconda3/envs/vlm/lib/python3.9/site-packages/transformers/trainer_utils.py", line 872, in __call__
return self.data_collator(features)
File "/Users/face8/works/vlm/train_vlm copy.py", line 152, in __call__
    raise ValueError("❌ All examples are empty dicts; cannot build a batch!")
ValueError: ❌ All examples are empty dicts; cannot build a batch!
(vlm) face8@jamesdeMac-Studio vlm %
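Why the batch arrives as [{}]: with the default remove_unused_columns=True,
Trainer wraps the custom collator in RemoveColumnsCollator (the
transformers/trainer_utils.py frame in the traceback), which drops every
feature key that does not appear in the model's forward() signature. The
custom keys "image", "instruction" and "ground_truth" are all unknown to
forward(), so each example reaches SafetyDataCollator as an empty dict. The
likely fix (hedged, but consistent with the traceback) is a single extra
argument:

training_args = TrainingArguments(
    output_dir=args.output_dir,
    # Keep the raw "image"/"instruction"/"ground_truth" keys so the custom
    # collator can see them; Trainer strips unrecognized keys by default.
    remove_unused_columns=False,
    # ... remaining arguments unchanged ...
)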