-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
125 lines (113 loc) · 5.62 KB
/
Copy pathmain.py
File metadata and controls
125 lines (113 loc) · 5.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Importing necessary libraries and modules
import time
import pytorch_lightning as pl
import torch
import arg_parser # Argument parser
import utils # Custom utility functions
from lightning_model import CustomLightningModel, get_datasets_and_dataloaders
# Main execution block, split into small helpers so each stage (model setup,
# trainer setup, training, evaluation) can be read on its own.

# ASCII-art banner shown at start-up. It must be a raw string: the art contains
# backslashes such as "\ ", "\/" and "\_", which are invalid escape sequences in
# a normal string literal (SyntaxWarning since Python 3.12). None of them are
# valid escapes, so the printed text is unchanged.
_BANNER = r"""
.----------------. .----------------. .----------------. .----------------. .----------------.
| .--------------. | .--------------. | .--------------. | .--------------. | .--------------. |
| | ____ ____ | | | _____ | | | ___ | | | ________ | | | _____ | |
| ||_ \ / _|| | | |_ _| | | | .' _ '. | | | |_ ___ `. | | | |_ _| | |
| | | \/ | | | | | | | | | | (_) '___ | | | | | `. \ | | | | | | |
| | | |\ /| | | | | | | _ | | | .`___'/ _/ | | | | | | | | | | | | _ | |
| | _| |_\/_| |_ | | | _| |__/ | | | | | (___) \_ | | | _| |___.' / | | | _| |__/ | | |
| ||_____||_____|| | | |________| | | | `._____.\__| | | | |________.' | | | |________| | |
| | | | | | | | | | | | | | | |
| '--------------' | '--------------' | '--------------' | '--------------' | '--------------' |
'----------------' '----------------' '----------------' '----------------' '----------------'
"""


def _print_banner():
    """Print the start-up banner followed by the welcome message."""
    print(_BANNER)
    print("Welcome to the ML&DL Project! Please wait while the program is starting up...\n")


def _snapshot_weights(model):
    """Return a copy of every named parameter of *model*, keyed by name.

    ``detach()`` takes the copy out of the autograd graph, so the snapshot is
    a plain value copy rather than a tensor that still tracks gradients.
    """
    return {name: param.detach().clone() for name, param in model.named_parameters()}


def _load_or_create_model(args, val_dataset, test_dataset):
    """Load a model from a checkpoint if ``args.test`` is set, else build a new one.

    Args:
        args: Parsed command-line arguments. ``args.test`` may be ``'latest'``,
            a checkpoint path, or falsy (train from scratch).
        val_dataset: Validation dataset passed to the model.
        test_dataset: Test dataset passed to the model.

    Returns:
        A tuple ``(model, should_train, initial_weights)``. ``initial_weights``
        is a parameter snapshot taken before training. It is ``None`` when a
        checkpoint was loaded, because such a run is evaluation-only.
    """
    if args.test == 'latest':
        # Load the latest checkpoint from the logs directory
        model, checkpoint_path = utils.load_latest_checkpoint_model(val_dataset, test_dataset)
        print(f"Loaded model from latest checkpoint: {checkpoint_path}")
        return model, False, None
    if args.test:
        # Load the model from the checkpoint path given on the command line
        model, _ = utils.load_model_from_checkpoint(args.test, val_dataset, test_dataset)
        print(f"Loaded model from checkpoint: {args.test}")
        return model, False, None
    # Initialize the model for training from scratch
    print("No checkpoint provided, initializing model for training...")
    model = CustomLightningModel(val_dataset, test_dataset, args)
    return model, True, _snapshot_weights(model)


def _select_hardware():
    """Pick the accelerator, device count and numeric precision for the Trainer.

    Returns:
        A tuple ``(accelerator, devices, precision)``.
    """
    if torch.cuda.is_available():
        print("Trainer configured with GPU.")
        return 'gpu', 1, 16  # single GPU, 16-bit precision
    print("Trainer configured with CPU.")
    # Lightning's CPU accelerator expects a positive integer process count.
    # Recent Lightning versions reject `devices=None`, so use one process.
    return 'cpu', 1, 32  # full precision on CPU


def _build_trainer(args, checkpoint_cb):
    """Create the Lightning Trainer for this run.

    Args:
        args: Parsed command-line arguments (``args.max_epochs`` is read).
        checkpoint_cb: Checkpoint callback returned by ``utils.checkpoint_setup``.
    """
    accelerator, devices, precision = _select_hardware()
    return pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        default_root_dir='./LOGS',
        num_sanity_val_steps=0,
        precision=precision,  # chosen by _select_hardware from GPU availability
        max_epochs=args.max_epochs,
        check_val_every_n_epoch=1,
        callbacks=[checkpoint_cb],
        reload_dataloaders_every_n_epochs=1,
        log_every_n_steps=20,
        enable_progress_bar=False,
    )


def _train_and_test(trainer, model, args, train_loader, val_loader, test_loader, initial_weights):
    """Train *model*, test it, and print timings plus a before/after weight summary."""
    print("Starting training...")
    training_start_time = time.time()
    # Train on the training loader, validating with the validation loader
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    training_duration = time.time() - training_start_time
    final_weights = _snapshot_weights(model)
    print("Training completed.")

    # Time only the test pass itself, not the bookkeeping done after training.
    print("Starting testing...")
    testing_start_time = time.time()
    trainer.test(model=model, dataloaders=test_loader)
    testing_duration = time.time() - testing_start_time
    print(f"Testing completed in {testing_duration:.2f} seconds.")

    # Print a summary of the model's performance
    print("\nModel Performance Summary:")
    print(f"Training Duration: {training_duration:.2f} seconds")
    print(f"Testing Duration: {testing_duration:.2f} seconds")
    utils.print_program_config(args)
    utils.print_model_configuration(model)
    utils.print_weights_summary(initial_weights, final_weights)


def _evaluate(trainer, model, args, test_loader):
    """Test a checkpoint-loaded model and print the run configuration."""
    print("Evaluating the model...")
    # main() has already run trainer.validate on this model and validation
    # loader. A second validation pass here would only repeat that work.
    print("Starting testing...")
    trainer.test(model=model, dataloaders=test_loader)
    print("Testing completed.")
    # Print a summary of the model's performance
    print("\nModel Performance Summary:")
    utils.print_program_config(args)
    utils.print_model_configuration(model)


def main():
    """Parse arguments, build or load the model, validate, then train+test or evaluate."""
    _print_banner()
    # Parse command line arguments
    args = arg_parser.parse_arguments()
    utils.print_program_config(args)

    _, val_dataset, test_dataset, train_loader, val_loader, test_loader = get_datasets_and_dataloaders(args)

    model, should_train, initial_weights = _load_or_create_model(args, val_dataset, test_dataset)
    print("Model loaded successfully")
    utils.print_model_configuration(model)

    checkpoint_cb = utils.checkpoint_setup(args)
    trainer = _build_trainer(args, checkpoint_cb)
    print("Trainer initialized, all ready.")

    print("Starting validation...")
    # Validate before training for a baseline, or once for a loaded checkpoint
    trainer.validate(model=model, dataloaders=val_loader)
    print("Validation completed.")

    if should_train:
        _train_and_test(trainer, model, args, train_loader, val_loader, test_loader, initial_weights)
    else:
        _evaluate(trainer, model, args, test_loader)


if __name__ == '__main__':
    main()