Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions DeepLense_Classification_Transformers_Archil_Srivastava/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ ___Note__: Axion files have extra data corresponding to mass of axion used in si

# __Training__

Install dependencies from this module first:
```bash
python -m pip install -r requirements.txt
```

Use the train.py script to train a particular model (specified by its timm model name). The script will ask for a WandB login key, so a WandB account is required. Example:
```bash
python3 train.py \
Expand All @@ -41,7 +46,8 @@ python3 train.py \
--tune \
--no-complex \
--device best \
--project ml4sci_deeplense_final
--project ml4sci_deeplense_final \
--entity $WANDB_ENTITY
```
| Arguments | Description |
| :--- | :--- |
Expand All @@ -60,6 +66,7 @@ python3 train.py \
| random_rotation | Random rotation for augmentation (in degrees) |
| log_interval | Log interval for logging to Weights and Biases |
| project | Project name in Weights and Biases |
| entity | W&B entity/org (defaults to `$WANDB_ENTITY` when set) |
| device | Device: cuda or mps or cpu or best |
| seed | Random seed |

Expand All @@ -70,7 +77,8 @@ Run evaluation of trained model on test sets using eval.py script. Pass the run_
python3 eval.py \
--run_id 1g9hi3n6 \
--device cuda \
-- project ml4sci_deeplense_final
--project ml4sci_deeplense_final \
--entity $WANDB_ENTITY
```

<br>
Expand Down
75 changes: 58 additions & 17 deletions DeepLense_Classification_Transformers_Archil_Srivastava/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torchmetrics.functional import auroc as auroc_fn, accuracy as accuracy_fn
from sklearn.metrics import ConfusionMatrixDisplay, roc_curve
import wandb
import numpy as np
import matplotlib.pyplot as plt
import argparse
import os

Expand Down Expand Up @@ -55,24 +53,46 @@ def evaluate(model, data_loader, loss_fn, device):

# Concatenate all results
logits, y = torch.cat(logits), torch.cat(y)
loss.append(loss_fn(logits, y))
accuracy.append(accuracy_fn(logits, y, num_classes=NUM_CLASSES))
class_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average=None))
macro_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average="macro"))
probs = torch.nn.functional.softmax(logits, dim=-1)
loss.append(loss_fn(logits, y).item())
accuracy.append(
accuracy_fn(probs, y, task="multiclass", num_classes=NUM_CLASSES).item()
)
class_auroc.append(
auroc_fn(
probs, y, task="multiclass", num_classes=NUM_CLASSES, average=None
).cpu()
)
# torchmetrics multiclass AUROC doesn't support micro-averaging directly.
# Compute micro-AUROC as binary AUROC over flattened one-vs-rest targets.
micro_auroc.append(
auroc_fn(
probs.reshape(-1),
torch.nn.functional.one_hot(y, num_classes=NUM_CLASSES)
.to(dtype=torch.int)
.reshape(-1),
task="binary",
).item()
)
macro_auroc.append(
auroc_fn(
probs, y, task="multiclass", num_classes=NUM_CLASSES, average="macro"
).item()
)

result = {
"ground_truth": y,
"logits": logits,
"loss": np.mean(loss),
"accuracy": np.mean(accuracy),
"micro_auroc": np.mean(micro_auroc),
"macro_auroc": np.mean(macro_auroc),
"loss": float(np.mean(loss)),
"accuracy": float(np.mean(accuracy)),
"micro_auroc": float(np.mean(micro_auroc)),
"macro_auroc": float(np.mean(macro_auroc)),
}

# Class-wise AUROC
class_auroc = class_auroc[0]
for i, label in enumerate(LABELS):
result[f"{label}_auroc"] = class_auroc[i]
result[f"{label}_auroc"] = float(class_auroc[i].item())

return result

Expand All @@ -82,19 +102,40 @@ def evaluate(model, data_loader, loss_fn, device):
parser = argparse.ArgumentParser()

# Wandb-specific params
parser.add_argument("--runid", type=str, help="ID of train run")
parser.add_argument(
"--runid",
"--run_id",
dest="runid",
type=str,
help="ID of train run",
)
parser.add_argument("--project", type=str, default="ml4sci_deeplense_final")
parser.add_argument(
"--entity",
type=str,
default=os.environ.get("WANDB_ENTITY"),
help="W&B entity/org. Defaults to $WANDB_ENTITY when set.",
)

# Device to run on
parser.add_argument(
"--device", choices=["cpu", "mps", "cuda", "best"], default="best"
)
run_config = parser.parse_args()

from sklearn.metrics import ConfusionMatrixDisplay, roc_curve
import matplotlib.pyplot as plt

# Start wandb run
with wandb.init(
entity="_archil", project=run_config.project, id=run_config.runid, resume="must"
):
wandb_init_kwargs = dict(
project=run_config.project,
id=run_config.runid,
resume="must",
)
if run_config.entity:
wandb_init_kwargs["entity"] = run_config.entity

with wandb.init(**wandb_init_kwargs):
# Get best device on machine
device = get_device(run_config.device)

Expand Down Expand Up @@ -169,7 +210,7 @@ def evaluate(model, data_loader, loss_fn, device):
roc_auc = dict()
for idx, cls in enumerate(LABELS):
class_truth = (metrics["ground_truth"].numpy() == idx).astype(int)
class_pred = torch.nn.functional.softmax(metrics["logits"]).numpy()[
class_pred = torch.nn.functional.softmax(metrics["logits"], dim=-1).numpy()[
..., idx
]
fpr[idx], tpr[idx], _ = roc_curve(class_truth, class_pred)
Expand All @@ -185,7 +226,7 @@ def evaluate(model, data_loader, loss_fn, device):

disp = ConfusionMatrixDisplay.from_predictions(
y_true=metrics["ground_truth"].numpy(),
y_pred=np.argmax(metrics["logits"], axis=-1),
y_pred=metrics["logits"].argmax(dim=-1).numpy(),
display_labels=LABELS,
cmap=plt.cm.Blues,
colorbar=False,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Usage:
# python -m pip install -r requirements.txt
#
# Note:
# - Installing PyTorch differs per OS/CUDA; install it separately if needed:
# https://pytorch.org/get-started/locally/

einops>=0.7
matplotlib>=3.6
numpy>=1.23
scikit-learn>=1.2
timm>=0.9
torchmetrics>=1.2
tqdm>=4.60
wandb>=0.16
17 changes: 13 additions & 4 deletions DeepLense_Classification_Transformers_Archil_Srivastava/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@ def train(
)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--project", type=str, default="ml4sci_deeplense_final")
parser.add_argument(
"--entity",
type=str,
default=os.environ.get("WANDB_ENTITY"),
help="W&B entity/org. Defaults to $WANDB_ENTITY when set.",
)

# Timm-Specific parameters
parser.add_argument("--model_name", type=str, default="vit_base_patch16_224")
Expand Down Expand Up @@ -233,13 +239,16 @@ def train(
group = f"{group}-complex"

# Start wandb run
with wandb.init(
entity="_archil",
wandb_init_kwargs = dict(
project=run_config.project,
config=run_config,
group=group,
job_type=f"{run_config.dataset}",
):
)
if run_config.entity:
wandb_init_kwargs["entity"] = run_config.entity

with wandb.init(**wandb_init_kwargs):
# Set random seed
if run_config.seed:
set_seed(run_config.seed)
Expand Down Expand Up @@ -319,7 +328,7 @@ def train(
# Scheduler
if run_config.decay_lr:
scheduler = CosineAnnealingWarmRestarts(
optimizer, T_0=15, T_mult=1, eta_min=1e-6, verbose=True
optimizer, T_0=15, T_mult=1, eta_min=1e-6
)
else:
scheduler = None
Expand Down
11 changes: 10 additions & 1 deletion DeepLense_Classification_Transformers_Archil_Srivastava/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,16 @@ def get_device(device):
return xm.xla_device()
if (device == "cuda" or device == "best") and torch.cuda.is_available():
return "cuda"
if (device == "mps" or device == "best") and torch.has_mps:
mps_backend = getattr(torch.backends, "mps", None)
mps_available = False
if mps_backend is not None:
is_available = getattr(mps_backend, "is_available", None)
is_built = getattr(mps_backend, "is_built", None)
if callable(is_available):
mps_available = bool(is_available())
elif callable(is_built):
mps_available = bool(is_built())
if (device == "mps" or device == "best") and mps_available:
return "mps"
if device == "cpu" or device == "best":
return "cpu"
Expand Down