Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 128 additions & 42 deletions act/front_end/torchvision_loader/data_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,67 @@
from act.front_end.torchvision_loader.model_definitions import _get_custom_model_definition


def _resolve_archive_dataset(
    dataset_info: Dict[str, Any],
    raw_dir: Path,
    train: bool,
    transform: Any = None,
    download: bool = False,
):
    """Build the dataset for an entry configured through the "download" key.

    The extracted "image_root" directory is opened with ImageFolder, so it
    must contain one sub-directory per class. If the configuration names both
    an "index_file" and a "split_file", the result is narrowed to the official
    train or test split with a Subset; otherwise the whole folder is returned
    and ``train`` has no effect.

    Args:
        dataset_info: Mapping entry; its "download" value supplies
            "image_root", "url", and optionally "md5", "index_file" and
            "split_file".
        raw_dir: Directory that holds (or will receive) the extracted archive.
        train: Select the train split when True, the test split when False.
        transform: Passed through to ImageFolder.
        download: When True and "image_root" is missing, fetch and extract
            the archive into ``raw_dir`` first.

    Returns:
        An ImageFolder, or a Subset of it when an official split is declared.

    Raises:
        FileNotFoundError: "image_root" is missing and ``download`` is False.
    """
    from torch.utils.data import Subset
    from torchvision.datasets import ImageFolder
    from torchvision.datasets.utils import download_and_extract_archive

    archive_cfg = dataset_info["download"]
    raw_dir = Path(raw_dir)
    image_root = raw_dir / archive_cfg["image_root"]

    needs_fetch = not image_root.exists()
    if needs_fetch and not download:
        raise FileNotFoundError(
            f"Archive dataset not found at {image_root}; download it first."
        )
    if needs_fetch:
        download_and_extract_archive(
            url=archive_cfg["url"],
            download_root=str(raw_dir),
            md5=archive_cfg.get("md5"),
        )

    folder = ImageFolder(str(image_root), transform=transform)

    # Without both split files there is no official split to apply.
    declares_split = "index_file" in archive_cfg and "split_file" in archive_cfg
    if not declares_split:
        return folder

    keep = _official_split_indices(folder, raw_dir, archive_cfg, train)
    return Subset(folder, keep)


def _official_split_indices(
    dataset: Any, raw_dir: Path, cfg: Dict[str, Any], train: bool,
) -> List[int]:
    """Return the indices of ``dataset.samples`` in the official split.

    Both files are plain text with one whitespace-separated record per line;
    blank lines are ignored.

    * ``cfg["index_file"]`` maps ``image_id`` to an image path. The path is
      matched against each sample's path relative to ``dataset.root``, in
      POSIX (forward-slash) form, so it must be relative to the image root
      itself. Everything after the id is taken as the path, so paths may
      contain spaces.
    * ``cfg["split_file"]`` maps ``image_id`` to ``1`` (train) or ``0``
      (test). Ids that do not appear in the index file are skipped.

    Args:
        dataset: Object exposing ``root`` and ``samples`` (a sequence of
            ``(path, label)`` pairs), e.g. an ImageFolder.
        raw_dir: Directory that ``cfg["index_file"]`` and
            ``cfg["split_file"]`` are relative to.
        cfg: The "download" configuration of the dataset entry.
        train: True selects records flagged as train, False those flagged
            as test.

    Returns:
        Sample indices in ascending order, deterministic for a given
        ``dataset.samples`` ordering.

    Raises:
        ValueError: A non-blank record is malformed (missing field or
            non-integer id/flag), or a sample path is not under
            ``dataset.root``.
    """
    raw_dir = Path(raw_dir)

    id_to_rel: Dict[int, str] = {}
    for line in (raw_dir / cfg["index_file"]).read_text().splitlines():
        record = line.strip()
        if not record:
            # Tolerate blank/trailing empty lines instead of failing to unpack.
            continue
        # maxsplit=1 keeps any whitespace inside the path intact.
        img_id, rel_path = record.split(maxsplit=1)
        id_to_rel[int(img_id)] = rel_path

    wanted = set()
    for line in (raw_dir / cfg["split_file"]).read_text().splitlines():
        if not line.strip():
            continue
        img_id, is_train = line.split()
        if bool(int(is_train)) != train:
            continue
        rel_path = id_to_rel.get(int(img_id))
        if rel_path is not None:
            wanted.add(rel_path)

    root = Path(dataset.root)
    return [
        idx for idx, (path, _) in enumerate(dataset.samples)
        if Path(path).relative_to(root).as_posix() in wanted
    ]


