This repository is a PyTorch implementation of the Regularized Modernized Dual Averaging (RMDA) algorithm for training structured neural network models. Details of the algorithm can be found in the following paper:
Zih-Syuan Huang, Ching-pei Lee. Training Structured Neural Networks Through Manifold Identification and Variance Reduction. [arXiv]
Given a regularizer and its proximal operator, this algorithm trains a neural network model that conforms to the structure induced by the regularizer. As illustrative examples, this repository includes the proximal operators of the L1 norm and the group-LASSO norm, but users can replace them with the proximal operator of any other regularizer.
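As a minimal sketch of what the two included proximal operators compute (written independently here, not copied from this repository): the L1 prox is elementwise soft-thresholding, and the group-LASSO prox shrinks the Euclidean norm of each group, zeroing whole groups at once. The function names `prox_l1` and `prox_group_lasso` are hypothetical, and we assume each group is one slice along the first dimension of a weight tensor (for example, one output neuron or one convolutional filter); the repository may define groups differently.

```python
import torch


def prox_l1(x: torch.Tensor, lam: float) -> torch.Tensor:
    """Prox of lam * ||x||_1: elementwise soft-thresholding (hypothetical helper)."""
    # Entries with |x_i| <= lam become exactly zero; the rest move toward zero by lam.
    return torch.sign(x) * torch.clamp(x.abs() - lam, min=0.0)


def prox_group_lasso(x: torch.Tensor, lam: float) -> torch.Tensor:
    """Prox of lam * sum_g ||x_g||_2 (hypothetical helper).

    Assumption: each group is one slice x[i, ...] along the first dimension.
    """
    flat = x.reshape(x.shape[0], -1)
    # Euclidean norm of each group, kept as a column so it broadcasts over the group.
    norms = flat.norm(dim=1, keepdim=True)
    # Shrink each group's norm by lam; groups with norm <= lam become exactly zero.
    # The clamp on norms avoids division by zero for all-zero groups.
    scale = torch.clamp(1.0 - lam / norms.clamp(min=1e-12), min=0.0)
    return (flat * scale).reshape_as(x)
```

For example, `prox_l1(torch.tensor([0.5, -2.0, 1.2]), 1.0)` gives approximately `[0.0, -1.0, 0.2]`, and applying `prox_group_lasso` with `lam=1.0` to the rows `[3.0, 4.0]` (norm 5) and `[0.3, 0.4]` (norm 0.5) scales the first row by 0.8 and zeroes the second row entirely. These exact zeros are what give the trained model its structure.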
This repository contains:
- Regularized modernized dual averaging (RMDA) algorithm.
- A scheduler for the learning rate and momentum, including restarts.
- Proximal operators for the group-LASSO and L1 norms.
- Training file: an example wrapper that uses our algorithm to train a structured neural network.
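The following is a simplified, hypothetical sketch of an RMDA-style update written as a `torch.optim.Optimizer` subclass, followed by a small training loop standing in for the training wrapper. It assumes the modernized dual averaging template: step weights s_t = lr * sqrt(t+1), a scaling beta_t = sqrt(t+1), a running weighted gradient sum V^t, a proximal step taken around the initial point W^0, and a constant averaging weight c = 1 - momentum. The class name `RMDASketch`, its parameters, and these schedules are assumptions for illustration only; the paper and this repository's code define the actual algorithm, its schedulers, and restarts.

```python
import math

import torch


class RMDASketch(torch.optim.Optimizer):
    """Hypothetical, simplified RMDA-style optimizer with an L1 regularizer."""

    def __init__(self, params, lr=0.1, momentum=0.9, lam=1e-3):
        # Assumption: a constant averaging weight c = 1 - momentum.
        defaults = dict(lr=lr, momentum=momentum, lam=lam)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr, lam = group["lr"], group["lam"]
            c = 1.0 - group["momentum"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if not state:
                    state["t"] = 0
                    state["alpha"] = 0.0                  # alpha_t: sum of step weights
                    state["x0"] = p.detach().clone()      # W^0: anchor of the prox step
                    state["v"] = torch.zeros_like(p)      # V^t: weighted gradient sum
                t = state["t"]
                beta = math.sqrt(t + 1)
                s = lr * beta                             # assumed step weight s_t
                state["alpha"] += s
                state["v"].add_(p.grad, alpha=s)
                # Dual-averaging point, then the prox of (alpha_t / beta_t) * lam * ||.||_1.
                z = state["x0"] - state["v"] / beta
                thresh = state["alpha"] * lam / beta
                w_tilde = torch.sign(z) * torch.clamp(z.abs() - thresh, min=0.0)
                # Averaging step: W^t = (1 - c) * W^{t-1} + c * w_tilde.
                p.mul_(1.0 - c).add_(w_tilde, alpha=c)
                state["t"] = t + 1
        return loss


if __name__ == "__main__":
    # Toy logistic regression on synthetic data: only features 0 and 1 matter.
    torch.manual_seed(0)
    X = torch.randn(256, 20)
    y = (X[:, 0] - X[:, 1] > 0).float()
    model = torch.nn.Linear(20, 1)
    opt = RMDASketch(model.parameters(), lr=0.1, momentum=0.9, lam=1e-2)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    for _ in range(200):
        opt.zero_grad()
        loss = loss_fn(model(X).squeeze(1), y)
        loss.backward()
        opt.step()
    print("final loss:", loss.item())
    print("exactly-zero weights:", int((model.weight == 0).sum()))
```

In this sketch, a coordinate of the proximal point `w_tilde` is exactly zero whenever |W^0 - V^t / beta_t| is at most alpha_t * lam / beta_t, so zeros are decided by the accumulated gradient history rather than by a single noisy step. For brevity, the demo applies the L1 penalty to every parameter, including the bias.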
To run the code, you will need to install torch and torchvision.
To run a logistic regression experiment on MNIST, run the following in the Experiments directory:
python LogisticRegression_on_MNIST.py