-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathapp.py
More file actions
162 lines (136 loc) · 5.36 KB
/
Copy pathapp.py
File metadata and controls
162 lines (136 loc) · 5.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
### Chatbot to ask your SQL database
### Author: Ardya Dipta Nandaviri ardyadipta@gmail.com
### Date : Nov 28, 2024
import os
from dotenv import load_dotenv
import streamlit as st
import mysql.connector
import google.generativeai as genai
# Configure Google Gemini API
# Load variables from a local .env file into the environment first, so the
# API key lookup below can see values defined there.
load_dotenv()
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
# Utility Functions
def configure_streamlit():
    """Set the browser-tab title and draw the page header.

    Both use the same title text. This calls ``st.set_page_config``, so it is
    invoked at the start of ``main()`` before other Streamlit output.
    """
    title = "Chat with your database through Gemini"
    st.set_page_config(page_title=title)
    st.header(title)
def get_gemini_response(question, prompt, model_name=None):
    """Send *prompt* and *question* to a Gemini model and return its text.

    Args:
        question: The user's question, sent as the second content part.
        prompt: Instructions for the model, sent as the first content part.
        model_name: Gemini model to use. When omitted, the ``GEMINI_MODEL``
            environment variable is used, falling back to ``"gemini-pro"``.
            This lets deployments move off deprecated model names without
            a code change.

    Returns:
        The text of the model's response, or None if anything raised while
        calling the API (the error is shown in the UI via ``st.error``).
    """
    model_name = model_name or os.getenv("GEMINI_MODEL", "gemini-pro")
    try:
        model = genai.GenerativeModel(model_name)
        response = model.generate_content([prompt, question])
        return response.text
    except Exception as e:
        # Broad catch on purpose: this is the UI boundary, and any SDK failure
        # should surface as a message on the page rather than crash the app.
        st.error(f"Error with Google Gemini API: {e}")
        return None
def connect_to_database():
    """Open a MySQL connection from the sidebar settings in session state.

    Reads the ``Host``, ``Port``, ``User``, ``Password`` and ``Database``
    entries of ``st.session_state`` (the keys of the sidebar text inputs).
    An empty or missing ``Port`` falls back to 3306, the sidebar's default.

    Returns:
        An open ``mysql.connector`` connection, which the caller must close,
        or None if the port is not an integer or the connection attempt
        failed (the reason is shown with ``st.error``).
    """
    # Pass the sidebar's port through explicitly; otherwise the value typed
    # into the "Port" box would be ignored.
    port_text = str(st.session_state.get("Port", "")).strip() or "3306"
    try:
        port = int(port_text)
    except ValueError:
        st.error(f"Invalid database port: {port_text!r}")
        return None

    try:
        connection = mysql.connector.connect(
            host=st.session_state["Host"],
            port=port,
            user=st.session_state["User"],
            password=st.session_state["Password"],
            database=st.session_state["Database"],
        )
    except mysql.connector.Error as err:
        st.error(f"Database connection failed: {err}")
        return None

    st.success(f"Connected to the database {st.session_state['Database']} successfully!")
    return connection
# Leading keywords accepted by is_read_only_query. WITH is deliberately
# excluded because MySQL allows a CTE to precede UPDATE/DELETE.
_READ_ONLY_KEYWORDS = frozenset({"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"})


def is_read_only_query(query):
    """Return True if *query* looks like a single read-only SQL statement.

    This is a coarse safety net for model-generated SQL. It checks the
    leading keyword against an allow-list and rejects text containing more
    than one statement (a ``;`` anywhere except at the end). It is not a
    substitute for connecting with a database account that only has SELECT
    privileges.

    Args:
        query: SQL text; None or empty is treated as not read-only.

    Returns:
        True when the statement starts with SELECT, SHOW, DESCRIBE, DESC or
        EXPLAIN (case-insensitive) and contains no statement separator.
    """
    import re

    if not query:
        return False
    statement = query.strip().rstrip(";")
    if ";" in statement:
        return False
    match = re.match(r"([A-Za-z]+)", statement)
    return bool(match) and match.group(1).upper() in _READ_ONLY_KEYWORDS


def read_sql_query(query):
    """Execute a read-only SQL query and return all result rows.

    The query text comes from the LLM, which is driven by free-form user
    input, so anything that is not a single read-only statement is refused
    before a connection is opened.

    Args:
        query: SQL text to run.

    Returns:
        The list of rows from ``cursor.fetchall()``, or None if the query was
        refused, the connection failed, or execution raised a
        ``mysql.connector.Error`` (each case reports via ``st.error``).
    """
    if not is_read_only_query(query):
        st.error("Refusing to run a query that is not a single read-only statement.")
        return None

    connection = connect_to_database()
    if connection is None:
        return None

    # Initialise before the try so the finally clause never sees an unbound
    # name when connection.cursor() itself raises.
    cursor = None
    try:
        cursor = connection.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    except mysql.connector.Error as err:
        st.error(f"Error executing query: {err}")
        return None
    finally:
        if cursor is not None:
            cursor.close()
        connection.close()
