diff --git a/.gitignore b/.gitignore index 42cd403..a4a97f2 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ __pycache__ graphs/ *.sqlite logs/ +*.pt +.locks +uv.lock diff --git a/benchmark.py b/benchmark.py index c4fcaeb..5b36d7d 100644 --- a/benchmark.py +++ b/benchmark.py @@ -43,6 +43,7 @@ plankton, plantnet, rarespecies, + mammalnet, ) log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s" @@ -68,6 +69,7 @@ class Args: interfaces.ModelArgs("open-clip", "ViT-B-16/openai"), interfaces.ModelArgs("open-clip", "ViT-B-16/laion400m_e32"), interfaces.ModelArgs("open-clip", "hf-hub:imageomics/bioclip"), + interfaces.ModelArgs("open-clip", "ViT-B-16/facebook/dinov2-base"), interfaces.ModelArgs("open-clip", "ViT-B-16-SigLIP/webli"), interfaces.ModelArgs("timm-vit", "vit_base_patch14_reg4_dinov2.lvd142m"), ] @@ -134,6 +136,10 @@ class Args: default_factory=rarespecies.Args ) """Arguments for the Rare Species benchmark.""" + mammalnet_run: bool = False + """Whether to run the MammalNet benchmark.""" + mammalnet_args: mammalnet.Args = dataclasses.field(default_factory=mammalnet.Args) + """Arguments for the MammalNet benchmark.""" # Reporting and graphing. 
report_to: str = os.path.join(".", "reports") @@ -337,6 +343,12 @@ def main(args: Args): ) job = executor.submit(rarespecies.benchmark, rarespecies_args, model_args) jobs.append(job) + if args.mammalnet_run: + mammalnet_args = dataclasses.replace( + args.mammalnet_args, device=args.device, debug=args.debug + ) + job = executor.submit(mammalnet.benchmark, mammalnet_args, model_args) + jobs.append(job) logger.info("Submitted %d jobs.", len(jobs)) @@ -398,7 +410,7 @@ def plot_task(conn: sqlite3.Connection, task: str): if not data: return - xs = [row["model_ckpt"] for row in data] + xs = [row["model_ckpt"].split("/")[-1] for row in data] ys = [row["mean_score"] for row in data] yerr = np.array([ys, ys]) diff --git a/biobench/dinov2_model.py b/biobench/dinov2_model.py new file mode 100644 index 0000000..31fd0e6 --- /dev/null +++ b/biobench/dinov2_model.py @@ -0,0 +1,45 @@ +""" DINOv2 model adapter +""" +import beartype +from jaxtyping import jaxtyped +from collections import OrderedDict + +import torch +import torch.nn as nn + +from transformers import AutoModel + + +@jaxtyped(typechecker=beartype.beartype) +class DINOv2Model(nn.Module): + """ + Add adapter head to DINOv2. + """ + def __init__( + self, + model_name: str, + embed_dim: int, + ): + super().__init__() + self.backbone = AutoModel.from_pretrained(model_name) + self.embed_dim = embed_dim + + prev_chs = self.backbone.config.hidden_size + self.backbone.embeddings.mask_token.requires_grad_(False) + + if embed_dim > 0: + head_layers = OrderedDict() + head_layers['drop'] = nn.Dropout(0.) 
+ head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=False) + self.head = nn.Sequential(head_layers) + else: + self.head = nn.Identity() + + def get_cast_dtype(self) -> torch.dtype: + return next(self.parameters()).dtype  # robust when embed_dim <= 0: self.head is nn.Identity and has no .proj + + def forward(self, x): + _, x = self.backbone(x, return_dict=False) + if self.head is not None: + x = self.head(x) + return x diff --git a/biobench/fishnet/__init__.py b/biobench/fishnet/__init__.py index 46674df..824f7f8 100644 --- a/biobench/fishnet/__init__.py +++ b/biobench/fishnet/__init__.py @@ -52,9 +52,9 @@ class Args(interfaces.TaskArgs): """number of dataloader worker processes.""" log_every: int = 10 """how often (number of epochs) to log progress.""" - n_epochs: int = 100 + n_epochs: int = 50 """How many epochs to train the MLP classifier.""" - learning_rate: float = 5e-4 + learning_rate: float = 1e-4 """The learning rate for training the MLP classifier.""" threshold: float = 0.5 """The threshold to predicted "presence" rather than "absence".""" @@ -113,10 +113,10 @@ def calc_macro_f1(examples: list[interfaces.Example]) -> float: """TODO: docs.""" y_pred = np.array([example.info["y_pred"] for example in examples]) y_true = np.array([example.info["y_true"] for example in examples]) - score = sklearn.metrics.f1_score( - y_true, y_pred, average="macro", labels=np.unique(y_true) - ) - return score.item() + + correct = np.all(y_pred == y_true, axis=1)  # NOTE(review): this computes exact-match (subset) accuracy, not macro F1, despite the function name + acc = np.sum(correct) / len(y_pred) + return acc @beartype.beartype @@ -168,7 +168,7 @@ def benchmark( if (epoch + 1) % args.log_every == 0: examples = evaluate(args, classifier, test_loader) score = calc_macro_f1(examples) - logger.info("Epoch %d/%d: %.3f", epoch + 1, args.n_epochs, score) + logger.info(f"Epoch {epoch + 1}/{args.n_epochs}: {score:.3f}") return model_args, interfaces.TaskReport( "FishNet", examples, calc_mean_score=calc_macro_f1 ) diff --git a/biobench/mammalnet/__init__.py b/biobench/mammalnet/__init__.py new file mode 100644 index 0000000..da16926 --- /dev/null +++
b/biobench/mammalnet/__init__.py @@ -0,0 +1,263 @@ +""" +# MammalNet + +MammalNet is built around a biological mammal taxonomy spanning 17 orders, 69 families and 173 mammal categories, and includes 12 common high-level mammal behaviors (e.g. hunt, groom). +We adopt the compositional low-shot animal and behavior recognition benchmark. + +While specialized architectures exist, we train a simple nearest-centroid classifier [which works well with few-shot tasks](https://arxiv.org/abs/1911.04623) over video representations. +We get video representations by embedding each frame of the video and taking the mean over the batch dimension. + +If you use this evaluation, be sure to cite the original work: + +``` +@InProceedings{Chen_2023_CVPR, + author = {Chen, Jun and Hu, Ming and Coker, Darren J. and Berumen, Michael L. and Costelloe, Blair and Beery, Sara and Rohrbach, Anna and Elhoseiny, Mohamed}, + title = {MammalNet: A Large-Scale Video Benchmark for Mammal Recognition and Behavior Understanding}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2023}, + pages = {13052-13061} +} +``` + +This task was contributed by [Jianyang Gu](https://vimar-gu.github.io/). +""" + +import csv +import dataclasses +import logging +import os +import typing +from tqdm import tqdm + +import beartype +import numpy as np +import torch +from jaxtyping import Float, Int, jaxtyped +from PIL import Image +from torch import Tensor +import torchvision.io as io +from torchvision import transforms as T + +from biobench import interfaces, registry, simpleshot + +logger = logging.getLogger("mammalnet") + + +@beartype.beartype +@dataclasses.dataclass(frozen=True) +class Args(interfaces.TaskArgs): + """Arguments for the MammalNet task.""" + + batch_size: int = 16 + """Batch size for deep model. 
Note that this is multiplied by 16 (number of frames)""" + n_workers: int = 8 + """Number of dataloader worker processes.""" + frame_agg: typing.Literal["mean", "max"] = "mean" + """How to aggregate features across time dimension.""" + + +@beartype.beartype +@dataclasses.dataclass(frozen=True) +class Video: + """A single video instance as a sequence of frames.""" + + video_id: int + file_name: str + """Path to actual video file.""" + label_behave: int + """Label for animal behavior.""" + label_species: int + """Label for animal species.""" + + +@jaxtyped(typechecker=beartype.beartype) +class Dataset(torch.utils.data.Dataset): + """ + Each video has two labels for behavior and species, respectively. + """ + + def __init__(self, path, split: str, transform=None, seed: int = 42): + self.path = path + self.split = split + self.transform = transform + self.seed = seed + + self.rng = np.random.default_rng(seed=seed) + + self.n_frames = 16 + + # Load videos + ############# + + file_name: dict[int, str] = {} + labels_behave: dict[int, int] = {} + labels_species: dict[int, int] = {} + + if not os.path.exists(self.path) or not os.path.isdir(self.path): + msg = f"Path '{self.path}' doesn't exist. Did you download the MammalNet dataset? See the docstring at the top of this file for instructions. 
If you did download it, pass the path as --dataset-dir PATH" + raise RuntimeError(msg) + + with open(os.path.join(self.path, "annotation", "composition", f"{split}.csv")) as fd: + video_files = fd.readlines() + for video_id, video_file in enumerate(video_files): + path, label_behave, label_species = video_file.strip().split(" ") + video_id = int(video_id) + label_behave = int(label_behave) + label_species = int(label_species) + + path = os.path.join(self.path, path[:13] + path[14:])  # NOTE(review): drops character 13 of the relative path; presumably matches the same rewrite in download.py -- confirm + file_name[video_id] = path + labels_behave[video_id] = label_behave + labels_species[video_id] = label_species + + self.videos = [ + Video(video_id, file_name[video_id], labels_behave[video_id], labels_species[video_id]) + for video_id in file_name.keys() + ] + + def __getitem__( + self, i: int + ) -> tuple[Float[Tensor, "n_frames 3 height width"], int, int]: + """ + Returns 16 frames and their labels evenly sampled from a clip. + """ + video = self.videos[i] + video_file = video.file_name + label_behave = video.label_behave + label_species = video.label_species + + frames, _, _ = io.read_video(video_file, pts_unit="sec") + + # Sample n_sample frames between the start and end with equal interval. + indices = torch.linspace(0, len(frames) - 1, self.n_frames).long() + frames = frames[indices] + frames = frames.permute(0, 3, 1, 2).float() / 255.0 + + if self.transform is not None: + frames = torch.stack([self.transform(frame) for frame in frames]) + + return frames, label_behave, label_species + + def __len__(self) -> int: + return len(self.videos) + + +@torch.no_grad() +@jaxtyped(typechecker=beartype.beartype) +def get_features( + args: Args, backbone: interfaces.VisionBackbone, dataloader +) -> tuple[ + Float[Tensor, "n_examples n_frames dim"], Int[Tensor, "n_examples"], Int[Tensor, "n_examples"] +]: + """ + Gets all model features and true labels for all frames and all examples in the dataloader.
+ + Returns it as a pair of big tensors; other tasks like `biobench.birds525` use a dedicated class for this, but here it's just a tuple. + + Args: + args: MammalNet task arguments. + backbone: Vision backbone. + dataloader: Dataloader for whatever data you want to get features for. + + Returns: + tuple of model features and true labels. See signature for shape. + """ + backbone = torch.compile(backbone) + all_features, all_labels_behave, all_labels_species = [], [], [] + + total = len(dataloader) if not args.debug else 2 + it = iter(dataloader) + logger.debug("Need to embed %d batches of %d images.", total, args.batch_size * 16) + for b in tqdm(range(total)): + frames, labels_behave, labels_species = next(it) + frames = frames.to(args.device) + + with torch.amp.autocast("cuda"): + # conv2d doesn't support multiple batch dimensions, so we have to view() before and after the model.img_encode() call. + bsz, n_frames, c, h, w = frames.shape + frames = frames.view(bsz * n_frames, c, h, w) + outputs = backbone.img_encode(frames) + features = outputs.img_features.view(bsz, n_frames, -1) + all_features.append(features.cpu()) + all_labels_behave.append(labels_behave.cpu()) + all_labels_species.append(labels_species.cpu()) + + logger.debug("Embedded batch %d/%d", b + 1, total) + + all_features = torch.cat(all_features, dim=0).cpu() + all_labels_behave = torch.cat(all_labels_behave).cpu() + all_labels_species = torch.cat(all_labels_species).cpu() + + return all_features, all_labels_behave, all_labels_species + + +@jaxtyped(typechecker=beartype.beartype) +def aggregate_frames( + args: Args, features: Float[Tensor, "n_examples n_frames dim"] +) -> Float[Tensor, "n_examples dim"]: + if args.frame_agg == "mean": + return torch.mean(features, dim=1) + elif args.frame_agg == "max": + return torch.max(features, dim=1).values + else: + typing.assert_never(args.frame_agg) + + +@beartype.beartype +def benchmark( + args: Args, model_args: interfaces.ModelArgs +) -> 
tuple[interfaces.ModelArgs, interfaces.TaskReport]: + """Runs MammalNet benchmark.""" + # 1. Load model + backbone = registry.load_vision_backbone(*model_args) + img_transform = T.Compose([ + T.Resize((224, 224)), + T.CenterCrop((224, 224)), + T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + ]) + backbone = backbone.to(args.device) + + # 2. Load data. + train_dataset = Dataset(args.datadir, "train", transform=img_transform) + val_dataset = Dataset(args.datadir, "val", transform=img_transform) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + num_workers=args.n_workers, + drop_last=False, + ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset, + batch_size=args.batch_size, + num_workers=args.n_workers, + drop_last=False, + ) + + # 3. Get features + val_features, val_labels_behave, val_labels_species = get_features(args, backbone, val_dataloader) + val_features = aggregate_frames(args, val_features) + + train_features, train_labels_behave, train_labels_species = get_features(args, backbone, train_dataloader) + train_features = aggregate_frames(args, train_features) + + # 4. Do simpleshot. + scores_behave = simpleshot.simpleshot( + train_features, train_labels_behave, + val_features, val_labels_behave, args.batch_size, args.device + ) + scores_species = simpleshot.simpleshot( + train_features, train_labels_species, + val_features, val_labels_species, args.batch_size, args.device + ) + + # Return benchmark report. + scores = scores_behave.long() & scores_species.long() + video_ids = [video.video_id for video in val_dataset.videos] + examples = [ + interfaces.Example(str(id), float(score), {}) + for id, score in zip(video_ids, scores.tolist()) + ] + # TODO: include example-specific info (class? 
something else) + return model_args, interfaces.TaskReport("MammalNet", examples) diff --git a/biobench/mammalnet/download.py b/biobench/mammalnet/download.py new file mode 100644 index 0000000..dff9a2a --- /dev/null +++ b/biobench/mammalnet/download.py @@ -0,0 +1,110 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "requests", +# "tqdm", +# "tyro", +# ] +# /// +""" +A script to download the MammalNet dataset. + +Run with: + +1. `python biobench/mammalnet/download.py --help` if `biobench/` is in your $PWD. +2. `python -m biobench.mammalnet.download --help` if you have installed `biobench` as a package. +""" + +import dataclasses +import os.path +import tarfile + +import requests +import tqdm +import tyro + +videos_url = ( + "https://mammalnet.s3.amazonaws.com/trimmed_video.tar.gz" +) +labels_url = "https://mammalnet.s3.amazonaws.com/annotation.tar" + + +@dataclasses.dataclass(frozen=True) +class Args: + """Configure download options.""" + + dir: str = "." + """Where to save data.""" + + chunk_size_kb: int = 1 + """How many KB to download at a time before writing to file.""" + + videos: bool = True + """Whether to download videos [148GB].""" + labels: bool = True + """Whether to download labels.""" + + +def main(args: Args): + """Download MammalNet.""" + os.makedirs(args.dir, exist_ok=True) + chunk_size = int(args.chunk_size_kb * 1024) + videos_tar_path = os.path.join(args.dir, "trimmed_video.tar.gz") + labels_tar_path = os.path.join(args.dir, "annotation.tar") + videos_dir_name = "trimmed_video" + videos_dir_path = os.path.join(args.dir, videos_dir_name) + labels_dir_name = "annotation" + labels_dir_path = os.path.join(args.dir, labels_dir_name) + + if args.labels: + # Download labels + r = requests.get(labels_url, stream=True) + r.raise_for_status() + + with open(labels_tar_path, "wb") as fd: + for chunk in r.iter_content(chunk_size=chunk_size): + fd.write(chunk) + print(f"Downloaded labels: {labels_tar_path}.") + + if args.videos: + # 
Download videos. + r = requests.get(videos_url, stream=True) + r.raise_for_status() + + n_bytes = int(r.headers["content-length"]) + + with open(videos_tar_path, "wb") as fd: + for chunk in tqdm.tqdm( + r.iter_content(chunk_size=chunk_size), + total=n_bytes / chunk_size, + unit="b", + unit_scale=1, + unit_divisor=1024, + desc="Downloading videos", + ): + fd.write(chunk) + print(f"Downloaded videos: {videos_tar_path}.") + + with tarfile.open(labels_tar_path, "r") as tar: + tar.extractall(path=args.dir) + print(f"Extracted labels: {labels_dir_path}.") + + n_videos = 0 + all_video_files = [] + for csv_file in ["train.csv", "test.csv"]: + with open(os.path.join(labels_dir_path, "composition", csv_file)) as fd: + video_files = fd.readlines() + video_files = [video_file.split(" ")[0] for video_file in video_files] + video_files = [video_file[:13] + video_file[14:] for video_file in video_files] + all_video_files += video_files + n_videos += len(video_files) + + with tarfile.open(videos_tar_path, "r") as tar: + for member in tqdm.tqdm(tar, desc="Extracting videos", total=n_videos + 1): + if member.name in all_video_files or member.name == "trimmed_video": + tar.extract(member, path=args.dir) + print(f"Extracted videos: {videos_dir_path}.") + + +if __name__ == "__main__": + main(tyro.cli(Args)) diff --git a/biobench/third_party_models.py b/biobench/third_party_models.py index dadebb4..b86c8aa 100644 --- a/biobench/third_party_models.py +++ b/biobench/third_party_models.py @@ -6,6 +6,7 @@ from torch import Tensor from biobench import interfaces +from biobench.dinov2_model import DINOv2Model logger = logging.getLogger("third_party") @@ -63,10 +64,28 @@ def __init__(self, ckpt: str, **kwargs): if ckpt.startswith("hf-hub:"): clip, self.img_transform = open_clip.create_model_from_pretrained(ckpt) else: - arch, ckpt = ckpt.split("/") - clip, self.img_transform = open_clip.create_model_from_pretrained( - arch, pretrained=ckpt, cache_dir=get_cache_dir() - ) + from 
open_clip.factory import load_state_dict + arch = ckpt.split("/")[0] + ckpt = "/".join(ckpt.split("/")[1:]) + if "facebook" in ckpt and "dino" in ckpt: + dino_model = DINOv2Model( + ckpt, + embed_dim=0 + ) + clip, _, self.img_transform = open_clip.create_model_and_transforms(arch) + clip.visual = dino_model + elif "dino" in ckpt: + dino_model = DINOv2Model( + "facebook/dinov2-base", + embed_dim=512 + ) + clip, _, self.img_transform = open_clip.create_model_and_transforms(arch) + clip.visual = dino_model + clip.load_state_dict(load_state_dict(ckpt)) + else: + clip, _, self.img_transform = open_clip.create_model_and_transforms( + arch, pretrained=ckpt, cache_dir=get_cache_dir() + ) self.model = clip.visual self.model.output_tokens = True # type: ignore diff --git a/pyproject.toml b/pyproject.toml index be70cde..2392744 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,9 @@ dependencies = [ "submitit>=1.5.2", "pycocotools>=2.0.8", "gdown>=5.2.0", + "transformers>=4.39.1", "altair>=5.4.1", + "av<14.0", ] [tool.ruff.lint]