def _resolve_archive_dataset(
    dataset_info: Dict[str, Any],
    raw_dir: Path,
    train: bool,
    transform: Any = None,
    download: bool = False,
) -> Any:
    """Load a non-TorchVision dataset declared via the mapping's "download" key.

    The archive's "image_root" must follow the ImageFolder layout (one
    sub-directory per class). When both "index_file" and "split_file" are
    present in the config, the dataset's official train/test split is applied
    by wrapping the ImageFolder in a Subset.

    Args:
        dataset_info: Mapping entry for the dataset; must contain a "download"
            dict with "url" and "image_root", and optionally "md5",
            "index_file" and "split_file".
        raw_dir: Directory the archive is downloaded to and extracted in; all
            config paths are resolved relative to it.
        train: Select the train split (True) or the test split (False). Only
            has an effect when the index/split files are configured.
        transform: Optional transform forwarded to ImageFolder.
        download: When True, download and extract the archive if
            "image_root" does not exist yet.

    Returns:
        An ImageFolder, or a Subset of it restricted to the requested split.

    Raises:
        FileNotFoundError: If "image_root" is missing and ``download`` is
            False, or if it is still missing after extraction.
    """
    cfg = dataset_info["download"]
    raw_dir = Path(raw_dir)
    image_root = raw_dir / cfg["image_root"]

    if not image_root.exists():
        if not download:
            raise FileNotFoundError(
                f"Archive dataset not found at {image_root}; download it first."
            )
        # Imported lazily so the cheap "not downloaded" check above does not
        # depend on torchvision being importable.
        from torchvision.datasets.utils import download_and_extract_archive

        download_and_extract_archive(
            url=cfg["url"], download_root=str(raw_dir), md5=cfg.get("md5"),
        )
        # A wrong "image_root" in the mapping would otherwise only surface as
        # an error raised from inside ImageFolder.
        if not image_root.exists():
            raise FileNotFoundError(
                f"Archive from {cfg['url']} was extracted into {raw_dir} but "
                f"does not contain {cfg['image_root']!r}."
            )

    from torchvision.datasets import ImageFolder

    dataset = ImageFolder(str(image_root), transform=transform)
    if "index_file" in cfg and "split_file" in cfg:
        from torch.utils.data import Subset

        dataset = Subset(
            dataset, _official_split_indices(dataset, raw_dir, cfg, train),
        )
    return dataset


def _official_split_indices(
    dataset: Any, raw_dir: Path, cfg: Dict[str, Any], train: bool,
) -> List[int]:
    """Return the sample indices that belong to the official train or test split.

    Both files are whitespace-separated text, one record per line:

    * ``cfg["index_file"]`` maps an integer image_id to an image path written
      relative to the dataset root (``dataset.root``, i.e. "image_root").
    * ``cfg["split_file"]`` maps an integer image_id to 1 (train) or 0 (test).

    Blank lines are ignored, and an image path may itself contain spaces.
    Image ids that appear in the split file but not in the index file are
    skipped.

    Args:
        dataset: Object exposing ``root`` and ``samples`` (a sequence of
            ``(path, label)`` pairs), as ImageFolder does.
        raw_dir: Directory that the index/split file paths are relative to.
        cfg: The "download" config holding "index_file" and "split_file".
        train: True to select train records, False to select test records.

    Returns:
        Positions in ``dataset.samples`` whose path is in the requested
        split, in ``dataset.samples`` order.

    Raises:
        ValueError: If a non-blank line is malformed or an id/flag is not an
            integer.
    """
    raw_dir = Path(raw_dir)

    id_to_rel: Dict[int, str] = {}
    index_text = (raw_dir / cfg["index_file"]).read_text(encoding="utf-8")
    for line in index_text.splitlines():
        line = line.strip()
        if not line:
            continue
        # maxsplit=1 keeps paths that contain whitespace in one piece.
        img_id, rel_path = line.split(maxsplit=1)
        id_to_rel[int(img_id)] = rel_path

    wanted = set()
    split_text = (raw_dir / cfg["split_file"]).read_text(encoding="utf-8")
    for line in split_text.splitlines():
        line = line.strip()
        if not line:
            continue
        id_text, is_train = line.split()
        img_id = int(id_text)
        if bool(int(is_train)) == train and img_id in id_to_rel:
            wanted.add(id_to_rel[img_id])

    root = Path(dataset.root)
    # Compare in POSIX form so the index file's forward slashes match on any OS.
    return [
        idx for idx, (path, _) in enumerate(dataset.samples)
        if Path(path).relative_to(root).as_posix() in wanted
    ]
