This repository contains code for GradMem, a memory mechanism where the model compresses a context into a small writable memory state using test-time gradient descent.
The key idea is not just to optimize memory at inference, but to meta-learn the model so that a few (<=5) test-time updates are effective.
arXiv: https://arxiv.org/abs/2603.13875
Each example is split into:
- context C (information to store),
- query Q (what to ask),
- target Y (expected output).
GradMem runs in two phases:
- WRITE (inner loop, K steps)
  - Start from the learned memory initialization M0.
  - Optimize the memory tokens M on a self-supervised reconstruction loss L_write(M, C).
  - Update only the memory state; model weights stay fixed during this phase.
- READ (outer objective)
  - Predict Y from [M; Q] after WRITE.
  - Train the model parameters and M0 by backpropagating the task loss through the WRITE updates (meta-learning).
In short: GradMem learns to use gradient descent itself as a writing operation.
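A minimal, dependency-free sketch of the WRITE phase is below. It is hypothetical and is not the code in grad_memgpt.py: memory is a plain vector M, the frozen model is replaced by a fixed linear decoder W, and the squared error ||W M - C||^2 stands in for the self-supervised reconstruction loss L_write(M, C). It only illustrates that K gradient steps on M alone, with W untouched, write the context into memory; the READ phase and the meta-learning of M0 are omitted.

```python
"""Toy sketch of the GradMem WRITE phase (hypothetical, not the repo's code).

Assumptions: memory is a plain vector M, the frozen "model" is a fixed linear
decoder W, and L_write(M, C) = ||W M - C||^2 stands in for the real
self-supervised reconstruction loss.
"""
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, m: Vector) -> Vector:
    return [sum(w_ij * m_j for w_ij, m_j in zip(row, m)) for row in W]


def write_loss(W: Matrix, m: Vector, c: Vector) -> float:
    # Squared error of reconstructing the context from memory.
    return sum((r - c_i) ** 2 for r, c_i in zip(matvec(W, m), c))


def write_grad(W: Matrix, m: Vector, c: Vector) -> Vector:
    # dL/dM = 2 W^T (W M - C); W itself is never updated.
    residual = [r - c_i for r, c_i in zip(matvec(W, m), c)]
    return [2.0 * sum(W[i][j] * residual[i] for i in range(len(W)))
            for j in range(len(m))]


def write(W: Matrix, m0: Vector, c: Vector, K: int, inner_lr: float) -> Vector:
    """Run K gradient steps on the memory only, starting from m0."""
    m = list(m0)
    for _ in range(K):
        g = write_grad(W, m, c)
        m = [m_j - inner_lr * g_j for m_j, g_j in zip(m, g)]
    return m


if __name__ == "__main__":
    W = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # frozen "model"
    C = [1.0, -1.0, 0.0]                      # context to store
    M0 = [0.0, 0.0]                           # memory initialization
    for K in (0, 1, 2, 5):
        M = write(W, M0, C, K=K, inner_lr=0.1)
        print(K, round(write_loss(W, M, C), 4))
```

In this toy the reconstruction loss shrinks with every extra step; in GradMem the outer objective additionally trains M0 and the model so that a small K is already enough.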
Repository structure:
- grad_memgpt.py - GradMem model and inner-loop memory optimization logic.
- rmt.py - Recurrent Memory Transformer (RMT) baseline with forward-only memory updates.
- run_gradmemgpt_on_*.py - GradMem training/eval entry points.
- run_rmt_on_*.py - RMT baseline entry points.
- run_gpt2_on_*.py - non-compressive causal LM baselines.
- kv_dataset_utils.py - synthetic key-value retrieval data generation and tokenizer helpers.
- squad_utils.py, phonebook_utils.py - NLP dataset preprocessing.
- prepare_pg19_chunks.py - PG19 chunking for language modeling/compression experiments.
- attn_double_bwd/ - custom attention double-backward implementations for higher-order GradMem training.
- scripts/ - runnable experiment presets and dataset download scripts.
Requirements:
- Python 3.11
- conda
Create the environment:
conda env create -f conda_env.yaml
conda activate /home/jovyan/kuratov/envs/py311_pt2.6_cu12.4
(adjust the environment path or name to your setup)
accelerate.yaml contains a default single-process setup.
Synthetic key-value retrieval samples contain !key:value! pairs in the context and a query like ?!K:; the target is the corresponding value.
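To make the sample format concrete, here is a hypothetical generator. It is not kv_dataset_utils.generate_sequence: the delimiters and query layout are taken only from the sentence above, and the sizes are an assumed reading of the dataset name N16-K2V2-V62 (16 pairs, 2-symbol keys and values, 62-symbol alphabet). The function name make_kv_sample is illustrative.

```python
"""Hypothetical generator for synthetic key-value retrieval samples.

Not kv_dataset_utils.generate_sequence: delimiters follow the README text
(!key:value! pairs, query ?!K:) and the sizes are an assumed reading of
N16-K2V2-V62 (16 pairs, 2-char keys, 2-char values, 62 symbols).
"""
import random
import string
from typing import Tuple

# Assumed 62-symbol alphabet: a-z, A-Z, 0-9.
ALPHABET = string.ascii_letters + string.digits


def make_kv_sample(n_pairs: int = 16, key_len: int = 2, value_len: int = 2,
                   seed: int = 0) -> Tuple[str, str, str]:
    """Return (context C, query Q, target Y) for one synthetic sample."""
    rng = random.Random(seed)
    key_set = set()
    while len(key_set) < n_pairs:  # unique keys keep the answer unambiguous
        key_set.add("".join(rng.choice(ALPHABET) for _ in range(key_len)))
    keys = sorted(key_set)  # sort first so the shuffle is seed-deterministic
    rng.shuffle(keys)
    values = ["".join(rng.choice(ALPHABET) for _ in range(value_len)) for _ in keys]

    context = "".join(f"!{k}:{v}!" for k, v in zip(keys, values))
    idx = rng.randrange(n_pairs)
    query = f"?!{keys[idx]}:"
    target = values[idx]
    return context, query, target


if __name__ == "__main__":
    C, Q, Y = make_kv_sample()
    print(C)
    print(Q, "->", Y)
```

The context C is what WRITE compresses into memory; Q and Y are only used in READ.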
Download prepared datasets from Hugging Face:
./scripts/download_kv_retrieval.sh
You can also generate data with notebooks/dump_dataset.ipynb (uses kv_dataset_utils.generate_sequence).
Download bAbI:
./scripts/download_babi.sh
Prepare PG19 chunks:
./scripts/prepare_pg19_chunks.sh
Example run of the non-compressive baseline on KV retrieval:
accelerate launch --config_file accelerate.yaml \
run_gpt2_on_kv_retrieval.py \
--exp_path ./runs/gpt2_example \
--per_device_batch_size 64 \
--data_path ./data/N16-K2V2-V62_1M \
--tokenizer_path ./tokenizers/kv_alphabet_62/ \
--base_model llama
Example run of GradMem on KV retrieval:
accelerate launch --config_file accelerate.yaml \
run_gradmemgpt_on_kv_retrieval.py \
--exp_path ./runs/gradmem_example \
--per_device_batch_size 64 \
--data_path ./data/N16-K2V2-V62_1M \
--tokenizer_path ./tokenizers/kv_alphabet_62/ \
--base_model llama \
--n_mem_tokens 8 \
--K 2 \
--inner_lr 0.04 \
--grad_mode second
For full experiment configurations, use the scripts in scripts/:
- scripts/run_gradmemgpt_on_kv_retrieval.sh
- scripts/run_rmt_on_kv_retrieval.sh
- scripts/run_gpt_on_kv_retrieval.sh
- scripts/run_gradmemgpt_on_babi.sh
- scripts/run_gradmemgpt_on_squad.sh
For text-compression experiments, prepare PG19 chunks with
scripts/prepare_pg19_chunks.sh and run run_gradmemgpt_on_text_compression.py directly with accelerate.
grad_mode controls gradient flow through the WRITE updates:
- none - no meta-gradient through inner updates.
- first - first-order approximation.
- second - full second-order differentiation through the inner loop (default for strongest results).
Second-order mode is more expensive in memory and compute, but we found that it is what actually makes GradMem learn; attn_double_bwd/ includes optimized double-backward implementations of attention for second-order optimization.
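The difference between the three grad_mode settings can be seen on a scalar toy problem. This is a hypothetical sketch, not the repository's implementation: memory is a scalar m, L_write(m) = (a m - c)^2, the task loss is (b m_K - y)^2, and only the meta-gradient with respect to the initialization M0 is computed. Each inner step m ← m - lr · dL_write/dm has Jacobian 1 - lr · d²L_write/dm²; second-order mode keeps that curvature term, the first-order approximation replaces it with 1, and (in this toy) none stops the gradient entirely.

```python
"""Toy scalar illustration of grad_mode none / first / second (hypothetical).

Assumptions (not from the repository): memory is a scalar m,
L_write(m) = (a*m - c)**2, task loss L_read(m_K) = (b*m_K - y)**2,
and only the meta-gradient w.r.t. the memory initialization m0 is computed.
"""


def inner_loop(m0: float, a: float, c: float, K: int, lr: float) -> float:
    """K gradient-descent steps on L_write, starting from m0."""
    m = m0
    for _ in range(K):
        m = m - lr * 2.0 * a * (a * m - c)
    return m


def read_loss(m0: float, a: float, c: float, b: float, y: float, K: int, lr: float) -> float:
    """Task loss after WRITE, as a function of the initialization m0."""
    m_K = inner_loop(m0, a, c, K, lr)
    return (b * m_K - y) ** 2


def meta_grad(m0: float, a: float, c: float, b: float, y: float,
              K: int, lr: float, mode: str) -> float:
    """d L_read / d m0 under a given grad_mode (toy version)."""
    m_K = inner_loop(m0, a, c, K, lr)
    outer = 2.0 * b * (b * m_K - y)  # dL_read/dm_K
    if mode == "none":
        # Inner updates treated as constants: no gradient reaches m0 through them.
        return 0.0
    if mode == "first":
        # First-order: inner gradients are detached, so each step's Jacobian is 1.
        return outer
    if mode == "second":
        # Exact: each step has Jacobian 1 - lr * d2L_write/dm2 = 1 - 2*lr*a**2.
        return outer * (1.0 - 2.0 * lr * a * a) ** K
    raise ValueError(f"unknown grad_mode: {mode}")


if __name__ == "__main__":
    a, c, b, y, K, lr, m0 = 1.5, 0.7, -0.8, 0.3, 3, 0.1, 0.2
    eps = 1e-6
    # Central finite difference of the true composed loss, as a reference.
    fd = (read_loss(m0 + eps, a, c, b, y, K, lr)
          - read_loss(m0 - eps, a, c, b, y, K, lr)) / (2 * eps)
    for mode in ("none", "first", "second"):
        print(mode, meta_grad(m0, a, c, b, y, K, lr, mode))
    print("finite difference", fd)
```

In this toy the second-order value matches the finite-difference reference, while the first-order value differs by the factor (1 - 2 lr a²)^K, i.e. exactly the curvature information that second-order mode retains.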
All runs write checkpoints, metrics, and trainer state under --exp_path (typically in ./runs/...).
@misc{kuratov2026gradmem,
title={GradMem: Learning to Write Context into Memory with Test-Time Gradient Descent},
author={Yuri Kuratov and Matvey Kairov and Aydar Bulatov and Ivan Rodkin and Mikhail Burtsev},
year={2026},
eprint={2603.13875},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2603.13875},
}
