diff --git a/README.md b/README.md
index 5f1e51d..9c084cb 100644
--- a/README.md
+++ b/README.md
@@ -78,6 +78,9 @@ Use it as the main entry point to find:
- extracted features
- pretrained models and checkpoints
+See the [Model Zoo](docs/model-zoo.md) for available pretrained models,
+reported scores, datasets, and loading snippets.
+
--
## Quickstart
@@ -96,7 +99,7 @@ from opensportslib.apis import ClassificationModel
my_model = ClassificationModel(
config="/path/to/classification.yaml",
- weights="/path/to/weights.pt", # optional
+ weights=None, # optional: path or Hugging Face model ID
)
my_model.train(
@@ -112,7 +115,7 @@ from opensportslib.apis import ClassificationModel
my_model = ClassificationModel(
config="/path/to/classification.yaml",
- weights="/path/to/weights.pt", # optional
+ weights=None, # optional: path or Hugging Face model ID
)
predictions = my_model.infer(
@@ -143,7 +146,7 @@ from opensportslib.apis import LocalizationModel
my_model = LocalizationModel(
config="/path/to/localization.yaml",
- weights="/path/to/weights.pt", # optional
+ weights=None, # optional: path or Hugging Face model ID
)
predictions = my_model.infer(
diff --git a/docs/model-zoo.md b/docs/model-zoo.md
new file mode 100644
index 0000000..d7a8a3b
--- /dev/null
+++ b/docs/model-zoo.md
@@ -0,0 +1,77 @@
+# Model Zoo
+
+This page lists the pretrained OpenSportsLib models published on Hugging Face.
+Use the model repository ID with `load_weights(...)` to load a checkpoint into an
+OpenSportsLib model.
+
+## Available Models
+
+| Model | Task | Dataset trained on | Backbone / architecture | Classes / label set | Scores | Hugging Face link | Load weights snippet |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| `OSL-cls-action-mvitv2` | Action / Event Classification | SoccerNet - MVFouls classification subset | MViT v2 | Not reported on model card | Accuracy: 0.57<br>Balanced Accuracy: 0.40<br>Top-2: 0.78 | [OpenSportsLab/OSL-cls-action-mvitv2](https://huggingface.co/OpenSportsLab/OSL-cls-action-mvitv2) | `myModel.load_weights(weights="OpenSportsLab/OSL-cls-action-mvitv2")` |
+| `OSL-loc-snbas-2023-e2e` | Action Spotting / Localization | SoccerNet - Ball Action Spotting 2023 | E2E, DALI backend | PASS, DRIVE | tight mAP: 71.48<br>loose mAP: 85.62 | [OpenSportsLab/OSL-loc-snbas-2023-e2e](https://huggingface.co/OpenSportsLab/OSL-loc-snbas-2023-e2e) | `myModel.load_weights(weights="OpenSportsLab/OSL-loc-snbas-2023-e2e")` |
+| `OSL-loc-snbas-2025-e2e` | Action Spotting / Localization | SoccerNet - Ball Action Spotting 2025 | E2E, DALI backend | PASS, DRIVE, HEADER, HIGH PASS, OUT, CROSS, THROW IN, SHOT, BALL PLAYER BLOCK, PLAYER SUCCESSFUL TACKLE, FREE KICK, GOAL | tight mAP: 47.98<br>loose mAP: 58.35 | [OpenSportsLab/OSL-loc-snbas-2025-e2e](https://huggingface.co/OpenSportsLab/OSL-loc-snbas-2025-e2e) | `myModel.load_weights(weights="OpenSportsLab/OSL-loc-snbas-2025-e2e")` |
+
+## OSL-cls-action-mvitv2
+
+**Intended use:** video-based soccer action / event classification.
+
+**Dataset/training source:** SoccerNet - MVFouls classification subset, using
+video clips.
+
+**Reported metrics:**
+
+| Metric | Score |
+| --- | --- |
+| Accuracy | 0.57 |
+| Balanced Accuracy | 0.40 |
+| Top-2 | 0.78 |
+
+**Hugging Face:** [OpenSportsLab/OSL-cls-action-mvitv2](https://huggingface.co/OpenSportsLab/OSL-cls-action-mvitv2)
+
+```python
+myModel.load_weights(weights="OpenSportsLab/OSL-cls-action-mvitv2")
+```
+
+## OSL-loc-snbas-2023-e2e
+
+**Intended use:** video-based soccer action spotting / localization.
+
+**Dataset/training source:** SoccerNet - Ball Action Spotting 2023, using video
+clips at 224p resolution. The model card reports two classes: `PASS` and
+`DRIVE`.
+
+**Reported metrics:**
+
+| Metric | Score |
+| --- | --- |
+| tight mAP | 71.48 |
+| loose mAP | 85.62 |
+
+**Hugging Face:** [OpenSportsLab/OSL-loc-snbas-2023-e2e](https://huggingface.co/OpenSportsLab/OSL-loc-snbas-2023-e2e)
+
+```python
+myModel.load_weights(weights="OpenSportsLab/OSL-loc-snbas-2023-e2e")
+```
+
+## OSL-loc-snbas-2025-e2e
+
+**Intended use:** video-based soccer action spotting / localization.
+
+**Dataset/training source:** SoccerNet - Ball Action Spotting 2025, using video
+clips at 224p resolution. The model card reports twelve classes: `PASS`,
+`DRIVE`, `HEADER`, `HIGH PASS`, `OUT`, `CROSS`, `THROW IN`, `SHOT`,
+`BALL PLAYER BLOCK`, `PLAYER SUCCESSFUL TACKLE`, `FREE KICK`, and `GOAL`.
+
+**Reported metrics:**
+
+| Metric | Score |
+| --- | --- |
+| tight mAP | 47.98 |
+| loose mAP | 58.35 |
+
+**Hugging Face:** [OpenSportsLab/OSL-loc-snbas-2025-e2e](https://huggingface.co/OpenSportsLab/OSL-loc-snbas-2025-e2e)
+
+```python
+myModel.load_weights(weights="OpenSportsLab/OSL-loc-snbas-2025-e2e")
+```
diff --git a/docs/tni/config-guide.md b/docs/tni/config-guide.md
index 278573e..13ec9a0 100644
--- a/docs/tni/config-guide.md
+++ b/docs/tni/config-guide.md
@@ -8,6 +8,7 @@ Main config files in the repo:
- `opensportslib/config/classification.yaml`
- `opensportslib/config/localization.yaml`
+- `opensportslib/config/localization-e2e-ocv.yaml`
- `opensportslib/config/localization-json_netvlad++_resnetpca512.yaml`
- `opensportslib/config/localization-json_calf_resnetpca512.yaml`
- `opensportslib/config/sngar-tracking.yaml`
@@ -40,7 +41,7 @@ Defines which task pipeline is used.
- `classification`: clip-level classification pipeline.
- `localization`: spotting/localization pipeline.
-If `TASK` does not match the selected API (`model.classification` / `model.localization`), behavior can be incorrect or fail.
+If `TASK` does not match the selected API (`ClassificationModel` / `LocalizationModel`), behavior can be incorrect or fail.
### `DATA`
@@ -121,7 +122,7 @@ This avoids duplication and keeps paths consistent.
| Key | Type | Example | Meaning |
|---|---|---|---|
| `DATA.dataset_name` | string | `mvfouls` | Dataset identifier |
-| `DATA.data_dir` | path | `/.../SoccerNet/mvfouls` | Dataset root directory |
+| `DATA.data_dir` | path | `/.../OSL-XFoul/224p` | Dataset root directory |
| `DATA.data_modality` | string | `video` | Input modality for loader |
| `DATA.view_type` | string | `multi` | Single-view or multi-view processing |
| `DATA.num_classes` | int | `8` | Number of target classes |
@@ -139,7 +140,7 @@ Each split (`train`, `valid`, `test`) has:
| Key | Type | Example | Meaning |
|---|---|---|---|
| `DATA.<split>.video_path` | path | `${DATA.data_dir}/train` | Video root for split; relative media paths in annotations are resolved from here |
-| `DATA.<split>.path` | path | `.../annotations-train.json` | Annotation file |
+| `DATA.<split>.path` | path | `${DATA.train.video_path}/train.json` | Annotation file |
| `DATA.<split>.dataloader.batch_size` | int | `8` | Batch size |
| `DATA.<split>.dataloader.shuffle` | bool | `true` | Shuffle data each epoch |
| `DATA.<split>.dataloader.num_workers` | int | `4` | Data loading worker count |
@@ -214,7 +215,7 @@ Each split (`train`, `valid`, `test`) has:
| Key | Type | Example | Meaning |
|---|---|---|---|
| `DATA.dataset_name` | string | `SoccerNet` | Dataset identity |
-| `DATA.data_dir` | path | `/.../annotations` | Data root |
+| `DATA.data_dir` | path | `/.../OSL-SNBAS/224p-2024` | Data root |
| `DATA.classes` | list[string] | `PASS, DRIVE, ...` | Event class set |
| `DATA.epoch_num_frames` | int | `500000` | Frames sampled per epoch |
| `DATA.mixup` | bool | `true` | Mixup augmentation |
diff --git a/docs/tni/tni.md b/docs/tni/tni.md
index 8eb032f..69605d1 100644
--- a/docs/tni/tni.md
+++ b/docs/tni/tni.md
@@ -12,355 +12,26 @@ For full key-by-key config documentation and Python-only override workflow, see
---
## Configuration Sample (.yaml) file
-### 1. Classification
-```bash
-TASK: classification
-
-DATA:
- dataset_name: mvfouls
- data_dir: /home/vorajv/opensportslib/SoccerNet/mvfouls
- data_modality: video
- view_type: multi # multi or single
- num_classes: 8 # mvfoul
- train:
- type: annotations_train.json
- video_path: ${DATA.data_dir}/train
- path: ${DATA.train.video_path}/annotations-train.json
- dataloader:
- batch_size: 8
- shuffle: true
- num_workers: 4
- pin_memory: true
- valid:
- type: annotations_valid.json
- video_path: ${DATA.data_dir}/valid
- path: ${DATA.valid.video_path}/annotations-valid.json
- dataloader:
- batch_size: 1
- num_workers: 1
- shuffle: false
- test:
- type: annotations_test.json
- video_path: ${DATA.data_dir}/test
- path: ${DATA.test.video_path}/annotations-test.json
- dataloader:
- batch_size: 1
- num_workers: 1
- shuffle: false
- num_frames: 16 # 8 before + 8 after the foul
- input_fps: 25 # Original FPS of video
- target_fps: 17 # Temporal downsampling to 1s clip (approx)
- start_frame: 63 # Start frame of clip relative to foul frame
- end_frame: 87 # End frame of clip relative to foul frame
- frame_size: [224, 224] # Spatial resolution (HxW)
- augmentations:
- random_affine: true
- translate: [0.1, 0.1]
- affine_scale: [0.9, 1.0]
- random_perspective: true
- distortion_scale: 0.3
- perspective_prob: 0.5
- random_rotation: true
- rotation_degrees: 5
- color_jitter: true
- jitter_params: [0.2, 0.2, 0.2, 0.1] # brightness, contrast, saturation, hue
- random_horizontal_flip: true
- flip_prob: 0.5
- random_crop: false
-
-MODEL:
- type: custom # huggingface, custom
- backbone:
- type: mvit_v2_s # video_mae, r3d_18, mc3_18, r2plus1d_18, s3d, mvit_v2_s
- neck:
- type: MV_Aggregate
- agr_type: max # max, mean, attention
- head:
- type: MV_LinearLayer
- pretrained_model: mvit_v2_s # MCG-NJU/videomae-base, OpenGVLab/VideoMAEv2-Base, r3d_18, mc3_18, r2plus1d_18, s3d, mvit_v2_s
- unfreeze_head: true # for videomae backbone
- unfreeze_last_n_layers: 3 # for videomae backbone
-
-
-TRAIN:
- monitor: balanced_accuracy # balanced_accuracy, loss
- mode: max # max or min
- enabled: true
- use_weighted_sampler: false
- use_weighted_loss: true
- epochs: 20 #20
- save_dir: ./checkpoints
- log_interval: 10
- save_every: 2 #5
-
- criterion:
- type: CrossEntropyLoss
-
- optimizer:
- type: AdamW
- lr: 0.0001 #0.001
- backbone_lr: 0.00005
- head_lr: 0.001
- betas: [0.9, 0.999]
- eps: 0.0000001
- weight_decay: 0.001 #0.01 - videomae, 0.001 - others
- amsgrad: false
-
- scheduler:
- type: StepLR
- step_size: 3
- gamma: 0.1
-
-SYSTEM:
- log_dir: ./logs
- use_seed: false
- seed: 42
- GPU: 4
- device: cuda # auto | cuda | cpu
- gpu_id: 0
+The examples below are included directly from the latest YAML files in
+`opensportslib/config/`, so the documentation stays aligned with the runnable
+configs.
+
+### 1. Classification
+```yaml
+--8<-- "opensportslib/config/classification.yaml"
```
### 2. Classification (Tracking)
-```bash
-TASK: classification
-
-DATA:
- dataset_name: sngar
- data_modality: tracking_parquet
- data_dir: /home/karkid/opensportslib/sngar-tracking
- preload_data: false
- train:
- type: annotations_train.json
- video_path: ${DATA.data_dir}/train
- path: ${DATA.train.video_path}/train.json
- dataloader:
- batch_size: 32
- shuffle: true
- num_workers: 8
- pin_memory: true
- valid:
- type: annotations_valid.json
- video_path: ${DATA.data_dir}/valid
- path: ${DATA.valid.video_path}/valid.json
- dataloader:
- batch_size: 32
- num_workers: 8
- shuffle: false
- test:
- type: annotations_test.json
- video_path: ${DATA.data_dir}/test
- path: ${DATA.test.video_path}/test.json
- dataloader:
- batch_size: 32
- num_workers: 8
- shuffle: false
- num_frames: 16
- frame_interval: 9
- augmentations:
- vertical_flip: true
- horizontal_flip: true
- team_flip: true
- normalize: true
- num_objects: 23
- feature_dim: 8
- pitch_half_length: 85.0
- pitch_half_width: 50.0
- max_displacement: 110.0
- max_ball_height: 30.0
-
-MODEL:
- type: custom
- backbone:
- type: graph_conv
- encoder: graphconv
- hidden_dim: 64
- num_layers: 20
- dropout: 0.1
- neck:
- type: TemporalAggregation
- agr_type: maxpool
- hidden_dim: 64
- dropout: 0.1
- head:
- type: TrackingClassifier
- hidden_dim: 64
- dropout: 0.1
- num_classes: 10
- edge: positional
- k: 8
- r: 15.0
-
-TRAIN:
- monitor: loss # balanced_accuracy, loss
- mode: min # max or min
- enabled: true
- use_weighted_sampler: true
- use_weighted_loss: false
- samples_per_class: 4000
- epochs: 10
- patience: 10
- save_every: 20
- detailed_results: true
-
- optimizer:
- type: Adam
- lr: 0.001
-
- scheduler:
- type: ReduceLROnPlateau
- mode: ${TRAIN.mode}
- patience: 10
- factor: 0.1
- min_lr: 1e-8
-
- criterion:
- type: CrossEntropyLoss
-
- save_dir: ./checkpoints_tracking
-
-SYSTEM:
- log_dir: ./logs
- use_seed: true
- seed: 42
- GPU: 4
- device: cuda # auto | cuda | cpu
- gpu_id: 0
+
+```yaml
+--8<-- "opensportslib/config/sngar-tracking.yaml"
```
### 3. Localization
-```bash
-TASK: localization
-
-dali: True
-
-DATA:
- dataset_name: SoccerNet
- data_dir: /home/vorajv/opensportslib/SoccerNet/annotations/
- classes:
- - PASS
- - DRIVE
- - HEADER
- - HIGH PASS
- - OUT
- - CROSS
- - THROW IN
- - SHOT
- - BALL PLAYER BLOCK
- - PLAYER SUCCESSFUL TACKLE
- - FREE KICK
- - GOAL
-
- epoch_num_frames: 500000
- mixup: true
- modality: rgb
- crop_dim: -1
- dilate_len: 0 # Dilate ground truth labels
- clip_len: 100
- input_fps: 25
- extract_fps: 2
- imagenet_mean: [0.485, 0.456, 0.406]
- imagenet_std: [0.229, 0.224, 0.225]
- target_height: 224
- target_width: 398
-
- train:
- type: VideoGameWithDali
- classes: ${DATA.classes}
- output_map: [data, label]
- video_path: ${DATA.data_dir}/train/
- path: ${DATA.train.video_path}/annotations-2024-224p-train.json
- dataloader:
- batch_size: 8
- shuffle: true
- num_workers: 4
- pin_memory: true
-
- valid:
- type: VideoGameWithDali
- classes: ${DATA.classes}
- output_map: [data, label]
- video_path: ${DATA.data_dir}/valid/
- path: ${DATA.valid.video_path}/annotations-2024-224p-valid.json
- dataloader:
- batch_size: 8
- shuffle: true
-
- valid_data_frames:
- type: VideoGameWithDaliVideo
- classes: ${DATA.classes}
- output_map: [data, label]
- video_path: ${DATA.valid.video_path}
- path: ${DATA.valid.path}
- overlap_len: 0
- dataloader:
- batch_size: 4
- shuffle: false
-
- test:
- type: VideoGameWithDaliVideo
- classes: ${DATA.classes}
- output_map: [data, label]
- video_path: ${DATA.data_dir}/test/
- path: ${DATA.test.video_path}/annotations-2024-224p-test.json
- results: results_spotting_test
- nms_window: 2
- metric: tight
- overlap_len: 50
- dataloader:
- batch_size: 4
- shuffle: false
-
- challenge:
- type: VideoGameWithDaliVideo
- overlap_len: 50
- output_map: [data, label]
- path: ${DATA.data_dir}/challenge/annotations.json
- dataloader:
- batch_size: 4
- shuffle: false
-
-MODEL:
- type: E2E
- runner:
- type: runner_e2e
- backbone:
- type: rny008_gsm
- head:
- type: gru
- multi_gpu: true
- load_weights: null
- save_dir: ./checkpoints
- work_dir: ${MODEL.save_dir}
-
-TRAIN:
- type: trainer_e2e
- num_epochs: 10
- acc_grad_iter: 1
- base_num_valid_epochs: 30
- start_valid_epoch: 4
- valid_map_every: 1
- criterion_valid: map
-
- criterion:
- type: CrossEntropyLoss
-
- optimizer:
- type: AdamWithScaler
- lr: 0.01
-
- scheduler:
- type: ChainedSchedulerE2E
- acc_grad_iter: 1
- num_epochs: ${TRAIN.num_epochs}
- warm_up_epochs: 3
-
-SYSTEM:
- log_dir: ./logs
- seed: 42
- GPU: 4 # number of gpus to use
- device: cuda # auto | cuda | cpu
- gpu_id: 0 # device id for single gpu training
+
+```yaml
+--8<-- "opensportslib/config/localization.yaml"
```
## Annotations (train/valid/test) (.json) Format
@@ -385,6 +56,9 @@ Download annotation files from the links below.
## Download Weights from HuggingFace
+For a comparison table with datasets, reported scores, and model links, see the
+[Model Zoo](../model-zoo.md).
+
### 1. Classification (MViT)
**MVFoul Classification (MViT backbone)**
@@ -420,7 +94,7 @@ import wandb
# Initialize model with config
myModel = model.ClassificationModel(
config="/path/to/classification.yaml",
- weights="/path/to/weights.pt", # optional
+ weights=None, # optional: path or Hugging Face model ID
)
## Localization ##
@@ -442,7 +116,7 @@ from opensportslib import model
def main():
myModel = model.ClassificationModel(
config="/path/to/classification.yaml",
- weights="/path/to/weights.pt", # optional
+ weights=None, # optional: path or Hugging Face model ID
)
## Localization ##
@@ -468,7 +142,7 @@ from opensportslib import model
# Load trained model
myModel = model.ClassificationModel(
config="/path/to/classification.yaml",
- weights="/path/to/weights.pt", # optional
+ weights=None, # optional: path or Hugging Face model ID
)
## Localization ##
@@ -498,7 +172,7 @@ from opensportslib import model
def main():
myModel = model.ClassificationModel(
config="/path/to/classification.yaml",
- weights="/path/to/weights.pt", # optional
+ weights=None, # optional: path or Hugging Face model ID
)
## Localization ##
diff --git a/examples/configs/README.md b/examples/configs/README.md
index 63e78ff..f101576 100644
--- a/examples/configs/README.md
+++ b/examples/configs/README.md
@@ -18,10 +18,10 @@ Point the OpenSportsLib Python API to one of these configs.
Example:
```python
-from opensportslib import model
+from opensportslib.apis import ClassificationModel
-my_model = model.classification(
- config="examples/configs/classification.yaml"
+my_model = ClassificationModel(
+ config="examples/configs/classification_video.yaml"
)
```
diff --git a/examples/configs/classification_tracking.yaml b/examples/configs/classification_tracking.yaml
index 32ad0d7..a482742 100644
--- a/examples/configs/classification_tracking.yaml
+++ b/examples/configs/classification_tracking.yaml
@@ -47,12 +47,15 @@ DATA:
pitch_half_width: 50.0
max_displacement: 110.0
max_ball_height: 30.0
+ data_slicing: # only used for data scaling experiments
+ enabled: false
+ training_matches: 45 # default: all 45 training matches
MODEL:
type: custom
backbone:
type: graph_conv
- encoder: graphconv
+ encoder: gin
hidden_dim: 64
num_layers: 20
dropout: 0.1
@@ -61,6 +64,7 @@ MODEL:
agr_type: maxpool
hidden_dim: 64
dropout: 0.1
+ use_position_encoding: true
head:
type: TrackingClassifier
hidden_dim: 64
@@ -77,7 +81,7 @@ TRAIN:
use_weighted_sampler: true
use_weighted_loss: false
samples_per_class: 4000
- epochs: 10
+ epochs: 100
patience: 10
save_every: 20
detailed_results: true
@@ -96,12 +100,11 @@ TRAIN:
criterion:
type: CrossEntropyLoss
- save_dir: ./checkpoints_tracking
-
SYSTEM:
- log_dir: ./logs
- use_seed: true
- seed: 42
- GPU: 4
- device: cuda # auto | cuda | cpu
- gpu_id: 0
\ No newline at end of file
+ log_dir: ./logs
+ save_dir: ./checkpoints_tracking
+ use_seed: true
+ seed: 42
+ GPU: 1
+ device: cuda # auto | cuda | cpu
+ gpu_id: 0
diff --git a/examples/configs/classification_video.yaml b/examples/configs/classification_video.yaml
index 539a10c..037f915 100644
--- a/examples/configs/classification_video.yaml
+++ b/examples/configs/classification_video.yaml
@@ -16,6 +16,9 @@ DATA:
shuffle: true
num_workers: 4
pin_memory: true
+ mp_context: spawn
+ persistent_workers: true
+ prefetch_factor: 4
valid:
type: annotations_valid.json
@@ -25,6 +28,9 @@ DATA:
batch_size: 1
num_workers: 1
shuffle: false
+ mp_context: spawn
+ persistent_workers: true
+ prefetch_factor: 4
test:
type: annotations_test.json
@@ -32,8 +38,11 @@ DATA:
path: ${DATA.test.video_path}/annotations-test.json
dataloader:
batch_size: 1
- num_workers: 1
+ num_workers: 0
shuffle: false
+ mp_context: spawn
+ persistent_workers: true
+ prefetch_factor: 4
num_frames: 16 # 8 before + 8 after the foul
input_fps: 25 # Original FPS of video
@@ -78,7 +87,6 @@ TRAIN:
use_weighted_sampler: false
use_weighted_loss: true
epochs: 20 #20
- save_dir: ./checkpoints
log_interval: 10
save_every: 2 #5
@@ -102,8 +110,9 @@ TRAIN:
SYSTEM:
log_dir: ./logs
+ save_dir: ./checkpoints
use_seed: false
seed: 42
GPU: 4
device: cuda # auto | cuda | cpu
- gpu_id: 0
\ No newline at end of file
+ gpu_id: 0
diff --git a/examples/configs/localization.yaml b/examples/configs/localization.yaml
index 6229beb..0532162 100644
--- a/examples/configs/localization.yaml
+++ b/examples/configs/localization.yaml
@@ -3,8 +3,8 @@ TASK: localization
dali: True
DATA:
- dataset_name: SoccerNet-Ball-Action-Spotting
- data_dir: /home/vorajv/opensportslib/SoccerNet/annotations/
+ dataset_name: SoccerNet
+ data_dir: /path/to/OSL-SNBAS/224p-2024
classes:
- PASS
- DRIVE
@@ -37,7 +37,7 @@ DATA:
classes: ${DATA.classes}
output_map: [data, label]
video_path: ${DATA.data_dir}/train/
- path: ${DATA.train.video_path}/annotations-2024-224p-train.json
+ path: ${DATA.train.video_path}/train.json
dataloader:
batch_size: 8
shuffle: true
@@ -49,10 +49,12 @@ DATA:
classes: ${DATA.classes}
output_map: [data, label]
video_path: ${DATA.data_dir}/valid/
- path: ${DATA.valid.video_path}/annotations-2024-224p-valid.json
+ path: ${DATA.valid.video_path}/valid.json
dataloader:
batch_size: 8
shuffle: true
+ num_workers: 4
+ pin_memory: true
valid_data_frames:
type: VideoGameWithDaliVideo
@@ -64,13 +66,15 @@ DATA:
dataloader:
batch_size: 4
shuffle: false
+ num_workers: 4
+ pin_memory: true
test:
type: VideoGameWithDaliVideo
classes: ${DATA.classes}
output_map: [data, label]
video_path: ${DATA.data_dir}/test/
- path: ${DATA.test.video_path}/annotations-2024-224p-test.json
+ path: ${DATA.test.video_path}/test.json
results: results_spotting_test
nms_window: 2
metric: tight
@@ -98,8 +102,6 @@ MODEL:
type: gru
multi_gpu: true
load_weights: null
- save_dir: ./checkpoints
- work_dir: ${MODEL.save_dir}
TRAIN:
type: trainer_e2e
@@ -125,7 +127,9 @@ TRAIN:
SYSTEM:
log_dir: ./logs
+ save_dir: ./checkpoints
+ work_dir: ${SYSTEM.save_dir}
seed: 42
GPU: 4 # number of gpus to use
device: cuda # auto | cuda | cpu
- gpu_id: 0 # device id for single gpu training
\ No newline at end of file
+ gpu_id: 0 # device id for single gpu training
diff --git a/examples/quickstart/basic_classification.py b/examples/quickstart/basic_classification.py
index 35cacb1..4560d9d 100644
--- a/examples/quickstart/basic_classification.py
+++ b/examples/quickstart/basic_classification.py
@@ -9,7 +9,7 @@ def main():
my_model = ClassificationModel(
config="examples/configs/classification_video.yaml",
- weights="/path/to/weights.pt", # optional
+ weights=None, # optional: path or Hugging Face model ID
)
my_model.train(
diff --git a/examples/quickstart/basic_localization.py b/examples/quickstart/basic_localization.py
index e55642a..24ccb81 100644
--- a/examples/quickstart/basic_localization.py
+++ b/examples/quickstart/basic_localization.py
@@ -9,7 +9,7 @@ def main():
my_model = LocalizationModel(
config="examples/configs/localization.yaml",
- weights="/path/to/weights.pt", # optional
+ weights=None, # optional: path or Hugging Face model ID
)
my_model.train(
diff --git a/mkdocs.yml b/mkdocs.yml
index 3c5fb0d..42a74df 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -27,6 +27,7 @@ theme:
nav:
- Home: index.md
+ - Model Zoo: model-zoo.md
- Getting Started:
- Installation: getting-started/installation.md
diff --git a/opensportslib/apis/README.md b/opensportslib/apis/README.md
index 477eda3..318e081 100644
--- a/opensportslib/apis/README.md
+++ b/opensportslib/apis/README.md
@@ -45,7 +45,7 @@ from opensportslib.apis import ClassificationModel
m = ClassificationModel(
config="/path/to/classification.yaml",
- weights="/path/to/weights.pt", # optional
+ weights=None, # optional: path or Hugging Face model ID
)
best_ckpt = m.train(
@@ -83,7 +83,7 @@ from opensportslib.apis import LocalizationModel
m = LocalizationModel(
config="/path/to/localization.yaml",
- weights="/path/to/weights.pt", # optional
+ weights=None, # optional: path or Hugging Face model ID
)
best_ckpt = m.train(
diff --git a/tools/training/README.md b/tools/training/README.md
index 0fd0691..1818422 100644
--- a/tools/training/README.md
+++ b/tools/training/README.md
@@ -6,8 +6,8 @@ Minimal training scripts for each task. Run from the **repository root**.
| Script | Task |
|---|---|
-| `basic_classification.py` | Action classification |
-| `basic_localization.py` | Action localization |
+| `classification.py` | Action classification |
+| `localization.py` | Action localization |
## Arguments
@@ -16,9 +16,9 @@ Both scripts accept the same CLI arguments:
| Argument | Required | Description |
|---|---|---|
| `--config` | yes | Path to the YAML config file |
-| `--train-set` | yes | Path to train annotations JSON |
-| `--valid-set` | yes | Path to validation annotations JSON |
-| `--test-set` | yes | Path to test annotations JSON |
+| `--train-set` | no | Path to train annotations JSON; defaults to `DATA.train.path` |
+| `--valid-set` | no | Path to validation annotations JSON; defaults to `DATA.valid.path` |
+| `--test-set` | no | Path to test annotations JSON; defaults to `DATA.test.path` |
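-| `--weights` | no | Path to pretrained weights |
+| `--weights` | no | Path to pretrained weights or a Hugging Face model ID (e.g. `OpenSportsLab/OSL-cls-action-mvitv2`) |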
## Usage
@@ -26,32 +26,23 @@ Both scripts accept the same CLI arguments:
### Classification
```bash
-python tools/training/basic_classification.py \
- --config examples/configs/classification_video.yaml \
- --train-set /path/to/train_annotations.json \
- --valid-set /path/to/valid_annotations.json \
- --test-set /path/to/test_annotations.json
+python tools/training/classification.py \
+ --config examples/configs/classification_video.yaml
```
With pretrained weights:
```bash
-python tools/training/basic_classification.py \
+python tools/training/classification.py \
--config examples/configs/classification_video.yaml \
- --weights /path/to/weights.pt \
- --train-set /path/to/train_annotations.json \
- --valid-set /path/to/valid_annotations.json \
- --test-set /path/to/test_annotations.json
+ --weights OpenSportsLab/OSL-cls-action-mvitv2
```
### Localization
```bash
-python tools/training/basic_localization.py \
- --config examples/configs/localization.yaml \
- --train-set /path/to/train_annotations.json \
- --valid-set /path/to/valid_annotations.json \
- --test-set /path/to/test_annotations.json
+python tools/training/localization.py \
+ --config examples/configs/localization.yaml
```
## Example Configs
diff --git a/tools/training/classification.py b/tools/training/classification.py
index 6cf3a60..bcc4190 100644
--- a/tools/training/classification.py
+++ b/tools/training/classification.py
@@ -7,9 +7,9 @@ def parse_args():
parser = argparse.ArgumentParser(description="Minimal classification training script.")
parser.add_argument("--config", required=True, help="Path to the YAML config file.")
parser.add_argument("--weights", default=None, help="Path to pretrained weights (optional).")
- parser.add_argument("--train-set", required=True, help="Path to train annotations JSON.")
- parser.add_argument("--valid-set", required=True, help="Path to validation annotations JSON.")
- parser.add_argument("--test-set", required=True, help="Path to test annotations JSON.")
+ parser.add_argument("--train-set", default=None, help="Path to train annotations JSON. Defaults to DATA.train.path from the config.")
+ parser.add_argument("--valid-set", default=None, help="Path to validation annotations JSON. Defaults to DATA.valid.path from the config.")
+ parser.add_argument("--test-set", default=None, help="Path to test annotations JSON. Defaults to DATA.test.path from the config.")
return parser.parse_args()
diff --git a/tools/training/localization.py b/tools/training/localization.py
index 1ac4944..aaa007e 100644
--- a/tools/training/localization.py
+++ b/tools/training/localization.py
@@ -7,9 +7,9 @@ def parse_args():
parser = argparse.ArgumentParser(description="Minimal localization training script.")
parser.add_argument("--config", required=True, help="Path to the YAML config file.")
parser.add_argument("--weights", default=None, help="Path to pretrained weights (optional).")
- parser.add_argument("--train-set", required=True, help="Path to train annotations JSON.")
- parser.add_argument("--valid-set", required=True, help="Path to validation annotations JSON.")
- parser.add_argument("--test-set", required=True, help="Path to test annotations JSON.")
+ parser.add_argument("--train-set", default=None, help="Path to train annotations JSON. Defaults to DATA.train.path from the config.")
+ parser.add_argument("--valid-set", default=None, help="Path to validation annotations JSON. Defaults to DATA.valid.path from the config.")
+ parser.add_argument("--test-set", default=None, help="Path to test annotations JSON. Defaults to DATA.test.path from the config.")
return parser.parse_args()