-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathretrainModel.py
More file actions
50 lines (44 loc) · 1.47 KB
/
retrainModel.py
File metadata and controls
50 lines (44 loc) · 1.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from transformers import TrainingArguments
from transformers import Trainer
from utils.dataset import *
import pandas as pd
def freeze_layers(model, n):
    """Turn off gradient tracking for the first ``n`` parameters of ``model``.

    ``n`` counts parameter tensors in the order yielded by
    ``model.parameters()``, not layers: a layer with both a weight and a
    bias (e.g. ``torch.nn.Linear``) contributes two entries.

    Args:
        model: Object exposing ``parameters()``, e.g. a ``torch.nn.Module``.
        n: Number of leading parameters to freeze; values <= 0 freeze nothing.

    Returns:
        The same ``model`` object, modified in place.
    """
    leading = (
        param
        for index, param in enumerate(model.parameters())
        if index < n
    )
    for param in leading:
        param.requires_grad = False
    return model
def retrain_vit_model(
    model,
    prepared_ds,
    epochs=2,
    output_dir="/content/drive/MyDrive/UCSB 2023/Code/ViT/saved_models",
    per_device_train_batch_size=16,
    learning_rate=2e-4,
):
    """Fine-tune ``model`` with the Hugging Face ``Trainer`` and reload the best result.

    Trains on ``prepared_ds["train"]`` and evaluates on ``prepared_ds["test"]``
    every 100 steps. A checkpoint is saved to ``output_dir`` every 100 steps,
    and at most 2 checkpoints are kept. After training, the final model,
    the training metrics and the trainer state are saved to ``output_dir``.

    Args:
        model: A ``transformers`` model. It is moved to CUDA when available,
            and its class's ``from_pretrained`` is used to reload the result.
        prepared_ds: Mapping with ``"train"`` and ``"test"`` dataset splits.
        epochs: Number of training epochs.
        output_dir: Directory for checkpoints, the saved model, metrics and state.
        per_device_train_batch_size: Training batch size per device.
        learning_rate: Initial learning rate passed to ``TrainingArguments``.

    Returns:
        A freshly loaded instance of ``type(model)``. It is read from the best
        checkpoint recorded by the trainer. If no checkpoint was recorded, it
        is read from the model saved in ``output_dir``. This function does not
        move the returned model to the training device.
    """
    # Import explicitly instead of relying on ``from utils.dataset import *``
    # to re-export torch.
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        # NOTE(review): newer transformers releases rename this argument to
        # ``eval_strategy``; switch if the installed version rejects it.
        evaluation_strategy="steps",
        num_train_epochs=epochs,
        # fp16=True,  # optional mixed-precision toggle, left disabled
        # save_steps must line up with eval_steps for load_best_model_at_end.
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=learning_rate,
        save_total_limit=2,
        # Keep dataset columns the collator needs even if the model's
        # forward() signature does not name them.
        remove_unused_columns=False,
        push_to_hub=False,
        report_to="tensorboard",
        load_best_model_at_end=True,
    )
    # NOTE(review): collate_fn, compute_metrics and feature_extractor are not
    # defined in this file; presumably they come from ``utils.dataset`` via the
    # star import — confirm.
    trainer = Trainer(
        model=model.to(device),
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_ds["train"],
        eval_dataset=prepared_ds["test"],
        tokenizer=feature_extractor,
    )
    train_results = trainer.train()
    # With no argument, save_model() writes to training_args.output_dir.
    trainer.save_model()
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    trainer.save_state()

    best_model_dir = trainer.state.best_model_checkpoint
    if best_model_dir is None:
        # No evaluation/checkpoint was recorded (e.g. fewer than 100 total
        # steps). Fall back to the model save_model() just wrote, rather than
        # crashing in from_pretrained(None) after training has finished.
        best_model_dir = output_dir
    updated_model = type(model).from_pretrained(best_model_dir)
    return updated_model