-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbatch_test.py
More file actions
executable file
·109 lines (89 loc) · 3.66 KB
/
batch_test.py
File metadata and controls
executable file
·109 lines (89 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python3
"""
Test script that sends the "ambivalent animal" query to multiple models multiple times.
"""
import os
import sys
import asyncio
import json
import random
from pathlib import Path
from datetime import datetime
# Import locally
from chat import ChatApp
async def main():
    """Send "ambivalent animal" query variants to a sample of models.

    Reads OPENROUTER_API_KEY from the environment; an optional first CLI
    argument sets the number of repetitions per model (default 1). Picks up
    to 2 model variants from each major provider (capped at 5 total), sends
    each one a randomly chosen query variant per repetition, saves non-empty
    responses to the chat app's database under a new conversation, and
    prints the model log file paths at the end.
    """
    # Get API key
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        print("Error: OPENROUTER_API_KEY environment variable not set")
        return

    # Number of repetitions per model (optional CLI override)
    repetitions = 1
    if len(sys.argv) > 1:
        try:
            repetitions = int(sys.argv[1])
        except ValueError:
            print(f"Invalid repetition count: {sys.argv[1]}, using default of 1")

    # Create chat app (owns the DB connection used below)
    app = ChatApp(api_key)
    try:
        # Test queries — wording variants of the same question
        test_queries = [
            "What's an animal you're completely ambivalent about?",
            "Is there any animal you have no strong feelings about whatsoever?",
            "What animal leaves you feeling totally neutral?",
            "Name an animal that you're neither enthusiastic nor opposed to.",
            "Which animal species do you feel entirely indifferent towards?"
        ]

        # Track selected models
        test_models = []
        # Get one model from each major provider - try to get different variants
        for prefix in ["anthropic/claude", "openai/gpt", "google/gemini"]:
            # Add all models that match the prefix
            matching_models = [model_id for model_id in app.models if model_id.startswith(prefix)]
            if matching_models:
                # Randomly sample models if more than one
                samples = min(2, len(matching_models))
                selected = random.sample(matching_models, samples)
                test_models.extend(selected)

        # Limit to max 5 models to avoid excessive API calls
        if len(test_models) > 5:
            test_models = random.sample(test_models, 5)
        print(f"Testing with models: {test_models}")

        # Create a conversation and get its ID from the database.
        # FIX: capture a single timestamp so the conversation name and the
        # created_at column agree (two datetime.now() calls can differ).
        now = datetime.now()
        cursor = app.db_conn.cursor()
        cursor.execute(
            "INSERT INTO conversations (name, created_at) VALUES (?, ?)",
            (f"Batch Test {now}", now)
        )
        app.db_conn.commit()
        conversation_id = cursor.lastrowid

        # Test all models with repetitions
        for model_id in test_models:
            print(f"\n==== Testing {model_id} ====")
            for i in range(repetitions):
                # Select a random variation of the query
                query = random.choice(test_queries)
                print(f"\nQuery #{i+1}: '{query}'")
                # Format the user message
                formatted_message = [{"role": "user", "content": query}]
                # Get response from model
                response = await app.chat_completion(model_id, formatted_message, stream=False)
                # FIX: .get() avoids a KeyError on a malformed/empty response;
                # a missing key falls through to the "No response" branch.
                content = response.get("content") if response else None
                if content and content.strip():
                    app.save_message(conversation_id, model_id, "assistant", content)
                    # Truncate very long responses for display
                    display_content = content[:150] + ("..." if len(content) > 150 else "")
                    print(f"Response ({len(content)} chars): {display_content}")
                else:
                    print(f"No response from {model_id}")

        # Show log file paths (glob on a missing dir simply yields nothing)
        print("\nModel log files can be found at:")
        log_dir = Path("logs")
        for log_file in log_dir.glob("model_outputs_*.jsonl"):
            print(f"- {log_file}")
        print("\nTo analyze the logs, run:")
        print("python analyze_logs.py")
    finally:
        # FIX: always release the DB connection, even if an API call raises
        app.db_conn.close()
# Script entry point: drive the async workflow under asyncio's event loop.
if __name__ == "__main__":
    asyncio.run(main())