
student_loss.backward() in LADD  #3


Description

@jzhang38
```python
student_loss.backward()
torch.nn.utils.clip_grad_norm_(student_unet.parameters(), 1.0)
student_optimizer.step()
student_scheduler.step()
student_optimizer.zero_grad()
```

Wouldn't the code above also generate gradients on the discriminator? In the next training iteration, those stale gradients on the discriminator would then be consumed by its optimizer.step(). I think we need a discriminator_optimizer.zero_grad() after student_optimizer.zero_grad()?
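To illustrate the concern, here is a minimal self-contained sketch (the `student`, `disc`, and loss below are stand-ins, not the repository's modules): because the adversarial student loss is computed by passing the student output through the discriminator, backward() also populates .grad on the discriminator's parameters, and a later discriminator step would consume those leftover gradients unless they are cleared.

```python
import torch
import torch.nn as nn

# Toy stand-ins (hypothetical, not the repo's classes): a "student" and a
# "discriminator" chained together the way an adversarial distillation loss chains them.
student = nn.Linear(4, 4)
disc = nn.Linear(4, 1)

x = torch.randn(2, 4)
student_loss = disc(student(x)).mean()  # the loss flows through the discriminator
student_loss.backward()

print(disc.weight.grad is not None)  # True: the discriminator picked up gradients too

# Suggested fix from this issue: clear the discriminator's gradients after the
# student update, e.g. discriminator_optimizer.zero_grad(set_to_none=True),
# so the next discriminator optimizer.step() starts from clean gradients.
for p in disc.parameters():
    p.grad = None
```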
