refactor: overhaul dataset loaders, add UTKFace, and extend AmuletDataset schema#73
Open
asim29 wants to merge 1 commit intopr/abstract-base-classesfrom
Open
Conversation
…aset schema
AmuletDataset gains two required-but-typed fields: modality ("image" |
"tabular") to describe tensor shape seen by models, and sensitive_columns
(list[str] | None) to name z_train/z_test columns in order.
Dataset loader changes:
- load_census: replaces ucimlrepo fetch with GDrive download (_CENSUS_GDRIVE_ID);
drops ucimlrepo from dependencies; all callers now get modality and
sensitive_columns populated.
- load_lfw: rewrites image/attribute pipeline using helper functions
(_lfw_read_attributes, _lfw_attr_labels, _lfw_build_images_npz,
_lfw_build_processed_cache); supports configurable target and two sensitive
attributes with parameter-keyed .npz caching.
- load_celeba: rewrites to use GDrive downloads (_CELEBA_IMAGES_GDRIVE_ID,
_CELEBA_ATTRS_GDRIVE_ID) and numpy-based image processing instead of
CSV pixel strings; populates modality and sensitive_columns.
- load_utkface: new loader. Parses age/gender/race from filenames, downloads
archive from GDrive, builds parameter-keyed .npz cache. Supports age
discretization via age_bins. Exported from amulet.datasets.
- load_cifar10, load_cifar100, load_fmnist, load_mnist: add modality="image"
to AmuletDataset construction; docstrings trimmed to Google style.
pipeline: adds load_utkface import and elif branch in load_data().
pyproject.toml: removes ucimlrepo==0.0.3 (replaced by GDrive download).
tests: updates test_dataset_dataclass.py to the refactor branch version
which covers modality and sensitive_columns fields.
db745b5 to
63758d2
Compare
be3e9a1 to
13e03b9
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR reworks the dataset substrate that all risk-module pipelines depend on.
AmuletDatasetgains two new fields, every loader is updated to populate them,ucimlrepois removed in favour of GDrive downloads, and a fullload_utkfaceloader is added.Stacked on:
pr/abstract-base-classesWhat changed and why
AmuletDatasetschema extension (amulet/datasets/__data.py)Two new fields are added to the dataclass:
modalityLiteral["image", "tabular"]"image"for(C, H, W)samples;"tabular"for 1-D feature vectors. LFW is"tabular"because its images are flattened in the loader. Required field (no default) so callers are forced to be explicit.sensitive_columnslist[str] | Nonez_train/z_testcolumns in order.Nonewhen the dataset has no sensitive attributes.load_censusrewrite (amulet/datasets/__tabular_datasets.py)ucimlrepodependency entirely. The Census CSV is now downloaded directly from Google Drive (_CENSUS_GDRIVE_ID) on first use and cached locally, matching how other datasets work.pyproject.toml: dropsucimlrepo==0.0.3;uv.lockupdated accordingly.AmuletDatasetwithmodality="tabular"andsensitive_columns=["race", "sex"].load_lfwrewrite (amulet/datasets/__tabular_datasets.py)The previous implementation loaded images via scikit-learn and fetched attributes from a fixed URL. The rewrite:
_LFW_ATTRIBUTES_GDRIVE_ID) on first use._lfw_read_attributes,_lfw_attr_labels,_lfw_build_images_npz,_lfw_build_processed_cache.target,attribute_1,attribute_2parameters (previously hardcoded to gender/race)..npzcache so repeated calls with the same arguments are fast.AmuletDatasetwithmodality="tabular"(images are flattened) andsensitive_columns=[attribute_1, attribute_2].load_celebarewrite (amulet/datasets/__image_datasets.py)_CELEBA_IMAGES_GDRIVE_ID,_CELEBA_ATTRS_GDRIVE_ID) and processed into a.npzcache.AmuletDatasetwithmodality="image"andsensitive_columnspopulated from the target attribute config.load_utkface— new loader (amulet/datasets/__image_datasets.py)_UTKFACE_GDRIVE_ID)age(0–116 int),gender(0/1),race(0–4).npz(target,attribute_1,attribute_2)age_binsvianp.digitize, applied to any attribute that is"age"(N, 3, 64, 64)float32 in[0, 1]sensitive_columns[attribute_1, attribute_2]Exported from
amulet.datasetsand registered inload_data()inamulet/utils/__pipeline.py.CIFAR-10/100, FMNIST, MNIST loaders
All four loaders receive
modality="image"in theirAmuletDatasetconstruction. Docstrings trimmed to Google style (imperative summary, concise Args/Returns).Test update (
tests/unit/test_dataset_dataclass.py)Replaced the PR1 stub with the refactor-branch version that covers
modalityandsensitive_columnsfields.Test plan
uv run pre-commit run --all-filespassesuv run pytest tests/unit/test_dataset_dataclass.pypassesfrom amulet.datasets import load_utkfaceresolvesAmuletDataset(train_set=..., test_set=..., num_features=1, num_classes=2)raisesTypeError(missingmodality)AmuletDataset(..., modality="tabular")constructs withsensitive_columns=None🤖 Generated with Claude Code