This is the code for the Conditional Mutual Information-Debiasing (CMID) method proposed in the paper *Mitigating Simplicity Bias in Deep Learning for Improved OOD Generalization and Robustness* by Bhavya Vasudeva, Kameron Shahabi, and Vatsal Sharan. (The base code builds on the group_DRO implementation.)
The code uses Python 3.6.8. Dependencies can be installed with:

```
pip install -r requirements.txt
```
Change the root_dir variable in data/data.py. Datasets will be stored in the location specified by root_dir. (Check this link for more details.)
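As a minimal sketch of this step (the variable name `root_dir` comes from `data/data.py`; the path below is a placeholder, not a value from the repo):

```python
import os

# In data/data.py, point root_dir at the directory where datasets
# should be stored. Replace the placeholder path with your own location.
root_dir = os.path.expanduser("~/datasets/cmid")
```

Each dataset is then expected under a subdirectory of this location (e.g. `[root_dir]/cub` for Waterbirds).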
Experiments are run on the Waterbirds, CelebA, MultiNLI, and CivilComments datasets. Set up each dataset as follows:
- Waterbirds: The code expects the following files/folders in the `[root_dir]/cub` directory:
  - `data/waterbird_complete95_forest2water2/`

  A tarball of this dataset can be downloaded from this link.
- CelebA: The code expects the following files/folders in the `[root_dir]/celebA` directory:
  - `data/list_eval_partition.csv`
  - `data/list_attr_celeba.csv`
  - `data/img_align_celeba/`

  These dataset files can be downloaded from this Kaggle link.
- MultiNLI: The code expects the following files/folders in the `[root_dir]/multinli` directory:
  - `data/metadata_random.csv`
  - `glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli`
  - `glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli-mm`
  - `glue_data/MNLI/cached_train_bert-base-uncased_128_mnli`

  The metadata file is included in the `dataset_metadata/multinli` folder. The `glue_data/MNLI` files are generated by the Hugging Face Transformers library and can be downloaded here.
- CivilComments: The code expects the following files/folders in the `[root_dir]/civcom` directory:
  - `all_data_with_grouped_identities.csv`
  - `all_data_with_identities.csv`

  A tarball of this dataset can be downloaded from this link.
The main files for running experiments and parsing the results are `run_expt.py` and `parse_log_file.py`, respectively. The specific commands are listed below:
- Waterbirds:

  ```
  python run_expt.py --log_dir /CMID/log-wb -s confounder -d CUB -t waterbird_complete95 -c forest2water2 --lr 0.0005 --batch_size 128 --weight_decay 0.0001 --model resnet50 --n_epochs 100 --cmi_reg --log_every 20 --reg_st 20.0 --cmistinc --scale 4
  python parse_log_file.py --log_dir /CMID/log-wb --num_groups 4
  ```

- CelebA:

  ```
  python run_expt.py --log_dir /CMID/log-cel -s confounder -d CelebA -t Blond_Hair -c Male --lr 0.0003 --batch_size 128 --weight_decay 0.001 --model resnet50 --n_epochs 50 --cmi_reg --log_every 20 --reg_st 10.0 --cmistinc --scale 5
  python parse_log_file.py --log_dir /CMID/log-cel --num_groups 4
  ```

- MultiNLI:

  ```
  python run_expt.py --log_dir /CMID/log-mnli -s confounder -d MultiNLI -t gold_label_random -c sentence2_has_negation --lr 5e-05 --batch_size 32 --weight_decay 0 --model bert --n_epochs 5 --cmi_reg --reg_st 75.0 --cmistinc --lr1 0.005
  python parse_log_file.py --log_dir /CMID/log-mnli --num_groups 6
  ```

- CivilComments:

  ```
  python run_expt.py --log_dir /CMID/log-ccom -s confounder -d CivComMod -t toxicity -c identity_any --lr 0.00001 --batch_size 32 --weight_decay 0.001 --model bert-base-uncased --n_epochs 10 --cmi_reg --reg_st 25.0 --cmistinc --lr1 0.0001
  python parse_log_file.py --log_dir /CMID/log-ccom --num_groups 16
  ```
For Camelyon17, the code expects the following files/folders in the `./camelyon` directory:
- `data/camelyon17_v1.0/metadata.csv`
- `data/camelyon17_v1.0/patches/` (including all the patch data)

If these files do not exist, the code will download them automatically at run time.
We use a separate script for Camelyon17 in order to use WILDS dataloading. To run it, change into the `./camelyon` directory and run the following sample command, which writes the results to `camelyon.txt` in the same directory:

```
python camelyon.py --cmi_reg --epochs 5 --epochs2 10 --lr 0.0001 --lr1 0.0001 --weight_decay 0.01 --reg_st 0.5 --batch_size 32 &> camelyon.txt
```
If you find our research useful, please cite our work:

```
@article{
  vasudeva2024mitigating,
  title={Mitigating Simplicity Bias in Deep Learning for Improved {OOD} Generalization and Robustness},
  author={Bhavya Vasudeva and Kameron Shahabi and Vatsal Sharan},
  journal={Transactions on Machine Learning Research},
  issn={2835-8856},
  year={2024},
  url={https://openreview.net/forum?id=XccFHGakyU},
  note={}
}
```