-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_setup.py
More file actions
151 lines (120 loc) · 4.38 KB
/
test_setup.py
File metadata and controls
151 lines (120 loc) · 4.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python3
"""
Test script to verify the enhanced SLM setup
"""
import torch
import json
from tokenizers import Tokenizer
from model.slm import SLM
def test_tokenizer(tokenizer_path="data/tokenizer.json"):
    """Load the subword tokenizer and encode/decode a sample sentence.

    Args:
        tokenizer_path: Path of the trained tokenizer JSON file.

    Returns:
        True if the tokenizer loaded and encode/decode ran, False if any
        step raised (the error is printed instead of propagated so the
        remaining checks still run).
    """
    print("🧪 Testing tokenizer...")
    try:
        tokenizer = Tokenizer.from_file(tokenizer_path)
        vocab_size = tokenizer.get_vocab_size()
        print("✅ Tokenizer loaded successfully")
        print(f" Vocabulary size: {vocab_size}")

        # The decoded text is printed for a visual check, not compared to the
        # input: subword normalisation may change spacing on the way back.
        test_text = "Hello, how are you today?"
        encoded = tokenizer.encode(test_text)
        decoded = tokenizer.decode(encoded.ids)
        print(f" Test text: '{test_text}'")
        print(f" Encoded tokens: {len(encoded.tokens)}")
        print(f" Decoded: '{decoded}'")
        return True
    except Exception as e:
        # Top-level check boundary: report and signal failure to main().
        print(f"❌ Tokenizer test failed: {e}")
        return False
def test_model(tokenizer_path="data/tokenizer.json"):
    """Build the SLM and run one forward pass on random token ids.

    Args:
        tokenizer_path: Path of the trained tokenizer JSON file; its
            vocabulary size is used as the model's ``vocab_size``.

    Returns:
        True if the model was built and the forward pass completed, False if
        any step raised (the error is printed instead of propagated so the
        remaining checks still run).
    """
    print("\n🧪 Testing model...")
    try:
        # The model's vocabulary size is taken from the tokenizer so the two agree.
        tokenizer = Tokenizer.from_file(tokenizer_path)
        vocab_size = tokenizer.get_vocab_size()

        model = SLM(
            vocab_size=vocab_size,
            embed_size=256,
            hidden_size=512,
            num_layers=3,
            dropout=0.1,
        )
        print("✅ Model created successfully")
        print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        # This is a shape smoke test, not a training step: eval() turns off
        # dropout and no_grad() avoids building an autograd graph.
        model.eval()

        # Random ids in [0, vocab_size): 2 sequences of 10 tokens each.
        test_input = torch.randint(0, vocab_size, (2, 10), device=device)
        with torch.no_grad():
            output, hidden = model(test_input)
        print(f" Input shape: {test_input.shape}")
        print(f" Output shape: {output.shape}")
        print(f" Hidden state: {type(hidden)}")
        return True
    except Exception as e:
        # Top-level check boundary: report and signal failure to main().
        print(f"❌ Model test failed: {e}")
        return False
def test_generation(tokenizer_path="data/tokenizer.json"):
    """Run ``SLM.generate`` on a short prompt and print the decoded output.

    The model is freshly constructed (no checkpoint is loaded), so the
    generated text is expected to be meaningless; this only checks that
    sampling runs end to end and the output can be decoded.

    Args:
        tokenizer_path: Path of the trained tokenizer JSON file.

    Returns:
        True if generation and decoding completed, False if any step raised
        (the error is printed instead of propagated so the remaining checks
        still run).
    """
    print("\n🧪 Testing generation...")
    try:
        tokenizer = Tokenizer.from_file(tokenizer_path)
        vocab_size = tokenizer.get_vocab_size()

        model = SLM(
            vocab_size=vocab_size,
            embed_size=256,
            hidden_size=512,
            num_layers=3,
            dropout=0.1,
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        # Inference only: disable dropout and autograd bookkeeping.
        model.eval()

        test_prompt = "Hello, how are you?"
        # 1-D list of ids -> tensor of shape (1, prompt_len): a batch of one.
        input_ids = torch.tensor(
            tokenizer.encode(test_prompt).ids, dtype=torch.long, device=device
        ).unsqueeze(0)
        with torch.no_grad():
            generated = model.generate(
                input_ids=input_ids,
                max_new_tokens=10,
                temperature=0.8,
                top_k=50,
                top_p=0.9,
                do_sample=True,
            )
        generated_text = tokenizer.decode(generated[0].tolist())
        print("✅ Generation test successful")
        print(f" Prompt: '{test_prompt}'")
        print(f" Generated: '{generated_text}'")
        return True
    except Exception as e:
        # Top-level check boundary: report and signal failure to main().
        print(f"❌ Generation test failed: {e}")
        return False
def main():
    """Run every setup check and print a summary.

    All checks run even if an earlier one fails, so the summary lists every
    failure at once.

    Returns:
        0 if every check passed, 1 otherwise, suitable as a process exit
        status.
    """
    print("🚀 Running Enhanced SLM Tests")
    print("=" * 50)

    tests = [
        test_tokenizer,
        test_model,
        test_generation,
    ]
    # A list (not a generator fed to all()) so every check runs even after a
    # failure.
    results = [test() for test in tests]

    print("\n" + "=" * 50)
    print("📊 Test Results:")
    if all(results):
        print("✅ All tests passed! Your enhanced SLM is ready to go!")
        print("\n🎯 Next steps:")
        print("1. Run: python GET_Data.py (to get better dataset)")
        print("2. Run: python tokenizer.py (to create subword tokenizer)")
        print("3. Run: python train.py (to train the enhanced model)")
        print("4. Run: python generate.py (to chat with your bot)")
        return 0

    print("❌ Some tests failed. Please check the errors above.")
    # Report failing checks by name rather than by position in the list.
    failed_tests = [test.__name__ for test, passed in zip(tests, results) if not passed]
    print(f" Failed tests: {', '.join(failed_tests)}")
    return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None -> 0), so
    # shell/CI callers can detect failed checks.
    raise SystemExit(main())