This repository contains a PyTorch implementation of the curve-finding methods and the WA-ensembling procedure from the paper "Low-loss connection of weight vectors: distribution-based approaches"
by Ivan Anokhin and Dmitry Yarotsky (ICML 2020).
Please cite our work if you find it useful in your research:
@article{anokhin2020low,
title={Low-loss connection of weight vectors: distribution-based approaches},
author={Anokhin, Ivan and Yarotsky, Dmitry},
journal={arXiv preprint arXiv:2008.00741},
year={2020}
}
Before use, go to the project directory, install the requirements, and export PYTHONPATH:
cd distribution_connector
pip install -r requirements.txt
export PYTHONPATH=$(pwd)
The code implements the curve-finding procedures of the various methods for dense ReLU networks and VGG16, as well as the ensembling procedure with Weight Adjustment (WA), as described in the paper.
To run either procedure, you first need to train two or more networks: they serve as the endpoints of the curve or as the inputs to WA ensembling. You can train each endpoint with the following command:
python3 train.py --dir=<DIR> \
--dataset=<DATASET> \
--data_path=<PATH> \
--model=<MODEL> \
--epochs=<EPOCHS> \
--lr_init=<LR_INIT> \
--wd=<WD> \
--seed=<SEED>
Parameters:
- DIR: path to the training directory where checkpoints will be stored
- DATASET: dataset name [MNIST/CIFAR10]
- PATH: path to the data directory
- MODEL: DNN model name:
  - for MNIST:
    - LinearOneLayer
  - for CIFAR10:
    - LinearOneLayer100, LinearOneLayer500, LinearOneLayer1000, LinearOneLayer2000
    - Linear3NoBias, Linear5NoBias, Linear7NoBias
    - VGG16
    - PreResNet110
- EPOCHS: number of training epochs
- LR_INIT: initial learning rate
- WD: weight decay
- SEED: random seed; use different seeds to get different endpoints
The examples below also pass the --cuda flag to run on a GPU.
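For example, the following commands train LinearOneLayer on MNIST and LinearOneLayer100, Linear3NoBias, and VGG16 on CIFAR10. Each command produces one endpoint in chp1; run it again with a different directory and seed (e.g. chp2 and --seed=2) to obtain the second endpoint used in the evaluation examples below: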
#LinearOneLayer
python3 train.py --dir=checkpoints/LinearOneLayer/chp1 --dataset=MNIST --data_path=data --model=LinearOneLayer --epochs=30 --seed=1 --cuda
#LinearOneLayer100
python3 train.py --dir=checkpoints/LinearOneLayer100/chp1 --dataset=CIFAR10 --data_path=data --model=LinearOneLayer100 --epochs=400 --seed=1 --cuda
#Linear3NoBias
python3 train.py --dir=checkpoints/Linear3NoBias/chp1 --dataset=CIFAR10 --data_path=data --model=Linear3NoBias --epochs=400 --seed=1 --cuda
#VGG16
python3 train.py --dir=checkpoints/VGG16/chp1 --dataset=CIFAR10 --data_path=data --model=VGG16 --epochs=200 --seed=1 --cuda
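The examples above only create chp1, while the evaluation commands below also read chp2. Below is a minimal sketch of a hypothetical helper (not part of the repository) that builds the train.py invocations for both endpoints using only the flags shown above; the chp<i> directory layout and the seed values 1 and 2 are assumptions taken from the examples.

```python
import shlex
from typing import List, Sequence


def build_train_commands(model: str, dataset: str, epochs: int,
                         seeds: Sequence[int] = (1, 2),
                         data_path: str = "data",
                         use_cuda: bool = True) -> List[List[str]]:
    """Build one train.py invocation per seed, writing to checkpoints/<model>/chp<i>."""
    commands = []
    for i, seed in enumerate(seeds, start=1):
        cmd = [
            "python3", "train.py",
            f"--dir=checkpoints/{model}/chp{i}",   # chp1, chp2, ... as in the examples
            f"--dataset={dataset}",
            f"--data_path={data_path}",
            f"--model={model}",
            f"--epochs={epochs}",
            f"--seed={seed}",                      # a different seed gives a different endpoint
        ]
        if use_cuda:
            cmd.append("--cuda")
        commands.append(cmd)
    return commands


if __name__ == "__main__":
    # Print the shell form of each command; nothing is executed here.
    for cmd in build_train_commands("Linear3NoBias", "CIFAR10", epochs=400):
        print(shlex.join(cmd))
```

To actually start training, pass each list to subprocess.run(cmd, check=True); building argument lists rather than shell strings avoids quoting issues.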
To evaluate a method for connecting the endpoints, use the following command:
python3 eval_curve.py --dir=<DIR> \
--point_finder=<POINTFINDER> \
--method=<METHOD> \
--end_time=<ENDTIME> \
--dataset=<DATASET> \
--data_path=<PATH> \
--model=<MODEL> \
--start=<START> \
--end=<END> \
--num_points=<NUM_POINTS>
Parameters:
- POINTFINDER: algorithm that proposes the distribution samples to connect and may perform additional steps to preserve the output of the network [PointFinderWithBias/PointFinderInverseWithBias/PointFinderTransportation/PointFinderInverseWithBiasOT/PointFinderSimultaneous/PointFinderStepWiseButterfly/PointFinderStepWiseInverse/PointFinderStepWiseTransportation/PointFinderStepWiseInverseOT]; the VGG16 example below uses PointFinderStepWiseButterflyConvWBiasOT
- METHOD: method that connects the samples proposed by POINTFINDER [lin_connect/arc_connect]; lin_connect and arc_connect correspond to Eq. 5 and Eq. 6 in the paper, respectively
- START: path to the first checkpoint saved by train.py
- END: path to the second checkpoint saved by train.py
- NUM_POINTS: number of points along the curve used for evaluation
- ENDTIME: time (curve parameter) at which the curve reaches the endpoint; its value depends on POINTFINDER and MODEL
POINTFINDER and METHOD together determine the curve-finding procedures examined in the paper:
- Table 1:
  - PointFinderWithBias + lin_connect: Linear
  - PointFinderWithBias + arc_connect: Arc
  - PointFinderInverseWithBias + lin_connect: Linear + Weight Adjustment
  - PointFinderInverseWithBias + arc_connect: Arc + Weight Adjustment
  - PointFinderTransportation + lin_connect: OT
  - PointFinderInverseWithBiasOT + lin_connect: OT + Weight Adjustment
- Table 2:
  - PointFinderSimultaneous + lin_connect: Linear
  - PointFinderSimultaneous + arc_connect: Arc
  - PointFinderStepWiseButterfly + lin_connect: Linear + B-fly
  - PointFinderStepWiseButterfly + arc_connect: Arc + B-fly
  - PointFinderStepWiseInverse + lin_connect: Linear + WA
  - PointFinderStepWiseInverse + arc_connect: Arc + WA
  - PointFinderStepWiseTransportation + lin_connect: OT + B-fly
  - PointFinderStepWiseInverseOT + lin_connect: OT + WA
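A minimal NumPy sketch of the two METHOD values, assuming lin_connect is plain linear interpolation between the two samples and arc_connect is the trigonometric path cos(πt/2)·w_A + sin(πt/2)·w_B for t in [0, 1]. Eq. 5 and Eq. 6 in the paper are the authoritative definitions and may differ in details; the functions below only mirror the METHOD names and are not the repository's code. The demo compares the spread of the interpolated weights for two independent standard-normal samples.

```python
import numpy as np


def lin_connect(w_a: np.ndarray, w_b: np.ndarray, t: float) -> np.ndarray:
    """Assumed linear path: (1 - t) * w_a + t * w_b, for t in [0, 1]."""
    return (1.0 - t) * w_a + t * w_b


def arc_connect(w_a: np.ndarray, w_b: np.ndarray, t: float) -> np.ndarray:
    """Assumed arc path: cos(pi t / 2) * w_a + sin(pi t / 2) * w_b, for t in [0, 1].

    Since cos^2 + sin^2 = 1, if w_a and w_b are independent zero-mean samples
    with the same covariance, every point on the arc has that covariance too.
    """
    angle = 0.5 * np.pi * t
    return np.cos(angle) * w_a + np.sin(angle) * w_b


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w_a = rng.normal(size=100_000)  # stand-ins for two independent weight samples
    w_b = rng.normal(size=100_000)
    for t in np.linspace(0.0, 1.0, 5):
        print(f"t={t:.2f}  lin std={lin_connect(w_a, w_b, t).std():.3f}  "
              f"arc std={arc_connect(w_a, w_b, t).std():.3f}")
```

At t = 0.5 the linear path shrinks the standard deviation to about 0.71, while the arc keeps it near 1; this is one motivation for an arc when the endpoints are viewed as samples from the same distribution.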
eval_curve.py prints the train and test loss and error along the curve, and saves a .npz file with more detailed statistics to <DIR>.
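The name of the saved .npz file and the arrays it contains are not documented here, so rather than guessing their keys, this small sketch (NumPy only; the helper name is hypothetical) lists every stored array and its shape.

```python
import numpy as np


def summarize_npz(path: str) -> dict:
    """Return {array name: shape} for every array stored in a .npz file."""
    with np.load(path) as data:  # NpzFile supports the context-manager protocol
        return {name: data[name].shape for name in data.files}
```

Call it with the path of the .npz file found in <DIR> and print the result to see which statistics are available. If any stored array holds Python objects, np.load additionally needs allow_pickle=True.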
For example, use the following commands to evaluate the paths on CIFAR10:
#PointFinderWithBias lin_connect for LinearOneLayer100 model
python3 eval_curve.py --dir=experiments/eval/LinearOneLayer100/PointFinderWithBias --point_finder=PointFinderWithBias --method=lin_connect --model=LinearOneLayer100 --end_time=1 --data_path=data --num_points=21 --start=checkpoints/LinearOneLayer100/chp1/checkpoint-400.pt --end=checkpoints/LinearOneLayer100/chp2/checkpoint-400.pt --cuda
#PointFinderInverseWithBias arc_connect for LinearOneLayer100 model
python3 eval_curve.py --dir=experiments/eval/LinearOneLayer100/PointFinderInverseWithBias --point_finder=PointFinderInverseWithBias --method=arc_connect --model=LinearOneLayer100 --end_time=2 --data_path=data --num_points=21 --start=checkpoints/LinearOneLayer100/chp1/checkpoint-400.pt --end=checkpoints/LinearOneLayer100/chp2/checkpoint-400.pt --cuda
#PointFinderTransportation lin_connect for LinearOneLayer100 model
python3 eval_curve.py --dir=experiments/eval/LinearOneLayer100/PointFinderTransportation --point_finder=PointFinderTransportation --method=lin_connect --model=LinearOneLayer100 --end_time=1 --data_path=data --num_points=21 --start=checkpoints/LinearOneLayer100/chp1/checkpoint-400.pt --end=checkpoints/LinearOneLayer100/chp2/checkpoint-400.pt --cuda
#PointFinderInverseWithBiasOT lin_connect for LinearOneLayer100 model
python3 eval_curve.py --dir=experiments/eval/LinearOneLayer100/PointFinderInverseWithBiasOT --point_finder=PointFinderInverseWithBiasOT --method=lin_connect --model=LinearOneLayer100 --end_time=2 --data_path=data --num_points=21 --start=checkpoints/LinearOneLayer100/chp1/checkpoint-400.pt --end=checkpoints/LinearOneLayer100/chp2/checkpoint-400.pt --cuda
#PointFinderSimultaneous lin_connect for Linear3NoBias model
python3 eval_curve.py --dir=experiments/eval/Linear3NoBias/PointFinderSimultaneous --point_finder=PointFinderSimultaneous --method=lin_connect --model=Linear3NoBias --end_time=1 --data_path=data --num_points=21 --start=checkpoints/Linear3NoBias/chp1/checkpoint-400.pt --end=checkpoints/Linear3NoBias/chp2/checkpoint-400.pt --cuda
#PointFinderStepWiseButterfly arc_connect for Linear3NoBias model
python3 eval_curve.py --dir=experiments/eval/Linear3NoBias/PointFinderStepWiseButterfly --point_finder=PointFinderStepWiseButterfly --method=arc_connect --model=Linear3NoBias --end_time=2 --data_path=data --num_points=21 --start=checkpoints/Linear3NoBias/chp1/checkpoint-400.pt --end=checkpoints/Linear3NoBias/chp2/checkpoint-400.pt --cuda
#PointFinderStepWiseInverse lin_connect for Linear3NoBias model
python3 eval_curve.py --dir=experiments/eval/Linear3NoBias/PointFinderStepWiseInverse --point_finder=PointFinderStepWiseInverse --method=lin_connect --model=Linear3NoBias --end_time=3 --data_path=data --num_points=31 --start=checkpoints/Linear3NoBias/chp1/checkpoint-400.pt --end=checkpoints/Linear3NoBias/chp2/checkpoint-400.pt --cuda
#PointFinderStepWiseTransportation lin_connect for Linear3NoBias model
python3 eval_curve.py --dir=experiments/eval/Linear3NoBias/PointFinderStepWiseTransportation --point_finder=PointFinderStepWiseTransportation --method=lin_connect --model=Linear3NoBias --end_time=2 --data_path=data --num_points=21 --start=checkpoints/Linear3NoBias/chp1/checkpoint-400.pt --end=checkpoints/Linear3NoBias/chp2/checkpoint-400.pt --cuda
#PointFinderStepWiseInverseOT lin_connect for Linear3NoBias model
python3 eval_curve.py --dir=experiments/eval/Linear3NoBias/PointFinderStepWiseInverseOT --point_finder=PointFinderStepWiseInverseOT --method=lin_connect --model=Linear3NoBias --end_time=3 --data_path=data --num_points=31 --start=checkpoints/Linear3NoBias/chp1/checkpoint-400.pt --end=checkpoints/Linear3NoBias/chp2/checkpoint-400.pt --cuda
#PointFinderStepWiseButterflyConvWBiasOT lin_connect for VGG16
python3 eval_curve.py --dir=experiments/eval/VGG16lin/PointFinderStepWiseButterflyConvWBiasOT/12 --point_finder=PointFinderStepWiseButterflyConvWBiasOT --method=lin_connect --model=VGG16 --end_time=15 --data_path=data --num_points=61 --start=checkpoints/VGG16/chp1/checkpoint-200.pt --end=checkpoints/VGG16/chp2/checkpoint-200.pt
To evaluate Ensembling with Weight Adjustment, use the following command:
python3 eval_ensemble.py --dir=<DIR> \
--data_path=<PATH> \
--model=<MODEL> \
--name=<NAME> \
--layer=<LAYER> \
--layer_ind=<LAYERIND> \
--model_paths=<MPATHS>
Parameters:
- NAME: substring contained in the names of all checkpoints you want to ensemble. For example, set NAME=400 to ensemble the checkpoints saved after 400 training epochs.
- LAYER: index of the layer in the PyTorch network implementation after which the Weight Adjustment procedure is performed
- LAYERIND: index of the layer in parameter space on which the Weight Adjustment procedure is performed
- MPATHS: path to the directory where the checkpoints to ensemble are stored
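Going by the examples, --model_paths points at checkpoints/<MODEL>/, whose checkpoints sit in the chp1/, chp2/, ... subdirectories written by train.py, so this sketch assumes a recursive search; how eval_ensemble.py actually matches NAME is not documented here. The hypothetical helper below previews which files a given NAME would select under that assumption.

```python
import os
from typing import List


def find_checkpoints(model_paths: str, name: str) -> List[str]:
    """Recursively collect files under model_paths whose file name contains `name`."""
    matches = []
    for root, _dirs, files in os.walk(model_paths):
        for fname in files:
            if name in fname:  # plain substring test, as NAME is described
                matches.append(os.path.join(root, fname))
    return sorted(matches)
```

Under this substring assumption, NAME=400 would also match a file such as checkpoint-1400.pt, so choose NAME so that it is unique to the checkpoints you want to ensemble.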
For example, the following commands evaluate WA(n) ensembling (see Section 6 of the paper for the definition of WA(n)):
#Linear3NoBias WA(1)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/Linear3NoBias/ --data_path=data --model=Linear3NoBias --name=400 --layer=1 --layer_ind=2 --model_paths=checkpoints/Linear3NoBias/
#Linear3NoBias WA(2)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/Linear3NoBias/ --data_path=data --model=Linear3NoBias --name=400 --layer=0 --layer_ind=1 --model_paths=checkpoints/Linear3NoBias/
#Linear5NoBias WA(1)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/Linear5NoBias/ --data_path=data --model=Linear5NoBias --name=400 --layer=3 --layer_ind=4 --model_paths=checkpoints/Linear5NoBias/
#Linear7NoBias WA(1)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/Linear7NoBias/ --data_path=data --model=Linear7NoBias --name=400 --layer=5 --layer_ind=6 --model_paths=checkpoints/Linear7NoBias/
#Linear7NoBias WA(3)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/Linear7NoBias/ --data_path=data --model=Linear7NoBias --name=400 --layer=3 --layer_ind=4 --model_paths=checkpoints/Linear7NoBias/
#VGG16 WA(9)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/VGG16cifar100w9/ --data_path=data --model=VGG16 --name=200 --layer=9 --layer_ind=-14 --model_paths=checkpoints/cifar100/VGG16 --dataset=CIFAR100
#VGG16 WA(10)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/VGG16cifar100w10/ --data_path=data --model=VGG16 --name=200 --layer=10 --layer_ind=-12 --model_paths=checkpoints/cifar100/VGG16 --dataset=CIFAR100
#VGG16 WA(3)
python3 eval_ensemble.py --dir=experiments/eval_ensemble/VGG16cifar100w3/ --data_path=data --model=VGG16 --name=200 --layer=3 --layer_ind=-26 --model_paths=checkpoints/cifar100/VGG16 --dataset=CIFAR100
eval_ensemble.py prints the ensembling statistics and saves a .npz file and a .png plot with more details to <DIR>.
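The exact Weight Adjustment procedure is defined in the paper and implemented in eval_ensemble.py. As a hedged illustration of the general idea of preserving a network's output after one layer's weights change, here is a generic least-squares refit of the following layer for a toy two-layer ReLU network without biases. This is our assumption about the flavor of the procedure, not the repository's code, and it does not model the LAYER/LAYERIND indexing.

```python
import numpy as np


def relu(z: np.ndarray) -> np.ndarray:
    return np.maximum(z, 0.0)


def adjust_next_layer(x: np.ndarray, w1_old: np.ndarray, w2_old: np.ndarray,
                      w1_new: np.ndarray) -> np.ndarray:
    """Refit w2 so that relu(x @ w1_new.T) @ w2.T matches the old outputs on x.

    Shapes: x (n, d_in), w1_* (h, d_in), w2_old (d_out, h). Returns (d_out, h).
    """
    target = relu(x @ w1_old.T) @ w2_old.T   # outputs of the original two layers
    hidden_new = relu(x @ w1_new.T)          # activations after swapping the first layer
    # Least squares: minimize ||hidden_new @ W - target||^2 over W of shape (h, d_out).
    w2_new_t, *_ = np.linalg.lstsq(hidden_new, target, rcond=None)
    return w2_new_t.T


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(2000, 20))
    w1_a, w2_a = rng.normal(size=(200, 20)), rng.normal(size=(10, 200))
    w1_b = rng.normal(size=(200, 20))  # stand-in for a different first layer (random here)
    target = relu(x @ w1_a.T) @ w2_a.T
    naive = relu(x @ w1_b.T) @ w2_a.T
    adjusted = relu(x @ w1_b.T) @ adjust_next_layer(x, w1_a, w2_a, w1_b).T
    print("MSE without adjustment:", np.mean((naive - target) ** 2))
    print("MSE with adjustment:   ", np.mean((adjusted - target) ** 2))
```

Because the old w2 is itself one candidate in the least-squares problem, the refit never increases the output error on x compared with leaving w2 unchanged.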