
Support for torch.cuda.amp in VQ-VAE training #65

Open
vvvm23 opened this issue Apr 28, 2021 · 6 comments
vvvm23 commented Apr 28, 2021

Feature request for AMP support in VQ-VAE training.
So far, I tried naively modifying the train function in train_vqvae.py like so:

# ...
scaler = torch.cuda.amp.GradScaler()  # created once, before the training loop

for i, (img, label) in enumerate(loader):
    model.zero_grad()

    img = img.to(device)

    # forward pass under autocast, so eligible ops run in fp16
    with torch.cuda.amp.autocast():
        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss

    # scale the loss before backward to avoid fp16 gradient underflow
    scaler.scale(loss).backward()

    if scheduler is not None:
        scheduler.step()

    scaler.step(optimizer)  # unscales gradients and skips the step if they contain inf/nan
    scaler.update()
# ...

The MSE (reconstruction) loss appears normal, but the latent loss becomes infinite.
I'm going to try a few ideas when I have the time. I suspect that half precision and/or loss scaling doesn't play well with the EMA codebook updates. One "workaround" is to replace the EMA update with the second term of the loss function from the original paper, so that the codebook is only updated via gradients, but that is far from ideal.
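
To be concrete about that workaround, I mean the codebook and commitment terms from the paper, roughly like this (untested sketch; z_e, quantized and latent_loss_no_ema are just placeholder names):

import torch.nn.functional as F

def latent_loss_no_ema(z_e, quantized, beta=0.25):
    # codebook loss: pull the embeddings towards the encoder output, ||sg[z_e] - e||^2
    codebook_loss = F.mse_loss(quantized, z_e.detach())
    # commitment loss: keep the encoder output close to the codebook, ||z_e - sg[e]||^2
    commitment_loss = F.mse_loss(z_e, quantized.detach())
    return codebook_loss + beta * commitment_loss

(The codebook would also have to become a regular nn.Parameter instead of an EMA buffer for the gradient to reach it.)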

Thanks!

@rosinality (Owner)

I think it would be safer to use fp32 for the entire quantize operation.

vvvm23 (Author) commented Apr 29, 2021

So, wrapping Quantize.forward in @torch.cuda.amp.autocast(enabled=False) and casting the buffers to torch.float32? We might also have to cast the input.
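
Roughly something like this (untested sketch; FP32Quantize is just an illustrative wrapper name, not something from the repo):

import torch
from torch import nn

class FP32Quantize(nn.Module):
    # Illustrative wrapper: run an existing Quantize module in fp32
    # even when the surrounding forward pass is inside autocast.
    def __init__(self, quantize):
        super().__init__()
        self.quantize = quantize.float()  # cast the codebook parameters/buffers to fp32

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, input):
        # under autocast the incoming activations may be fp16, so cast them back up
        return self.quantize(input.float())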

@rosinality (Owner)

Yes. It may work.

vvvm23 (Author) commented Apr 29, 2021

Okay! I can make a pull request for this if you want? If not, I can just close this.

@rosinality (Owner)

If it suffices to reproduce the results of fp32 training, it would definitely be nice to have.

vvvm23 (Author) commented May 5, 2021

For some reason I can't improve forward-pass speed under fp16 (maybe it is bottlenecked by the fp32 quantize operations?). Memory usage is improved, though. I'll play around with this a little more and then maybe make a pull request.
