import argparse
import torch
from diffusers import StableDiffusionPipeline, DDPMScheduler, AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
# ---------- Dataset ----------
class ImageDataset(Dataset):
    def __init__(self, image_dir):
        self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith((".jpg", ".png"))]
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        return self.transform(image)
# ---------- Inject LoRA ----------
def inject_lora(unet, rank=4):
    # LoRAAttnProcessor (older diffusers releases, where it carries trainable weights) needs
    # each layer's hidden size and, for cross-attention layers, the text-encoder embedding dim.
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            hidden_size = list(reversed(unet.config.block_out_channels))[int(name[len("up_blocks.")])]
        else:  # down_blocks
            hidden_size = unet.config.block_out_channels[int(name[len("down_blocks.")])]
        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank)
    unet.set_attn_processor(lora_attn_procs)
    print(f"[INFO] LoRA attention processors set with rank={rank}")
# ---------- Save weights ----------
def save_lora_weights(unet, save_path):
    weights = {}
    for name, module in unet.named_modules():
        if isinstance(module, LoRAAttnProcessor):
            weights[name] = module.state_dict()
    torch.save(weights, save_path)
# ---------- Training ----------
def train_lora(concept_token, data_dir, output_dir, epochs=10, lr=1e-4, batch_size=2):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load in float32: pure-fp16 weights break AdamW updates and would mismatch the fp32 LoRA layers.
    pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
    unet = pipeline.unet
    vae = pipeline.vae
    text_encoder = pipeline.text_encoder
    tokenizer = pipeline.tokenizer

    # Freeze the base model; only the LoRA layers injected below are trained.
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    inject_lora(unet, rank=4)

    dataset = ImageDataset(data_dir)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    noise_scheduler = DDPMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

    # Optimize only the newly created (trainable) LoRA parameters.
    lora_params = [p for p in unet.parameters() if p.requires_grad]
    optimizer = AdamW(lora_params, lr=lr)

    unet.train()
    for epoch in range(epochs):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            batch = batch.to(device)
            with torch.no_grad():
                # Encode images to latents (SD 1.x scaling factor) and the prompt to text embeddings.
                latents = vae.encode(batch).latent_dist.sample() * 0.18215
                text_input = tokenizer([concept_token] * batch.shape[0], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
                encoder_hidden_states = text_encoder(text_input.input_ids.to(device))[0]
            # Standard DDPM objective: predict the noise added at a random timestep.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch.shape[0],), device=device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item())

    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, "lora_weights.pt")
    save_lora_weights(unet, save_path)
    print(f"LoRA weights saved to {save_path}")
# ---------- Main ----------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--concept_token", type=str, required=True)
    parser.add_argument("--instance_data_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--batch_size", type=int, default=2)
    args = parser.parse_args()
    train_lora(args.concept_token, args.instance_data_dir, args.output_dir, args.epochs, args.lr, args.batch_size)
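# Example invocation, matching the argparse flags above (the prompt and paths are illustrative):
#   python train_lora.py --concept_token "a photo of sks dog" \
#       --instance_data_dir ./instance_images --output_dir ./lora_out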
File "/home/whp/下载/KnowledgeBase/train_lora.py", line 102, in <module>
train_lora(args.concept_token, args.instance_data_dir, args.output_dir, args.epochs, args.lr, args.batch_size)
File "/home/whp/下载/KnowledgeBase/train_lora.py", line 57, in train_lora
inject_lora(unet, rank=4)
File "/home/whp/下载/KnowledgeBase/train_lora.py", line 35, in inject_lora
lora_attn_procs[name] = LoRAAttnProcessor(rank=rank)
TypeError: __init__() got an unexpected keyword argument 'rank'
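The TypeError above comes from a newer diffusers release in which LoRAAttnProcessor no longer carries trainable weights or accepts a rank argument; recent versions attach LoRA through PEFT instead. A minimal sketch of that route, assuming diffusers >= 0.22 with peft installed (the rank, alpha, and target module names below follow the official diffusers LoRA examples and are illustrative):

from peft import LoraConfig

unet.requires_grad_(False)  # freeze the base UNet
unet_lora_config = LoraConfig(
    r=4,                  # LoRA rank
    lora_alpha=4,         # scaling factor
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # attention projections
)
unet.add_adapter(unet_lora_config)  # replaces inject_lora() on newer diffusers

With this setup only the injected adapter parameters have requires_grad=True, so the optimizer construction and training loop in train_lora() can stay as written.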