
How to diagnose problems when training a custom inpainting model #9990

@Marquess98

Discussed in #9989

Originally posted by Marquess98 November 22, 2024
What I want to do is perform image inpainting when the input is a set of multimodal images, using SDXL as the pretrained model. But the results are currently very poor, and I cannot determine whether the problem lies in the code, the dataset, the pretrained model, or the training parameters.
The inference code snippet is as follows:

```python
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler
from tqdm import tqdm

# vae, unet, text_encoder, dataset, masked_image, img2, mask, device,
# denoise_steps (int) and get_timesteps are defined elsewhere (not shown).

noise_scheduler = DDIMScheduler.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler"
)
noise_scheduler.set_timesteps(denoise_steps, device=device)

start_step = 5  # index into the timestep list; used for BOTH add_noise and the loop

with torch.no_grad():  # inference only: no gradients for the VAE, text encoder or UNet
    # Encode the inputs to scaled latents (zi and zi_m both come from masked_image).
    zi = vae.encode(masked_image).latent_dist.sample() * vae.config.scaling_factor
    zd = vae.encode(img2).latent_dist.sample() * vae.config.scaling_factor
    zi_m = vae.encode(masked_image).latent_dist.sample() * vae.config.scaling_factor

    # Noise zi to the level of the timestep at index start_step.
    noise = torch.randn_like(zi)
    denoise_steps = torch.tensor(denoise_steps, dtype=torch.int32, device=device)
    timesteps_add, _ = get_timesteps(noise_scheduler, denoise_steps, 1.0, device, denoising_start=None)
    zi_t = noise_scheduler.add_noise(zi, noise, timesteps_add[start_step])

    # Resize the mask (expected shape (B, 1, H, W)) to the latent resolution.
    m = F.interpolate(mask.to(zi.dtype), size=zi.shape[-2:], mode="bilinear", align_corners=False)

    input_ids = dataset["prompt_ids"].to(device).unsqueeze(0)
    encoder_hidden_states = text_encoder(input_ids, return_dict=False)[0]

    timesteps = noise_scheduler.timesteps
    iterable = tqdm(
        enumerate(timesteps),
        total=len(timesteps),
        leave=False,
        desc=" " * 4 + "Diffusion denoising",
    )

    # ----------------------- denoise -----------------------
    for i, t in iterable:
        if i < start_step:
            continue  # skip timesteps noisier than the one zi_t was noised to
        # 4 + 4 + 4 + 1 = 13 channels, in the same order as during training.
        unet_input = torch.cat([zi_t, zi_m, zd, m], dim=1)
        # An unmodified SDXL UNet also expects added_cond_kwargs (text_embeds, time_ids).
        noise_pred = unet(unet_input, t, encoder_hidden_states)[0]
        zi_t = noise_scheduler.step(noise_pred, t, zi_t).prev_sample

    decode_rgb = vae.decode(zi_t / vae.config.scaling_factor).sample.squeeze()
```
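The concatenated UNet input has 4 + 4 + 4 + 1 = 13 channels (the SD 1.5 and SDXL VAEs produce 4-channel latents), so the UNet's input convolution must have been widened to 13 channels, and the order of the pieces has to match the training code exactly; a silent mismatch there would produce poor results without any error. A minimal sketch of one way to guard against that, assuming a hypothetical `UNET_INPUT_LAYOUT` shared by the training and inference scripts: it checks shape tuples (for example `tuple(zi_t.shape)`) against the layout and against `unet.config.in_channels` before the UNet is called.

```python
# Hypothetical shared description of the UNet input, imported by both the
# training and the inference script so the concatenation order cannot drift.
UNET_INPUT_LAYOUT = (("zi_t", 4), ("zi_m", 4), ("zd", 4), ("m", 1))


def check_unet_inputs(shapes, in_channels, layout=UNET_INPUT_LAYOUT):
    """shapes maps name -> (B, C, H, W). Returns the total channel count,
    or raises ValueError if the layout, channels or sizes disagree."""
    expected = sum(channels for _, channels in layout)
    if expected != in_channels:
        raise ValueError(f"layout has {expected} channels, UNet expects {in_channels}")
    ref = None
    for name, channels in layout:
        b, c, h, w = shapes[name]
        if c != channels:
            raise ValueError(f"{name}: expected {channels} channels, got {c}")
        if ref is None:
            ref = (b, h, w)  # batch and spatial size of the first input
        elif (b, h, w) != ref:
            raise ValueError(f"{name}: batch/spatial size {(b, h, w)} != {ref}")
    return expected


shapes = {
    "zi_t": (1, 4, 128, 128),
    "zi_m": (1, 4, 128, 128),
    "zd": (1, 4, 128, 128),
    "m": (1, 1, 128, 128),
}
print(check_unet_inputs(shapes, in_channels=13))  # 13
```

Listing the pieces in the same order as the `torch.cat` call keeps the check readable; the same layout tuple can drive the concatenation in both scripts.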

The results for start_step = 0, 5, and 15, respectively, are shown below:
[Images: frame_000940_pred_ddim_st_0, frame_000940_pred_ddim_st_5, frame_000940_pred_ddim_st_15]
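To see why different start_step values give such different outputs, it helps to look at how much of zi survives the noising step; for the denoising to be consistent, the loop must also begin at the same index the noise was added for. The sketch below is a pure-Python re-implementation (not the diffusers code) of DDIM "leading" timestep spacing and the "scaled_linear" beta schedule, assuming the Stable Diffusion 1.5 scheduler values num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, steps_offset=1, and assuming denoise_steps=50 (its actual value is not given above). For each start_step it prints the timestep and the two coefficients that `add_noise` applies: sqrt(alpha_bar_t) to the latent and sqrt(1 - alpha_bar_t) to the noise.

```python
import math


def scaled_linear_alphas_cumprod(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    """alpha_bar_t for the 'scaled_linear' schedule (betas linear in sqrt-space)."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    n = num_train_timesteps
    alpha_bar, prod = [], 1.0
    for i in range(n):
        beta = (lo + (hi - lo) * i / (n - 1)) ** 2
        prod *= 1.0 - beta
        alpha_bar.append(prod)
    return alpha_bar


def leading_timesteps(num_inference_steps, num_train_timesteps=1000, steps_offset=1):
    """Descending timesteps with 'leading' spacing, e.g. [981, 961, ..., 1] for 50 steps."""
    step_ratio = num_train_timesteps // num_inference_steps
    return [i * step_ratio + steps_offset for i in reversed(range(num_inference_steps))]


alpha_bar = scaled_linear_alphas_cumprod()
timesteps = leading_timesteps(50)  # assumption: denoise_steps = 50

for start_step in (0, 5, 15):
    t = timesteps[start_step]
    signal = math.sqrt(alpha_bar[t])       # weight on zi in add_noise
    noise = math.sqrt(1.0 - alpha_bar[t])  # weight on the Gaussian noise
    print(f"start_step={start_step:2d}  t={t:3d}  signal={signal:.3f}  noise={noise:.3f}")
```

At start_step=0 almost none of zi is left, so the output depends almost entirely on what the UNet has learned from the conditioning channels; larger start_step values keep more of zi (here the encoded masked image). If results only look acceptable when a lot of zi survives, that points more toward the model or training than toward the noising arithmetic.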

Another weird thing is that decode_rgb lies roughly in [-2, 2]. Shouldn't it be in [-1, 1]?
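The VAE decoder output is not clamped, so values somewhat outside [-1, 1] can occur; diffusers pipelines map it to [0, 1] and clamp when post-processing (`VaeImageProcessor.denormalize` computes `(x * 0.5 + 0.5).clamp(0, 1)`). A range as wide as [-2, 2] often suggests that the latents given to the decoder are unlike those it was trained on. A minimal diagnostic sketch, assuming the tensors are flattened to Python lists (for example `decode_rgb.flatten().tolist()`); `summarize` and `denormalize` are hypothetical helpers, and with real tensors `.min()`, `.max()`, `.mean()`, `.std()` and `.clamp()` do the same job.

```python
import math


def summarize(values):
    """Min, max, mean, std and fraction of values outside [-1, 1] for a flat list of floats."""
    n = len(values)
    mean = sum(values) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / n)
    outside = sum(1 for v in values if v < -1.0 or v > 1.0) / n
    return {"min": min(values), "max": max(values), "mean": mean, "std": std, "frac_outside": outside}


def denormalize(values):
    """Map [-1, 1] to [0, 1] and clamp, like (x * 0.5 + 0.5).clamp(0, 1)."""
    return [min(max(v * 0.5 + 0.5, 0.0), 1.0) for v in values]


decoded = [-2.0, -0.5, 0.0, 0.5, 2.0]  # toy stand-in for decode_rgb values
print(summarize(decoded))
print(denormalize(decoded))  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

Running the same summary on (a) zi before noising, (b) zi_t after the loop, and (c) `vae.decode(zi / vae.config.scaling_factor)` as a pure round trip separates the stages: scaled latents of real images should have roughly unit standard deviation, and if (c) stays close to [-1, 1] while the denoised result does not, the VAE and scaling are likely fine and the problem is upstream in the denoising.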
Currently, I suspect the problem lies either in the inference code or in the size of the dataset (about 5,000 image sets so far). Can someone guide me on how to determine which part is causing the problem?
Any suggestions and ideas will be greatly appreciated!
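One common way to separate an inference bug from a training or data problem is to test the model on exactly the objective it was trained on: take a sample, noise it to a known timestep, run the UNet once, and measure the noise-prediction error. The sketch below shows only the arithmetic in pure Python, assuming an epsilon-prediction model (the `prediction_type` in the SD 1.5 and SDXL base scheduler configs); `one_step_check` and `predict_noise` are hypothetical, the latter standing in for a wrapper around the UNet call from the snippet above. Roughly: low error on training samples but poor full-loop outputs points at the sampling code; high error even on training samples points at the training setup; low error on training samples but high error on held-out samples suggests the dataset is too small.

```python
import math
import random


def one_step_check(x0, alpha_bar_t, predict_noise, seed=0):
    """Noise a clean latent to level alpha_bar_t, predict the noise once, and
    return (noise MSE, MSE of the implied x0 estimate). Inputs are flat lists."""
    rng = random.Random(seed)
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    a, s = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    # Forward process, the same formula add_noise uses for DDPM/DDIM schedulers.
    x_t = [a * x + s * e for x, e in zip(x0, eps)]
    eps_hat = predict_noise(x_t)
    # x0 implied by an epsilon prediction: (x_t - sqrt(1 - alpha_bar) * eps) / sqrt(alpha_bar)
    x0_hat = [(xt - s * eh) / a for xt, eh in zip(x_t, eps_hat)]
    n = len(x0)
    noise_mse = sum((e - eh) ** 2 for e, eh in zip(eps, eps_hat)) / n
    x0_mse = sum((x - xh) ** 2 for x, xh in zip(x0, x0_hat)) / n
    return noise_mse, x0_mse


# Toy usage: a predictor that always outputs zeros behaves like an untrained model.
x0 = [0.3, -1.2, 0.8, 0.0]
print(one_step_check(x0, alpha_bar_t=0.5, predict_noise=lambda x_t: [0.0] * len(x_t)))
```

Averaging this over many samples and a few timesteps, separately for training and held-out data, gives numbers that can be compared directly with the training loss.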
