From 55dbd83b6215274cb1a09e91f3f454920d94a60c Mon Sep 17 00:00:00 2001 From: Hanyuan Li Date: Tue, 9 Jun 2026 16:26:14 +1000 Subject: [PATCH 1/4] feat: add support for non-TorchVision datasets --- .../torchvision_loader/custom/README.md | 42 +++++ .../torchvision_loader/custom/__init__.py | 1 + .../torchvision_loader/custom/cub/__init__.py | 1 + .../custom/cub/data_processing.py | 71 +++++++ .../torchvision_loader/custom/cub/dataset.py | 131 +++++++++++++ .../torchvision_loader/data_model_loader.py | 18 +- .../torchvision_loader/data_model_mapping.py | 13 ++ ipynb/cub_load.ipynb | 176 ++++++++++++++++++ 8 files changed, 449 insertions(+), 4 deletions(-) create mode 100644 act/front_end/torchvision_loader/custom/README.md create mode 100644 act/front_end/torchvision_loader/custom/__init__.py create mode 100644 act/front_end/torchvision_loader/custom/cub/__init__.py create mode 100644 act/front_end/torchvision_loader/custom/cub/data_processing.py create mode 100644 act/front_end/torchvision_loader/custom/cub/dataset.py create mode 100644 ipynb/cub_load.ipynb diff --git a/act/front_end/torchvision_loader/custom/README.md b/act/front_end/torchvision_loader/custom/README.md new file mode 100644 index 000000000..fd5879431 --- /dev/null +++ b/act/front_end/torchvision_loader/custom/README.md @@ -0,0 +1,42 @@ +# Custom dataset types + +This folder contains the `Dataset` types for custom datasets, outside of `torchvision.datasets`. To add your own dataset type, please: + +1. Make a new folder for your dataset. +2. Create an `__init__.py` inside your folder so that only your dataset is exposed within the folder. +3. Add your dataset to `custom/__init__.py`. + +Once that is done, you must add an extra field to your custom dataset definition in `data_model_mapping.py`, `"class_name"`, which is equal to your +dataset class's name. For example, if your dataset class name is `ABCDataset`: + +```py +DATASET_MODEL_MAPPING: Dict[str, Dict[str, Any]] = { + # ... + "ABC": { + # ... + "class_name": "ABCDataset", + }, + # ... +} +``` + +Your dataset type must have the following mandatory initialisation arguments (to align with the ones in `torchvision.datasets`), and all other arguments +must be optional: + +```py +from torch.util.data import Dataset + + +class ABCDataset(Dataset): + def __init__( + self, + root: str, + train: bool, + download: bool, + # All other args must be optional: + foo: str | None = None, + bar: int = 5, + ): + # ... + pass +``` diff --git a/act/front_end/torchvision_loader/custom/__init__.py b/act/front_end/torchvision_loader/custom/__init__.py new file mode 100644 index 000000000..71f6e9db2 --- /dev/null +++ b/act/front_end/torchvision_loader/custom/__init__.py @@ -0,0 +1 @@ +from .cub import CUBDataset diff --git a/act/front_end/torchvision_loader/custom/cub/__init__.py b/act/front_end/torchvision_loader/custom/cub/__init__.py new file mode 100644 index 000000000..593bd8f5f --- /dev/null +++ b/act/front_end/torchvision_loader/custom/cub/__init__.py @@ -0,0 +1 @@ +from .dataset import CUBDataset diff --git a/act/front_end/torchvision_loader/custom/cub/data_processing.py b/act/front_end/torchvision_loader/custom/cub/data_processing.py new file mode 100644 index 000000000..3fc369c81 --- /dev/null +++ b/act/front_end/torchvision_loader/custom/cub/data_processing.py @@ -0,0 +1,71 @@ +""" +Code modified from https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/data_processing.py +Make train, val, test datasets based on train_test_split.txt, and by sampling val_ratio of the official train data to make a validation set +Each dataset is a list of metadata, each includes official image id, full image path, class label, attribute labels, attribute certainty scores, and attribute labels calibrated for uncertainty +""" +import os +import random +from os import listdir +from os.path import isfile, isdir, join +from collections import defaultdict as ddict + + +def extract_data(data_dir): + cwd = os.getcwd() + data_path = join(cwd,data_dir + '/images') + val_ratio = 0.2 + + path_to_id_map = dict() #map from full image path to image id + with open(data_path.replace('images', 'images.txt'), 'r') as f: + for line in f: + items = line.strip().split() + path_to_id_map[join(data_path, items[1])] = int(items[0]) + + attribute_labels_all = ddict(list) #map from image id to a list of attribute labels + attribute_certainties_all = ddict(list) #map from image id to a list of attribute certainties + attribute_uncertain_labels_all = ddict(list) #map from image id to a list of attribute labels calibrated for uncertainty + # 1 = not visible, 2 = guessing, 3 = probably, 4 = definitely + uncertainty_map = {1: {1: 0, 2: 0.5, 3: 0.75, 4:1}, #calibrate main label based on uncertainty label + 0: {1: 0, 2: 0.5, 3: 0.25, 4: 0}} + with open(join(cwd, data_dir + '/attributes/image_attribute_labels.txt'), 'r') as f: + for line in f: + file_idx, attribute_idx, attribute_label, attribute_certainty = line.strip().split()[:4] + attribute_label = int(attribute_label) + attribute_certainty = int(attribute_certainty) + uncertain_label = uncertainty_map[attribute_label][attribute_certainty] + attribute_labels_all[int(file_idx)].append(attribute_label) + attribute_uncertain_labels_all[int(file_idx)].append(uncertain_label) + attribute_certainties_all[int(file_idx)].append(attribute_certainty) + + is_train_test = dict() #map from image id to 0 / 1 (1 = train) + with open(join(cwd, data_dir + '/train_test_split.txt'), 'r') as f: + for line in f: + idx, is_train = line.strip().split() + is_train_test[int(idx)] = int(is_train) + print("Number of train images from official train test split:", sum(list(is_train_test.values()))) + + train_val_data, test_data = [], [] + train_data, val_data = [], [] + folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] + folder_list.sort() #sort by class index + for i, folder in enumerate(folder_list): + folder_path = join(data_path, folder) + classfile_list = [cf for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')] + #classfile_list.sort() + for cf in classfile_list: + img_id = path_to_id_map[join(folder_path, cf)] + img_path = join(folder_path, cf) + metadata = {'id': img_id, 'img_path': img_path, 'class_label': i, + 'attribute_label': attribute_labels_all[img_id], 'attribute_certainty': attribute_certainties_all[img_id], + 'uncertain_attribute_label': attribute_uncertain_labels_all[img_id]} + if is_train_test[img_id]: + train_val_data.append(metadata) + else: + test_data.append(metadata) + + random.shuffle(train_val_data) + split = int(val_ratio * len(train_val_data)) + train_data = train_val_data[split :] + val_data = train_val_data[: split] + print('Size of train set:', len(train_data)) + return train_data, val_data, test_data diff --git a/act/front_end/torchvision_loader/custom/cub/dataset.py b/act/front_end/torchvision_loader/custom/cub/dataset.py new file mode 100644 index 000000000..cd641e3b4 --- /dev/null +++ b/act/front_end/torchvision_loader/custom/cub/dataset.py @@ -0,0 +1,131 @@ +""" +Code modified from https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/dataset.py +General utils for training, evaluation and data loading +""" +import os +import pickle +import numpy as np + +from PIL import Image +from torch.utils.data import Dataset +from torchvision.datasets.utils import check_integrity, download_and_extract_archive, extract_archive + +from .data_processing import extract_data + + +N_ATTRIBUTES = 312 + +class CUBDataset(Dataset): + """ + Returns a compatible Torch Dataset object customized for the CUB dataset + """ + + RAW_DATASET_URL = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz" + MD5 = "97eceeb196236b17998738112f37df78" + + # Modified to be in format + # ``` + # dataset_class( + # root=str(raw_dir), + # train=True, + # download=True + # ) + # ``` + def __init__( + self, + root, + train, + download, + # Use sensible defaults + use_attr=False, + no_img=False, + uncertain_label=False, + image_dir="images", + n_class_attr=1, + transform=None, + ): + """ + Arguments: + pkl_file_paths: list of full path to all the pkl data + use_attr: whether to load the attributes (e.g. False for simple finetune) + no_img: whether to load the images (e.g. False for A -> Y model) + uncertain_label: if True, use 'uncertain_attribute_label' field (i.e. label weighted by uncertainty score, e.g. 1 & 3(probably) -> 0.75) + image_dir: default = 'images'. Will be append to the parent dir + n_class_attr: number of classes to predict for each attribute. If 3, then make a separate class for not visible + transform: whether to apply any special transformation. Default = None, i.e. use standard ImageNet preprocessing + """ + self.root = root + self.train = train + + if download: + self.download_and_process() + + if self.train: + self.data = pickle.load(open(f"{root}/processed/train", "rb")) + else: + self.data = pickle.load(open(f"{root}/processed/test", "rb")) + + self.transform = transform + self.use_attr = use_attr + self.no_img = no_img + self.uncertain_label = uncertain_label + self.image_dir = image_dir + self.n_class_attr = n_class_attr + + def download_and_process(self): + if check_integrity(f"{self.root}/CUB_200_2011.tgz", self.MD5): + return + + download_and_extract_archive( + url=self.RAW_DATASET_URL, + download_root=self.root, + extract_root=f"{self.root}/decompressed" + ) + + train, _, test = extract_data(f"{self.root}/decompressed") + + os.mkdir(f"{self.root}/processed") + pickle.dump(train, open(f"{self.root}/processed/train", "wb")) + pickle.dump(test, open(f"{self.root}/processed/test", "wb")) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + img_data = self.data[idx] + img_path = img_data['img_path'] + # Trim unnecessary paths + try: + idx = img_path.split('/').index('CUB_200_2011') + if self.image_dir != 'images': + img_path = '/'.join([self.image_dir] + img_path.split('/')[idx+1:]) + img_path = img_path.replace('images/', '') + else: + img_path = '/'.join(img_path.split('/')[idx:]) + img = Image.open(img_path).convert('RGB') + except: + img_path_split = img_path.split('/') + split = 'train' if self.train else 'test' + img_path = '/'.join(img_path_split[:2] + [split] + img_path_split[2:]) + img = Image.open(img_path).convert('RGB') + + class_label = img_data['class_label'] + if self.transform: + img = self.transform(img) + + if self.use_attr: + if self.uncertain_label: + attr_label = img_data['uncertain_attribute_label'] + else: + attr_label = img_data['attribute_label'] + if self.no_img: + if self.n_class_attr == 3: + one_hot_attr_label = np.zeros((N_ATTRIBUTES, self.n_class_attr)) + one_hot_attr_label[np.arange(N_ATTRIBUTES), attr_label] = 1 + return one_hot_attr_label, class_label + else: + return attr_label, class_label + else: + return img, class_label, attr_label + else: + return img, class_label diff --git a/act/front_end/torchvision_loader/data_model_loader.py b/act/front_end/torchvision_loader/data_model_loader.py index a1d332c6d..79bd44f55 100644 --- a/act/front_end/torchvision_loader/data_model_loader.py +++ b/act/front_end/torchvision_loader/data_model_loader.py @@ -21,6 +21,7 @@ # Import from data_model_mapping from act.front_end.torchvision_loader.data_model_mapping import ( + DATASET_MODEL_MAPPING, get_dataset_info, validate_dataset_model_compatibility, create_preprocessing_pipeline, @@ -213,8 +214,12 @@ def download_dataset_model_pair( if not dataset_exists: print(f"\n[1/3] Downloading dataset...") - import torchvision.datasets - dataset_class = getattr(torchvision.datasets, dataset_name, None) + if "class_name" in dataset_info: + import act.front_end.torchvision_loader.custom + dataset_class = getattr(act.front_end.torchvision_loader.custom, dataset_info["class_name"], None) + else: + import torchvision.datasets + dataset_class = getattr(torchvision.datasets, dataset_name, None) if dataset_class is None: return { @@ -611,8 +616,13 @@ def load_dataset_model_pair( preprocessing = create_preprocessing_pipeline(dataset_name) # Load dataset - import torchvision.datasets - dataset_class = getattr(torchvision.datasets, dataset_name) + if "class_name" in DATASET_MODEL_MAPPING[dataset_name]: + import act.front_end.torchvision_loader.custom + dataset_class = getattr(act.front_end.torchvision_loader.custom, DATASET_MODEL_MAPPING[dataset_name]["class_name"]) + else: + import torchvision.datasets + dataset_class = getattr(torchvision.datasets, dataset_name) + is_train = (split == "train") try: diff --git a/act/front_end/torchvision_loader/data_model_mapping.py b/act/front_end/torchvision_loader/data_model_mapping.py index ae9a144f7..a8c4876f0 100644 --- a/act/front_end/torchvision_loader/data_model_mapping.py +++ b/act/front_end/torchvision_loader/data_model_mapping.py @@ -423,6 +423,19 @@ "num_classes": 8142, # varies by year "category": "classification", "notes": "Species classification, highly imbalanced with long tail" + }, + + # ========== Datasets not in TorchVision ========== + "CUB200": { + "models": ["resnet18"], + "input_size": (3, 224, 224), + "num_classes": 200, + "category": "classification", + "preprocessing": { + "resize_to": (224, 224) + }, + "notes": "Images of birds with 200 classes and 312 binary attributes.", + "class_name": "CUBDataset" } } diff --git a/ipynb/cub_load.ipynb b/ipynb/cub_load.ipynb new file mode 100644 index 000000000..1719e8be3 --- /dev/null +++ b/ipynb/cub_load.ipynb @@ -0,0 +1,176 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "2993cdc8", + "metadata": {}, + "outputs": [], + "source": [ + "# Setup ACT paths using path_config\n", + "import os\n", + "import sys\n", + "\n", + "\n", + "act_root = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))\n", + "if act_root not in sys.path:\n", + " sys.path.insert(0, act_root)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "19572480", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "================================================================================\n", + "LOADING: CUB200 + resnet18\n", + "================================================================================\n", + "[1/3] Loading dataset (test split)...\n", + " ✓ Loaded 5794 samples\n", + "[2/3] Loading model architecture...\n", + " ✓ Loaded resnet18 with pre-trained weights\n", + " ✓ Adjusted final layer: 512 → 200 classes\n", + " ✓ Loaded resnet18 from torchvision.models\n", + "[3/3] Summary...\n", + " Dataset: 5794 samples (test split)\n", + " Model: 11,279,112 parameters (11,279,112 trainable)\n", + " Batch size: 1\n", + " Preprocessing: Yes\n", + "\n", + "================================================================================\n", + "✓ LOADED SUCCESSFULLY\n", + "================================================================================\n" + ] + }, + { + "data": { + "text/plain": [ + "ResNet(\n", + " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", + " (layer1): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (layer2): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (layer3): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (layer4): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", + " (fc): Linear(in_features=512, out_features=200, bias=True)\n", + ")" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from act.front_end.torchvision_loader.data_model_loader import load_dataset_model_pair\n", + "\n", + "\n", + "pair = load_dataset_model_pair(\"CUB200\", \"resnet18\", split=\"test\")\n", + "model, dataset = pair['model'], pair['dataset']\n", + "model.eval()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "act-slcbm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 277f8a7b1506d342890665c153e2e06c16e2e51f Mon Sep 17 00:00:00 2001 From: Hanyuan Li Date: Tue, 9 Jun 2026 16:33:04 +1000 Subject: [PATCH 2/4] fix: use get_dataset_info over MAPPING directly --- act/front_end/torchvision_loader/data_model_loader.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/act/front_end/torchvision_loader/data_model_loader.py b/act/front_end/torchvision_loader/data_model_loader.py index 79bd44f55..57e663aab 100644 --- a/act/front_end/torchvision_loader/data_model_loader.py +++ b/act/front_end/torchvision_loader/data_model_loader.py @@ -21,7 +21,6 @@ # Import from data_model_mapping from act.front_end.torchvision_loader.data_model_mapping import ( - DATASET_MODEL_MAPPING, get_dataset_info, validate_dataset_model_compatibility, create_preprocessing_pipeline, @@ -616,9 +615,11 @@ def load_dataset_model_pair( preprocessing = create_preprocessing_pipeline(dataset_name) # Load dataset - if "class_name" in DATASET_MODEL_MAPPING[dataset_name]: + dataset_info = get_dataset_info(dataset_name) + + if "class_name" in dataset_info: import act.front_end.torchvision_loader.custom - dataset_class = getattr(act.front_end.torchvision_loader.custom, DATASET_MODEL_MAPPING[dataset_name]["class_name"]) + dataset_class = getattr(act.front_end.torchvision_loader.custom, dataset_info["class_name"]) else: import torchvision.datasets dataset_class = getattr(torchvision.datasets, dataset_name) From bd0c0b8f693a5ded558171368021b513ff56728e Mon Sep 17 00:00:00 2001 From: Hanyuan Li Date: Wed, 10 Jun 2026 11:22:13 +1000 Subject: [PATCH 3/4] chore: remove cub_load.ipynb --- ipynb/cub_load.ipynb | 176 ------------------------------------------- 1 file changed, 176 deletions(-) delete mode 100644 ipynb/cub_load.ipynb diff --git a/ipynb/cub_load.ipynb b/ipynb/cub_load.ipynb deleted file mode 100644 index 1719e8be3..000000000 --- a/ipynb/cub_load.ipynb +++ /dev/null @@ -1,176 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "2993cdc8", - "metadata": {}, - "outputs": [], - "source": [ - "# Setup ACT paths using path_config\n", - "import os\n", - "import sys\n", - "\n", - "\n", - "act_root = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))\n", - "if act_root not in sys.path:\n", - " sys.path.insert(0, act_root)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "19572480", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "================================================================================\n", - "LOADING: CUB200 + resnet18\n", - "================================================================================\n", - "[1/3] Loading dataset (test split)...\n", - " ✓ Loaded 5794 samples\n", - "[2/3] Loading model architecture...\n", - " ✓ Loaded resnet18 with pre-trained weights\n", - " ✓ Adjusted final layer: 512 → 200 classes\n", - " ✓ Loaded resnet18 from torchvision.models\n", - "[3/3] Summary...\n", - " Dataset: 5794 samples (test split)\n", - " Model: 11,279,112 parameters (11,279,112 trainable)\n", - " Batch size: 1\n", - " Preprocessing: Yes\n", - "\n", - "================================================================================\n", - "✓ LOADED SUCCESSFULLY\n", - "================================================================================\n" - ] - }, - { - "data": { - "text/plain": [ - "ResNet(\n", - " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", - " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", - " (layer1): Sequential(\n", - " (0): BasicBlock(\n", - " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (1): BasicBlock(\n", - " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (layer2): Sequential(\n", - " (0): BasicBlock(\n", - " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", - " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (downsample): Sequential(\n", - " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (1): BasicBlock(\n", - " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (layer3): Sequential(\n", - " (0): BasicBlock(\n", - " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", - " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (downsample): Sequential(\n", - " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", - " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (1): BasicBlock(\n", - " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (layer4): Sequential(\n", - " (0): BasicBlock(\n", - " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", - " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (downsample): Sequential(\n", - " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", - " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (1): BasicBlock(\n", - " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " )\n", - " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", - " (fc): Linear(in_features=512, out_features=200, bias=True)\n", - ")" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from act.front_end.torchvision_loader.data_model_loader import load_dataset_model_pair\n", - "\n", - "\n", - "pair = load_dataset_model_pair(\"CUB200\", \"resnet18\", split=\"test\")\n", - "model, dataset = pair['model'], pair['dataset']\n", - "model.eval()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "act-slcbm", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From b41ed5676eff9b10a402ef72a4c016518ad8b42f Mon Sep 17 00:00:00 2001 From: guanqin-123 Date: Wed, 10 Jun 2026 17:43:55 +1000 Subject: [PATCH 4/4] feat(torchvision): load archive-distributed datasets via mapping download config Non-TorchVision datasets (e.g. CUB-200-2011) are declared with a 'download' key (url/md5/image_root + optional index_file/split_file). The loader downloads and extracts the archive into data/torchvision//raw, wraps image_root in ImageFolder, and applies the dataset's official deterministic train/test split via Subset. The custom/cub package is removed: ImageFolder covers the CUB layout, the official split comes from images.txt + train_test_split.txt, and the ConceptBottleneck-derived attribute/pickle code was unused by ACT's pipeline. Verified on real CUB-200-2011: 5994 train / 5794 test. --- .../torchvision_loader/custom/README.md | 42 ----- .../torchvision_loader/custom/__init__.py | 1 - .../torchvision_loader/custom/cub/__init__.py | 1 - .../custom/cub/data_processing.py | 71 ------- .../torchvision_loader/custom/cub/dataset.py | 131 ------------- .../torchvision_loader/data_model_loader.py | 177 +++++++++++++----- .../torchvision_loader/data_model_mapping.py | 15 +- 7 files changed, 139 insertions(+), 299 deletions(-) delete mode 100644 act/front_end/torchvision_loader/custom/README.md delete mode 100644 act/front_end/torchvision_loader/custom/__init__.py delete mode 100644 act/front_end/torchvision_loader/custom/cub/__init__.py delete mode 100644 act/front_end/torchvision_loader/custom/cub/data_processing.py delete mode 100644 act/front_end/torchvision_loader/custom/cub/dataset.py diff --git a/act/front_end/torchvision_loader/custom/README.md b/act/front_end/torchvision_loader/custom/README.md deleted file mode 100644 index fd5879431..000000000 --- a/act/front_end/torchvision_loader/custom/README.md +++ /dev/null @@ -1,42 +0,0 @@ -# Custom dataset types - -This folder contains the `Dataset` types for custom datasets, outside of `torchvision.datasets`. To add your own dataset type, please: - -1. Make a new folder for your dataset. -2. Create an `__init__.py` inside your folder so that only your dataset is exposed within the folder. -3. Add your dataset to `custom/__init__.py`. - -Once that is done, you must add an extra field to your custom dataset definition in `data_model_mapping.py`, `"class_name"`, which is equal to your -dataset class's name. For example, if your dataset class name is `ABCDataset`: - -```py -DATASET_MODEL_MAPPING: Dict[str, Dict[str, Any]] = { - # ... - "ABC": { - # ... - "class_name": "ABCDataset", - }, - # ... -} -``` - -Your dataset type must have the following mandatory initialisation arguments (to align with the ones in `torchvision.datasets`), and all other arguments -must be optional: - -```py -from torch.util.data import Dataset - - -class ABCDataset(Dataset): - def __init__( - self, - root: str, - train: bool, - download: bool, - # All other args must be optional: - foo: str | None = None, - bar: int = 5, - ): - # ... - pass -``` diff --git a/act/front_end/torchvision_loader/custom/__init__.py b/act/front_end/torchvision_loader/custom/__init__.py deleted file mode 100644 index 71f6e9db2..000000000 --- a/act/front_end/torchvision_loader/custom/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .cub import CUBDataset diff --git a/act/front_end/torchvision_loader/custom/cub/__init__.py b/act/front_end/torchvision_loader/custom/cub/__init__.py deleted file mode 100644 index 593bd8f5f..000000000 --- a/act/front_end/torchvision_loader/custom/cub/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .dataset import CUBDataset diff --git a/act/front_end/torchvision_loader/custom/cub/data_processing.py b/act/front_end/torchvision_loader/custom/cub/data_processing.py deleted file mode 100644 index 3fc369c81..000000000 --- a/act/front_end/torchvision_loader/custom/cub/data_processing.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Code modified from https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/data_processing.py -Make train, val, test datasets based on train_test_split.txt, and by sampling val_ratio of the official train data to make a validation set -Each dataset is a list of metadata, each includes official image id, full image path, class label, attribute labels, attribute certainty scores, and attribute labels calibrated for uncertainty -""" -import os -import random -from os import listdir -from os.path import isfile, isdir, join -from collections import defaultdict as ddict - - -def extract_data(data_dir): - cwd = os.getcwd() - data_path = join(cwd,data_dir + '/images') - val_ratio = 0.2 - - path_to_id_map = dict() #map from full image path to image id - with open(data_path.replace('images', 'images.txt'), 'r') as f: - for line in f: - items = line.strip().split() - path_to_id_map[join(data_path, items[1])] = int(items[0]) - - attribute_labels_all = ddict(list) #map from image id to a list of attribute labels - attribute_certainties_all = ddict(list) #map from image id to a list of attribute certainties - attribute_uncertain_labels_all = ddict(list) #map from image id to a list of attribute labels calibrated for uncertainty - # 1 = not visible, 2 = guessing, 3 = probably, 4 = definitely - uncertainty_map = {1: {1: 0, 2: 0.5, 3: 0.75, 4:1}, #calibrate main label based on uncertainty label - 0: {1: 0, 2: 0.5, 3: 0.25, 4: 0}} - with open(join(cwd, data_dir + '/attributes/image_attribute_labels.txt'), 'r') as f: - for line in f: - file_idx, attribute_idx, attribute_label, attribute_certainty = line.strip().split()[:4] - attribute_label = int(attribute_label) - attribute_certainty = int(attribute_certainty) - uncertain_label = uncertainty_map[attribute_label][attribute_certainty] - attribute_labels_all[int(file_idx)].append(attribute_label) - attribute_uncertain_labels_all[int(file_idx)].append(uncertain_label) - attribute_certainties_all[int(file_idx)].append(attribute_certainty) - - is_train_test = dict() #map from image id to 0 / 1 (1 = train) - with open(join(cwd, data_dir + '/train_test_split.txt'), 'r') as f: - for line in f: - idx, is_train = line.strip().split() - is_train_test[int(idx)] = int(is_train) - print("Number of train images from official train test split:", sum(list(is_train_test.values()))) - - train_val_data, test_data = [], [] - train_data, val_data = [], [] - folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] - folder_list.sort() #sort by class index - for i, folder in enumerate(folder_list): - folder_path = join(data_path, folder) - classfile_list = [cf for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')] - #classfile_list.sort() - for cf in classfile_list: - img_id = path_to_id_map[join(folder_path, cf)] - img_path = join(folder_path, cf) - metadata = {'id': img_id, 'img_path': img_path, 'class_label': i, - 'attribute_label': attribute_labels_all[img_id], 'attribute_certainty': attribute_certainties_all[img_id], - 'uncertain_attribute_label': attribute_uncertain_labels_all[img_id]} - if is_train_test[img_id]: - train_val_data.append(metadata) - else: - test_data.append(metadata) - - random.shuffle(train_val_data) - split = int(val_ratio * len(train_val_data)) - train_data = train_val_data[split :] - val_data = train_val_data[: split] - print('Size of train set:', len(train_data)) - return train_data, val_data, test_data diff --git a/act/front_end/torchvision_loader/custom/cub/dataset.py b/act/front_end/torchvision_loader/custom/cub/dataset.py deleted file mode 100644 index cd641e3b4..000000000 --- a/act/front_end/torchvision_loader/custom/cub/dataset.py +++ /dev/null @@ -1,131 +0,0 @@ -""" -Code modified from https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/dataset.py -General utils for training, evaluation and data loading -""" -import os -import pickle -import numpy as np - -from PIL import Image -from torch.utils.data import Dataset -from torchvision.datasets.utils import check_integrity, download_and_extract_archive, extract_archive - -from .data_processing import extract_data - - -N_ATTRIBUTES = 312 - -class CUBDataset(Dataset): - """ - Returns a compatible Torch Dataset object customized for the CUB dataset - """ - - RAW_DATASET_URL = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz" - MD5 = "97eceeb196236b17998738112f37df78" - - # Modified to be in format - # ``` - # dataset_class( - # root=str(raw_dir), - # train=True, - # download=True - # ) - # ``` - def __init__( - self, - root, - train, - download, - # Use sensible defaults - use_attr=False, - no_img=False, - uncertain_label=False, - image_dir="images", - n_class_attr=1, - transform=None, - ): - """ - Arguments: - pkl_file_paths: list of full path to all the pkl data - use_attr: whether to load the attributes (e.g. False for simple finetune) - no_img: whether to load the images (e.g. False for A -> Y model) - uncertain_label: if True, use 'uncertain_attribute_label' field (i.e. label weighted by uncertainty score, e.g. 1 & 3(probably) -> 0.75) - image_dir: default = 'images'. Will be append to the parent dir - n_class_attr: number of classes to predict for each attribute. If 3, then make a separate class for not visible - transform: whether to apply any special transformation. Default = None, i.e. use standard ImageNet preprocessing - """ - self.root = root - self.train = train - - if download: - self.download_and_process() - - if self.train: - self.data = pickle.load(open(f"{root}/processed/train", "rb")) - else: - self.data = pickle.load(open(f"{root}/processed/test", "rb")) - - self.transform = transform - self.use_attr = use_attr - self.no_img = no_img - self.uncertain_label = uncertain_label - self.image_dir = image_dir - self.n_class_attr = n_class_attr - - def download_and_process(self): - if check_integrity(f"{self.root}/CUB_200_2011.tgz", self.MD5): - return - - download_and_extract_archive( - url=self.RAW_DATASET_URL, - download_root=self.root, - extract_root=f"{self.root}/decompressed" - ) - - train, _, test = extract_data(f"{self.root}/decompressed") - - os.mkdir(f"{self.root}/processed") - pickle.dump(train, open(f"{self.root}/processed/train", "wb")) - pickle.dump(test, open(f"{self.root}/processed/test", "wb")) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - img_data = self.data[idx] - img_path = img_data['img_path'] - # Trim unnecessary paths - try: - idx = img_path.split('/').index('CUB_200_2011') - if self.image_dir != 'images': - img_path = '/'.join([self.image_dir] + img_path.split('/')[idx+1:]) - img_path = img_path.replace('images/', '') - else: - img_path = '/'.join(img_path.split('/')[idx:]) - img = Image.open(img_path).convert('RGB') - except: - img_path_split = img_path.split('/') - split = 'train' if self.train else 'test' - img_path = '/'.join(img_path_split[:2] + [split] + img_path_split[2:]) - img = Image.open(img_path).convert('RGB') - - class_label = img_data['class_label'] - if self.transform: - img = self.transform(img) - - if self.use_attr: - if self.uncertain_label: - attr_label = img_data['uncertain_attribute_label'] - else: - attr_label = img_data['attribute_label'] - if self.no_img: - if self.n_class_attr == 3: - one_hot_attr_label = np.zeros((N_ATTRIBUTES, self.n_class_attr)) - one_hot_attr_label[np.arange(N_ATTRIBUTES), attr_label] = 1 - return one_hot_attr_label, class_label - else: - return attr_label, class_label - else: - return img, class_label, attr_label - else: - return img, class_label diff --git a/act/front_end/torchvision_loader/data_model_loader.py b/act/front_end/torchvision_loader/data_model_loader.py index 57e663aab..594358547 100644 --- a/act/front_end/torchvision_loader/data_model_loader.py +++ b/act/front_end/torchvision_loader/data_model_loader.py @@ -32,6 +32,67 @@ from act.front_end.torchvision_loader.model_definitions import _get_custom_model_definition +def _resolve_archive_dataset( + dataset_info: Dict[str, Any], + raw_dir: Path, + train: bool, + transform: Any = None, + download: bool = False, +): + """Load a non-TorchVision dataset declared via the mapping's "download" key. + + The archive's "image_root" must follow the ImageFolder layout (one + sub-directory per class). When "index_file" and "split_file" are present, + the dataset's official train/test split is applied via Subset. + """ + from torch.utils.data import Subset + from torchvision.datasets import ImageFolder + from torchvision.datasets.utils import download_and_extract_archive + + cfg = dataset_info["download"] + raw_dir = Path(raw_dir) + image_root = raw_dir / cfg["image_root"] + if not image_root.exists(): + if not download: + raise FileNotFoundError( + f"Archive dataset not found at {image_root}; download it first." + ) + download_and_extract_archive( + url=cfg["url"], download_root=str(raw_dir), md5=cfg.get("md5"), + ) + dataset = ImageFolder(str(image_root), transform=transform) + if "index_file" in cfg and "split_file" in cfg: + dataset = Subset( + dataset, _official_split_indices(dataset, raw_dir, cfg, train), + ) + return dataset + + +def _official_split_indices( + dataset: Any, raw_dir: Path, cfg: Dict[str, Any], train: bool, +) -> List[int]: + """Deterministic official-split filter from index/split text files. + + Both files are whitespace-separated: index_file maps image_id to a path + relative to image_root's parent; split_file maps image_id to 1 (train) + or 0 (test). + """ + id_to_rel: Dict[int, str] = {} + for line in (raw_dir / cfg["index_file"]).read_text().splitlines(): + img_id, rel_path = line.split() + id_to_rel[int(img_id)] = rel_path + wanted = set() + for line in (raw_dir / cfg["split_file"]).read_text().splitlines(): + img_id, is_train = line.split() + if bool(int(is_train)) == train and int(img_id) in id_to_rel: + wanted.add(id_to_rel[int(img_id)]) + root = Path(dataset.root) + return [ + idx for idx, (path, _) in enumerate(dataset.samples) + if Path(path).relative_to(root).as_posix() in wanted + ] + + def download_dataset_model_pair( dataset_name: str, model_name: str, @@ -212,45 +273,57 @@ def download_dataset_model_pair( # Download dataset (only if not already present) if not dataset_exists: print(f"\n[1/3] Downloading dataset...") - - if "class_name" in dataset_info: - import act.front_end.torchvision_loader.custom - dataset_class = getattr(act.front_end.torchvision_loader.custom, dataset_info["class_name"], None) + + if "download" in dataset_info: + # Archive datasets ship train+test in one archive: download + # once, then both splits are available regardless of `split`. + print(f" • Downloading archive from {dataset_info['download']['url']} ...") + try: + for split_name in ('test', 'train'): + ds = _resolve_archive_dataset( + dataset_info, raw_dir, + train=(split_name == 'train'), + download=True, + ) + downloaded_splits.append(split_name) + print(f" ✓ {split_name.capitalize()} split: {len(ds)} samples") + except Exception as e: + print(f" ⚠ Archive download failed: {e}") else: import torchvision.datasets dataset_class = getattr(torchvision.datasets, dataset_name, None) - - if dataset_class is None: - return { - 'status': 'error', - 'message': f"Dataset {dataset_name} not found in torchvision.datasets" - } - - if split in ['test', 'both']: - print(f" • Downloading test split...") - try: - test_dataset = dataset_class( - root=str(raw_dir), - train=False, - download=True - ) - downloaded_splits.append('test') - print(f" ✓ Test split: {len(test_dataset)} samples") - except Exception as e: - print(f" ⚠ Test split failed: {e}") - - if split in ['train', 'both']: - print(f" • Downloading train split...") - try: - train_dataset = dataset_class( - root=str(raw_dir), - train=True, - download=True - ) - downloaded_splits.append('train') - print(f" ✓ Train split: {len(train_dataset)} samples") - except Exception as e: - print(f" ⚠ Train split failed: {e}") + + if dataset_class is None: + return { + 'status': 'error', + 'message': f"Dataset {dataset_name} not found in torchvision.datasets" + } + + if split in ['test', 'both']: + print(f" • Downloading test split...") + try: + test_dataset = dataset_class( + root=str(raw_dir), + train=False, + download=True + ) + downloaded_splits.append('test') + print(f" ✓ Test split: {len(test_dataset)} samples") + except Exception as e: + print(f" ⚠ Test split failed: {e}") + + if split in ['train', 'both']: + print(f" • Downloading train split...") + try: + train_dataset = dataset_class( + root=str(raw_dir), + train=True, + download=True + ) + downloaded_splits.append('train') + print(f" ✓ Train split: {len(train_dataset)} samples") + except Exception as e: + print(f" ⚠ Train split failed: {e}") if not downloaded_splits: return { @@ -616,23 +689,25 @@ def load_dataset_model_pair( # Load dataset dataset_info = get_dataset_info(dataset_name) - - if "class_name" in dataset_info: - import act.front_end.torchvision_loader.custom - dataset_class = getattr(act.front_end.torchvision_loader.custom, dataset_info["class_name"]) - else: - import torchvision.datasets - dataset_class = getattr(torchvision.datasets, dataset_name) - is_train = (split == "train") - + try: - dataset = dataset_class( - root=str(raw_dir), - train=is_train, - transform=preprocessing, - download=False # Already downloaded - ) + if "download" in dataset_info: + dataset = _resolve_archive_dataset( + dataset_info, raw_dir, + train=is_train, + transform=preprocessing, + download=False, # Already downloaded + ) + else: + import torchvision.datasets + dataset_class = getattr(torchvision.datasets, dataset_name) + dataset = dataset_class( + root=str(raw_dir), + train=is_train, + transform=preprocessing, + download=False # Already downloaded + ) print(f" ✓ Loaded {len(dataset)} samples") except Exception as e: raise RuntimeError(f"Failed to load dataset: {e}") diff --git a/act/front_end/torchvision_loader/data_model_mapping.py b/act/front_end/torchvision_loader/data_model_mapping.py index a8c4876f0..c7dfef434 100644 --- a/act/front_end/torchvision_loader/data_model_mapping.py +++ b/act/front_end/torchvision_loader/data_model_mapping.py @@ -426,6 +426,11 @@ }, # ========== Datasets not in TorchVision ========== + # Non-TorchVision datasets are configured declaratively via the "download" + # key: an archive URL whose extracted "image_root" follows the ImageFolder + # layout (one sub-directory per class). Optional "index_file"/"split_file" + # (image_id -> relative path / image_id -> is_train) select the dataset's + # official train/test split. No custom Dataset class is required. "CUB200": { "models": ["resnet18"], "input_size": (3, 224, 224), @@ -434,8 +439,14 @@ "preprocessing": { "resize_to": (224, 224) }, - "notes": "Images of birds with 200 classes and 312 binary attributes.", - "class_name": "CUBDataset" + "notes": "Caltech-UCSD Birds 200-2011; archive-distributed ImageFolder layout.", + "download": { + "url": "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz", + "md5": "97eceeb196236b17998738112f37df78", + "image_root": "CUB_200_2011/images", + "index_file": "CUB_200_2011/images.txt", + "split_file": "CUB_200_2011/train_test_split.txt" + } } }