3 changes: 3 additions & 0 deletions .gitignore
@@ -4,3 +4,6 @@ __pycache__
 graphs/
 *.sqlite
 logs/
+*.pt
+.locks
+uv.lock
14 changes: 13 additions & 1 deletion benchmark.py
@@ -43,6 +43,7 @@
     plankton,
     plantnet,
     rarespecies,
+    mammalnet,
 )
 
 log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
@@ -68,6 +69,7 @@ class Args:
         interfaces.ModelArgs("open-clip", "ViT-B-16/openai"),
         interfaces.ModelArgs("open-clip", "ViT-B-16/laion400m_e32"),
         interfaces.ModelArgs("open-clip", "hf-hub:imageomics/bioclip"),
+        interfaces.ModelArgs("open-clip", "ViT-B-16/facebook/dinov2-base"),
         interfaces.ModelArgs("open-clip", "ViT-B-16-SigLIP/webli"),
         interfaces.ModelArgs("timm-vit", "vit_base_patch14_reg4_dinov2.lvd142m"),
     ]
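Note that the new checkpoint string contains two slashes, unlike the other open-clip entries. A minimal sketch of how it would presumably be parsed, assuming the loader splits architecture from checkpoint on the first "/" (an assumption; the actual parsing lives elsewhere in biobench):

```python
# Hypothetical illustration: splitting the new model string on the first "/".
ckpt = "ViT-B-16/facebook/dinov2-base"
arch, _, pretrained = ckpt.partition("/")
print(arch)        # "ViT-B-16"
print(pretrained)  # "facebook/dinov2-base", a Hugging Face model id
```

Under that assumption, the "pretrained" part is a Hugging Face model id rather than an open_clip pretrained tag, which is presumably why this PR also adds the `dinov2_model.py` adapter below.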
@@ -134,6 +136,10 @@ class Args:
         default_factory=rarespecies.Args
     )
     """Arguments for the Rare Species benchmark."""
+    mammalnet_run: bool = False
+    """Whether to run the MammalNet benchmark."""
+    mammalnet_args: mammalnet.Args = dataclasses.field(default_factory=mammalnet.Args)
+    """Arguments for the MammalNet benchmark."""
 
     # Reporting and graphing.
     report_to: str = os.path.join(".", "reports")
@@ -337,6 +343,12 @@ def main(args: Args):
         )
         job = executor.submit(rarespecies.benchmark, rarespecies_args, model_args)
         jobs.append(job)
+    if args.mammalnet_run:
+        mammalnet_args = dataclasses.replace(
+            args.mammalnet_args, device=args.device, debug=args.debug
+        )
+        job = executor.submit(mammalnet.benchmark, mammalnet_args, model_args)
+        jobs.append(job)
 
     logger.info("Submitted %d jobs.", len(jobs))
 
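Wiring MammalNet in follows the same contract as the other benchmarks: a module-level `Args` dataclass whose `device` and `debug` fields `main` overrides via `dataclasses.replace`, plus a `benchmark(args, model_args)` entry point. A minimal sketch of that contract; only `device`, `debug`, and the return shape are visible in this diff, so everything else here is an assumption:

```python
# Hedged sketch of the task-module contract assumed by main(); field
# defaults and the "MammalNet" label are assumptions, not the real module.
import dataclasses

from biobench import interfaces


@dataclasses.dataclass
class Args(interfaces.TaskArgs):
    device: str = "cuda"  # overwritten by main() via dataclasses.replace
    debug: bool = False   # likewise propagated from the top-level Args


def benchmark(args: Args, model_args: interfaces.ModelArgs):
    examples: list[interfaces.Example] = []  # one entry per evaluated example
    # A calc_mean_score callable may also be passed, as FishNet does below.
    return model_args, interfaces.TaskReport("MammalNet", examples)
```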
@@ -398,7 +410,7 @@ def plot_task(conn: sqlite3.Connection, task: str):
     if not data:
         return
 
-    xs = [row["model_ckpt"] for row in data]
+    xs = [row["model_ckpt"].split("/")[-1] for row in data]
     ys = [row["mean_score"] for row in data]
 
     yerr = np.array([ys, ys])
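The plotting tweak keeps only the segment after the last "/" so long checkpoint names do not overflow the x-axis labels. For example:

```python
# Before: full checkpoint strings; after: just the final path segment.
for ckpt in ["ViT-B-16/laion400m_e32", "hf-hub:imageomics/bioclip"]:
    print(ckpt.split("/")[-1])  # -> "laion400m_e32", then "bioclip"
```

One trade-off: this drops the architecture prefix, so "ViT-B-16/openai" is labeled simply "openai".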
45 changes: 45 additions & 0 deletions biobench/dinov2_model.py
@@ -0,0 +1,45 @@
+"""DINOv2 model adapter: wraps a Hugging Face DINOv2 backbone with an
+optional linear projection head."""
+
+from collections import OrderedDict
+
+import beartype
+import torch
+import torch.nn as nn
+from jaxtyping import jaxtyped
+from transformers import AutoModel
+
+
+@jaxtyped(typechecker=beartype.beartype)
+class DINOv2Model(nn.Module):
+    """Adds an optional adapter head on top of a DINOv2 backbone."""
+
+    def __init__(self, model_name: str, embed_dim: int):
+        super().__init__()
+        self.backbone = AutoModel.from_pretrained(model_name)
+        self.embed_dim = embed_dim
+
+        prev_chs = self.backbone.config.hidden_size
+        # The mask token is only used for masked-image modeling; freeze it.
+        self.backbone.embeddings.mask_token.requires_grad_(False)
+
+        if embed_dim > 0:
+            head_layers = OrderedDict()
+            head_layers["drop"] = nn.Dropout(0.0)
+            head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=False)
+            self.head = nn.Sequential(head_layers)
+        else:
+            self.head = nn.Identity()
+
+    def get_cast_dtype(self) -> torch.dtype:
+        # Fall back to the backbone dtype when there is no projection head.
+        if isinstance(self.head, nn.Identity):
+            return self.backbone.dtype
+        return self.head.proj.weight.dtype
+
+    def forward(self, x):
+        # With return_dict=False the backbone returns (sequence_output,
+        # pooled_output); keep the pooled features and project them.
+        _, x = self.backbone(x, return_dict=False)
+        return self.head(x)
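A usage sketch for the new adapter. The "facebook/dinov2-base" checkpoint comes from the model list in benchmark.py above; the 512-d projection and the 224x224 input resolution are illustrative assumptions:

```python
# Hedged usage sketch; embed_dim=512 and the input size are assumptions.
import torch

from biobench.dinov2_model import DINOv2Model

model = DINOv2Model("facebook/dinov2-base", embed_dim=512).eval()
pixels = torch.randn(2, 3, 224, 224)  # dummy batch of two RGB images
with torch.no_grad():
    feats = model(pixels)
print(feats.shape)  # torch.Size([2, 512]): pooled features after projection
```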
14 changes: 7 additions & 7 deletions biobench/fishnet/__init__.py
@@ -52,9 +52,9 @@ class Args(interfaces.TaskArgs):
     """number of dataloader worker processes."""
     log_every: int = 10
     """how often (number of epochs) to log progress."""
-    n_epochs: int = 100
+    n_epochs: int = 50
     """How many epochs to train the MLP classifier."""
-    learning_rate: float = 5e-4
+    learning_rate: float = 1e-4
     """The learning rate for training the MLP classifier."""
     threshold: float = 0.5
     """The threshold to predict "presence" rather than "absence"."""
@@ -113,10 +113,10 @@ def calc_macro_f1(examples: list[interfaces.Example]) -> float:
     """TODO: docs."""
     y_pred = np.array([example.info["y_pred"] for example in examples])
     y_true = np.array([example.info["y_true"] for example in examples])
-    score = sklearn.metrics.f1_score(
-        y_true, y_pred, average="macro", labels=np.unique(y_true)
-    )
-    return score.item()
+
+    correct = np.all(y_pred == y_true, axis=1)
+    acc = np.sum(correct) / len(y_pred)
+    return acc
 
 
 @beartype.beartype
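Worth flagging: `calc_macro_f1` keeps its old name (and is still passed as `calc_mean_score` below), but it now computes exact-match (subset) accuracy over the multi-label vector rather than macro F1; a row only counts as correct when every label matches. A small numeric illustration of the new behavior:

```python
# Exact-match accuracy as computed above: a row is correct only if all
# of its predicted labels equal the true labels.
import numpy as np

y_true = np.array([[1, 0, 1], [0, 1, 0]])
y_pred = np.array([[1, 0, 1], [0, 1, 1]])  # second row misses one label

correct = np.all(y_pred == y_true, axis=1)
print(correct)                        # [ True False]
print(np.sum(correct) / len(y_pred))  # 0.5
```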
@@ -168,7 +168,7 @@ def benchmark(
         if (epoch + 1) % args.log_every == 0:
             examples = evaluate(args, classifier, test_loader)
             score = calc_macro_f1(examples)
-            logger.info("Epoch %d/%d: %.3f", epoch + 1, args.n_epochs, score)
+            logger.info(f"Epoch {epoch + 1}/{args.n_epochs}: {score:.3f}")
 
     return model_args, interfaces.TaskReport(
         "FishNet", examples, calc_mean_score=calc_macro_f1