From d7ddeda8f7aeb4c7ca7d17b5dc32dc3be1d28bfa Mon Sep 17 00:00:00 2001
From: vimar-gu
Date: Thu, 24 Oct 2024 12:34:11 -0400
Subject: [PATCH 1/5] update fishnet

---
 biobench/fishnet/__init__.py | 14 +++++++-------
 pyproject.toml               |  1 +
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/biobench/fishnet/__init__.py b/biobench/fishnet/__init__.py
index 4687434..6ac9a74 100644
--- a/biobench/fishnet/__init__.py
+++ b/biobench/fishnet/__init__.py
@@ -53,9 +53,9 @@ class Args(interfaces.TaskArgs):
     """number of dataloader worker processes."""
     log_every: int = 10
     """how often (number of epochs) to log progress."""
-    n_epochs: int = 100
+    n_epochs: int = 50
     """How many epochs to train the MLP classifier."""
-    learning_rate: float = 5e-4
+    learning_rate: float = 1e-4
     """The learning rate for training the MLP classifier."""
     threshold: float = 0.5
     """The threshold to predicted "presence" rather than "absence"."""
@@ -114,10 +114,10 @@ def calc_macro_f1(examples: list[interfaces.Example]) -> float:
     """TODO: docs."""
     y_pred = np.array([example.info["y_pred"] for example in examples])
     y_true = np.array([example.info["y_true"] for example in examples])
-    score = sklearn.metrics.f1_score(
-        y_true, y_pred, average="macro", labels=np.unique(y_true)
-    )
-    return score.item()
+    # Exact-match accuracy: an example counts as correct only if every label matches.
+    correct = np.all(y_pred == y_true, axis=1)
+    acc = np.sum(correct) / len(y_pred)
+    return acc
 
 
 @beartype.beartype
@@ -169,7 +169,7 @@ def benchmark(
         if (epoch + 1) % args.log_every == 0:
             examples = evaluate(args, classifier, test_loader)
             score = calc_macro_f1(examples)
-            logger.info("Epoch %d/%d: %.3f", epoch + 1, args.n_epochs, score)
+            logger.info(f"Epoch {epoch + 1}/{args.n_epochs}: {score:.3f}")
 
     return model_args, interfaces.TaskReport(
         "FishNet", examples, calc_mean_score=calc_macro_f1
diff --git a/pyproject.toml b/pyproject.toml
index 8e0f71d..ea3bf9e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
     "submitit>=1.5.2",
     "pycocotools>=2.0.8",
     "gdown>=5.2.0",
+    "transformers>=4.39.1",
 ]
 
 [tool.ruff.lint]

From e35bf35457f9474db85844988d00e390f966dfca Mon Sep 17 00:00:00 2001
From: vimar-gu
Date: Sat, 28 Dec 2024 20:18:01 -0500
Subject: [PATCH 2/5] add facebook dinov2 support compatible with contrastive
 finetuning

---
 benchmark.py                   |  3 ++-
 biobench/dinov2_model.py       | 40 ++++++++++++++++++++++++++++++++++
 biobench/third_party_models.py | 29 ++++++++++++++++++++----
 3 files changed, 67 insertions(+), 5 deletions(-)
 create mode 100644 biobench/dinov2_model.py

diff --git a/benchmark.py b/benchmark.py
index c4fcaeb..273505b 100644
--- a/benchmark.py
+++ b/benchmark.py
@@ -68,6 +68,7 @@ class Args:
         interfaces.ModelArgs("open-clip", "ViT-B-16/openai"),
         interfaces.ModelArgs("open-clip", "ViT-B-16/laion400m_e32"),
         interfaces.ModelArgs("open-clip", "hf-hub:imageomics/bioclip"),
+        interfaces.ModelArgs("open-clip", "ViT-B-16/facebook/dinov2-base"),
         interfaces.ModelArgs("open-clip", "ViT-B-16-SigLIP/webli"),
         interfaces.ModelArgs("timm-vit", "vit_base_patch14_reg4_dinov2.lvd142m"),
     ]
@@ -398,7 +399,7 @@ def plot_task(conn: sqlite3.Connection, task: str):
     if not data:
         return
 
-    xs = [row["model_ckpt"] for row in data]
+    xs = [row["model_ckpt"].split("/")[-1] for row in data]
     ys = [row["mean_score"] for row in data]
 
     yerr = np.array([ys, ys])
diff --git a/biobench/dinov2_model.py b/biobench/dinov2_model.py
new file mode 100644
index 0000000..2bd9a16
--- /dev/null
+++ b/biobench/dinov2_model.py
@@ -0,0 +1,40 @@
+"""DINOv2 model adapter."""
+
+import beartype
+from jaxtyping import jaxtyped
+
+import torch
+import torch.nn as nn
+
+from transformers import AutoModel
+
+
+@jaxtyped(typechecker=beartype.beartype)
+class DINOv2Model(nn.Module):
+    """
+    Add adapter head to DINOv2.
+    """
+    def __init__(
+        self,
+        model_name: str,
+        embed_dim: int,
+    ):
+        super().__init__()
+        self.backbone = AutoModel.from_pretrained(model_name)
+        self.embed_dim = embed_dim
+
+        prev_chs = self.backbone.config.hidden_size
+        self.backbone.embeddings.mask_token.requires_grad_(False)
+        if embed_dim == 0:
+            self.head = None
+        else:
+            self.head = nn.Linear(prev_chs, embed_dim, bias=False)
+
+    def get_cast_dtype(self) -> torch.dtype:
+        return self.head.proj.weight.dtype
+
+    def forward(self, x):
+        _, x = self.backbone(x, return_dict=False)
+        if self.head is not None:
+            x = self.head(x)
+        return x
diff --git a/biobench/third_party_models.py b/biobench/third_party_models.py
index dadebb4..3eebc4c 100644
--- a/biobench/third_party_models.py
+++ b/biobench/third_party_models.py
@@ -6,6 +6,7 @@
 from torch import Tensor
 
 from biobench import interfaces
+from biobench.dinov2_model import DINOv2Model
 
 logger = logging.getLogger("third_party")
 
@@ -63,10 +64,30 @@ def __init__(self, ckpt: str, **kwargs):
         if ckpt.startswith("hf-hub:"):
             clip, self.img_transform = open_clip.create_model_from_pretrained(ckpt)
         else:
-            arch, ckpt = ckpt.split("/")
-            clip, self.img_transform = open_clip.create_model_from_pretrained(
-                arch, pretrained=ckpt, cache_dir=get_cache_dir()
-            )
+            from open_clip.factory import load_state_dict
+            arch = ckpt.split("/")[0]
+            ckpt = "/".join(ckpt.split("/")[1:])
+            if "facebook" in ckpt and "dino" in ckpt:
+                dino_model = DINOv2Model(
+                    ckpt,
+                    pretrained=True,
+                    embed_dim=0
+                )
+                clip, _, self.img_transform = open_clip.create_model_and_transforms(arch)
+                clip.visual = dino_model
+            elif "dino" in ckpt:
+                dino_model = DINOv2Model(
+                    "facebook/dinov2-base",
+                    pretrained=False,
+                    embed_dim=512
+                )
+                clip, _, self.img_transform = open_clip.create_model_and_transforms(arch, force_image_size=336)
+                clip.visual = dino_model
+                clip.load_state_dict(load_state_dict(ckpt))
+            else:
+                clip, _, self.img_transform = open_clip.create_model_and_transforms(
+                    arch, pretrained=ckpt, cache_dir=get_cache_dir()
+                )
 
         self.model = clip.visual
         self.model.output_tokens = True  # type: ignore

From 4c13a4aa490f50ff512cc63a195dfcd1785be78c Mon Sep 17 00:00:00 2001
From: vimar-gu
Date: Thu, 9 Jan 2025 19:03:32 -0500
Subject: [PATCH 3/5] fix DINOv2 adapter head and model construction

---
 biobench/dinov2_model.py       | 11 ++++++++---
 biobench/third_party_models.py |  4 +---
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/biobench/dinov2_model.py b/biobench/dinov2_model.py
index 2bd9a16..31fd0e6 100644
--- a/biobench/dinov2_model.py
+++ b/biobench/dinov2_model.py
@@ -2,6 +2,7 @@
 
 import beartype
 from jaxtyping import jaxtyped
+from collections import OrderedDict
 
 import torch
 import torch.nn as nn
@@ -25,10 +26,14 @@ def __init__(
 
         prev_chs = self.backbone.config.hidden_size
         self.backbone.embeddings.mask_token.requires_grad_(False)
-        if embed_dim == 0:
-            self.head = None
+
+        if embed_dim > 0:
+            head_layers = OrderedDict()
+            head_layers['drop'] = nn.Dropout(0.)
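+            # The named 'proj' submodule below lets get_cast_dtype() reach self.head.proj and keeps state_dict key names stable.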
+ head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=False) + self.head = nn.Sequential(head_layers) else: - self.head = nn.Linear(prev_chs, embed_dim, bias=False) + self.head = nn.Identity() def get_cast_dtype(self) -> torch.dtype: return self.head.proj.weight.dtype diff --git a/biobench/third_party_models.py b/biobench/third_party_models.py index 3eebc4c..b86c8aa 100644 --- a/biobench/third_party_models.py +++ b/biobench/third_party_models.py @@ -70,7 +70,6 @@ def __init__(self, ckpt: str, **kwargs): if "facebook" in ckpt and "dino" in ckpt: dino_model = DINOv2Model( ckpt, - pretrained=True, embed_dim=0 ) clip, _, self.img_transform = open_clip.create_model_and_transforms(arch) @@ -78,10 +77,9 @@ def __init__(self, ckpt: str, **kwargs): elif "dino" in ckpt: dino_model = DINOv2Model( "facebook/dinov2-base", - pretrained=False, embed_dim=512 ) - clip, _, self.img_transform = open_clip.create_model_and_transforms(arch, force_image_size=336) + clip, _, self.img_transform = open_clip.create_model_and_transforms(arch) clip.visual = dino_model clip.load_state_dict(load_state_dict(ckpt)) else: From 6a61066fabf7d4ebba161cabcaa0b869c01e75a9 Mon Sep 17 00:00:00 2001 From: vimar-gu Date: Wed, 5 Feb 2025 13:21:43 -0500 Subject: [PATCH 4/5] init mammalnet --- .gitignore | 3 + benchmark.py | 11 ++ biobench/mammalnet/__init__.py | 270 +++++++++++++++++++++++++++++++++ biobench/mammalnet/download.py | 110 ++++++++++++++ pyproject.toml | 1 + 5 files changed, 395 insertions(+) create mode 100644 biobench/mammalnet/__init__.py create mode 100644 biobench/mammalnet/download.py diff --git a/.gitignore b/.gitignore index 42cd403..a4a97f2 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ __pycache__ graphs/ *.sqlite logs/ +*.pt +.locks +uv.lock diff --git a/benchmark.py b/benchmark.py index 273505b..5b36d7d 100644 --- a/benchmark.py +++ b/benchmark.py @@ -43,6 +43,7 @@ plankton, plantnet, rarespecies, + mammalnet, ) log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s" @@ -135,6 +136,10 @@ class Args: default_factory=rarespecies.Args ) """Arguments for the Rare Species benchmark.""" + mammalnet_run: bool = False + """Whether to run the MammalNet benchmark.""" + mammalnet_args: mammalnet.Args = dataclasses.field(default_factory=mammalnet.Args) + """Arguments for the MammalNet benchmark.""" # Reporting and graphing. report_to: str = os.path.join(".", "reports") @@ -338,6 +343,12 @@ def main(args: Args): ) job = executor.submit(rarespecies.benchmark, rarespecies_args, model_args) jobs.append(job) + if args.mammalnet_run: + mammalnet_args = dataclasses.replace( + args.mammalnet_args, device=args.device, debug=args.debug + ) + job = executor.submit(mammalnet.benchmark, mammalnet_args, model_args) + jobs.append(job) logger.info("Submitted %d jobs.", len(jobs)) diff --git a/biobench/mammalnet/__init__.py b/biobench/mammalnet/__init__.py new file mode 100644 index 0000000..ad67bce --- /dev/null +++ b/biobench/mammalnet/__init__.py @@ -0,0 +1,270 @@ +""" +# MammalNet + +MammalNet is built around a biological mammal taxonomy spanning 17 orders, 69 families and 173 mammal categories, and includes 12 common high-level mammal behaviors (e.g. hunt, groom). +We adopt the compositional low-shot animal and behavior recognition benchmark. + +While specialized architectures exist, we train a simple nearest-centroid classifier [which works well with few-shot tasks](https://arxiv.org/abs/1911.04623) over video representations. 
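
In outline, the nearest-centroid step amounts to the following (a sketch with illustrative names, not the exact `biobench.simpleshot` API):

```
import torch
import torch.nn.functional as F

def nearest_centroid(train_feats, train_labels, test_feats):
    # Mean-subtract and L2-normalize features ("CL2N" from the SimpleShot paper).
    mean = train_feats.mean(dim=0)
    train_feats = F.normalize(train_feats - mean, dim=-1)
    test_feats = F.normalize(test_feats - mean, dim=-1)
    # One centroid per class; predict the class of the nearest centroid.
    classes = train_labels.unique()
    centroids = torch.stack([train_feats[train_labels == c].mean(dim=0) for c in classes])
    return classes[torch.cdist(test_feats, centroids).argmin(dim=1)]
```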
We get video representations by embedding each frame of the video and taking the mean over the frame (time) dimension.

If you use this evaluation, be sure to cite the original work:

```
@InProceedings{Chen_2023_CVPR,
    author    = {Chen, Jun and Hu, Ming and Coker, Darren J. and Berumen, Michael L. and Costelloe, Blair and Beery, Sara and Rohrbach, Anna and Elhoseiny, Mohamed},
    title     = {MammalNet: A Large-Scale Video Benchmark for Mammal Recognition and Behavior Understanding},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {13052-13061}
}
```

This task was contributed by [Jianyang Gu](https://vimar-gu.github.io/).
"""

import csv
import dataclasses
import logging
import os
import typing

import beartype
import numpy as np
import torch
from jaxtyping import Float, Int, jaxtyped
from PIL import Image
from torch import Tensor
import torchvision.io as io

from biobench import interfaces, registry, simpleshot

logger = logging.getLogger("mammalnet")


@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Args(interfaces.TaskArgs):
    """Arguments for the MammalNet task."""

    batch_size: int = 16
    """Batch size for deep model. Note that this is multiplied by 16 (number of frames)."""
    n_workers: int = 4
    """Number of dataloader worker processes."""
    frame_agg: typing.Literal["mean", "max"] = "mean"
    """How to aggregate features across time dimension."""


@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Video:
    """A single video instance as a sequence of frames."""

    video_id: int
    file_name: str
    """Path to actual frame images."""
    label_behave: int
    """Label for animal behavior."""
    label_species: int
    """Label for animal species."""


@jaxtyped(typechecker=beartype.beartype)
class Dataset(torch.utils.data.Dataset):
    """
    Clips of at most 90 frames in Charades format with each frame stored as an image.
    """

    def __init__(self, path, split: str, transform=None, seed: int = 42):
        self.path = path
        self.split = split
        self.transform = transform
        self.seed = seed

        self.rng = np.random.default_rng(seed=seed)

        self.n_frames = 16

        # Load videos
        #############

        file_name: dict[int, str] = {}
        labels_behave: dict[int, int] = {}
        labels_species: dict[int, int] = {}

        if not os.path.exists(self.path) or not os.path.isdir(self.path):
            msg = f"Path '{self.path}' doesn't exist. Did you download the MammalNet dataset? See the docstring at the top of this file for instructions. If you did download it, pass the path as --dataset-dir PATH"
            raise RuntimeError(msg)

        with open(os.path.join(self.path, "annotation", "composition", f"{split}.csv")) as fd:
            reader = csv.reader(fd, delimiter=" ")
            for video_id, (path, label_behave, label_species) in enumerate(reader):
                video_id = int(video_id)
                label_behave = int(label_behave)
                label_species = int(label_species)

                path = os.path.join(self.path, path)
                file_name[video_id] = path
                labels_behave[video_id] = label_behave
                labels_species[video_id] = label_species

        self.videos = [
            Video(video_id, file_name[video_id], labels_behave[video_id], labels_species[video_id])
            for video_id in file_name.keys()
        ]

    def __getitem__(
        self, i: int
    ) -> tuple[list[Float[Tensor, "3 width height"]], list[int]]:
        """
        Returns 16 frames and their labels sampled every 5 frames from a clip. The start of the clip is uniformly sampled. If there are fewer
        """
        video = self.videos[i]
        video_file = video.file_name
        label_behave = video.label_behave
        label_species = video.label_species

        frames, _, _ = io.read_video(video_file, pts_unit="sec")
        frames = frames.permute(0, 3, 1, 2).float() / 255.0

        # Sample n_sample frames between the start and end with equal interval.
        indices = torch.linspace(0, len(frames) - 1, self.n_frames).long()
        frames = frames[indices]

        if self.transform is not None:
            frames = torch.stack([self.transform(frame) for frame in frames])

        return frames, label_behave, label_species

    def __len__(self) -> int:
        return len(self.videos)


@torch.no_grad()
@jaxtyped(typechecker=beartype.beartype)
def get_features(
    args: Args, backbone: interfaces.VisionBackbone, dataloader
) -> tuple[
    Float[Tensor, "n_frames n_examples dim"], Int[Tensor, "n_frames n_examples"]
]:
    """
    Gets all model features and true labels for all frames and all examples in the dataloader.

    Returns them as big tensors; other tasks like `biobench.birds525` use a dedicated class for this, but here it's just a tuple.

    Args:
        args: MammalNet task arguments.
        backbone: Vision backbone.
        dataloader: Dataloader for whatever data you want to get features for.

    Returns:
        tuple of model features and true labels. See signature for shape.
    """
    backbone = torch.compile(backbone)
    all_features, all_labels_behave, all_labels_species = [], [], []

    total = len(dataloader) if not args.debug else 2
    it = iter(dataloader)
    logger.debug("Need to embed %d batches of %d images.", total, args.batch_size * 16)
    for b in range(total):
        frames, labels_behave, labels_species = next(it)
        frames = torch.stack(frames, dim=0)
        labels_behave = torch.stack(labels_behave, dim=0)
        labels_species = torch.stack(labels_species, dim=0)
        frames = frames.to(args.device)

        with torch.amp.autocast("cuda"):
            # conv2d doesn't support multiple batch dimensions, so we have to view() before and after the model.img_encode() call.
            n_frames, bsz, c, h, w = frames.shape
            frames = frames.view(bsz * n_frames, c, h, w)
            outputs = backbone.img_encode(frames)
            features = outputs.img_features.view(n_frames, bsz, -1)
            all_features.append(features.cpu())
            all_labels_behave.append(labels_behave.cpu())
            all_labels_species.append(labels_species.cpu())

        logger.debug("Embedded batch %d/%d", b + 1, total)

    all_features = torch.cat(all_features, dim=1).cpu()
    all_labels_behave = torch.cat(all_labels_behave, dim=1).cpu()
    all_labels_species = torch.cat(all_labels_species, dim=1).cpu()

    return all_features, all_labels_behave, all_labels_species


@jaxtyped(typechecker=beartype.beartype)
def aggregate_labels(
    args: Args, labels: Int[Tensor, "n_frames n_examples"]
) -> Int[Tensor, " n_examples"]:
    """Aggregate per-frame labels to a per-video label. Uses the most common label (mode)."""
    return torch.mode(labels, dim=0).values


@jaxtyped(typechecker=beartype.beartype)
def aggregate_frames(
    args: Args, features: Float[Tensor, "n_frames n_examples dim"]
) -> Float[Tensor, "n_examples dim"]:
    if args.frame_agg == "mean":
        return torch.mean(features, dim=0)
    elif args.frame_agg == "max":
        return torch.max(features, dim=0).values
    else:
        typing.assert_never(args.frame_agg)


@beartype.beartype
def benchmark(
    args: Args, model_args: interfaces.ModelArgs
) -> tuple[interfaces.ModelArgs, interfaces.TaskReport]:
    """Runs MammalNet benchmark."""
    # 1.
Load model + backbone = registry.load_vision_backbone(*model_args) + img_transform = backbone.make_img_transform() + backbone = backbone.to(args.device) + + # 2. Load data. + train_dataset = Dataset(args.datadir, "train", transform=img_transform) + val_dataset = Dataset(args.datadir, "val", transform=img_transform) + + data = train_dataset[0] + import pdb; pdb.set_trace() + assert 1==0 + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + num_workers=args.n_workers, + drop_last=False, + ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset, + batch_size=args.batch_size, + num_workers=args.n_workers, + drop_last=False, + ) + + # 3. Get features + val_features, val_labels_behave, val_labels_species = get_features(args, backbone, val_dataloader) + val_features = aggregate_frames(args, val_features) + import pdb; pdb.set_trace() + val_labels_behave = aggregate_labels(args, val_labels_behave) + val_labels_species = aggregate_labels(args, val_labels_species) + + train_features, train_labels_behave, train_labale_species = get_features(args, backbone, train_dataloader) + train_features = aggregate_frames(args, train_features) + train_labels = aggregate_labels(args, train_labels) + + # 4. Do simpleshot. + scores = simpleshot.simpleshot( + args, train_features, train_labels_behave, train_labale_species, + val_features, val_labels_behave, val_labels_species + ) + + # Return benchmark report. + video_ids = [video.video_id for video in val_dataset.videos] + examples = [ + interfaces.Example(str(id), float(score), {}) + for id, score in zip(video_ids, scores.tolist()) + ] + # TODO: include example-specific info (class? something else) + return model_args, interfaces.TaskReport("MammalNet", examples) diff --git a/biobench/mammalnet/download.py b/biobench/mammalnet/download.py new file mode 100644 index 0000000..fe20a28 --- /dev/null +++ b/biobench/mammalnet/download.py @@ -0,0 +1,110 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "requests", +# "tqdm", +# "tyro", +# ] +# /// +""" +A script to download the MammalNet dataset. + +Run with: + +1. `python biobench/mammalnet/download.py --help` if `biobench/` is in your $PWD. +2. `python -m biobench.mammalnet.download --help` if you have installed `biobench` as a package. +""" + +import dataclasses +import os.path +import tarfile + +import requests +import tqdm +import tyro + +videos_url = ( + "https://mammalnet.s3.amazonaws.com/trimmed_video.tar.gz" +) +labels_url = "https://mammalnet.s3.amazonaws.com/annotation.tar" + + +@dataclasses.dataclass(frozen=True) +class Args: + """Configure download options.""" + + dir: str = "." 
+ """Where to save data.""" + + chunk_size_kb: int = 1 + """How many KB to download at a time before writing to file.""" + + videos: bool = True + """Whether to download videos [148GB].""" + labels: bool = True + """Whether to download labels.""" + + +def main(args: Args): + """Download MammalNet.""" + os.makedirs(args.dir, exist_ok=True) + chunk_size = int(args.chunk_size_kb * 1024) + videos_tar_path = os.path.join(args.dir, "trimmed_video.tar.gz") + labels_tar_path = os.path.join(args.dir, "annotation.tar") + videos_dir_name = "trimmed_video" + videos_dir_path = os.path.join(args.dir, videos_dir_name) + labels_dir_name = "annotation" + labels_dir_path = os.path.join(args.dir, labels_dir_name) + + if args.labels: + # Download labels + r = requests.get(labels_url, stream=True) + r.raise_for_status() + + with open(labels_tar_path, "wb") as fd: + for chunk in r.iter_content(chunk_size=chunk_size): + fd.write(chunk) + print(f"Downloaded labels: {labels_tar_path}.") + + if args.videos: + # Download videos. + r = requests.get(videos_url, stream=True) + r.raise_for_status() + + n_bytes = int(r.headers["content-length"]) + + with open(videos_tar_path, "wb") as fd: + for chunk in tqdm.tqdm( + r.iter_content(chunk_size=chunk_size), + total=n_bytes / chunk_size, + unit="b", + unit_scale=1, + unit_divisor=1024, + desc="Downloading videos", + ): + fd.write(chunk) + print(f"Downloaded videos: {videos_tar_path}.") + + with tarfile.open(labels_tar_path, "r") as tar: + tar.extractall(path=args.dir) + print(f"Extracted labels: {labels_dir_path}.") + + n_videos = 0 + all_video_files = [] + for csv_file in ["train.csv", "test.csv"]: + with open(os.path.join(labels_dir_path, "composition", csv_file)) as fd: + video_files = fd.readlines() + video_files = [video_file.split(" ")[0] for video_file in video_files] + video_files = [video_file[:13] + video_file[14:] for video_file in video_files] + all_video_files += video_files + n_videos += len(video_files) + + with tarfile.open(videos_tar_path, "r") as tar: + for member in tqdm.tqdm(tar, desc="Extracting videos", total=n_videos): + if member.name in all_video_files or member.name == "trimmed_video": + tar.extract(member, path=args.dir) + print(f"Extracted videos: {videos_dir_path}.") + + +if __name__ == "__main__": + main(tyro.cli(Args)) diff --git a/pyproject.toml b/pyproject.toml index 4597cea..2392744 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "gdown>=5.2.0", "transformers>=0.39.1", "altair>=5.4.1", + "av<14.0", ] [tool.ruff.lint] From 1779ca1c1cb9e36ec8f82ccabd8fd5af93d6cc5a Mon Sep 17 00:00:00 2001 From: vimar-gu Date: Thu, 13 Feb 2025 13:33:21 -0500 Subject: [PATCH 5/5] update mammalnet --- biobench/mammalnet/__init__.py | 79 ++++++++++++++++------------------ biobench/mammalnet/download.py | 2 +- 2 files changed, 37 insertions(+), 44 deletions(-) diff --git a/biobench/mammalnet/__init__.py b/biobench/mammalnet/__init__.py index ad67bce..da16926 100644 --- a/biobench/mammalnet/__init__.py +++ b/biobench/mammalnet/__init__.py @@ -28,6 +28,7 @@ import logging import os import typing +from tqdm import tqdm import beartype import numpy as np @@ -36,6 +37,7 @@ from PIL import Image from torch import Tensor import torchvision.io as io +from torchvision import transforms as T from biobench import interfaces, registry, simpleshot @@ -49,7 +51,7 @@ class Args(interfaces.TaskArgs): batch_size: int = 16 """Batch size for deep model. 
Note that this is multiplied by 16 (number of frames)""" - n_workers: int = 4 + n_workers: int = 8 """Number of dataloader worker processes.""" frame_agg: typing.Literal["mean", "max"] = "mean" """How to aggregate features across time dimension.""" @@ -62,7 +64,7 @@ class Video: video_id: int file_name: str - """Path to actual frame images.""" + """Path to actual video file.""" label_behave: int """Label for animal behavior.""" label_species: int @@ -72,7 +74,7 @@ class Video: @jaxtyped(typechecker=beartype.beartype) class Dataset(torch.utils.data.Dataset): """ - Clips of at most 90 frames in Charades format with each frame stored as an image. + Each video has two labels for behavior and species, respectively. """ def __init__(self, path, split: str, transform=None, seed: int = 42): @@ -97,13 +99,14 @@ def __init__(self, path, split: str, transform=None, seed: int = 42): raise RuntimeError(msg) with open(os.path.join(self.path, "annotation", "composition", f"{split}.csv")) as fd: - reader = csv.reader(fd, delimiter=" ") - for video_id, (path, label_behave, label_species) in enumerate(reader): + video_files = fd.readlines() + for video_id, video_file in enumerate(video_files): + path, label_behave, label_species = video_file.strip().split(" ") video_id = int(video_id) label_behave = int(label_behave) label_species = int(label_species) - path = os.path.join(self.path, path) + path = os.path.join(self.path, path[:13] + path[14:]) file_name[video_id] = path labels_behave[video_id] = label_behave labels_species[video_id] = label_species @@ -115,9 +118,9 @@ def __init__(self, path, split: str, transform=None, seed: int = 42): def __getitem__( self, i: int - ) -> tuple[list[Float[Tensor, "3 width height"]], list[int]]: + ) -> tuple[list[Float[Tensor, "3 width height"]], int, int]: """ - Returns 16 frames and their labels sampled every 5 frames from a clip. The start of the clip is uniformly sampled. If there are fewer + Returns 16 frames and their labels evenly sampled from a clip. """ video = self.videos[i] video_file = video.file_name @@ -125,11 +128,11 @@ def __getitem__( label_species = video.label_species frames, _, _ = io.read_video(video_file, pts_unit="sec") - frames = frames.permute(0, 3, 1, 2).float() / 255.0 # Sample n_sample frames between the start and end with equal interval. indices = torch.linspace(0, len(frames) - 1, self.n_frames).long() frames = frames[indices] + frames = frames.permute(0, 3, 1, 2).float() / 255.0 if self.transform is not None: frames = torch.stack([self.transform(frame) for frame in frames]) @@ -145,7 +148,7 @@ def __len__(self) -> int: def get_features( args: Args, backbone: interfaces.VisionBackbone, dataloader ) -> tuple[ - Float[Tensor, "n_frames n_examples dim"], Int[Tensor, "n_frames n_examples"] + Float[Tensor, "n_examples n_frames dim"], Int[Tensor, "n_examples"], Int[Tensor, "n_examples"] ]: """ Gets all model features and true labels for all frames and all examples in the dataloader. 
@@ -166,48 +169,37 @@ def get_features( total = len(dataloader) if not args.debug else 2 it = iter(dataloader) logger.debug("Need to embed %d batches of %d images.", total, args.batch_size * 16) - for b in range(total): + for b in tqdm(range(total)): frames, labels_behave, labels_species = next(it) - frames = torch.stack(frames, dim=0) - labels_behave = torch.stack(labels_behave, dim=0) - labels_species = torch.stack(labels_species, dim=0) frames = frames.to(args.device) with torch.amp.autocast("cuda"): # conv2d doesn't support multiple batch dimensions, so we have to view() before and after the model.img_encode() call. - n_frames, bsz, c, h, w = frames.shape + bsz, n_frames, c, h, w = frames.shape frames = frames.view(bsz * n_frames, c, h, w) outputs = backbone.img_encode(frames) - features = outputs.img_features.view(n_frames, bsz, -1) + features = outputs.img_features.view(bsz, n_frames, -1) all_features.append(features.cpu()) all_labels_behave.append(labels_behave.cpu()) all_labels_species.append(labels_species.cpu()) logger.debug("Embedded batch %d/%d", b + 1, total) - all_features = torch.cat(all_features, dim=1).cpu() - all_labels_behave = torch.cat(all_labels_behave, dim=1).cpu() - all_labels_species = torch.cat(all_labels_species, dim=1).cpu() + all_features = torch.cat(all_features, dim=0).cpu() + all_labels_behave = torch.cat(all_labels_behave).cpu() + all_labels_species = torch.cat(all_labels_species).cpu() return all_features, all_labels_behave, all_labels_species -@jaxtyped(typechecker=beartype.beartype) -def aggregate_labels( - args: Args, labels: Int[Tensor, "n_frames n_examples"] -) -> Int[Tensor, " n_examples"]: - """Aggregate per-frame labels to a per-video label. Uses the most common label (mode).""" - return torch.mode(labels, dim=0).values - - @jaxtyped(typechecker=beartype.beartype) def aggregate_frames( - args: Args, features: Float[Tensor, "n_frames n_examples dim"] + args: Args, features: Float[Tensor, "n_examples n_frames dim"] ) -> Float[Tensor, "n_examples dim"]: if args.frame_agg == "mean": - return torch.mean(features, dim=0) + return torch.mean(features, dim=1) elif args.frame_agg == "max": - return torch.max(features, dim=0).values + return torch.max(features, dim=1).values else: typing.assert_never(args.frame_agg) @@ -219,17 +211,17 @@ def benchmark( """Runs MammalNet benchmark.""" # 1. Load model backbone = registry.load_vision_backbone(*model_args) - img_transform = backbone.make_img_transform() + img_transform = T.Compose([ + T.Resize((224, 224)), + T.CenterCrop((224, 224)), + T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + ]) backbone = backbone.to(args.device) # 2. Load data. train_dataset = Dataset(args.datadir, "train", transform=img_transform) val_dataset = Dataset(args.datadir, "val", transform=img_transform) - data = train_dataset[0] - import pdb; pdb.set_trace() - assert 1==0 - train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, @@ -246,21 +238,22 @@ def benchmark( # 3. 
Get features val_features, val_labels_behave, val_labels_species = get_features(args, backbone, val_dataloader) val_features = aggregate_frames(args, val_features) - import pdb; pdb.set_trace() - val_labels_behave = aggregate_labels(args, val_labels_behave) - val_labels_species = aggregate_labels(args, val_labels_species) - train_features, train_labels_behave, train_labale_species = get_features(args, backbone, train_dataloader) + train_features, train_labels_behave, train_labels_species = get_features(args, backbone, train_dataloader) train_features = aggregate_frames(args, train_features) - train_labels = aggregate_labels(args, train_labels) # 4. Do simpleshot. - scores = simpleshot.simpleshot( - args, train_features, train_labels_behave, train_labale_species, - val_features, val_labels_behave, val_labels_species + scores_behave = simpleshot.simpleshot( + train_features, train_labels_behave, + val_features, val_labels_behave, args.batch_size, args.device + ) + scores_species = simpleshot.simpleshot( + train_features, train_labels_species, + val_features, val_labels_species, args.batch_size, args.device ) # Return benchmark report. + scores = scores_behave.long() & scores_species.long() video_ids = [video.video_id for video in val_dataset.videos] examples = [ interfaces.Example(str(id), float(score), {}) diff --git a/biobench/mammalnet/download.py b/biobench/mammalnet/download.py index fe20a28..dff9a2a 100644 --- a/biobench/mammalnet/download.py +++ b/biobench/mammalnet/download.py @@ -100,7 +100,7 @@ def main(args: Args): n_videos += len(video_files) with tarfile.open(videos_tar_path, "r") as tar: - for member in tqdm.tqdm(tar, desc="Extracting videos", total=n_videos): + for member in tqdm.tqdm(tar, desc="Extracting videos", total=n_videos + 1): if member.name in all_video_files or member.name == "trimmed_video": tar.extract(member, path=args.dir) print(f"Extracted videos: {videos_dir_path}.")