def download_dataset_model_pair(
dataset_name: str,
model_name: str,
Expand Down Expand Up @@ -212,41 +273,57 @@ def download_dataset_model_pair(
# Download dataset (only if not already present)
if not dataset_exists:
print(f"\n[1/3] Downloading dataset...")

import torchvision.datasets
dataset_class = getattr(torchvision.datasets, dataset_name, None)

if dataset_class is None:
return {
'status': 'error',
'message': f"Dataset {dataset_name} not found in torchvision.datasets"
}

if split in ['test', 'both']:
print(f" • Downloading test split...")
try:
test_dataset = dataset_class(
root=str(raw_dir),
train=False,
download=True
)
downloaded_splits.append('test')
print(f" ✓ Test split: {len(test_dataset)} samples")
except Exception as e:
print(f" ⚠ Test split failed: {e}")

if split in ['train', 'both']:
print(f" • Downloading train split...")

if "download" in dataset_info:
# Archive datasets ship train+test in one archive: download
# once, then both splits are available regardless of `split`.
print(f" • Downloading archive from {dataset_info['download']['url']} ...")
try:
train_dataset = dataset_class(
root=str(raw_dir),
train=True,
download=True
)
downloaded_splits.append('train')
print(f" ✓ Train split: {len(train_dataset)} samples")
for split_name in ('test', 'train'):
ds = _resolve_archive_dataset(
dataset_info, raw_dir,
train=(split_name == 'train'),
download=True,
)
downloaded_splits.append(split_name)
print(f" ✓ {split_name.capitalize()} split: {len(ds)} samples")
except Exception as e:
print(f" ⚠ Train split failed: {e}")
print(f" ⚠ Archive download failed: {e}")
else:
import torchvision.datasets
dataset_class = getattr(torchvision.datasets, dataset_name, None)

if dataset_class is None:
return {
'status': 'error',
'message': f"Dataset {dataset_name} not found in torchvision.datasets"
}

if split in ['test', 'both']:
print(f" • Downloading test split...")
try:
test_dataset = dataset_class(
root=str(raw_dir),
train=False,
download=True
)
downloaded_splits.append('test')
print(f" ✓ Test split: {len(test_dataset)} samples")
except Exception as e:
print(f" ⚠ Test split failed: {e}")

if split in ['train', 'both']:
print(f" • Downloading train split...")
try:
train_dataset = dataset_class(
root=str(raw_dir),
train=True,
download=True
)
downloaded_splits.append('train')
print(f" ✓ Train split: {len(train_dataset)} samples")
except Exception as e:
print(f" ⚠ Train split failed: {e}")

if not downloaded_splits:
return {
Expand Down Expand Up @@ -611,17 +688,26 @@ def load_dataset_model_pair(
preprocessing = create_preprocessing_pipeline(dataset_name)

# Load dataset
import torchvision.datasets
dataset_class = getattr(torchvision.datasets, dataset_name)
dataset_info = get_dataset_info(dataset_name)
is_train = (split == "train")

try:
dataset = dataset_class(
root=str(raw_dir),
train=is_train,
transform=preprocessing,
download=False # Already downloaded
)
if "download" in dataset_info:
dataset = _resolve_archive_dataset(
dataset_info, raw_dir,
train=is_train,
transform=preprocessing,
download=False, # Already downloaded
)
else:
import torchvision.datasets
dataset_class = getattr(torchvision.datasets, dataset_name)
dataset = dataset_class(
root=str(raw_dir),
train=is_train,
transform=preprocessing,
download=False # Already downloaded
)
print(f" ✓ Loaded {len(dataset)} samples")
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {e}")
Expand Down
24 changes: 24 additions & 0 deletions act/front_end/torchvision_loader/data_model_mapping.py

@guanqin-123 guanqin-123 Jun 9, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, could you add a new key to the dataset entry that holds the download URL, specifically for CUB200? Then the cub folder and the accompanying code in the new files would no longer be needed.

Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,30 @@
"num_classes": 8142, # varies by year
"category": "classification",
"notes": "Species classification, highly imbalanced with long tail"
},

# ========== Datasets not in TorchVision ==========
# Non-TorchVision datasets are configured declaratively via the "download"
# key: an archive URL whose extracted "image_root" follows the ImageFolder
# layout (one sub-directory per class). Optional "index_file"/"split_file"
# (image_id -> relative path / image_id -> is_train) select the dataset's
# official train/test split. No custom Dataset class is required.
"CUB200": {
"models": ["resnet18"],
"input_size": (3, 224, 224),
"num_classes": 200,
"category": "classification",
"preprocessing": {
"resize_to": (224, 224)
},
"notes": "Caltech-UCSD Birds 200-2011; archive-distributed ImageFolder layout.",
"download": {
"url": "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz",
"md5": "97eceeb196236b17998738112f37df78",
"image_root": "CUB_200_2011/images",
"index_file": "CUB_200_2011/images.txt",
"split_file": "CUB_200_2011/train_test_split.txt"
}
}
}

Expand Down
Loading