# inference.py
"""Run inference with a pre-trained Prot2Token model and write the predictions to a CSV file."""
import argparse
import os
from time import time

import numpy as np
import pandas as pd
import torch
import transformers
import yaml
from accelerate import Accelerator, DataLoaderConfiguration
from tqdm import tqdm

from dataset import prepare_inference_dataloader
from model import prepare_models
from utils.utils import load_configs, test_gpu_cuda, get_logging, load_checkpoints_inference, prepare_saving_dir


def inference(net, dataloader, configs, decoder_tokenizer, mode):
    """Run the model over the dataloader and collect the decoded predictions."""
    results = []
    net.eval()
    inference_config = {
        "beam_width": configs.beam_search.beam_width,
        "temperature": configs.beam_search.temperature,
        "top_k": configs.beam_search.top_k
    }
    for data in tqdm(dataloader, desc=f'{mode}', total=len(dataloader),
                     leave=False, disable=not configs.tqdm_progress_bar):
        protein_sequence, target, molecule_sequence, sequence, task_name = data
        batch = {"protein_sequence": protein_sequence, "molecule_sequence": molecule_sequence,
                 "target_input": target}
        with torch.inference_mode():
            preds = net(batch, mode=configs.inference_type, inference_config=inference_config)
        # The batch size is assumed to be 1 at inference time; take the single decoded sequence.
        preds = preds.detach().cpu().numpy().tolist()[0]
        # Strip the leading and trailing special tokens, then map token indices back to strings.
        preds = [decoder_tokenizer.index_token_dict[pred] for pred in preds[2:-1]]
        results.append([sequence[0], task_name[0], configs.merging_character.join(preds)])
    return results
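
# Illustrative call, using the objects that main() below prepares:
#   results = inference(net, inference_dataloader, inference_configs, decoder_tokenizer, mode='inference')
# Each returned row is [input_sequence, task_name, merged_prediction_string], matching the
# DataFrame columns ('input', 'task_name', 'predicted') written by main().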


def main(dict_inference_config, dict_config, inference_config_file_path):
    """Build the model and dataloader, run inference, and save the results."""
    configs = load_configs(dict_config)
    inference_configs = load_configs(dict_inference_config, inference=True)
    transformers.logging.set_verbosity_error()

    if isinstance(configs.fix_seed, int):
        torch.manual_seed(configs.fix_seed)
        torch.random.manual_seed(configs.fix_seed)
        np.random.seed(configs.fix_seed)

    torch.cuda.empty_cache()
    test_gpu_cuda()

    result_path, _ = prepare_saving_dir(inference_configs, inference_config_file_path, inference_result=True)
    logging = get_logging(result_path)

    inference_dataloader, protein_encoder_tokenizer, molecule_encoder_tokenizer, decoder_tokenizer = prepare_inference_dataloader(
        configs, inference_configs)
    logging.info('preparing the dataloaders is done')

    dataloader_config = DataLoaderConfiguration(dispatch_batches=False)
    accelerator = Accelerator(
        mixed_precision=configs.train_settings.mixed_precision,
        gradient_accumulation_steps=configs.train_settings.grad_accumulation,
        dataloader_config=dataloader_config
    )

    net = prepare_models(configs, protein_encoder_tokenizer, decoder_tokenizer, logging, accelerator, inference=True)
    logging.info('preparing the model is done')

    net = load_checkpoints_inference(inference_configs.checkpoint_path, logging, net)
    logging.info('loading model weights is done')

    # Optionally compile the model for faster, more efficient inference on the GPU.
    if inference_configs.compile_model:
        net = torch.compile(net)
        if accelerator.is_main_process:
            logging.info('compiling the model is done')

    net, inference_dataloader = accelerator.prepare(net, inference_dataloader)
    net.to(accelerator.device)

    if accelerator.is_main_process:
        inference_steps = len(inference_dataloader)
        logging.info(f'number of inference steps: {int(inference_steps)}')

    torch.cuda.empty_cache()
    start_time = time()
    inference_results = inference(net, inference_dataloader, inference_configs, decoder_tokenizer, mode='inference')
    end_time = time()
    inference_time = end_time - start_time

    if accelerator.is_main_process:
        logging.info(
            f'inference dataset 1 - steps {len(inference_dataloader)} - time {np.round(inference_time, 2)}s')
        inference_results = pd.DataFrame(inference_results, columns=['input', 'task_name', 'predicted'])
        inference_results.to_csv(os.path.join(result_path, 'inference_results.csv'), index=False)

    accelerator.free_memory()
    del net, inference_dataloader, accelerator
    torch.cuda.empty_cache()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Run inference with a pre-trained Prot2Token model.")
    parser.add_argument("--config_path", "-c", help="The location of the inference config file",
                        default='./configs/inference_config.yaml')
    args = parser.parse_args()

    inference_config_path = args.config_path
    with open(inference_config_path) as file:
        inference_config_file = yaml.full_load(file)

    # The inference config points to the config file saved alongside the training results.
    result_config_path = inference_config_file['result_config_path']
    with open(result_config_path) as file:
        result_config_file = yaml.full_load(file)

    main(inference_config_file, result_config_file, inference_config_path)
    print('done!')
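
# Illustrative usage (the config path below is the argparse default above; adjust to your setup):
#   python inference.py --config_path ./configs/inference_config.yaml
#
# A minimal sketch of the inference config keys this script reads; the values shown are
# placeholders, and the authoritative schema is the repository's example config:
#   checkpoint_path: /path/to/checkpoint.pth          # hypothetical path
#   result_config_path: /path/to/result_config.yaml   # hypothetical path
#   compile_model: False
#   tqdm_progress_bar: True
#   inference_type: ...      # passed through as the model's forward mode
#   merging_character: ''    # placeholder value
#   beam_search:
#     beam_width: 1          # placeholder values
#     temperature: 1.0
#     top_k: 1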