-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patheval.py
More file actions
70 lines (57 loc) · 3.02 KB
/
eval.py
File metadata and controls
70 lines (57 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
# Created: 2023-08-09 10:28
# Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology
# Author: Qingwen Zhang (https://kin-zhang.github.io/)
#
# This file is part of DeFlow (https://github.com/KTH-RPL/DeFlow).
# If you find this repo helpful, please cite the respective publication as
# listed on the above website.
# Description: Output the evaluation results, for either local evaluation or online evaluation.
"""
import torch
from torch.utils.data import DataLoader
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from omegaconf import DictConfig
import hydra, wandb, os, sys
from hydra.core.hydra_config import HydraConfig
from scripts.network.dataloader import HDF5Dataset
from scripts.pl_model import ModelWrapper
def precheck_cfg_valid(cfg):
    """Sanity-check the evaluation config and return it unchanged.

    Checks, in order:
      1. ``<cfg.dataset_path>/<cfg.av2_mode>`` exists on disk.
      2. ``cfg.supervised_flag`` compares equal to ``True`` or ``False``.
      3. ``cfg.leaderboard_version`` compares equal to ``1`` or ``2``.

    Args:
        cfg: Config object exposing ``dataset_path``, ``av2_mode``,
            ``supervised_flag`` and ``leaderboard_version`` as attributes.

    Returns:
        The very same ``cfg`` object that was passed in.

    Raises:
        ValueError: On the first check that fails.
    """
    split_dir = cfg.dataset_path + f"/{cfg.av2_mode}"
    if not os.path.exists(split_dir):
        raise ValueError(f"Dataset {cfg.dataset_path}/{cfg.av2_mode} does not exist. Please check the path.")

    # Membership uses ==, so values equal to a bool (e.g. 0 / 1) are accepted too.
    allowed_flags = (True, False)
    if cfg.supervised_flag not in allowed_flags:
        raise ValueError(f"Supervised flag {cfg.supervised_flag} is not valid. Please set it to True or False.")

    allowed_versions = (1, 2)
    if cfg.leaderboard_version not in allowed_versions:
        raise ValueError(f"Leaderboard version {cfg.leaderboard_version} is not valid. Please set it to 1 or 2.")

    return cfg
@hydra.main(version_base=None, config_path="conf", config_name="eval")
def main(cfg):
    """Run validation for a saved checkpoint on one dataset split and log it to W&B.

    The hyper-parameters stored in the checkpoint are merged into ``cfg.model``
    and used to build the run name in ``cfg.output``; the model is then restored
    through ``ModelWrapper.load_from_checkpoint`` and handed to
    ``Trainer.validate`` on a single device.

    Args:
        cfg: Hydra config composed from ``conf/eval``. Fields read here:
            ``seed``, ``checkpoint``, ``av2_mode``, ``leaderboard_version``,
            ``dataset_path``, ``wandb_mode`` and ``model``. ``cfg.output`` and
            ``cfg.model`` are overwritten/updated in place.

    Exits with status 1 (after printing a message) when ``cfg.checkpoint`` does
    not exist on disk.
    """
    pl.seed_everything(cfg.seed, workers=True)
    output_dir = HydraConfig.get().runtime.output_dir

    if not os.path.exists(cfg.checkpoint):
        print(f"Checkpoint {cfg.checkpoint} does not exist. Need checkpoints for evaluation.")
        sys.exit(1)

    # This copy is only read for `hyper_parameters` and `epoch`; the model itself
    # is restored by `load_from_checkpoint` below. Mapping to CPU avoids placing a
    # second copy of the tensors on the accelerator and lets the metadata be read
    # on hosts that lack the device the checkpoint was saved from.
    torch_load_ckpt = torch.load(cfg.checkpoint, map_location="cpu")
    checkpoint_params = DictConfig(torch_load_ckpt["hyper_parameters"])

    # Run name: <training output name>-e<epoch>-<split>-v<leaderboard version>.
    cfg.output = checkpoint_params.cfg.output + f"-e{torch_load_ckpt['epoch']}-{cfg.av2_mode}-v{cfg.leaderboard_version}"
    # The model section saved with the checkpoint takes precedence over the eval config.
    cfg.model.update(checkpoint_params.cfg.model)

    mymodel = ModelWrapper.load_from_checkpoint(cfg.checkpoint, cfg=cfg, eval=True)
    print(
        f"\n---LOG[eval]: Loaded model from {cfg.checkpoint}. The backbone network is {checkpoint_params.cfg.model.name}.\n")

    wandb_logger = WandbLogger(
        save_dir=output_dir,
        entity="kth-rpl",
        project="deflow-eval",
        name=f"{cfg.output}",
        offline=(cfg.wandb_mode == "offline"),
    )

    trainer = pl.Trainer(logger=wandb_logger, devices=1)

    # Checkpoints whose saved config has no `num_frames` entry fall back to 2 frames.
    n_frames = checkpoint_params.cfg.num_frames if 'num_frames' in checkpoint_params.cfg else 2
    eval_dataset = HDF5Dataset(
        cfg.dataset_path + f"/{cfg.av2_mode}",
        n_frames=n_frames,
        eval=True,
        leaderboard_version=cfg.leaderboard_version,
    )
    eval_loader = DataLoader(eval_dataset, batch_size=1, shuffle=False)

    # NOTE(Qingwen): search & check: def eval_only_step_(self, batch, res_dict)
    trainer.validate(model=mymodel, dataloaders=eval_loader)
    wandb.finish()
# Script entry point: invoke main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()