This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[RFC] [WIP] Making sampling methods differentiable. #16196

Open
xidulu opened this issue Sep 18, 2019 · 6 comments
Labels
Feature request RFC Post requesting for comments

Comments

@xidulu
Contributor

xidulu commented Sep 18, 2019

Background

Backpropagation through random variables is no easy task. Two main methods are often adopted for gradient estimation: the score function (SF) estimator and the pathwise derivative estimator (see https://arxiv.org/abs/1506.05254 for more details). The former is widely used in reinforcement learning, while the pathwise derivative estimator appears frequently in variational-autoencoder-related models, where it is often referred to as the reparameterization trick. One key difference between the two methods is that the pathwise derivative estimator differentiates through the sample itself, which requires the sampling operation to have a gradient with respect to the distribution parameter θ, while the SF estimator bypasses such a calculation via the log-derivative trick.
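The contrast between the two estimators can be made concrete with a Monte Carlo toy example (a minimal NumPy sketch, not MXNet code; the target is d/dθ E[x²] for x ~ N(θ, 1), whose true value is 2θ):

```python
import numpy as np

# Toy target: d/dtheta E[f(x)] for x ~ N(theta, 1) with f(x) = x^2.
# E[x^2] = theta^2 + 1, so the true gradient is 2*theta.
rng = np.random.default_rng(0)
theta, n = 1.5, 200_000

# Score-function (REINFORCE) estimator:
#   E[f(x) * d/dtheta log p(x; theta)], with d/dtheta log p = x - theta.
# Only needs samples of x, not a differentiable sampler.
x = rng.normal(theta, 1.0, size=n)
sf_grad = np.mean(x ** 2 * (x - theta))

# Pathwise (reparameterization) estimator:
#   write x = theta + eps with eps ~ N(0, 1), then differentiate f directly:
#   d/dtheta f(theta + eps) = 2 * (theta + eps).
eps = rng.standard_normal(n)
path_grad = np.mean(2 * (theta + eps))

print(sf_grad, path_grad)  # both approach the true gradient 2*theta = 3.0
```

Both estimates converge to 2θ = 3.0, but the pathwise estimator typically has much lower variance, which is why a differentiable sampler is worth having.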

Proposal

I'm planning to prototype the pathwise gradient for some of the sampling methods in Deep Numpy (Gaussian and Gamma for now) by applying the following modifications:

  1. Add a require_grads parameter to the Python frontend.
  2. Add a backward function in the backend.
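For the Gaussian case, the forward/backward pair such a modification would need might look like this (a hypothetical NumPy sketch of the pathwise trick, not the actual MXNet backend; `sample_gaussian` and `sample_gaussian_backward` are illustrative names):

```python
import numpy as np

rng = np.random.default_rng(42)

def sample_gaussian(mu, sigma, shape):
    """Forward pass: reparameterized sample z = mu + sigma * eps, eps ~ N(0, 1).
    Also returns eps, which the backward pass needs."""
    eps = rng.standard_normal(shape)
    return mu + sigma * eps, eps

def sample_gaussian_backward(grad_z, eps):
    """Backward pass: dz/dmu = 1 and dz/dsigma = eps, so the upstream
    gradient grad_z reduces to (grad_z.sum(), (grad_z * eps).sum())."""
    return grad_z.sum(), (grad_z * eps).sum()

# Check against a loss with a known gradient: L = mean(z^2) gives
# dL/dmu = E[2z] = 2*mu and dL/dsigma = E[2z*eps] = 2*sigma.
mu, sigma = 0.7, 1.0
z, eps = sample_gaussian(mu, sigma, (100_000,))
grad_z = 2 * z / z.size  # dL/dz for L = mean(z^2)
g_mu, g_sigma = sample_gaussian_backward(grad_z, eps)
print(g_mu, g_sigma)  # close to 2*mu = 1.4 and 2*sigma = 2.0
```

The Gamma distribution has no such simple location-scale form, so its backward would need a different technique (e.g. implicit differentiation of the sampler), which is part of what makes this prototype non-trivial.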

If my experiment goes well, these enhanced sampling methods could serve as the foundation for the distribution module mentioned in the MXNet 2.0 Roadmap #16167.
Also, differentiable sampling has been available in both TensorFlow (tf.distributions) and PyTorch (torch.distributions) for years; I think it is necessary for MXNet to have such a feature as well.


Update:

Gradient for the Gaussian sampler added, under review: #16330
Next, I will implement a vanilla VAE demo on top of it to find out whether the interface is easy to use in practice.
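A VAE is a natural stress test because its encoder must backpropagate through the latent draw. The relevant pieces can be sketched as follows (a minimal NumPy illustration with hypothetical helper names, not the planned interface):

```python
import numpy as np

rng = np.random.default_rng(0)

def vae_latent(mu, log_var):
    """Reparameterized latent draw z = mu + exp(log_var / 2) * eps.
    Because z is a deterministic function of (mu, log_var, eps), gradients
    of the decoder loss can flow back into the encoder outputs."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)) term of the ELBO."""
    return 0.5 * np.sum(np.exp(log_var) + mu ** 2 - 1.0 - log_var)

# With a standard-normal posterior, the KL penalty is exactly zero.
mu, log_var = np.zeros(8), np.zeros(8)
z = vae_latent(mu, log_var)
print(z.shape, kl_to_standard_normal(mu, log_var))  # (8,) 0.0
```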

@mxnet-label-bot
Contributor

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended label(s): Feature

@zachgk added the Feature request and RFC labels on Sep 18, 2019
@sxjscience
Member

Ping @szhengac, who should be in charge of the distribution module.

@xidulu
Contributor Author

xidulu commented Sep 26, 2019

@sxjscience
Thanks for your reply. I've briefly read REBAR before; estimating gradients by combining the reparameterization trick with REINFORCE has become increasingly popular these days (e.g. https://arxiv.org/abs/1807.11143 , https://arxiv.org/abs/1711.00123 ).
I'll have further discussion with @szhengac regarding the distribution module.

@sxjscience
Member

Let me link it to #12932

@sxjscience
Member

@xidulu Also, you may refer to https://www.tensorflow.org/probability .
