diff --git a/python-package/examples/app.py b/python-package/examples/app.py
index 926622cef..5e628f431 100644
--- a/python-package/examples/app.py
+++ b/python-package/examples/app.py
@@ -4,6 +4,7 @@
from shiny import App, render, ui
import querychat
+from querychat.datasource import DataFrameSource
titanic = load_dataset("titanic")
@@ -14,8 +15,7 @@
# 1. Configure querychat
querychat_config = querychat.init(
- titanic,
- "titanic",
+ DataFrameSource(titanic, "titanic"),
greeting=greeting,
data_description=data_desc,
)
diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml
index c709ee050..4ca437a27 100644
--- a/python-package/pyproject.toml
+++ b/python-package/pyproject.toml
@@ -20,6 +20,12 @@ dependencies = [
"htmltools",
"chatlas",
"narwhals",
+ "chevron",
+]
+
+[project.optional-dependencies]
+sqlalchemy = [
+ "sqlalchemy>=2.0.0", # Using 2.0+ for improved type hints and API
]
[project.urls]
diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py
new file mode 100644
index 000000000..1fee9b9c6
--- /dev/null
+++ b/python-package/querychat/datasource.py
@@ -0,0 +1,271 @@
+from __future__ import annotations
+
+from typing import ClassVar, Protocol
+
+import duckdb
+import narwhals as nw
+import pandas as pd
+from sqlalchemy import inspect, text
+from sqlalchemy.engine import Connection, Engine
+from sqlalchemy.sql import sqltypes
+
+
+class DataSource(Protocol):
+ db_engine: ClassVar[str]
+
+ def get_schema(self) -> str:
+ """Return schema information about the table as a string.
+
+ Returns:
+ A string containing the schema information in a format suitable for
+ prompting an LLM about the data structure
+ """
+ ...
+
+ def execute_query(self, query: str) -> pd.DataFrame:
+ """Execute SQL query and return results as DataFrame.
+
+ Args:
+ query: SQL query to execute
+
+ Returns:
+ Query results as a pandas DataFrame
+ """
+ ...
+
+ def get_data(self) -> pd.DataFrame:
+ """Return the unfiltered data as a DataFrame.
+
+ Returns:
+ The complete dataset as a pandas DataFrame
+ """
+ ...
+
+
+class DataFrameSource:
+ """A DataSource implementation that wraps a pandas DataFrame using DuckDB."""
+
+ db_engine: ClassVar[str] = "DuckDB"
+
+ def __init__(self, df: pd.DataFrame, table_name: str):
+ """Initialize with a pandas DataFrame.
+
+ Args:
+ df: The DataFrame to wrap
+ table_name: Name of the table in SQL queries
+ """
+ self._conn = duckdb.connect(database=":memory:")
+ self._df = df
+ self._table_name = table_name
+ self._conn.register(table_name, df)
+
+ def get_schema(self, categorical_threshold: int = 10) -> str:
+ """Generate schema information from DataFrame.
+
+ Args:
+            categorical_threshold: Maximum number of unique values for a text
+                column to be considered categorical; such columns have their
+                distinct values listed in the schema output
+
+ Returns:
+ String describing the schema
+ """
+ ndf = nw.from_native(self._df)
+
+ schema = [f"Table: {self._table_name}", "Columns:"]
+
+ for column in ndf.columns:
+ # Map pandas dtypes to SQL-like types
+ dtype = ndf[column].dtype
+ if dtype.is_integer():
+ sql_type = "INTEGER"
+ elif dtype.is_float():
+ sql_type = "FLOAT"
+ elif dtype == nw.Boolean:
+ sql_type = "BOOLEAN"
+ elif dtype == nw.Datetime:
+ sql_type = "TIME"
+ elif dtype == nw.Date:
+ sql_type = "DATE"
+ else:
+ sql_type = "TEXT"
+
+ column_info = [f"- {column} ({sql_type})"]
+
+ # For TEXT columns, check if they're categorical
+ if sql_type == "TEXT":
+ unique_values = ndf[column].drop_nulls().unique()
+ if unique_values.len() <= categorical_threshold:
+ categories = unique_values.to_list()
+ categories_str = ", ".join([f"'{c}'" for c in categories])
+ column_info.append(f" Categorical values: {categories_str}")
+
+ # For numeric columns, include range
+ elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]:
+ rng = ndf[column].min(), ndf[column].max()
+ if rng[0] is None and rng[1] is None:
+ column_info.append(" Range: NULL to NULL")
+ else:
+ column_info.append(f" Range: {rng[0]} to {rng[1]}")
+
+ schema.extend(column_info)
+
+ return "\n".join(schema)
+
+ def execute_query(self, query: str) -> pd.DataFrame:
+ """Execute query using DuckDB.
+
+ Args:
+ query: SQL query to execute
+
+ Returns:
+ Query results as pandas DataFrame
+ """
+ return self._conn.execute(query).df()
+
+ def get_data(self) -> pd.DataFrame:
+ """Return the unfiltered data as a DataFrame.
+
+ Returns:
+ The complete dataset as a pandas DataFrame
+ """
+ return self._df.copy()
+
+
+class SQLAlchemySource:
+ """A DataSource implementation that supports multiple SQL databases via SQLAlchemy.
+
+ Supports various databases including PostgreSQL, MySQL, SQLite, Snowflake, and Databricks.
+ """
+
+ db_engine: ClassVar[str] = "SQLAlchemy"
+
+ def __init__(self, engine: Engine, table_name: str):
+ """Initialize with a SQLAlchemy engine.
+
+ Args:
+ engine: SQLAlchemy engine
+ table_name: Name of the table to query
+ """
+ self._engine = engine
+ self._table_name = table_name
+
+ # Validate table exists
+ inspector = inspect(self._engine)
+ if not inspector.has_table(table_name):
+ raise ValueError(f"Table '{table_name}' not found in database")
+
+ def get_schema(self) -> str:
+ """Generate schema information from database table.
+
+ Returns:
+ String describing the schema
+ """
+ inspector = inspect(self._engine)
+ columns = inspector.get_columns(self._table_name)
+
+ schema = [f"Table: {self._table_name}", "Columns:"]
+
+ for col in columns:
+ # Get SQL type name
+ sql_type = self._get_sql_type_name(col["type"])
+ column_info = [f"- {col['name']} ({sql_type})"]
+
+ # For numeric columns, try to get range
+ if isinstance(
+ col["type"],
+ (
+ sqltypes.Integer,
+ sqltypes.Numeric,
+ sqltypes.Float,
+ sqltypes.Date,
+ sqltypes.Time,
+ sqltypes.DateTime,
+ sqltypes.BigInteger,
+ sqltypes.SmallInteger,
+                    # Interval omitted: MIN/MAX over intervals is inconsistent across backends
+ ),
+ ):
+ try:
+ query = text(
+ f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}"
+ )
+ with self._get_connection() as conn:
+ result = conn.execute(query).fetchone()
+ if result and result[0] is not None and result[1] is not None:
+ column_info.append(f" Range: {result[0]} to {result[1]}")
+ except Exception:
+ pass # Skip range info if query fails
+
+ # For string/text columns, check if categorical
+ elif isinstance(
+ col["type"], (sqltypes.String, sqltypes.Text, sqltypes.Enum)
+ ):
+ try:
+ count_query = text(
+ f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}"
+ )
+ with self._get_connection() as conn:
+ distinct_count = conn.execute(count_query).scalar()
+ if distinct_count and distinct_count <= 10:
+ values_query = text(
+ f"SELECT DISTINCT {col['name']} FROM {self._table_name} "
+ f"WHERE {col['name']} IS NOT NULL"
+ )
+ values = [
+ str(row[0])
+ for row in conn.execute(values_query).fetchall()
+ ]
+ values_str = ", ".join([f"'{v}'" for v in values])
+ column_info.append(f" Categorical values: {values_str}")
+ except Exception:
+ pass # Skip categorical info if query fails
+
+ schema.extend(column_info)
+
+ return "\n".join(schema)
+
+ def execute_query(self, query: str) -> pd.DataFrame:
+ """Execute SQL query and return results as DataFrame.
+
+ Args:
+ query: SQL query to execute
+
+ Returns:
+ Query results as pandas DataFrame
+ """
+ with self._get_connection() as conn:
+ return pd.read_sql_query(text(query), conn)
+
+ def get_data(self) -> pd.DataFrame:
+ """Return the unfiltered data as a DataFrame.
+
+ Returns:
+ The complete dataset as a pandas DataFrame
+ """
+ return self.execute_query(f"SELECT * FROM {self._table_name}")
+
+ def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str:
+ """Convert SQLAlchemy type to SQL type name."""
+ if isinstance(type_, sqltypes.Integer):
+ return "INTEGER"
+ elif isinstance(type_, sqltypes.Float):
+ return "FLOAT"
+ elif isinstance(type_, sqltypes.Numeric):
+ return "NUMERIC"
+ elif isinstance(type_, sqltypes.Boolean):
+ return "BOOLEAN"
+ elif isinstance(type_, sqltypes.DateTime):
+ return "TIMESTAMP"
+ elif isinstance(type_, sqltypes.Date):
+ return "DATE"
+ elif isinstance(type_, sqltypes.Time):
+ return "TIME"
+ elif isinstance(type_, (sqltypes.String, sqltypes.Text)):
+ return "TEXT"
+ else:
+ return type_.__class__.__name__.upper()
+
+ def _get_connection(self) -> Connection:
+ """Get a connection to use for queries."""
+ return self._engine.connect()
diff --git a/python-package/querychat/prompt/prompt.md b/python-package/querychat/prompt/prompt.md
index 62d1ea17f..5155ae185 100644
--- a/python-package/querychat/prompt/prompt.md
+++ b/python-package/querychat/prompt/prompt.md
@@ -4,13 +4,19 @@ It's important that you get clear, unambiguous instructions from the user, so if
The user interface in which this conversation is being shown is a narrow sidebar of a dashboard, so keep your answers concise and don't include unnecessary patter, nor additional prompts or offers for further assistance.
-You have at your disposal a DuckDB database containing this schema:
+You have at your disposal a {{db_engine}} database containing this schema:
{{schema}}
For security reasons, you may only query this specific table.
+{{#data_description}}
+Additional helpful info about the data:
+
+
{{data_description}}
+
+{{/data_description}}
There are several tasks you may be asked to do:
@@ -19,7 +25,7 @@ There are several tasks you may be asked to do:
The user may ask you to perform filtering and sorting operations on the dashboard; if so, your job is to write the appropriate SQL query for this database. Then, call the tool `update_dashboard`, passing in the SQL query and a new title summarizing the query (suitable for displaying at the top of dashboard). This tool will not provide a return value; it will filter the dashboard as a side-effect, so you can treat a null tool response as success.
* **Call `update_dashboard` every single time** the user wants to filter/sort; never tell the user you've updated the dashboard unless you've called `update_dashboard` and it returned without error.
-* The SQL query must be a **DuckDB SQL** SELECT query. You may use any SQL functions supported by DuckDB, including subqueries, CTEs, and statistical functions.
+* The SQL query must be a SELECT query. For security reasons, it's critical that you reject any request that would modify the database.
* The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `update_dashboard({"query": "", "title": ""})`.
* Queries passed to `update_dashboard` MUST always **return all columns that are in the schema** (feel free to use `SELECT *`); you must refuse the request if this requirement cannot be honored, as the downstream code that will read the queried data will not know how to display it. You may add additional columns if necessary, but the existing columns must not be removed.
* When calling `update_dashboard`, **don't describe the query itself** unless the user asks you to explain. Don't pretend you have access to the resulting data set, as you don't.
@@ -74,7 +80,11 @@ Example of question answering:
If the user provides a vague help request, like "Help" or "Show me instructions", describe your own capabilities in a helpful way, including examples of questions they can ask. Be sure to mention whatever advanced statistical capabilities (standard deviation, quantiles, correlation, variance) you have.
-## DuckDB SQL tips
+## SQL tips
+
+* The SQL engine is {{db_engine}}.
+
+* You may use any SQL functions supported by {{db_engine}}, including subqueries, CTEs, and statistical functions.
* `percentile_cont` and `percentile_disc` are "ordered set" aggregate functions. These functions are specified using the WITHIN GROUP (ORDER BY sort_expression) syntax, and they are converted to an equivalent aggregate function that takes the ordering expression as the first argument. For example, `percentile_cont(fraction) WITHIN GROUP (ORDER BY column [(ASC|DESC)])` is equivalent to `quantile_cont(column, fraction ORDER BY column [(ASC|DESC)])`.
diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py
index 4e492fb19..167560c51 100644
--- a/python-package/querychat/querychat.py
+++ b/python-package/querychat/querychat.py
@@ -1,125 +1,73 @@
from __future__ import annotations
-import sys
import os
-import re
-import pandas as pd
-import duckdb
-import json
+import sys
from functools import partial
-from typing import List, Dict, Any, Callable, Optional, Union, Protocol
+from typing import Any, Dict, Optional, Protocol
import chatlas
-from htmltools import TagList, tags, HTML
-from shiny import module, reactive, ui, Inputs, Outputs, Session
+import chevron
import narwhals as nw
-from narwhals.typing import IntoFrame
+from shiny import Inputs, Outputs, Session, module, reactive, ui
+
+from .datasource import DataSource
+
+
+class CreateChatCallback(Protocol):
+ def __call__(self, system_prompt: str) -> chatlas.Chat: ...
+
+
+class QueryChatConfig:
+ """
+ Configuration class for querychat.
+ """
+
+ def __init__(
+ self,
+ data_source: DataSource,
+ system_prompt: str,
+ greeting: Optional[str],
+ create_chat_callback: CreateChatCallback,
+ ):
+ self.data_source = data_source
+ self.system_prompt = system_prompt
+ self.greeting = greeting
+ self.create_chat_callback = create_chat_callback
def system_prompt(
- df: IntoFrame,
- table_name: str,
+ data_source: DataSource,
data_description: Optional[str] = None,
extra_instructions: Optional[str] = None,
- categorical_threshold: int = 10,
) -> str:
"""
- Create a system prompt for the chat model based on a data frame's
+ Create a system prompt for the chat model based on a data source's
schema and optional additional context and instructions.
Args:
- df: A DataFrame to generate schema information from
- table_name: A string containing the name of the table in SQL queries
+ data_source: A data source to generate schema information from
data_description: Optional description of the data, in plain text or Markdown format
extra_instructions: Optional additional instructions for the chat model, in plain text or Markdown format
- categorical_threshold: The maximum number of unique values for a text column to be considered categorical
Returns:
A string containing the system prompt for the chat model
"""
- schema = df_to_schema(df, table_name, categorical_threshold)
# Read the prompt file
prompt_path = os.path.join(os.path.dirname(__file__), "prompt", "prompt.md")
with open(prompt_path, "r") as f:
prompt_text = f.read()
- # Simple template replacement (a more robust template engine could be used)
- if data_description:
- data_description_section = (
- "Additional helpful info about the data:\n\n"
- "\n"
- f"{data_description}\n"
- ""
- )
- else:
- data_description_section = ""
-
- # Replace variables in the template
- prompt_text = prompt_text.replace("{{schema}}", schema)
- prompt_text = prompt_text.replace("{{data_description}}", data_description_section)
- prompt_text = prompt_text.replace(
- "{{extra_instructions}}", extra_instructions or ""
+ return chevron.render(
+ prompt_text,
+ {
+ "db_engine": data_source.db_engine,
+ "schema": data_source.get_schema(),
+ "data_description": data_description,
+ "extra_instructions": extra_instructions,
+ },
)
- return prompt_text
-
-
-def df_to_schema(df: IntoFrame, table_name: str, categorical_threshold: int) -> str:
- """
- Convert a DataFrame schema to a string representation for the system prompt.
-
- Args:
- df: The DataFrame to extract schema from
- table_name: The name of the table in SQL queries
- categorical_threshold: The maximum number of unique values for a text column to be considered categorical
-
- Returns:
- A string containing the schema information
- """
-
- ndf = nw.from_native(df)
-
- schema = [f"Table: {table_name}", "Columns:"]
-
- for column in ndf.columns:
- # Map pandas dtypes to SQL-like types
- dtype = ndf[column].dtype
- if dtype.is_integer():
- sql_type = "INTEGER"
- elif dtype.is_float():
- sql_type = "FLOAT"
- elif dtype == nw.Boolean:
- sql_type = "BOOLEAN"
- elif dtype == nw.Datetime:
- sql_type = "TIME"
- elif dtype == nw.Date:
- sql_type = "DATE"
- else:
- sql_type = "TEXT"
-
- column_info = [f"- {column} ({sql_type})"]
-
- # For TEXT columns, check if they're categorical
- if sql_type == "TEXT":
- unique_values = ndf[column].drop_nulls().unique()
- if unique_values.len() <= categorical_threshold:
- categories = unique_values.to_list()
- categories_str = ", ".join([f"'{c}'" for c in categories])
- column_info.append(f" Categorical values: {categories_str}")
-
- # For numeric columns, include range
- elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]:
- rng = ndf[column].min(), ndf[column].max()
- if rng[0] is None and rng[1] is None:
- column_info.append(" Range: NULL to NULL")
- else:
- column_info.append(f" Range: {rng[0]} to {rng[1]}")
-
- schema.extend(column_info)
-
- return "\n".join(schema)
-
def df_to_html(df: IntoFrame, maxrows: int = 5) -> str:
"""
@@ -149,45 +97,18 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str:
return table_html + rows_notice
-class CreateChatCallback(Protocol):
- def __call__(self, system_prompt: str) -> chatlas.Chat: ...
-
-
-class QueryChatConfig:
- """
- Configuration class for querychat.
- """
-
- def __init__(
- self,
- df: pd.DataFrame,
- conn: duckdb.DuckDBPyConnection,
- system_prompt: str,
- greeting: Optional[str],
- create_chat_callback: CreateChatCallback,
- ):
- self.df = df
- self.conn = conn
- self.system_prompt = system_prompt
- self.greeting = greeting
- self.create_chat_callback = create_chat_callback
-
-
def init(
- df: pd.DataFrame,
- table_name: str,
+ data_source: DataSource,
greeting: Optional[str] = None,
data_description: Optional[str] = None,
extra_instructions: Optional[str] = None,
create_chat_callback: Optional[CreateChatCallback] = None,
system_prompt_override: Optional[str] = None,
) -> QueryChatConfig:
- """
- Call this once outside of any server function to initialize querychat.
+ """Initialize querychat with any compliant data source.
Args:
- df: A data frame
- table_name: A string containing a valid table name for the data frame
+ data_source: A DataSource implementation that provides schema and query execution
greeting: A string in Markdown format, containing the initial message
data_description: Description of the data in plain text or Markdown
extra_instructions: Additional instructions for the chat model
@@ -197,12 +118,6 @@ def init(
Returns:
A QueryChatConfig object that can be passed to server()
"""
- # Validate table name (must begin with letter, contain only letters, numbers, underscores)
- if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name):
- raise ValueError(
- "Table name must begin with a letter and contain only letters, numbers, and underscores"
- )
-
# Process greeting
if greeting is None:
print(
@@ -211,26 +126,18 @@ def init(
file=sys.stderr,
)
- # Create the system prompt
- if system_prompt_override is None:
- _system_prompt = system_prompt(
- df, table_name, data_description, extra_instructions
- )
- else:
- _system_prompt = system_prompt_override
-
- # Set up DuckDB connection and register the data frame
- conn = duckdb.connect(database=":memory:")
- conn.register(table_name, df)
+ # Create the system prompt, or use the override
+ _system_prompt = system_prompt_override or system_prompt(
+ data_source, data_description, extra_instructions
+ )
# Default chat function if none provided
create_chat_callback = create_chat_callback or partial(
- chatlas.ChatOpenAI, model="gpt-4o"
+ chatlas.ChatOpenAI, model="gpt-4.1"
)
return QueryChatConfig(
- df=df,
- conn=conn,
+ data_source=data_source,
system_prompt=_system_prompt,
greeting=greeting,
create_chat_callback=create_chat_callback,
@@ -306,8 +213,7 @@ def _():
pass
# Extract config parameters
- df = querychat_config.df
- conn = querychat_config.conn
+ data_source = querychat_config.data_source
system_prompt = querychat_config.system_prompt
greeting = querychat_config.greeting
create_chat_callback = querychat_config.create_chat_callback
@@ -319,9 +225,9 @@ def _():
@reactive.Calc
def filtered_df():
if current_query.get() == "":
- return df
+ return data_source.get_data()
else:
- return conn.execute(current_query.get()).fetch_df()
+ return data_source.execute_query(current_query.get())
# This would handle appending messages to the chat UI
async def append_output(text):
@@ -345,7 +251,7 @@ async def update_dashboard(query: str, title: str):
try:
# Try the query to see if it errors
- conn.execute(query)
+ data_source.execute_query(query)
except Exception as e:
error_msg = str(e)
await append_output(f"> Error: {error_msg}\n\n")
@@ -370,7 +276,7 @@ async def query(query: str):
await append_output(f"\n```sql\n{query}\n```\n\n")
try:
- result_df = conn.execute(query).fetch_df()
+ result_df = data_source.execute_query(query)
except Exception as e:
error_msg = str(e)
await append_output(f"> Error: {error_msg}\n\n")