import argparse
import glob
from collections import OrderedDict

import torch
from torch.utils.data import DataLoader

from testing.testing import *
from models.model_coupled_v1 import Unet
from data.data_load import *

device = "cuda:0" if torch.cuda.is_available() else "cpu"
cat = True
image_size = 256
channels = 4
batch_size = 1
timesteps = 1000
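
# Note on the constants above: with cat=True the sampling code (p_sample_loop_hints)
# concatenates the line-art sketch with the color channels being denoised, so
# channels=4 appears to describe that stacked model input (1 sketch + 3 RGB);
# the denoised image itself is 3-channel.
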
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--sketch_dir', type=str, required=False, default='/home/featurize/data/AnimeDiffusion-main/AnimeDiffusion_Dataset/train_data/sketch', help='Path to the directory containing line art images.')
    parser.add_argument('--scrib_dir', type=str, required=False, default='/home/featurize/work/Diffusart-CVPRW/samples/scrib/imgtuya', help='Path to the directory containing color scribble images.')
    parser.add_argument('--target_dir', type=str, required=False, default='/home/featurize/data/AnimeDiffusion-main/AnimeDiffusion_Dataset/train_data/reference', help='Path to the directory containing target (reference) color images.')
    parser.add_argument('--out_dir', type=str, required=False, default='/home/featurize/work/Diffusart-CVPRW/results', help='Path to the output directory for results and checkpoints.')
    parser.add_argument('--model_path', type=str, required=False, default='./checkpoint/diffusart_v1.pth', help='Path to the .pth model file.')
    args = parser.parse_args()
    sketch_path = glob.glob(args.sketch_dir + '/*.jpg')
    target_path = glob.glob(args.target_dir + '/*.jpg')
    scrib_path = glob.glob(args.scrib_dir + '/*.png')
    loader_train = MyData_paper_train(sketch_path, scrib_path, target_path, size=image_size)
    dataloader_train = DataLoader(loader_train, batch_size=batch_size, num_workers=0, shuffle=True)

    val_sketch_path = glob.glob(args.sketch_dir.replace('train_data', 'val_data') + '/*.jpg')
    val_target_path = glob.glob(args.target_dir.replace('train_data', 'val_data') + '/*.jpg')
    val_scrib_path = glob.glob(args.scrib_dir.replace('train_data', 'val_data') + '/*.png')
    loader_val = MyData_paper_train(val_sketch_path, val_scrib_path, val_target_path, size=image_size)
    dataloader_val = DataLoader(loader_val, batch_size=batch_size, num_workers=0, shuffle=False)

    model = Unet(
        dim=image_size,
        channels=channels,
        dim_mults=(1, 2)
    ).to(device)
    print('Loading checkpoint and starting training')
    state_dict = torch.load(args.model_path, map_location=device)
    # Checkpoints saved from nn.DataParallel prefix every key with "module.";
    # strip that prefix (when present) so the weights load into a plain model.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    training_scribs(
        model=model,
        optimizer=optimizer,
        dataloader=dataloader_train,
        val_dataloader=dataloader_val,
        channels=channels,
        image_size=image_size,
        out_path=args.out_dir,
        device=device,
        cat=cat,
        val_interval=10,
        save_interval=100,
        max_epochs=100
    )
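

# ---------------------------------------------------------------------------
# What follows is a separate module: the training/sampling utilities
# (presumably testing/testing.py, pulled in above via the star import).
# ---------------------------------------------------------------------------
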
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from diffusers import DPMSolverMultistepScheduler
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
import sys

project_root = "/home/featurize/work/Diffusart-CVPRW"
if project_root not in sys.path:
    sys.path.append(project_root)

from models.schedulers import *
reverse_transform_torch = transforms.Compose([
    transforms.Lambda(lambda t: (t + 1) / 2),  # map tensors from [-1, 1] back to [0, 1]
])

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Number of reverse steps actually run at train/inference time; the forward
# process below is still defined over the full `timesteps` schedule.
timesteps_inf = 2
timesteps = 1000

# Standard DDPM quantities derived from a linear beta schedule.
betas = linear_beta_schedule(timesteps=timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# Variance of q(x_{t-1} | x_t, x_0): beta_t * (1 - alphabar_{t-1}) / (1 - alphabar_t)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
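
# The reverse step implemented by p_sample_hints below is the usual DDPM update:
#   mu_theta(x_t, t) = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alphabar_t) * eps_theta(x_t, feat, t))
#   x_{t-1} = mu_theta + sqrt(posterior_variance_t) * z,  z ~ N(0, I)  (no noise added when t = 0)
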
def p_sample_hints(model, x_in, feat, t, t_index, cat):
    # Predict the noise eps_theta for the current (possibly sketch-conditioned) input.
    noise_pred = model(x_in, feat, t.to(device))
    if cat:
        # Drop the conditioning channel prepended in the sampling loop; only the
        # image channels are being denoised.
        x = x_in[:, 1:, :, :]
    else:
        x = x_in

    betas_t = extract_inf(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract_inf(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract_inf(sqrt_recip_alphas, t, x.shape)
    sqrt_recip_alphas_t = sqrt_recip_alphas_t.to(device)
    betas_t = betas_t.to(device)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod_t.to(device)

    # Posterior mean: 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alphabar_t) * eps_theta)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * noise_pred / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        # Final step: return the mean without adding noise.
        return model_mean
    else:
        posterior_variance_t = extract_inf(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        noise = noise.to(device)
        sqrt_var = torch.sqrt(posterior_variance_t)
        sqrt_var = sqrt_var.to(device)
        return model_mean + sqrt_var * noise

def p_sample_loop_hints(model, noise, feat, hints, shape, cat=None):
    b = shape[0]
    sketch = feat
    # Conditioning features: sketch and color hints stacked along the channel axis.
    feat = torch.cat((feat[:b], hints[:b]), dim=1)
    if cat:
        # The sketch is re-attached at every step, so start from the color channels only.
        img = noise[:, 1:, :, :]
    else:
        img = noise
    imgs = []
    for i in tqdm(reversed(range(0, timesteps_inf)), desc='sampling loop time step',
                  total=timesteps_inf, disable=True):
        if cat:
            img = torch.cat((sketch[:b], img[:b]), dim=1)
        img = p_sample_hints(model, img, feat, torch.full((b,), i, dtype=torch.long), i, cat)
        imgs.append(img.cpu())
    return imgs

def sample_hints(model, noise, feat, hints, image_size, batch_size=16, channels=3, cat=None):
    return p_sample_loop_hints(model, noise, feat, hints,
                               shape=(batch_size, channels, image_size, image_size), cat=cat)

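# Usage sketch (hypothetical snippet; names mirror the training loop in training_scribs below):
#   noise = torch.randn(b, channels, image_size, image_size, device=device)
#   frames = sample_hints(model, noise, sketch, hints, image_size=image_size,
#                         batch_size=b, channels=channels, cat=True)
#   colored = frames[-1]   # final sample; reverse_transform_torch maps it back to [0, 1]
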
def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def extract_inf(tensor, t, shape):
    # Gather the per-timestep coefficients for a batch of timesteps t and reshape them
    # to (batch, 1, ..., 1) so they broadcast against tensors of the given shape.
    out = tensor.gather(-1, t.cpu())
    return out.reshape(shape[0], *((1,) * (len(shape) - 1))).to(t.device)
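# For example (illustrative shapes): extract_inf(betas, torch.tensor([3, 7]), (2, 3, 256, 256))
# returns a (2, 1, 1, 1) tensor holding beta_3 and beta_7, which broadcasts over the image dims.
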
def training_scribs(model, optimizer, dataloader, val_dataloader, channels, image_size, out_path, device, cat,
                    val_interval=10, save_interval=100, max_epochs=100):
    if not isinstance(model, nn.Module):
        raise ValueError("model must be an instance of nn.Module")
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise ValueError("optimizer must be an instance of torch.optim.Optimizer")
    if not isinstance(dataloader, torch.utils.data.DataLoader):
        raise ValueError("dataloader must be an instance of torch.utils.data.DataLoader")
    if val_dataloader is not None and not isinstance(val_dataloader, torch.utils.data.DataLoader):
        raise ValueError("val_dataloader must be an instance of torch.utils.data.DataLoader")
    if not isinstance(channels, int) or channels <= 0:
        raise ValueError("channels must be a positive integer")
    if not isinstance(image_size, int) or image_size <= 0:
        raise ValueError("image_size must be a positive integer")
    if not isinstance(out_path, str) or not out_path:
        raise ValueError("out_path must be a non-empty string")
    if not isinstance(device, (str, torch.device)) or not str(device).startswith(('cpu', 'cuda')):
        raise ValueError("device must be a valid device name (e.g. 'cpu' or 'cuda:0')")
    if not isinstance(cat, bool):
        raise ValueError("cat must be a boolean")
    if not isinstance(val_interval, int) or val_interval <= 0:
        raise ValueError("val_interval must be a positive integer")
    if not isinstance(save_interval, int) or save_interval <= 0:
        raise ValueError("save_interval must be a positive integer")
    if not isinstance(max_epochs, int) or max_epochs <= 0:
        raise ValueError("max_epochs must be a positive integer")
    create_directory(out_path)
    model.train()
    # Note: this DPM-Solver scheduler is configured here but not used in the loop below,
    # which samples with the hand-rolled DDPM loop (p_sample_loop_hints) instead.
    scheduler_DPM = DPMSolverMultistepScheduler(beta_schedule='linear', beta_start=1e-4,
                                                algorithm_type='dpmsolver++', solver_order=2,
                                                num_train_timesteps=1000, thresholding=True)
    scheduler_DPM.set_timesteps(num_inference_steps=100)

    writer = SummaryWriter(log_dir=os.path.join(out_path, 'logs'))
    best_val_loss = float('inf')
    for epoch in range(max_epochs):
        train_losses = []
        for idx, batch in enumerate(dataloader):
            batch_size = batch[0].shape[0]
            sketch = batch[0].to(device).to(dtype=torch.float)
            hints = batch[1].to(device).to(dtype=torch.float)
            target = batch[2].to(device).to(dtype=torch.float)
            shape = (batch_size, channels, image_size, image_size)

            # Fixed seed: the same starting noise is drawn for every batch.
            torch.manual_seed(2)
            noise = torch.randn(shape, device=device)

            # Run the short reverse diffusion loop; gradients flow back through every step.
            samples = sample_hints(model, noise, sketch, hints, image_size=image_size,
                                   batch_size=batch_size, channels=channels, cat=cat)

            # Image grids for visualisation (built here but not written to TensorBoard).
            samples_hints = make_grid(reverse_transform_torch(hints[:, :3, :, :]))
            samples_grid = make_grid(reverse_transform_torch(samples[-1]))

            # Reconstruction loss between the fully sampled image and the color target.
            loss = F.mse_loss(samples[-1].to(device), target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            writer.add_scalar('Loss/train', loss.item(), epoch * len(dataloader) + idx)
            print(f"Epoch: {epoch}, Batch: {idx}, Loss: {loss.item()}")

        avg_train_loss = sum(train_losses) / len(train_losses)
        print(f"Epoch {epoch} - Average Training Loss: {avg_train_loss}")
        if val_dataloader is not None and epoch % val_interval == 0:
            model.eval()
            val_losses = []
            with torch.no_grad():
                for val_idx, val_batch in enumerate(val_dataloader):
                    val_batch_size = val_batch[0].shape[0]
                    val_sketch = val_batch[0].to(device).to(dtype=torch.float)
                    val_hints = val_batch[1].to(device).to(dtype=torch.float)
                    val_target = val_batch[2].to(device).to(dtype=torch.float)
                    val_shape = (val_batch_size, channels, image_size, image_size)
                    val_noise = torch.randn(val_shape, device=device)
                    val_samples = sample_hints(model, val_noise, val_sketch, val_hints,
                                               image_size=image_size, batch_size=val_batch_size,
                                               channels=channels, cat=cat)
                    val_loss = F.mse_loss(val_samples[-1].to(device), val_target).item()
                    val_losses.append(val_loss)
            avg_val_loss = sum(val_losses) / len(val_losses)
            writer.add_scalar('Loss/validation', avg_val_loss, epoch)
            print(f"Epoch: {epoch}, Validation Loss: {avg_val_loss}")
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                save_path = os.path.join(out_path, 'model_best.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_val_loss
                }, save_path)
                print(f"Best model saved to {save_path} with validation loss: {avg_val_loss}")
            model.train()
        if epoch % save_interval == 0:
            save_path = os.path.join(out_path, f'model_epoch_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_train_loss
            }, save_path)
            print(f"Model saved to {save_path}")
    final_save_path = os.path.join(out_path, 'model_final.pth')
    torch.save({
        'epoch': max_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_train_loss
    }, final_save_path)
    print(f"Final model saved to {final_save_path}")
    writer.close()