Training entry point + model saving

import argparse
from torch.utils.data import DataLoader
from testing.testing import *
from models.model_coupled_v1 import Unet
from data.data_load import *
import glob
from collections import OrderedDict


device = "cuda:0" if torch.cuda.is_available() else "cpu"

cat = True  # Concatenate sketch on input
image_size = 256
channels = 4
batch_size = 1
timesteps = 1000

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--sketch_dir', type=str, required=False, default='/home/featurize/data/AnimeDiffusion-main/AnimeDiffusion_Dataset/train_data/sketch', help='Path to the directory containing line art images.')
    parser.add_argument('--scrib_dir', type=str, required=False, default='/home/featurize/work/Diffusart-CVPRW/samples/scrib/imgtuya', help='Path to the directory containing color scribble images.')
    parser.add_argument('--target_dir', type=str, required=False, default='/home/featurize/data/AnimeDiffusion-main/AnimeDiffusion_Dataset/train_data/reference', help='Path to the directory containing target (reference) color images.')
    parser.add_argument('--out_dir', type=str, required=False, default='/home/featurize/work/Diffusart-CVPRW/results', help='Path to the output directory for results and checkpoints.')
    parser.add_argument('--model_path', type=str, required=False, default='./checkpoint/diffusart_v1.pth', help='Path to the .pth model file.')
    args = parser.parse_args()

    # Reading all images from directories
    sketch_path = glob.glob(args.sketch_dir + '/*.jpg')
    target_path = glob.glob(args.target_dir + '/*.jpg')
    scrib_path = glob.glob(args.scrib_dir + '/*.png')
    loader_train = MyData_paper_train(sketch_path, scrib_path, target_path, size=image_size)

    dataloader_train = DataLoader(loader_train, batch_size=batch_size, num_workers=0, shuffle=True)

    # Define validation data paths
    val_sketch_path = glob.glob(args.sketch_dir.replace('train_data', 'val_data') + '/*.jpg')
    val_target_path = glob.glob(args.target_dir.replace('train_data', 'val_data') + '/*.jpg')
    val_scrib_path = glob.glob(args.scrib_dir.replace('train_data', 'val_data') + '/*.png')

    # Create the validation dataset
    loader_val = MyData_paper_train(val_sketch_path, val_scrib_path, val_target_path, size=image_size)

    # Create the validation dataloader
    dataloader_val = DataLoader(loader_val, batch_size=batch_size, num_workers=0, shuffle=False)

    model = Unet(
        dim=image_size,
        channels=channels,
        dim_mults=(1, 2,)
    ).to(device)

    print('Loading pretrained model and starting training')

    # Loading the model
    state_dict = torch.load(args.model_path, map_location=device)

    # Strip the `module.` prefix added by DataParallel, if present
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    # load params
    model.load_state_dict(new_state_dict)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    training_scribs(
        model=model,
        optimizer=optimizer,
        dataloader=dataloader_train,
        val_dataloader=dataloader_val,
        channels=channels,
        image_size=image_size,
        out_path=args.out_dir,
        device=device,
        cat=cat,
        val_interval=10,
        save_interval=100,
        max_epochs=100
    )
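# Example invocation (illustrative; assumes this entry script is saved as train.py and that
# the default paths above, including the pretrained checkpoint at --model_path, exist):
#   python train.py --sketch_dir <sketch dir> --scrib_dir <scribble dir> \
#       --target_dir <reference dir> --out_dir ./results --model_path ./checkpoint/diffusart_v1.pth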

Sampling and training functions (the module imported above as testing.testing)

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from diffusers import DPMSolverMultistepScheduler
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
import sys
# Add the project root to sys.path
project_root = "/home/featurize/work/Diffusart-CVPRW"
if project_root not in sys.path:
    sys.path.append(project_root)

from models.schedulers import *

reverse_transform_torch = transforms.Compose([
    transforms.Lambda(lambda t: (t + 1) / 2),
])

device = "cuda:0" if torch.cuda.is_available() else "cpu"

################################################# Forward process #######################################################
timesteps_inf = 2  # number of reverse diffusion steps used by the sampling loop below
timesteps = 1000
# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)
# betas = betas.to(device)
# betas = cosine_beta_schedule(timesteps)
# define alphas
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
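
# The constants above give the closed-form forward process q(x_t | x_0).
# A minimal sketch of how they would be used to noise a clean image (the helper name
# `q_sample` is illustrative and not part of the original code):
def q_sample(x_start, t, noise=None):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    if noise is None:
        noise = torch.randn_like(x_start)
    sqrt_ac_t = extract_inf(sqrt_alphas_cumprod, t, x_start.shape).to(x_start.device)
    sqrt_omac_t = extract_inf(sqrt_one_minus_alphas_cumprod, t, x_start.shape).to(x_start.device)
    return sqrt_ac_t * x_start + sqrt_omac_t * noise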

################################### with hints ###########################################################

def p_sample_hints(model, x_in, feat, t, t_index, cat):
    # print(x_in.shape)
    noise_pred = model(x_in, feat, t.to(device))
    if cat:
        x = x_in[:, 1:, :, :]
    else:
        x = x_in
    # print('in', x_in.shape)
    betas_t = extract_inf(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract_inf(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract_inf(sqrt_recip_alphas, t, x.shape)
    # Equation 11 in the paper
    # Use our model (noise predictor) to predict the mean
    sqrt_recip_alphas_t = sqrt_recip_alphas_t.to(device)
    betas_t = betas_t.to(device)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod_t.to(device)
    # print(x.shape, noise_pred.shape)
    # noise_pred = torch.cat((feat[:, 0:1, :, :], noise_pred), dim=1)
    # print(sqrt_recip_alphas_t.shape, x.shape, betas_t.shape, noise_pred.shape, sqrt_one_minus_alphas_cumprod_t.shape)
    model_mean = sqrt_recip_alphas_t * (
            x - betas_t * noise_pred / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract_inf(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        noise = noise.to(device)
        sqrt_var = torch.sqrt(posterior_variance_t)
        sqrt_var = sqrt_var.to(device)
        # Algorithm 2 line 4:
        return model_mean + sqrt_var * noise
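
# For reference, the mean computed above is Eq. 11 of the DDPM paper (Ho et al., 2020):
#   mu_theta(x_t, t) = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t))
# and the variance added for t > 0 is the posterior_variance defined earlier.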

# Algorithm 2 (sampling loop, returning all intermediate images)
def p_sample_loop_hints(model, noise, feat, hints, shape, cat=None):
    # device = next(model.parameters()).device
    b = shape[0]
    sketch = feat
    feat = torch.cat((feat[:b], hints[:b]), dim=1)
    # start from pure noise (for each example in the batch)
    if cat:
        img = noise[:, 1:, :, :]
    else:
        img = noise
    imgs = []

    # Disable the tqdm progress bar (disable=True)
    for i in tqdm(reversed(range(0, timesteps_inf)), desc='sampling loop time step', total=timesteps_inf, disable=True):
        if cat:
            img = torch.cat((sketch[:b], img[:b]), dim=1)

        img = p_sample_hints(model, img, feat, torch.full((b,), i, dtype=torch.long), i, cat)
        imgs.append(img.cpu())
    return imgs


def sample_hints(model, noise, feat, hints, image_size, batch_size=16, channels=3, cat=None):
    return p_sample_loop_hints(model, noise, feat, hints, shape=(batch_size, channels, image_size, image_size), cat=cat)


def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)


def extract_inf(tensor, t, shape):
    # Gather the per-timestep value from `tensor` at indices `t` for each batch element
    # and reshape it so that it broadcasts over an input of shape `shape`
    out = tensor.gather(-1, t.cpu())
    return out.reshape(shape[0], *((1,) * (len(shape) - 1))).to(t.device)
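
# Illustrative sanity check of the broadcasting behaviour (not part of the original code):
# for timesteps t = [999, 0] and an image shape of (2, 3, 256, 256), extract_inf returns a
# (2, 1, 1, 1) tensor, so each per-sample coefficient broadcasts over channels and pixels:
#   coeff = extract_inf(sqrt_alphas_cumprod, torch.tensor([999, 0]), (2, 3, 256, 256))
#   assert coeff.shape == (2, 1, 1, 1)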


def training_scribs(model, optimizer, dataloader, val_dataloader, channels, image_size, out_path, device, cat, val_interval=10, save_interval=100, max_epochs=100):
    # Validate input arguments
    if not isinstance(model, nn.Module):
        raise ValueError("model must be an instance of nn.Module")
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise ValueError("optimizer must be an instance of torch.optim.Optimizer")
    if not isinstance(dataloader, torch.utils.data.DataLoader):
        raise ValueError("dataloader must be an instance of torch.utils.data.DataLoader")
    if val_dataloader is not None and not isinstance(val_dataloader, torch.utils.data.DataLoader):
        raise ValueError("val_dataloader must be an instance of torch.utils.data.DataLoader")
    if not isinstance(channels, int) or channels <= 0:
        raise ValueError("channels must be a positive integer")
    if not isinstance(image_size, int) or image_size <= 0:
        raise ValueError("image_size must be a positive integer")
    if not isinstance(out_path, str) or not out_path:
        raise ValueError("out_path must be a non-empty string")
    try:
        torch.device(device)  # accept any valid device string, e.g. 'cpu', 'cuda:1'
    except (RuntimeError, TypeError):
        raise ValueError("device must be a valid device name")
    if not isinstance(cat, bool):
        raise ValueError("cat must be a boolean")
    if not isinstance(val_interval, int) or val_interval <= 0:
        raise ValueError("val_interval must be a positive integer")
    if not isinstance(save_interval, int) or save_interval <= 0:
        raise ValueError("save_interval must be a positive integer")
    if not isinstance(max_epochs, int) or max_epochs <= 0:
        raise ValueError("max_epochs must be a positive integer")

    create_directory(out_path)
    model.train()
    # DPM-Solver++ scheduler (configured here, although the sampling loop below uses the
    # ancestral p_sample loop rather than this scheduler)
    scheduler_DPM = DPMSolverMultistepScheduler(beta_schedule='linear', beta_start=1e-4, algorithm_type='dpmsolver++', solver_order=2, num_train_timesteps=1000, thresholding=True)
    scheduler_DPM.set_timesteps(num_inference_steps=100)

    # Initialize TensorBoard logging
    writer = SummaryWriter(log_dir=os.path.join(out_path, 'logs'))
    best_val_loss = float('inf')  # initialize the best validation loss to +inf

    for epoch in range(max_epochs):
        # Training phase
        train_losses = []
        for idx, batch in enumerate(dataloader):
            batch_size = batch[0].shape[0]
            # Optional shape checks:
            # if batch[0].shape[1:] != (channels, image_size, image_size):
            #     raise ValueError(f"Unexpected shape for sketch data: {batch[0].shape[1:]}, expected ({channels}, {image_size}, {image_size})")
            # if batch[1].shape[1:] != (channels, image_size, image_size):
            #     raise ValueError(f"Unexpected shape for hints data: {batch[1].shape[1:]}, expected ({channels}, {image_size}, {image_size})")
            # if batch[2].shape[1:] != (channels, image_size, image_size):
            #     raise ValueError(f"Unexpected shape for target data: {batch[2].shape[1:]}, expected ({channels}, {image_size}, {image_size})")

            sketch = batch[0].to(device).to(dtype=torch.float)
            hints = batch[1].to(device).to(dtype=torch.float)
            target = batch[2].to(device).to(dtype=torch.float)

            shape = (batch_size, channels, image_size, image_size)
            torch.manual_seed(2)  # fixed seed: the sampled noise is identical for every batch
            noise = torch.randn(shape, device=device)
            samples = sample_hints(model, noise, sketch, hints, image_size=image_size, batch_size=batch_size,
                                   channels=channels, cat=cat)

            samples_hints = make_grid(reverse_transform_torch(hints[:, :3, :, :].detach().cpu()))
            samples_grid = make_grid(reverse_transform_torch(samples[-1].detach()))
            writer.add_image('Samples/hints', samples_hints, epoch * len(dataloader) + idx)
            writer.add_image('Samples/output', samples_grid, epoch * len(dataloader) + idx)

            loss = F.mse_loss(samples[-1].to(device), target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            # Log the training loss to TensorBoard
            writer.add_scalar('Loss/train', loss.item(), epoch * len(dataloader) + idx)
            print(f"Epoch: {epoch}, Batch: {idx}, Loss: {loss.item()}")

        # Average training loss over the epoch
        avg_train_loss = sum(train_losses) / len(train_losses)
        print(f"Epoch {epoch} - Average Training Loss: {avg_train_loss}")

        # Validation phase
        if val_dataloader is not None and epoch % val_interval == 0:
            model.eval()
            val_losses = []
            with torch.no_grad():
                for val_idx, val_batch in enumerate(val_dataloader):
                    val_batch_size = val_batch[0].shape[0]
                    # Optional shape checks:
                    # if val_batch[0].shape[1:] != (channels, image_size, image_size):
                    #     raise ValueError(f"Unexpected shape for val_sketch data: {val_batch[0].shape[1:]}, expected ({channels}, {image_size}, {image_size})")
                    # if val_batch[1].shape[1:] != (channels, image_size, image_size):
                    #     raise ValueError(f"Unexpected shape for val_hints data: {val_batch[1].shape[1:]}, expected ({channels}, {image_size}, {image_size})")
                    # if val_batch[2].shape[1:] != (channels, image_size, image_size):
                    #     raise ValueError(f"Unexpected shape for val_target data: {val_batch[2].shape[1:]}, expected ({channels}, {image_size}, {image_size})")

                    val_sketch = val_batch[0].to(device).to(dtype=torch.float)
                    val_hints = val_batch[1].to(device).to(dtype=torch.float)
                    val_target = val_batch[2].to(device).to(dtype=torch.float)

                    val_shape = (val_batch_size, channels, image_size, image_size)
                    val_noise = torch.randn(val_shape, device=device)
                    val_samples = sample_hints(model, val_noise, val_sketch, val_hints, image_size=image_size, batch_size=val_batch_size,
                                               channels=channels, cat=cat)

                    val_loss = F.mse_loss(val_samples[-1].to(device), val_target).item()
                    val_losses.append(val_loss)

            avg_val_loss = sum(val_losses) / len(val_losses)
            writer.add_scalar('Loss/validation', avg_val_loss, epoch)
            print(f"Epoch: {epoch}, Validation Loss: {avg_val_loss}")

            # Save the model whenever the validation loss improves on the best so far
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                save_path = os.path.join(out_path, 'model_best.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_val_loss
                }, save_path)
                print(f"Best model saved to {save_path} with validation loss: {avg_val_loss}")

            model.train()

        # Periodic checkpoint of the model parameters
        if epoch % save_interval == 0:
            save_path = os.path.join(out_path, f'model_epoch_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_train_loss
            }, save_path)
            print(f"Model saved to {save_path}")

    # Save the final model
    final_save_path = os.path.join(out_path, 'model_final.pth')
    torch.save({
        'epoch': max_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_train_loss
    }, final_save_path)
    print(f"Final model saved to {final_save_path}")

    writer.close()
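

def load_training_checkpoint(path, model, optimizer=None, map_location='cpu'):
    # Illustrative helper (not part of the original code): restores a checkpoint written by
    # training_scribs above. Note that, unlike the raw state_dict loaded by the entry script,
    # these checkpoints wrap the weights in a dict with 'epoch', 'model_state_dict',
    # 'optimizer_state_dict', and 'loss' keys.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']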