wp99cp/continual-learning

Prototype-Based Sampling for Class-Incremental Learning

Abstract

In real-world environments, machine learning models frequently need to adapt to tasks not present in their initial training data while retaining their previously learned capabilities. If one simply continued training the model on the new data, it would exhibit catastrophic forgetting. Several continual learning algorithms have been proposed to remedy this issue. One prominent example is experience replay, which stores a subset of samples from the previous data distribution so that the model can revisit them while training on the new data. We aim to contribute to solving this problem by exploring the correlation between catastrophic forgetting and the statistical properties of class prototypes in feature space. Our analysis reveals that these properties are indeed relevant to forgetting. Additionally, the methods we designed based on these insights achieve state-of-the-art performance, demonstrating the relevance of our model.
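To make "statistical properties of class prototypes in feature space" concrete, here is a minimal sketch assuming the common definition of a class prototype as the mean feature vector of that class. The helper names (class_prototypes, within_class_spread, prototype_distances) are hypothetical and not taken from this repository, and the two statistics shown (spread of samples around their prototype, distance between prototypes) are only illustrative examples; the exact statistics the project analyzes are not stated here.

```python
import math
from collections import defaultdict

Vector = list[float]


def class_prototypes(features: list[Vector], labels: list[int]) -> dict[int, Vector]:
    """Prototype of each class = mean of the feature vectors of that class."""
    sums: dict[int, Vector] = {}
    counts: dict[int, int] = defaultdict(int)
    for x, y in zip(features, labels):
        acc = sums.setdefault(y, [0.0] * len(x))
        sums[y] = [a + v for a, v in zip(acc, x)]
        counts[y] += 1
    return {y: [a / counts[y] for a in acc] for y, acc in sums.items()}


def within_class_spread(
    features: list[Vector], labels: list[int], prototypes: dict[int, Vector]
) -> dict[int, float]:
    """Mean Euclidean distance of a class's samples to its prototype."""
    totals: dict[int, float] = defaultdict(float)
    counts: dict[int, int] = defaultdict(int)
    for x, y in zip(features, labels):
        totals[y] += math.dist(x, prototypes[y])
        counts[y] += 1
    return {y: totals[y] / counts[y] for y in totals}


def prototype_distances(prototypes: dict[int, Vector]) -> dict[tuple[int, int], float]:
    """Euclidean distance between every pair of class prototypes."""
    ids = sorted(prototypes)
    return {
        (a, b): math.dist(prototypes[a], prototypes[b])
        for i, a in enumerate(ids)
        for b in ids[i + 1 :]
    }


# Toy 2-D "feature space" with two classes.
feats = [[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [10.0, 12.0]]
labs = [0, 0, 1, 1]
protos = class_prototypes(feats, labs)          # {0: [1.0, 0.0], 1: [10.0, 11.0]}
spread = within_class_spread(feats, labs, protos)  # {0: 1.0, 1: 1.0}
dists = prototype_distances(protos)             # {(0, 1): ~14.21}
```

In practice the feature vectors would come from the model's penultimate layer; plain lists are used here only to keep the sketch dependency-free.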

Reproduce the Experiments

  1. Clone the repository

  2. Install the requirements

    # create and activate a virtual environment
    python3 -m venv .venv
    source .venv/bin/activate
    
    # install the requirements
    pip install -r requirements.txt
  3. Run the experiments

    # run the experiments
    python main.py

We are using Python 3.12 for this project.
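If you want a quick sanity check that the active interpreter matches this version before running main.py, a small standard-library snippet like the following works; the helper name python_version_matches is hypothetical and not part of the repository.

```python
import sys

REQUIRED = (3, 12)  # major.minor version stated in this README


def python_version_matches(required: tuple[int, int] = REQUIRED) -> bool:
    """Return True if the running interpreter's (major, minor) equals `required`."""
    return tuple(sys.version_info[:2]) == required


if not python_version_matches():
    print(
        f"Warning: running Python {sys.version_info.major}.{sys.version_info.minor}, "
        f"but this project uses {REQUIRED[0]}.{REQUIRED[1]}"
    )
```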

Hyperparameters

Inside main.py, you will find several parameters for selecting which experiments are executed and which dataset is used. Essential hyperparameters such as batch_size and the number of tasks can also be tuned directly in main.py.
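One way to keep such run-level settings together and validated is a frozen dataclass. This is only an illustrative sketch: the class name RunSettings, its field names, and all default values are hypothetical placeholders, and main.py may expose these settings as plain variables with different names.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class RunSettings:
    """Hypothetical container for run-level settings (not the project's actual names)."""

    experiments: tuple[str, ...] = ("baseline",)  # which experiments to execute
    dataset: str = "example_dataset"              # which dataset to use (placeholder)
    batch_size: int = 64                          # placeholder default
    n_tasks: int = 5                              # number of tasks (placeholder default)

    def __post_init__(self) -> None:
        # Fail early on obviously invalid values.
        if self.batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if self.n_tasks <= 0:
            raise ValueError("n_tasks must be positive")


# Change a single setting while keeping every other default.
settings = replace(RunSettings(), batch_size=128)
```

Because the dataclass is frozen, a run's settings cannot be mutated accidentally halfway through; `dataclasses.replace` builds a new, re-validated copy instead.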

Additional experiment-specific hyperparameters can be found in the corresponding experiment classes. Essential hyperparameters that are shared across all experiments are defined in experiments/base_experiment.py. This is where lr, optimizer, ... are defined and can be changed.
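The split between shared and experiment-specific hyperparameters typically maps onto class inheritance. The sketch below illustrates that pattern under assumptions: BaseExperiment, ReplayExperiment, hyperparameters(), buffer_size, and every default value are hypothetical stand-ins, not the contents of experiments/base_experiment.py.

```python
from typing import Any


class BaseExperiment:
    """Hypothetical base class holding hyperparameters shared by all experiments."""

    lr: float = 1e-3        # placeholder default, not the project's value
    optimizer: str = "sgd"  # placeholder default

    def hyperparameters(self) -> dict[str, Any]:
        return {"lr": self.lr, "optimizer": self.optimizer}


class ReplayExperiment(BaseExperiment):
    """Hypothetical experiment: adds its own setting and overrides one shared default."""

    lr: float = 5e-4        # overrides the shared default for this experiment only
    buffer_size: int = 500  # experiment-specific hyperparameter (placeholder)

    def hyperparameters(self) -> dict[str, Any]:
        params = super().hyperparameters()
        params["buffer_size"] = self.buffer_size
        return params
```

Changing a value on the base class affects every experiment that does not override it, while an override in a subclass stays local to that experiment.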

Results

The results are stored in the tb_data directory. You can visualize them with TensorBoard:

tensorboard --logdir tb_data --port 6066

Open TensorBoard in your browser at localhost:6066. You should now see all the results and plots used in the paper.

Regenerate Result Figures from Data

All returned metrics are saved in a .pkl file. You can regenerate the figures from this file with the following command:

python main.py --res_path tb_data/<date-of-run>_results

The regenerated results are then written to a TensorBoard directory, which you can visualize with TensorBoard:

tensorboard --logdir tb_data/<date-of-run>_results --port 6066

Development / Debugging

We use black to format the code. You can install it with pip and run it on the code:

black .

Please run black before committing your changes, and configure your editor to run black on save.

Figures Generated By the Code

For debugging purposes, we generate several figures, all of which are added to TensorBoard. We show and explain some of them below. Note that the following figures are not part of the paper and may not reflect the final results described in the report; they were chosen to illustrate each figure's purpose during development.

  1. Confusion Matrix: We plot the confusion matrix of the current model (image: imgs/confusion_matrix.png).

  2. Accuracy: We plot the test-set accuracy on all tasks (image: imgs/accuracy_history.png).

  3. Slow Learners: We plot the slow learners of the model. This figure recreates figure 2 (b) of the paper "Forgetting Order of Continual Learning: Examples That are Learned First are Forgotten Last" (image: imgs/slow_learners.png).

  4. Composition of Mini-Batches: To check the composition of the mini-batches, we plot the number of samples per class in each mini-batch (image: imgs/samples_per_batch.png).

  5. Composition of Mini-Batches (Unique Samples): Similar to the previous figure, but showing the number of unique samples per class in each mini-batch (image: imgs/unique_samples_per_batch.png).
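The quantities behind the confusion-matrix and mini-batch-composition figures are simple counts. The following dependency-free sketch shows one way to compute them; the function names are hypothetical, and it assumes that each sample in a mini-batch can be identified by its dataset index, so that a sample drawn more than once (for example, a replayed sample) counts only once in the "unique" tally.

```python
from collections import Counter


def confusion_matrix(y_true: list[int], y_pred: list[int], n_classes: int) -> list[list[int]]:
    """Rows are true classes, columns are predicted classes."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def batch_composition(
    batch_labels: list[int], sample_ids: list[int]
) -> tuple[dict[int, int], dict[int, int]]:
    """Return (samples per class, unique samples per class) for one mini-batch.

    sample_ids holds the dataset index of each sample in the batch (assumed
    available); repeated indices are counted once in the unique tally.
    """
    per_class = Counter(batch_labels)
    unique_per_class = Counter(label for label, _ in set(zip(batch_labels, sample_ids)))
    return dict(per_class), dict(unique_per_class)


cm = confusion_matrix([0, 0, 1, 2], [0, 1, 1, 2], n_classes=3)
# Sample 7 of class 0 appears twice in this batch.
per_class, unique = batch_composition([0, 0, 1, 1, 1], [7, 7, 3, 4, 5])
```

Collecting `batch_composition` for every mini-batch and plotting the counts over batch index gives plots of the kind described in items 4 and 5.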

Short Introduction to Avalanche

YouTube: Antonio Carta | "Avalanche: an End-to-End Library for Continual Learning"

About

Continual Learning: Course Project for the Deep Learning Lecture at ETH.
