diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/README.md b/DeepLense_Classification_Transformers_Archil_Srivastava/README.md
index 8ca8a1d..449c1ce 100644
--- a/DeepLense_Classification_Transformers_Archil_Srivastava/README.md
+++ b/DeepLense_Classification_Transformers_Archil_Srivastava/README.md
@@ -32,6 +32,11 @@ ___Note__: Axion files have extra data corresponding to mass of axion used in si
 
 # __Training__
 
+Install dependencies from this module first:
+```bash
+python -m pip install -r requirements.txt
+```
+
 Use the train.py script to train a particular model (using timm model name). The script will ask for a WandB login key, hence a WandB account is needed. Example:
 ```bash
 python3 train.py \
@@ -41,7 +46,8 @@ python3 train.py \
 --tune \
 --no-complex \
 --device best \
---project ml4sci_deeplense_final
+--project ml4sci_deeplense_final \
+--entity $WANDB_ENTITY
 ```
 | Arguments | Description |
 | :--- | :--- |
@@ -60,6 +66,7 @@ python3 train.py \
 | random_rotation | Random rotation for augmentation (in degreees) |
 | log_interval | Log interval for logging to weights and biases |
 | project | Project name in Weight and Biases
+| entity | W&B entity/org (defaults to `$WANDB_ENTITY` when set) |
 | device | Device: cuda or mps or cpu or best |
 | seed | Random seed |
 
@@ -70,7 +77,8 @@ Run evaluation of trained model on test sets using eval.py script. Pass the run_
 python3 eval.py \
 --run_id 1g9hi3n6 \
 --device cuda \
--- project ml4sci_deeplense_final
+--project ml4sci_deeplense_final \
+--entity $WANDB_ENTITY
 ```
 
diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py b/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py
index a3f12b3..98a5d6f 100644
--- a/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py
+++ b/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py
@@ -2,10 +2,8 @@
 from torch.utils.data import DataLoader
 from torch.nn import CrossEntropyLoss
 from torchmetrics.functional import auroc as auroc_fn, accuracy as accuracy_fn
-from sklearn.metrics import ConfusionMatrixDisplay, roc_curve
 import wandb
 import numpy as np
-import matplotlib.pyplot as plt
 import argparse
 import os
 
@@ -55,24 +53,46 @@ def evaluate(model, data_loader, loss_fn, device):
 
     # Concatenate all results
    logits, y = torch.cat(logits), torch.cat(y)
-    loss.append(loss_fn(logits, y))
-    accuracy.append(accuracy_fn(logits, y, num_classes=NUM_CLASSES))
-    class_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average=None))
-    macro_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average="macro"))
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    loss.append(loss_fn(logits, y).item())
+    accuracy.append(
+        accuracy_fn(probs, y, task="multiclass", num_classes=NUM_CLASSES).item()
+    )
+    class_auroc.append(
+        auroc_fn(
+            probs, y, task="multiclass", num_classes=NUM_CLASSES, average=None
+        ).cpu()
+    )
+    # torchmetrics multiclass AUROC doesn't support micro-averaging directly.
+    # Compute micro-AUROC as binary AUROC over flattened one-vs-rest targets.
+    micro_auroc.append(
+        auroc_fn(
+            probs.reshape(-1),
+            torch.nn.functional.one_hot(y, num_classes=NUM_CLASSES)
+            .to(dtype=torch.int)
+            .reshape(-1),
+            task="binary",
+        ).item()
+    )
+    macro_auroc.append(
+        auroc_fn(
+            probs, y, task="multiclass", num_classes=NUM_CLASSES, average="macro"
+        ).item()
+    )
 
     result = {
         "ground_truth": y,
         "logits": logits,
-        "loss": np.mean(loss),
-        "accuracy": np.mean(accuracy),
-        "micro_auroc": np.mean(micro_auroc),
-        "macro_auroc": np.mean(macro_auroc),
+        "loss": float(np.mean(loss)),
+        "accuracy": float(np.mean(accuracy)),
+        "micro_auroc": float(np.mean(micro_auroc)),
+        "macro_auroc": float(np.mean(macro_auroc)),
     }
 
     # Class-wise AUROC
     class_auroc = class_auroc[0]
     for i, label in enumerate(LABELS):
-        result[f"{label}_auroc"] = class_auroc[i]
+        result[f"{label}_auroc"] = float(class_auroc[i].item())
 
     return result
 
@@ -82,8 +102,20 @@ def evaluate(model, data_loader, loss_fn, device):
     parser = argparse.ArgumentParser()
 
     # Wandb-specific params
-    parser.add_argument("--runid", type=str, help="ID of train run")
+    parser.add_argument(
+        "--runid",
+        "--run_id",
+        dest="runid",
+        type=str,
+        help="ID of train run",
+    )
     parser.add_argument("--project", type=str, default="ml4sci_deeplense_final")
+    parser.add_argument(
+        "--entity",
+        type=str,
+        default=os.environ.get("WANDB_ENTITY"),
+        help="W&B entity/org. Defaults to $WANDB_ENTITY when set.",
+    )
 
     # Device to run on
     parser.add_argument(
@@ -91,10 +123,19 @@ def evaluate(model, data_loader, loss_fn, device):
     )
     run_config = parser.parse_args()
 
+    from sklearn.metrics import ConfusionMatrixDisplay, roc_curve
+    import matplotlib.pyplot as plt
+
     # Start wandb run
-    with wandb.init(
-        entity="_archil", project=run_config.project, id=run_config.runid, resume="must"
-    ):
+    wandb_init_kwargs = dict(
+        project=run_config.project,
+        id=run_config.runid,
+        resume="must",
+    )
+    if run_config.entity:
+        wandb_init_kwargs["entity"] = run_config.entity
+
+    with wandb.init(**wandb_init_kwargs):
         # Get best device on machine
         device = get_device(run_config.device)
 
@@ -169,7 +210,7 @@ def evaluate(model, data_loader, loss_fn, device):
         roc_auc = dict()
         for idx, cls in enumerate(LABELS):
             class_truth = (metrics["ground_truth"].numpy() == idx).astype(int)
-            class_pred = torch.nn.functional.softmax(metrics["logits"]).numpy()[
+            class_pred = torch.nn.functional.softmax(metrics["logits"], dim=-1).numpy()[
                 ..., idx
             ]
             fpr[idx], tpr[idx], _ = roc_curve(class_truth, class_pred)
@@ -185,7 +226,7 @@ def evaluate(model, data_loader, loss_fn, device):
 
         disp = ConfusionMatrixDisplay.from_predictions(
             y_true=metrics["ground_truth"].numpy(),
-            y_pred=np.argmax(metrics["logits"], axis=-1),
+            y_pred=metrics["logits"].argmax(dim=-1).numpy(),
             display_labels=LABELS,
             cmap=plt.cm.Blues,
             colorbar=False,
diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/requirements.txt b/DeepLense_Classification_Transformers_Archil_Srivastava/requirements.txt
new file mode 100644
index 0000000..9be1e3c
--- /dev/null
+++ b/DeepLense_Classification_Transformers_Archil_Srivastava/requirements.txt
@@ -0,0 +1,15 @@
+# Usage:
+# python -m pip install -r requirements.txt
+#
+# Note:
+# - Installing PyTorch differs per OS/CUDA; install it separately if needed:
+# https://pytorch.org/get-started/locally/
+
+einops>=0.7
+matplotlib>=3.6
+numpy>=1.23
+scikit-learn>=1.2
+timm>=0.9
+torchmetrics>=1.2
+tqdm>=4.60
+wandb>=0.16
diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/train.py b/DeepLense_Classification_Transformers_Archil_Srivastava/train.py
index a5a6303..f3d6204 100644
--- a/DeepLense_Classification_Transformers_Archil_Srivastava/train.py
+++ b/DeepLense_Classification_Transformers_Archil_Srivastava/train.py
@@ -187,6 +187,12 @@ def train(
     )
     parser.add_argument("--log_interval", type=int, default=100)
     parser.add_argument("--project", type=str, default="ml4sci_deeplense_final")
+    parser.add_argument(
+        "--entity",
+        type=str,
+        default=os.environ.get("WANDB_ENTITY"),
+        help="W&B entity/org. Defaults to $WANDB_ENTITY when set.",
+    )
 
     # Timm-Specific parameters
     parser.add_argument("--model_name", type=str, default="vit_base_patch16_224")
@@ -233,13 +239,16 @@ def train(
         group = f"{group}-complex"
 
     # Start wandb run
-    with wandb.init(
-        entity="_archil",
+    wandb_init_kwargs = dict(
         project=run_config.project,
         config=run_config,
         group=group,
         job_type=f"{run_config.dataset}",
-    ):
+    )
+    if run_config.entity:
+        wandb_init_kwargs["entity"] = run_config.entity
+
+    with wandb.init(**wandb_init_kwargs):
         # Set random seed
         if run_config.seed:
             set_seed(run_config.seed)
@@ -319,7 +328,7 @@ def train(
         # Scheduler
         if run_config.decay_lr:
             scheduler = CosineAnnealingWarmRestarts(
-                optimizer, T_0=15, T_mult=1, eta_min=1e-6, verbose=True
+                optimizer, T_0=15, T_mult=1, eta_min=1e-6
             )
         else:
             scheduler = None
diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/utils.py b/DeepLense_Classification_Transformers_Archil_Srivastava/utils.py
index 759d3fe..908a2bc 100644
--- a/DeepLense_Classification_Transformers_Archil_Srivastava/utils.py
+++ b/DeepLense_Classification_Transformers_Archil_Srivastava/utils.py
@@ -29,7 +29,16 @@ def get_device(device):
         return xm.xla_device()
     if (device == "cuda" or device == "best") and torch.cuda.is_available():
         return "cuda"
-    if (device == "mps" or device == "best") and torch.has_mps:
+    mps_backend = getattr(torch.backends, "mps", None)
+    mps_available = False
+    if mps_backend is not None:
+        is_available = getattr(mps_backend, "is_available", None)
+        is_built = getattr(mps_backend, "is_built", None)
+        if callable(is_available):
+            mps_available = bool(is_available())
+        elif callable(is_built):
+            mps_available = bool(is_built())
+    if (device == "mps" or device == "best") and mps_available:
         return "mps"
     if device == "cpu" or device == "best":
         return "cpu"
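
The micro-AUROC computation added to `evaluate` is the least obvious change: torchmetrics' multiclass AUROC has no `average="micro"` option, so the patch flattens one-vs-rest targets and scores them with a binary AUROC. Below is a minimal standalone sketch of that idea (not part of the patch), assuming the task-based torchmetrics API (>=0.11) and the scikit-learn pin from requirements.txt; the dummy tensors and variable names are illustrative only.

```python
# Standalone sanity check: the flattened one-vs-rest binary AUROC used in
# evaluate() should match scikit-learn's micro-averaged AUROC.
import torch
import torch.nn.functional as F
from torchmetrics.functional import auroc
from sklearn.metrics import roc_auc_score

NUM_CLASSES = 3
torch.manual_seed(0)

logits = torch.randn(32, NUM_CLASSES)      # dummy model outputs
y = torch.randint(0, NUM_CLASSES, (32,))   # dummy ground-truth labels
probs = F.softmax(logits, dim=-1)

# Micro-AUROC via flattened one-vs-rest targets, as in the patched evaluate().
targets = F.one_hot(y, num_classes=NUM_CLASSES).to(dtype=torch.int)
micro_tm = auroc(probs.reshape(-1), targets.reshape(-1), task="binary").item()

# Reference: scikit-learn's micro-averaged AUROC over the same indicator matrix.
micro_sk = roc_auc_score(targets.numpy(), probs.numpy(), average="micro")

print(micro_tm, micro_sk)  # the two values should agree
```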