# Prompt Definitions
# First-stage prompt: turns the user's English question into MySQL. It
# describes the schema the model must target, so keep it in sync with the real
# table. The example must use the real table name ("sales_table"); models tend
# to copy the example's identifiers verbatim.
PROMPT_QUERY = """
You are an expert in converting English questions to MySQL queries!
The SQL database "sales_database" has the table "sales_table" and the following columns:
- ORDERNUMBER: Unique identifier for sales
- QUANTITYORDERED: Number of products sold
- SALES: Amount of sales or revenue in USD
- PRODUCTLINE: Line of products
- ORDERDATE: Date of order
- YEAR_ID: Year of the order
- COUNTRY: Country
- CUSTOMERNAME: Name of customer
Example SQL command: SELECT COUNT(*) FROM sales_table;
Return exactly one SELECT statement and nothing else.
The output should not include ``` or the word "sql".
"""
# Second-stage prompt: main() fills it with
# str.format(question=..., result=...) and sends it back to Gemini, which
# phrases the raw query rows as a customer-facing answer. Because it goes
# through str.format, any literal "{" or "}" added here must be doubled.
PROMPT_HUMANE_RESPONSE_TEMPLATE = """
You are a customer service agent.
Previously, you were asked: "{question}"
The query result from the database is: "{result}".
Please respond to the customer in a humane and friendly and detailed manner.
For example, if the question is "What is the biggest sales of product A?",
you should answer "The biggest sales of product A is 1000 USD".
"""
# Be careful and don't sugarcoat the number. Whether the number is rising or declining, tell the customer as it is.
#
# Main Application Logic
_FOOTER_HTML = """
<style>
.footer {position: fixed;left: 0;bottom: 0;width: 100%;background-color: #000;color: white;text-align: center;}
</style>
<div class='footer'>
<p>ardyadipta@gmail.com</p>
</div>
"""


def strip_sql_fences(text):
    """Remove a Markdown code fence (optionally tagged ``sql``) around *text*.

    The prompt asks Gemini not to wrap its answer in fences, but the model
    does not always comply, and fenced text would not parse as SQL.

    Args:
        text: Raw model output.

    Returns:
        The text with surrounding whitespace and any enclosing fence removed.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
        if cleaned[:3].lower() == "sql":
            cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
    return cleaned.strip()


def _render_database_sidebar():
    """Draw the connection-settings inputs and the "Test Connection" button."""
    with st.sidebar:
        st.subheader("Database Settings")
        # Widget keys double as the st.session_state keys that hold the
        # connection settings.
        st.text_input("Host", value="localhost", key="Host")
        st.text_input("Port", value="3306", key="Port")
        st.text_input("User", value="root", key="User")
        st.text_input("Password", type="password", value="", key="Password")
        st.text_input("Database", value="sales_database", key="Database")
        if st.button("Test Connection"):
            with st.spinner("Testing database connection..."):
                connection = connect_to_database()
                if connection:
                    # connect_to_database() already reports success; close the
                    # probe connection so it is not leaked.
                    connection.close()


def _answer_question(question, show_query):
    """Turn *question* into SQL, run it, and display a natural-language answer."""
    raw_sql = get_gemini_response(question, PROMPT_QUERY)
    if not raw_sql:
        return  # get_gemini_response reports API errors itself
    sql_query = strip_sql_fences(raw_sql)
    if not sql_query:
        st.error("The model did not return a SQL query.")
        return
    if show_query:
        st.subheader("Generated SQL Query:")
        st.write(sql_query)

    result = read_sql_query(sql_query)
    if result is None:
        return  # read_sql_query has already displayed why it failed
    if not result:
        st.error("No results returned from the query.")
        return
    if show_query:
        st.subheader("Query Results:")
        for row in result:
            st.write(row)

    humane_response = get_gemini_response(
        question,
        PROMPT_HUMANE_RESPONSE_TEMPLATE.format(question=question, result=result),
    )
    if humane_response is None:
        return  # error already shown; avoid rendering a literal "None"
    st.subheader("AI Response:")
    st.write(humane_response)


def _render_footer():
    """Pin the contact footer to the bottom of the page."""
    st.markdown(_FOOTER_HTML, unsafe_allow_html=True)


def main(show_query=False):
    """Render the app and answer the user's question about the database.

    Args:
        show_query: When True, also display the generated SQL and the raw
            result rows (debug aid). Defaults to False.
    """
    configure_streamlit()
    # Render the sidebar before handling the question so the connection
    # settings widgets exist in this run before they are read.
    _render_database_sidebar()

    question = st.text_input("Input: ", key="input")
    if st.button("Ask the question") and question:
        with st.spinner("Processing your query..."):
            _answer_question(question, show_query)

    _render_footer()
# Run the app
# Entry point: build the page by calling main() when this file is executed
# as the main script (presumably via `streamlit run` — confirm deployment).
if __name__ == "__main__":
    main()