+ print(f" • Downloading archive from {dataset_info['download']['url']} ...") try: - train_dataset = dataset_class( - root=str(raw_dir), - train=True, - download=True - ) - downloaded_splits.append('train') - print(f" ✓ Train split: {len(train_dataset)} samples") + for split_name in ('test', 'train'): + ds = _resolve_archive_dataset( + dataset_info, raw_dir, + train=(split_name == 'train'), + download=True, + ) + downloaded_splits.append(split_name) + print(f" ✓ {split_name.capitalize()} split: {len(ds)} samples") except Exception as e: - print(f" ⚠ Train split failed: {e}") + print(f" ⚠ Archive download failed: {e}") + else: + import torchvision.datasets + dataset_class = getattr(torchvision.datasets, dataset_name, None) + + if dataset_class is None: + return { + 'status': 'error', + 'message': f"Dataset {dataset_name} not found in torchvision.datasets" + } + + if split in ['test', 'both']: + print(f" • Downloading test split...") + try: + test_dataset = dataset_class( + root=str(raw_dir), + train=False, + download=True + ) + downloaded_splits.append('test') + print(f" ✓ Test split: {len(test_dataset)} samples") + except Exception as e: + print(f" ⚠ Test split failed: {e}") + + if split in ['train', 'both']: + print(f" • Downloading train split...") + try: + train_dataset = dataset_class( + root=str(raw_dir), + train=True, + download=True + ) + downloaded_splits.append('train') + print(f" ✓ Train split: {len(train_dataset)} samples") + except Exception as e: + print(f" ⚠ Train split failed: {e}") if not downloaded_splits: return { @@ -611,17 +688,26 @@ def load_dataset_model_pair( preprocessing = create_preprocessing_pipeline(dataset_name) # Load dataset - import torchvision.datasets - dataset_class = getattr(torchvision.datasets, dataset_name) + dataset_info = get_dataset_info(dataset_name) is_train = (split == "train") - + try: - dataset = dataset_class( - root=str(raw_dir), - train=is_train, - transform=preprocessing, - download=False # Already downloaded - ) + 
if "download" in dataset_info: + dataset = _resolve_archive_dataset( + dataset_info, raw_dir, + train=is_train, + transform=preprocessing, + download=False, # Already downloaded + ) + else: + import torchvision.datasets + dataset_class = getattr(torchvision.datasets, dataset_name) + dataset = dataset_class( + root=str(raw_dir), + train=is_train, + transform=preprocessing, + download=False # Already downloaded + ) print(f" ✓ Loaded {len(dataset)} samples") except Exception as e: raise RuntimeError(f"Failed to load dataset: {e}") diff --git a/act/front_end/torchvision_loader/data_model_mapping.py b/act/front_end/torchvision_loader/data_model_mapping.py index ae9a144f7..c7dfef434 100644 --- a/act/front_end/torchvision_loader/data_model_mapping.py +++ b/act/front_end/torchvision_loader/data_model_mapping.py @@ -423,6 +423,30 @@ "num_classes": 8142, # varies by year "category": "classification", "notes": "Species classification, highly imbalanced with long tail" + }, + + # ========== Datasets not in TorchVision ========== + # Non-TorchVision datasets are configured declaratively via the "download" + # key: an archive URL whose extracted "image_root" follows the ImageFolder + # layout (one sub-directory per class). Optional "index_file"/"split_file" + # (image_id -> relative path / image_id -> is_train) select the dataset's + # official train/test split. No custom Dataset class is required. + "CUB200": { + "models": ["resnet18"], + "input_size": (3, 224, 224), + "num_classes": 200, + "category": "classification", + "preprocessing": { + "resize_to": (224, 224) + }, + "notes": "Caltech-UCSD Birds 200-2011; archive-distributed ImageFolder layout.", + "download": { + "url": "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz", + "md5": "97eceeb196236b17998738112f37df78", + "image_root": "CUB_200_2011/images", + "index_file": "CUB_200_2011/images.txt", + "split_file": "CUB_200_2011/train_test_split.txt" + } } }