Code to reproduce the experimental results of the paper.
Requires Python 3+.
- Create a conda environment: `conda env create -f environment.yml`
- Activate the environment: `conda activate environment`
The project implements both active learning (AL, `--strategy 0`) and data pruning (DP, `--strategy 1`). The command-line flag `--auto_config` fills in the appropriate hyperparameters based on the model (recommended). The workflow of the main script is as follows:
1. Train a query model (possibly across multiple initializations) and retrieve sample scores;
2. Acquire (for AL) or remove (for DP) samples based on the scores and other factors (e.g., class-wise quotas);
3. Potentially repeat steps 1-2 across multiple iterations (`--iterations`, common for AL);
4. Once the final dataset is determined, train the final model and save its metrics in JSON format (see the sketch below).
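The following is a minimal, self-contained Python sketch of this loop. All helpers and data here are hypothetical stand-ins (random scores over integer sample IDs), not the project's actual API; it only illustrates how scoring, selection, and iteration fit together.

```python
import json
import random

def query_scores(dataset, num_inits):
    # Stand-in for training the query model and collecting per-sample scores,
    # averaged over `num_inits` initializations; a real scorer such as EL2N
    # or Forgetting would replace the random draw.
    return [sum(random.random() for _ in range(num_inits)) / num_inits
            for _ in dataset]

def prune(final_frac=0.7, iterations=1, num_inits=1):
    dataset = list(range(1000))              # stand-in for the training set
    full_size = len(dataset)
    for step in range(1, iterations + 1):
        # Steps 1-2: score the current dataset, then drop the lowest-scoring
        # samples (an AL run would instead acquire high-scoring samples
        # from an unlabeled pool).
        scores = query_scores(dataset, num_inits)
        target = int(full_size * (1 - step * (1 - final_frac) / iterations))
        ranked = sorted(zip(scores, dataset), reverse=True)
        dataset = [sample for _, sample in ranked[:target]]
    # Step 4: train the final model and save its metrics as JSON (here, the
    # final dataset size stands in for real test metrics).
    metrics = {"final_size": len(dataset)}
    with open("metrics.json", "w") as f:
        json.dump(metrics, f)

prune(final_frac=0.7, iterations=2, num_inits=5)
```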
Here are a few simple usage examples; note that `--final_frac` specifies the fraction of data to keep, so pruning 30% corresponds to `--final_frac 0.7`. The commands should be executed from the parent directory of the project folder.
- Prune 30% of CIFAR-10 using VGG-16 and the EL2N scorer:

```
python -m drop-data-pruning.main --auto_config --use_gpu --strategy 1 --final_frac 0.7 --model_name VGG16 --scorer_name EL2N
```
- Randomly prune 30% of CIFAR-10 using VGG-16 and DRoP class-wise ratios, with the query model retrained 5 times:

```
python -m drop-data-pruning.main --auto_config --use_gpu --strategy 1 --final_frac 0.7 --model_name VGG16 --scorer_name Random --quoter_name DRoP --num_inits 5
```
- Prune 30% of CIFAR-10 using VGG-16 and the Forgetting scorer, and train the final model with the cost-sensitive optimization algorithm CDB-W:

```
python -m drop-data-pruning.main --auto_config --use_gpu --cdbw_final --strategy 1 --final_frac 0.7 --model_name VGG16 --scorer_name Forgetting
```
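An active learning run uses the same interface with `--strategy 0` and `--iterations`. The command below is only an illustrative sketch assembled from the flags documented above; the specific flag values (and the interpretation of `--final_frac` under AL) are assumptions, not a recipe from the paper:

```
python -m drop-data-pruning.main --auto_config --use_gpu --strategy 0 --final_frac 0.3 --model_name VGG16 --scorer_name EL2N --iterations 5
```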
If you find this code useful, please cite:

```
@InProceedings{vysogorets2025drop,
  title     = {{DRoP}: Distributionally Robust Data Pruning},
  author    = {Vysogorets, Artem and Ahuja, Kartik and Kempe, Julia},
  booktitle = {Proceedings of the 13th International Conference on Learning Representations},
  pages     = {1--25},
  year      = {2025},
  month     = {24--28 Apr},
  series    = {Proceedings of Machine Learning Research},
  publisher = {PMLR}
}
```