Imitation Learning for the Godot Bonus Unit

This repository contains a lightweight PyTorch/Lightning implementation of behavior cloning for a Godot RL environment from the Hugging Face Deep RL course bonus lesson: a long-horizon, sparse-reward task. The project compares a classical Generative Adversarial Imitation Learning (GAIL) baseline against Explicit Behavioral Cloning (EBC) with state-based and sequence-based models. Implicit Behavioral Cloning (IBC) using an energy-based model is planned for future work.

Environment

The environment features a robot that needs to:

  1. Pull a lever to raise the stairs leading to the second room,
  2. Navigate to the key 🔑 and collect it while avoiding falling into traps, into the water, or off the map,
  3. Navigate back to the treasure chest in the first room, and open it. Victory! 🏆

🧭 Observations

The observation vector combines ray perception, relative directions/distances to key objects, progress flags, raft state, and normalized player motion information into one flattened array with a total of 209 values.

Category | Contents | Details
Raycasts | Raycast sensor outputs | Concatenated values from all raycast_sensors
Chest (relative) | direction (x, y, z) + distance | Local-space direction (normalized) and clipped distance
Lever (relative) | direction (x, y, z) + distance | Local-space direction (normalized) and clipped distance
Key (relative) | direction (x, y, z) + distance | Local-space direction (normalized) and clipped distance
Raft (relative) | direction (x, y, z) + distance | Local-space direction (normalized) and clipped distance
Raft state | movement_direction_multiplier | Scalar multiplier
Task flags | lever pulled, chest opened, key collected | Stored as floats (0.0 or 1.0)
Player grounded | is_on_floor() | 0.0 or 1.0
Player velocity | (vx, vy, vz) | Local-space velocity, clipped and normalized
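
As a rough illustration, the flat vector can be sliced back into these groups. The group order and the raycast block size (209 − 24 = 185 values) in the sketch below are assumptions inferred from the table, not taken from the environment code:

```python
import numpy as np

# Hypothetical layout of the 209-value observation vector.
# Order and sizes are assumed from the table above.
OBS_LAYOUT = [
    ("raycasts", 185),      # 209 total minus the 24 values below (assumed)
    ("chest_rel", 4),       # direction (x, y, z) + distance
    ("lever_rel", 4),
    ("key_rel", 4),
    ("raft_rel", 4),
    ("raft_state", 1),      # movement_direction_multiplier
    ("task_flags", 3),      # lever pulled, chest opened, key collected
    ("grounded", 1),        # is_on_floor()
    ("velocity", 3),        # local-space (vx, vy, vz)
]

def split_observation(obs: np.ndarray) -> dict[str, np.ndarray]:
    """Slice a flat 209-dim observation into named groups."""
    assert obs.shape == (209,)
    groups, offset = {}, 0
    for name, size in OBS_LAYOUT:
        groups[name] = obs[offset:offset + size]
        offset += size
    return groups
```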

🎮 Actions

The agent uses continuous control, grouped as follows:

Action group | Size | Meaning
movement | 2 | planar movement (x, y input)
rotation | 1 | yaw rotation
jump | 1 | jump request (continuous → thresholded)
use_action | 1 | interact/use request (continuous → thresholded)

The total action vector length is 5; all values are clipped to the range [-1, +1].

Behavior (see the sketch below):

  • movement and rotation are used directly as continuous values

  • jump and use_action are thresholded:

    • value > 0 → pressed
    • value ≤ 0 → not pressed
  • rotation affects only the Y-axis (turning left/right)
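
A minimal sketch of this postprocessing, assuming a flat 5-value action vector ordered as in the table above (the function and field names are illustrative, not the project's API):

```python
import numpy as np

def postprocess_action(raw: np.ndarray) -> dict:
    """Map a raw 5-dim policy output to environment inputs (illustrative)."""
    a = np.clip(raw, -1.0, 1.0)       # all actions live in [-1, +1]
    return {
        "movement": a[0:2],           # planar (x, y) input, continuous
        "rotation": a[2:3],           # yaw only (Y-axis turning)
        "jump": bool(a[3] > 0.0),     # value > 0 → pressed, ≤ 0 → not pressed
        "use_action": bool(a[4] > 0.0),
    }
```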

Results

GAIL performs best and is perfectly stable (zero variance across the evaluation episodes); the Transformer-based EBC outperforms the MLP EBC, and token masking further improves both performance and stability.

Method | Avg. cumulative reward (50 episodes)
GAIL | 3.000 ± 0.000
MLP EBC | 0.520 ± 0.640
Transformer EBC | 2.700 ± 0.755
Transformer + token masking EBC | 2.820 ± 0.623

Video renders

(The README embeds side-by-side video renders of the MLP EBC policy and the Transformer + token masking EBC policy.)

Project layout

  • scripts/: entrypoints for data cleaning, training, evaluation, and a Stable-Baselines3 GAIL baseline.
  • src/imitation_learning/: Lightning module, data pipeline, and model architectures (MLP and Transformer behavior cloning).
  • configs/: Hydra configuration files for training and evaluation.
  • data/raw/: raw expert demonstrations from the course; data/clean/ is created after filtering.
  • data/models/: pretrained checkpoints (ONNX for GAIL, Lightning checkpoints for the others).

Installation

This project targets Python 3.12+. Install dependencies and the package in editable mode:

python -m venv .venv
source .venv/bin/activate
pip install -e .

If you plan to log to Weights & Biases, ensure the wandb CLI is configured (set WANDB_API_KEY or run wandb login).

Data

Raw demonstrations live under data/raw/*.json. Cleaned trajectories are generated by running:

python scripts/clean_data.py

For the bundled dataset, cleaning loads 121 trajectories, filters out sequences shorter than 300 steps, and leaves 101 trajectories with lengths between 305 and 608 (mean 417.24 steps). Clean data is written to data/clean/filtered_trajectories.json for training.
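
The filtering step amounts to a simple length cut. A minimal sketch, assuming each JSON file holds a list of trajectories that are themselves lists of per-step records (the actual schema in scripts/clean_data.py may differ):

```python
import json
from pathlib import Path

MIN_STEPS = 300  # trajectories shorter than this are dropped

# Assumed schema: each raw file contains a list of trajectories,
# each trajectory being a list of per-step records.
trajectories = []
for path in sorted(Path("data/raw").glob("*.json")):
    trajectories.extend(json.loads(path.read_text()))

kept = [t for t in trajectories if len(t) >= MIN_STEPS]
print(f"kept {len(kept)}/{len(trajectories)} trajectories")

Path("data/clean").mkdir(parents=True, exist_ok=True)
Path("data/clean/filtered_trajectories.json").write_text(json.dumps(kept))
```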

Training

Behavior cloning is configured via Hydra. The default configuration trains an MLP policy on windowed observations of length 1. Example command:

python scripts/train_torch.py

Key options (edit in configs/train_torch.yaml or override via the CLI, as in the example after this list):

  • train.seq_len: sequence length for windowed training (set to >1 for the Transformer).
  • train.action_mode: probabilistic for discrete actions, continuous for continuous control.
  • train.batch_size, train.epochs, train.lr: optimization hyperparameters.
  • train.export_onnx_path: set to export the trained network to ONNX after training.
  • eval.*: controls periodic evaluation in the Godot environment during training.
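
For example, a Transformer run on longer observation windows could be launched with Hydra-style overrides (the values here are illustrative, not recommended settings):

python scripts/train_torch.py train.seq_len=16 train.batch_size=64 train.lr=1e-4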

Checkpoints are stored under output/checkpoints/last.ckpt by default.

Evaluation

You can evaluate either a Lightning checkpoint or an ONNX export:

# Using a Lightning checkpoint
python scripts/eval.py model.checkpoint_path=output/checkpoints/last.ckpt

# Using an ONNX model
python scripts/eval.py model.onnx_path=model_bc.onnx

Evaluation runs multiple parallel environments (16 by default) and reports the mean episodic reward across eval.n_episodes episodes. When using ONNX, inference relies on onnxruntime.
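
For a quick standalone check of an ONNX export, a minimal onnxruntime sketch like the one below works, assuming the policy takes a single (batch, 209) observation input; the actual input name and shape should be read from the session:

```python
import numpy as np
import onnxruntime as ort

# Load the exported policy and inspect its input signature.
session = ort.InferenceSession("model_bc.onnx")
inp = session.get_inputs()[0]
print(inp.name, inp.shape)

# Run inference on a dummy observation (assumed shape: 1 x 209).
obs = np.zeros((1, 209), dtype=np.float32)
outputs = session.run(None, {inp.name: obs})
print([o.shape for o in outputs])
```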

Pretrained results

Two pretrained ONNX policies are bundled:

  • model_bc.onnx: behavior cloning baseline trained on the filtered demonstrations.
  • model_gail.onnx: adversarial imitation (GAIL) baseline trained via Stable-Baselines3 (scripts/sb3_imitation.py).

Use these to sanity-check your environment setup or to compare against your own training runs.
