From 973a43301e47a89b290be2b5d6b9cdf0ba6fc253 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Thu, 3 Apr 2025 22:47:22 -0700 Subject: [PATCH 01/11] First attempt at genericizing data source --- python-package/examples/app.py | 4 +- python-package/querychat/datasource.py | 207 +++++++++++++++++++++++++ python-package/querychat/querychat.py | 162 +++++-------------- 3 files changed, 251 insertions(+), 122 deletions(-) create mode 100644 python-package/querychat/datasource.py diff --git a/python-package/examples/app.py b/python-package/examples/app.py index 926622cef..5e628f431 100644 --- a/python-package/examples/app.py +++ b/python-package/examples/app.py @@ -4,6 +4,7 @@ from shiny import App, render, ui import querychat +from querychat.datasource import DataFrameSource titanic = load_dataset("titanic") @@ -14,8 +15,7 @@ # 1. Configure querychat querychat_config = querychat.init( - titanic, - "titanic", + DataFrameSource(titanic, "titanic"), greeting=greeting, data_description=data_desc, ) diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py new file mode 100644 index 000000000..495139ed7 --- /dev/null +++ b/python-package/querychat/datasource.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from typing import Protocol +import pandas as pd +import duckdb +import sqlite3 +import narwhals as nw + + +class DataSource(Protocol): + def get_schema(self) -> str: + """Return schema information about the table as a string. + + Returns: + A string containing the schema information in a format suitable for + prompting an LLM about the data structure + """ + ... + + def execute_query(self, query: str) -> pd.DataFrame: + """Execute SQL query and return results as DataFrame. + + Args: + query: SQL query to execute + + Returns: + Query results as a pandas DataFrame + """ + ... + + def get_data(self) -> pd.DataFrame: + """Return the unfiltered data as a DataFrame. + + Returns: + The complete dataset as a pandas DataFrame + """ + ... + + +class DataFrameSource: + """A DataSource implementation that wraps a pandas DataFrame using DuckDB.""" + + def __init__(self, df: pd.DataFrame, table_name: str): + """Initialize with a pandas DataFrame. + + Args: + df: The DataFrame to wrap + table_name: Name of the table in SQL queries + """ + self._conn = duckdb.connect(database=":memory:") + self._df = df + self._table_name = table_name + self._conn.register(table_name, df) + + def get_schema(self, categorical_threshold: int = 10) -> str: + """Generate schema information from DataFrame. + + Args: + table_name: Name to use for the table in schema description + categorical_threshold: Maximum number of unique values for a text column + to be considered categorical + + Returns: + String describing the schema + """ + ndf = nw.from_native(self._df) + + schema = [f"Table: {self._table_name}", "Columns:"] + + for column in ndf.columns: + # Map pandas dtypes to SQL-like types + dtype = ndf[column].dtype + if dtype.is_integer(): + sql_type = "INTEGER" + elif dtype.is_float(): + sql_type = "FLOAT" + elif dtype == nw.Boolean: + sql_type = "BOOLEAN" + elif dtype == nw.Datetime: + sql_type = "TIME" + elif dtype == nw.Date: + sql_type = "DATE" + else: + sql_type = "TEXT" + + column_info = [f"- {column} ({sql_type})"] + + # For TEXT columns, check if they're categorical + if sql_type == "TEXT": + unique_values = ndf[column].drop_nulls().unique() + if unique_values.len() <= categorical_threshold: + categories = unique_values.to_list() + categories_str = ", ".join([f"'{c}'" for c in categories]) + column_info.append(f" Categorical values: {categories_str}") + + # For numeric columns, include range + elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]: + rng = ndf[column].min(), ndf[column].max() + if rng[0] is None and rng[1] is None: + column_info.append(" Range: NULL to NULL") + else: + column_info.append(f" Range: {rng[0]} to {rng[1]}") + + schema.extend(column_info) + + return "\n".join(schema) + + def execute_query(self, query: str) -> pd.DataFrame: + """Execute query using DuckDB. + + Args: + query: SQL query to execute + + Returns: + Query results as pandas DataFrame + """ + return self._conn.execute(query).df() + + def get_data(self) -> pd.DataFrame: + """Return the unfiltered data as a DataFrame. + + Returns: + The complete dataset as a pandas DataFrame + """ + return self._df.copy() + + +class SQLiteSource: + """A DataSource implementation that wraps a SQLite connection.""" + + def __init__(self, conn: sqlite3.Connection, table_name: str): + """Initialize with a SQLite connection. + + Args: + conn: SQLite database connection + """ + self._conn = conn + self._table_name = table_name + + def get_schema(self) -> str: + """Generate schema information from SQLite table. + + Returns: + String describing the schema + """ + # Get column info + cursor = self._conn.execute(f"PRAGMA table_info({self._table_name})") + columns = cursor.fetchall() + + schema = [f"Table: {self._table_name}", "Columns:"] + + for col in columns: + # col format: (cid, name, type, notnull, dflt_value, pk) + column_info = [f"- {col[1]} ({col[2].upper()})"] + + # For numeric columns, try to get range + if col[2].upper() in ["INTEGER", "FLOAT", "REAL", "NUMERIC"]: + try: + cursor = self._conn.execute( + f"SELECT MIN({col[1]}), MAX({col[1]}) FROM {self._table_name}" + ) + min_val, max_val = cursor.fetchone() + if min_val is not None and max_val is not None: + column_info.append(f" Range: {min_val} to {max_val}") + except sqlite3.Error: + pass # Skip range info if query fails + + # For text columns, check if categorical (limited distinct values) + elif col[2].upper() == "TEXT": + try: + cursor = self._conn.execute( + f"SELECT COUNT(DISTINCT {col[1]}) FROM {self._table_name}" + ) + distinct_count = cursor.fetchone()[0] + if distinct_count <= 10: # Use fixed threshold for simplicity + cursor = self._conn.execute( + f"SELECT DISTINCT {col[1]} FROM {self._table_name} " + f"WHERE {col[1]} IS NOT NULL" + ) + values = [str(row[0]) for row in cursor.fetchall()] + values_str = ", ".join([f"'{v}'" for v in values]) + column_info.append(f" Categorical values: {values_str}") + except sqlite3.Error: + pass # Skip categorical info if query fails + + schema.extend(column_info) + + return "\n".join(schema) + + def execute_query(self, query: str) -> pd.DataFrame: + """Execute query using SQLite. + + Args: + query: SQL query to execute + + Returns: + Query results as pandas DataFrame + """ + return pd.read_sql_query(query, self._conn) + + def get_data(self) -> pd.DataFrame: + """Return the unfiltered data as a DataFrame. + + Returns: + The complete dataset as a pandas DataFrame + """ + return pd.read_sql_query(f"SELECT * FROM {self._table_name}", self._conn) diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index 4e492fb19..22b2f5ff9 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -15,29 +15,49 @@ import narwhals as nw from narwhals.typing import IntoFrame +from .datasource import DataSource + + +class CreateChatCallback(Protocol): + def __call__(self, system_prompt: str) -> chatlas.Chat: ... + + +class QueryChatConfig: + """ + Configuration class for querychat. + """ + + def __init__( + self, + data_source: DataSource, + system_prompt: str, + greeting: Optional[str], + create_chat_callback: CreateChatCallback, + ): + self.data_source = data_source + self.system_prompt = system_prompt + self.greeting = greeting + self.create_chat_callback = create_chat_callback + def system_prompt( - df: IntoFrame, - table_name: str, + data_source: DataSource, data_description: Optional[str] = None, extra_instructions: Optional[str] = None, - categorical_threshold: int = 10, ) -> str: """ - Create a system prompt for the chat model based on a data frame's + Create a system prompt for the chat model based on a data source's schema and optional additional context and instructions. Args: - df: A DataFrame to generate schema information from - table_name: A string containing the name of the table in SQL queries + data_source: A data source to generate schema information from data_description: Optional description of the data, in plain text or Markdown format extra_instructions: Optional additional instructions for the chat model, in plain text or Markdown format - categorical_threshold: The maximum number of unique values for a text column to be considered categorical Returns: A string containing the system prompt for the chat model """ - schema = df_to_schema(df, table_name, categorical_threshold) + schema = data_source.get_schema() # Read the prompt file prompt_path = os.path.join(os.path.dirname(__file__), "prompt", "prompt.md") @@ -65,62 +85,6 @@ def system_prompt( return prompt_text -def df_to_schema(df: IntoFrame, table_name: str, categorical_threshold: int) -> str: - """ - Convert a DataFrame schema to a string representation for the system prompt. - - Args: - df: The DataFrame to extract schema from - table_name: The name of the table in SQL queries - categorical_threshold: The maximum number of unique values for a text column to be considered categorical - - Returns: - A string containing the schema information - """ - - ndf = nw.from_native(df) - - schema = [f"Table: {table_name}", "Columns:"] - - for column in ndf.columns: - # Map pandas dtypes to SQL-like types - dtype = ndf[column].dtype - if dtype.is_integer(): - sql_type = "INTEGER" - elif dtype.is_float(): - sql_type = "FLOAT" - elif dtype == nw.Boolean: - sql_type = "BOOLEAN" - elif dtype == nw.Datetime: - sql_type = "TIME" - elif dtype == nw.Date: - sql_type = "DATE" - else: - sql_type = "TEXT" - - column_info = [f"- {column} ({sql_type})"] - - # For TEXT columns, check if they're categorical - if sql_type == "TEXT": - unique_values = ndf[column].drop_nulls().unique() - if unique_values.len() <= categorical_threshold: - categories = unique_values.to_list() - categories_str = ", ".join([f"'{c}'" for c in categories]) - column_info.append(f" Categorical values: {categories_str}") - - # For numeric columns, include range - elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]: - rng = ndf[column].min(), ndf[column].max() - if rng[0] is None and rng[1] is None: - column_info.append(" Range: NULL to NULL") - else: - column_info.append(f" Range: {rng[0]} to {rng[1]}") - - schema.extend(column_info) - - return "\n".join(schema) - - def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: """ Convert a DataFrame to an HTML table for display in chat. @@ -149,45 +113,18 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: return table_html + rows_notice -class CreateChatCallback(Protocol): - def __call__(self, system_prompt: str) -> chatlas.Chat: ... - - -class QueryChatConfig: - """ - Configuration class for querychat. - """ - - def __init__( - self, - df: pd.DataFrame, - conn: duckdb.DuckDBPyConnection, - system_prompt: str, - greeting: Optional[str], - create_chat_callback: CreateChatCallback, - ): - self.df = df - self.conn = conn - self.system_prompt = system_prompt - self.greeting = greeting - self.create_chat_callback = create_chat_callback - - def init( - df: pd.DataFrame, - table_name: str, + data_source: DataSource, greeting: Optional[str] = None, data_description: Optional[str] = None, extra_instructions: Optional[str] = None, create_chat_callback: Optional[CreateChatCallback] = None, system_prompt_override: Optional[str] = None, ) -> QueryChatConfig: - """ - Call this once outside of any server function to initialize querychat. + """Initialize querychat with any compliant data source. Args: - df: A data frame - table_name: A string containing a valid table name for the data frame + data_source: A DataSource implementation that provides schema and query execution greeting: A string in Markdown format, containing the initial message data_description: Description of the data in plain text or Markdown extra_instructions: Additional instructions for the chat model @@ -197,12 +134,6 @@ def init( Returns: A QueryChatConfig object that can be passed to server() """ - # Validate table name (must begin with letter, contain only letters, numbers, underscores) - if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name): - raise ValueError( - "Table name must begin with a letter and contain only letters, numbers, and underscores" - ) - # Process greeting if greeting is None: print( @@ -211,26 +142,18 @@ def init( file=sys.stderr, ) - # Create the system prompt - if system_prompt_override is None: - _system_prompt = system_prompt( - df, table_name, data_description, extra_instructions - ) - else: - _system_prompt = system_prompt_override - - # Set up DuckDB connection and register the data frame - conn = duckdb.connect(database=":memory:") - conn.register(table_name, df) + # Create the system prompt, or use the override + _system_prompt = system_prompt_override or system_prompt( + data_source, data_description, extra_instructions + ) # Default chat function if none provided create_chat_callback = create_chat_callback or partial( - chatlas.ChatOpenAI, model="gpt-4o" + chatlas.ChatOpenAI, model="gpt-4" ) return QueryChatConfig( - df=df, - conn=conn, + data_source=data_source, system_prompt=_system_prompt, greeting=greeting, create_chat_callback=create_chat_callback, @@ -306,8 +229,7 @@ def _(): pass # Extract config parameters - df = querychat_config.df - conn = querychat_config.conn + data_source = querychat_config.data_source system_prompt = querychat_config.system_prompt greeting = querychat_config.greeting create_chat_callback = querychat_config.create_chat_callback @@ -319,9 +241,9 @@ def _(): @reactive.Calc def filtered_df(): if current_query.get() == "": - return df + return data_source.get_data() else: - return conn.execute(current_query.get()).fetch_df() + return data_source.execute_query(current_query.get()) # This would handle appending messages to the chat UI async def append_output(text): @@ -345,7 +267,7 @@ async def update_dashboard(query: str, title: str): try: # Try the query to see if it errors - conn.execute(query) + data_source.execute_query(query) except Exception as e: error_msg = str(e) await append_output(f"> Error: {error_msg}\n\n") @@ -370,7 +292,7 @@ async def query(query: str): await append_output(f"\n```sql\n{query}\n```\n\n") try: - result_df = conn.execute(query).fetch_df() + result_df = data_source.execute_query(query) except Exception as e: error_msg = str(e) await append_output(f"> Error: {error_msg}\n\n") From 8de0ac71d3e687ec66151b7e977ced697f2a590a Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Fri, 4 Apr 2025 08:53:21 -0700 Subject: [PATCH 02/11] Unify prompts by adding chevron Python dependency --- python-package/pyproject.toml | 1 + python-package/querychat/prompt/prompt.md | 6 ++++ python-package/querychat/querychat.py | 39 +++++++---------------- 3 files changed, 18 insertions(+), 28 deletions(-) diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index c709ee050..dca3b063c 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "htmltools", "chatlas", "narwhals", + "chevron", ] [project.urls] diff --git a/python-package/querychat/prompt/prompt.md b/python-package/querychat/prompt/prompt.md index 62d1ea17f..154ce0cc8 100644 --- a/python-package/querychat/prompt/prompt.md +++ b/python-package/querychat/prompt/prompt.md @@ -10,7 +10,13 @@ You have at your disposal a DuckDB database containing this schema: For security reasons, you may only query this specific table. +{{#data_description}} +Additional helpful info about the data: + + {{data_description}} + +{{/data_description}} There are several tasks you may be asked to do: diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index 22b2f5ff9..37af66e1f 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -1,19 +1,13 @@ from __future__ import annotations -import sys import os -import re -import pandas as pd -import duckdb -import json +import sys from functools import partial -from typing import List, Dict, Any, Callable, Optional, Union, Protocol +from typing import Any, Dict, Optional, Protocol import chatlas -from htmltools import TagList, tags, HTML -from shiny import module, reactive, ui, Inputs, Outputs, Session -import narwhals as nw -from narwhals.typing import IntoFrame +import chevron +from shiny import Inputs, Outputs, Session, module, reactive, ui from .datasource import DataSource @@ -64,26 +58,15 @@ def system_prompt( with open(prompt_path, "r") as f: prompt_text = f.read() - # Simple template replacement (a more robust template engine could be used) - if data_description: - data_description_section = ( - "Additional helpful info about the data:\n\n" - "\n" - f"{data_description}\n" - "" - ) - else: - data_description_section = "" - - # Replace variables in the template - prompt_text = prompt_text.replace("{{schema}}", schema) - prompt_text = prompt_text.replace("{{data_description}}", data_description_section) - prompt_text = prompt_text.replace( - "{{extra_instructions}}", extra_instructions or "" + return chevron.render( + prompt_text, + { + "schema": schema, + "data_description": data_description, + "extra_instructions": extra_instructions, + }, ) - return prompt_text - def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: """ From 53c7df3ddeda8b07f534a906165b04205ef83b31 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Fri, 18 Apr 2025 14:03:24 -0700 Subject: [PATCH 03/11] Make prompt aware of what engine is being used --- python-package/querychat/datasource.py | 13 ++++++++++--- python-package/querychat/prompt/prompt.md | 10 +++++++--- python-package/querychat/querychat.py | 4 ++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py index 495139ed7..e408e4b03 100644 --- a/python-package/querychat/datasource.py +++ b/python-package/querychat/datasource.py @@ -1,13 +1,16 @@ from __future__ import annotations -from typing import Protocol -import pandas as pd -import duckdb import sqlite3 +from typing import ClassVar, Protocol + +import duckdb import narwhals as nw +import pandas as pd class DataSource(Protocol): + db_engine: ClassVar[str] + def get_schema(self) -> str: """Return schema information about the table as a string. @@ -40,6 +43,8 @@ def get_data(self) -> pd.DataFrame: class DataFrameSource: """A DataSource implementation that wraps a pandas DataFrame using DuckDB.""" + db_engine: ClassVar[str] = "DuckDB" + def __init__(self, df: pd.DataFrame, table_name: str): """Initialize with a pandas DataFrame. @@ -128,6 +133,8 @@ def get_data(self) -> pd.DataFrame: class SQLiteSource: """A DataSource implementation that wraps a SQLite connection.""" + db_engine: ClassVar[str] = "SQLite" + def __init__(self, conn: sqlite3.Connection, table_name: str): """Initialize with a SQLite connection. diff --git a/python-package/querychat/prompt/prompt.md b/python-package/querychat/prompt/prompt.md index 154ce0cc8..5155ae185 100644 --- a/python-package/querychat/prompt/prompt.md +++ b/python-package/querychat/prompt/prompt.md @@ -4,7 +4,7 @@ It's important that you get clear, unambiguous instructions from the user, so if The user interface in which this conversation is being shown is a narrow sidebar of a dashboard, so keep your answers concise and don't include unnecessary patter, nor additional prompts or offers for further assistance. -You have at your disposal a DuckDB database containing this schema: +You have at your disposal a {{db_engine}} database containing this schema: {{schema}} @@ -25,7 +25,7 @@ There are several tasks you may be asked to do: The user may ask you to perform filtering and sorting operations on the dashboard; if so, your job is to write the appropriate SQL query for this database. Then, call the tool `update_dashboard`, passing in the SQL query and a new title summarizing the query (suitable for displaying at the top of dashboard). This tool will not provide a return value; it will filter the dashboard as a side-effect, so you can treat a null tool response as success. * **Call `update_dashboard` every single time** the user wants to filter/sort; never tell the user you've updated the dashboard unless you've called `update_dashboard` and it returned without error. -* The SQL query must be a **DuckDB SQL** SELECT query. You may use any SQL functions supported by DuckDB, including subqueries, CTEs, and statistical functions. +* The SQL query must be a SELECT query. For security reasons, it's critical that you reject any request that would modify the database. * The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `update_dashboard({"query": "", "title": ""})`. * Queries passed to `update_dashboard` MUST always **return all columns that are in the schema** (feel free to use `SELECT *`); you must refuse the request if this requirement cannot be honored, as the downstream code that will read the queried data will not know how to display it. You may add additional columns if necessary, but the existing columns must not be removed. * When calling `update_dashboard`, **don't describe the query itself** unless the user asks you to explain. Don't pretend you have access to the resulting data set, as you don't. @@ -80,7 +80,11 @@ Example of question answering: If the user provides a vague help request, like "Help" or "Show me instructions", describe your own capabilities in a helpful way, including examples of questions they can ask. Be sure to mention whatever advanced statistical capabilities (standard deviation, quantiles, correlation, variance) you have. -## DuckDB SQL tips +## SQL tips + +* The SQL engine is {{db_engine}}. + +* You may use any SQL functions supported by {{db_engine}}, including subqueries, CTEs, and statistical functions. * `percentile_cont` and `percentile_disc` are "ordered set" aggregate functions. These functions are specified using the WITHIN GROUP (ORDER BY sort_expression) syntax, and they are converted to an equivalent aggregate function that takes the ordering expression as the first argument. For example, `percentile_cont(fraction) WITHIN GROUP (ORDER BY column [(ASC|DESC)])` is equivalent to `quantile_cont(column, fraction ORDER BY column [(ASC|DESC)])`. diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index 37af66e1f..fb0e6997b 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -51,7 +51,6 @@ def system_prompt( Returns: A string containing the system prompt for the chat model """ - schema = data_source.get_schema() # Read the prompt file prompt_path = os.path.join(os.path.dirname(__file__), "prompt", "prompt.md") @@ -61,7 +60,8 @@ def system_prompt( return chevron.render( prompt_text, { - "schema": schema, + "db_engine": data_source.db_engine, + "schema": data_source.get_schema(), "data_description": data_description, "extra_instructions": extra_instructions, }, From a2122f22da9233ce6edc3ece5ac12440e5a35f63 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Fri, 18 Apr 2025 14:37:45 -0700 Subject: [PATCH 04/11] Replace SQLite support with SQLAlchemy support --- python-package/pyproject.toml | 5 + python-package/querychat/datasource.py | 133 ++++++++++++++++++------- 2 files changed, 100 insertions(+), 38 deletions(-) diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index dca3b063c..4ca437a27 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -23,6 +23,11 @@ dependencies = [ "chevron", ] +[project.optional-dependencies] +sqlalchemy = [ + "sqlalchemy>=2.0.0", # Using 2.0+ for improved type hints and API +] + [project.urls] Homepage = "https://github.com/posit-dev/querychat" Issues = "https://github.com/posit-dev/querychat/issues" diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py index e408e4b03..e33711e78 100644 --- a/python-package/querychat/datasource.py +++ b/python-package/querychat/datasource.py @@ -1,11 +1,13 @@ from __future__ import annotations -import sqlite3 from typing import ClassVar, Protocol import duckdb import narwhals as nw import pandas as pd +from sqlalchemy import inspect, text +from sqlalchemy.engine import Engine, Connection +from sqlalchemy.sql import sqltypes class DataSource(Protocol): @@ -130,64 +132,93 @@ def get_data(self) -> pd.DataFrame: return self._df.copy() -class SQLiteSource: - """A DataSource implementation that wraps a SQLite connection.""" +class SQLAlchemySource: + """A DataSource implementation that supports multiple SQL databases via SQLAlchemy. - db_engine: ClassVar[str] = "SQLite" + Supports various databases including PostgreSQL, MySQL, SQLite, Snowflake, and Databricks. + """ - def __init__(self, conn: sqlite3.Connection, table_name: str): - """Initialize with a SQLite connection. + db_engine: ClassVar[str] = "SQLAlchemy" + + def __init__(self, engine: Engine, table_name: str): + """Initialize with a SQLAlchemy engine. Args: - conn: SQLite database connection + engine: SQLAlchemy engine + table_name: Name of the table to query """ - self._conn = conn + self._engine = engine self._table_name = table_name + # Validate table exists + inspector = inspect(self._engine) + if table_name not in inspector.get_table_names(): + raise ValueError(f"Table '{table_name}' not found in database") + def get_schema(self) -> str: - """Generate schema information from SQLite table. + """Generate schema information from database table. Returns: String describing the schema """ - # Get column info - cursor = self._conn.execute(f"PRAGMA table_info({self._table_name})") - columns = cursor.fetchall() + inspector = inspect(self._engine) + columns = inspector.get_columns(self._table_name) schema = [f"Table: {self._table_name}", "Columns:"] for col in columns: - # col format: (cid, name, type, notnull, dflt_value, pk) - column_info = [f"- {col[1]} ({col[2].upper()})"] + # Get SQL type name + sql_type = self._get_sql_type_name(col["type"]) + column_info = [f"- {col['name']} ({sql_type})"] # For numeric columns, try to get range - if col[2].upper() in ["INTEGER", "FLOAT", "REAL", "NUMERIC"]: + if isinstance( + col["type"], + ( + sqltypes.Integer, + sqltypes.Numeric, + sqltypes.Float, + sqltypes.Date, + sqltypes.Time, + sqltypes.DateTime, + sqltypes.BigInteger, + sqltypes.SmallInteger, + # sqltypes.Interval, + ), + ): try: - cursor = self._conn.execute( - f"SELECT MIN({col[1]}), MAX({col[1]}) FROM {self._table_name}" + query = text( + f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}" ) - min_val, max_val = cursor.fetchone() - if min_val is not None and max_val is not None: - column_info.append(f" Range: {min_val} to {max_val}") - except sqlite3.Error: + with self._get_connection() as conn: + result = conn.execute(query).fetchone() + if result and result[0] is not None and result[1] is not None: + column_info.append(f" Range: {result[0]} to {result[1]}") + except Exception: pass # Skip range info if query fails - # For text columns, check if categorical (limited distinct values) - elif col[2].upper() == "TEXT": + # For string/text columns, check if categorical + elif isinstance( + col["type"], (sqltypes.String, sqltypes.Text, sqltypes.Enum) + ): try: - cursor = self._conn.execute( - f"SELECT COUNT(DISTINCT {col[1]}) FROM {self._table_name}" + count_query = text( + f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}" ) - distinct_count = cursor.fetchone()[0] - if distinct_count <= 10: # Use fixed threshold for simplicity - cursor = self._conn.execute( - f"SELECT DISTINCT {col[1]} FROM {self._table_name} " - f"WHERE {col[1]} IS NOT NULL" - ) - values = [str(row[0]) for row in cursor.fetchall()] - values_str = ", ".join([f"'{v}'" for v in values]) - column_info.append(f" Categorical values: {values_str}") - except sqlite3.Error: + with self._get_connection() as conn: + distinct_count = conn.execute(count_query).scalar() + if distinct_count and distinct_count <= 10: + values_query = text( + f"SELECT DISTINCT {col['name']} FROM {self._table_name} " + f"WHERE {col['name']} IS NOT NULL" + ) + values = [ + str(row[0]) + for row in conn.execute(values_query).fetchall() + ] + values_str = ", ".join([f"'{v}'" for v in values]) + column_info.append(f" Categorical values: {values_str}") + except Exception: pass # Skip categorical info if query fails schema.extend(column_info) @@ -195,7 +226,7 @@ def get_schema(self) -> str: return "\n".join(schema) def execute_query(self, query: str) -> pd.DataFrame: - """Execute query using SQLite. + """Execute SQL query and return results as DataFrame. Args: query: SQL query to execute @@ -203,7 +234,8 @@ def execute_query(self, query: str) -> pd.DataFrame: Returns: Query results as pandas DataFrame """ - return pd.read_sql_query(query, self._conn) + with self._get_connection() as conn: + return pd.read_sql_query(text(query), conn) def get_data(self) -> pd.DataFrame: """Return the unfiltered data as a DataFrame. @@ -211,4 +243,29 @@ def get_data(self) -> pd.DataFrame: Returns: The complete dataset as a pandas DataFrame """ - return pd.read_sql_query(f"SELECT * FROM {self._table_name}", self._conn) + return self.execute_query(f"SELECT * FROM {self._table_name}") + + def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: + """Convert SQLAlchemy type to SQL type name.""" + if isinstance(type_, sqltypes.Integer): + return "INTEGER" + elif isinstance(type_, sqltypes.Float): + return "FLOAT" + elif isinstance(type_, sqltypes.Numeric): + return "NUMERIC" + elif isinstance(type_, sqltypes.Boolean): + return "BOOLEAN" + elif isinstance(type_, sqltypes.DateTime): + return "TIMESTAMP" + elif isinstance(type_, sqltypes.Date): + return "DATE" + elif isinstance(type_, sqltypes.Time): + return "TIME" + elif isinstance(type_, (sqltypes.String, sqltypes.Text)): + return "TEXT" + else: + return type_.__class__.__name__.upper() + + def _get_connection(self) -> Connection: + """Get a connection to use for queries.""" + return self._engine.connect() From a218fb914963a4477598c8f4d0081bae043de286 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Wed, 23 Apr 2025 16:26:58 -0700 Subject: [PATCH 05/11] Don't fail when given table name's case differs from SQLAlchemy Inspector --- python-package/querychat/datasource.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py index e33711e78..1fee9b9c6 100644 --- a/python-package/querychat/datasource.py +++ b/python-package/querychat/datasource.py @@ -6,7 +6,7 @@ import narwhals as nw import pandas as pd from sqlalchemy import inspect, text -from sqlalchemy.engine import Engine, Connection +from sqlalchemy.engine import Connection, Engine from sqlalchemy.sql import sqltypes @@ -152,7 +152,7 @@ def __init__(self, engine: Engine, table_name: str): # Validate table exists inspector = inspect(self._engine) - if table_name not in inspector.get_table_names(): + if not inspector.has_table(table_name): raise ValueError(f"Table '{table_name}' not found in database") def get_schema(self) -> str: From dc0814ef6a68575d0bb9624f43596507d769f4e3 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Thu, 1 May 2025 16:58:29 -0400 Subject: [PATCH 06/11] Forgot import --- python-package/querychat/querychat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index fb0e6997b..ed5583627 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -7,6 +7,7 @@ import chatlas import chevron +import narwhals as nw from shiny import Inputs, Outputs, Session, module, reactive, ui from .datasource import DataSource From 9d95d1d0f47db306c3a422d913cfbcf8c6e0d244 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 16:12:35 -0700 Subject: [PATCH 07/11] Have server() return proper class with typed methods, instead of dict --- .gitignore | 3 +- python-package/examples/app-database.py | 55 +++++++++ .../examples/{app.py => app-dataframe.py} | 7 +- python-package/querychat/querychat.py | 104 ++++++++++++++++-- 4 files changed, 154 insertions(+), 15 deletions(-) create mode 100644 python-package/examples/app-database.py rename python-package/examples/{app.py => app-dataframe.py} (97%) diff --git a/.gitignore b/.gitignore index 98ab22956..32d0462bc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__/ animation.screenflow/ README_files/ -README.html \ No newline at end of file +README.html +.DS_Store \ No newline at end of file diff --git a/python-package/examples/app-database.py b/python-package/examples/app-database.py new file mode 100644 index 000000000..cfee136e6 --- /dev/null +++ b/python-package/examples/app-database.py @@ -0,0 +1,55 @@ +import sqlite3 +from pathlib import Path + +import querychat +from querychat.datasource import SQLAlchemySource +from seaborn import load_dataset +from shiny import App, render, ui +from sqlalchemy import create_engine + +# Load titanic data and create SQLite database +db_path = Path(__file__).parent / "titanic.db" +engine = create_engine("sqlite:///" + str(db_path)) +# titanic = load_dataset("titanic") +# titanic.to_sql("titanic", conn, if_exists="replace", index=False) + +with open(Path(__file__).parent / "greeting.md", "r") as f: + greeting = f.read() +with open(Path(__file__).parent / "data_description.md", "r") as f: + data_desc = f.read() + +# 1. Configure querychat +querychat_config = querychat.init( + SQLAlchemySource(engine, "titanic"), + greeting=greeting, + data_description=data_desc, +) + +# Create UI +app_ui = ui.page_sidebar( + # 2. Place the chat component in the sidebar + querychat.sidebar("chat"), + # Main panel with data viewer + ui.card( + ui.output_data_frame("data_table"), + fill=True, + ), + title="querychat with Python (SQLite)", + fillable=True, +) + + +# Define server logic +def server(input, output, session): + # 3. Initialize querychat server with the config from step 1 + chat = querychat.server("chat", querychat_config) + + # 4. Display the filtered dataframe + @render.data_frame + def data_table(): + # Access filtered data via chat.df() reactive + return chat["df"]() + + +# Create Shiny app +app = App(app_ui, server) diff --git a/python-package/examples/app.py b/python-package/examples/app-dataframe.py similarity index 97% rename from python-package/examples/app.py rename to python-package/examples/app-dataframe.py index 5e628f431..13d224fbb 100644 --- a/python-package/examples/app.py +++ b/python-package/examples/app-dataframe.py @@ -1,10 +1,9 @@ from pathlib import Path -from seaborn import load_dataset -from shiny import App, render, ui - import querychat from querychat.datasource import DataFrameSource +from seaborn import load_dataset +from shiny import App, render, ui titanic = load_dataset("titanic") @@ -43,7 +42,7 @@ def server(input, output, session): @render.data_frame def data_table(): # Access filtered data via chat.df() reactive - return chat["df"]() + return chat.df() # Create Shiny app diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index ed5583627..093dec163 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -3,11 +3,13 @@ import os import sys from functools import partial -from typing import Any, Dict, Optional, Protocol +from typing import Any, Callable, Optional, Protocol import chatlas import chevron import narwhals as nw +import pandas as pd +from narwhals.typing import IntoFrame from shiny import Inputs, Outputs, Session, module, reactive, ui from .datasource import DataSource @@ -35,6 +37,93 @@ def __init__( self.create_chat_callback = create_chat_callback +class QueryChat: + """ + An object representing a query chat session. This is created within a Shiny + server function or Shiny module server function by using + `querychat.server()`. Use this object to bridge the chat interface with the + rest of the Shiny app, for example, by displaying the filtered data. + """ + + def __init__( + self, + chat: chatlas.Chat, + sql: Callable[[], str], + title: Callable[[], str | None], + df: Callable[[], pd.DataFrame], + ): + """ + Initialize a QueryChat object. + + Args: + chat: The chat object for the session + sql: Reactive that returns the current SQL query + title: Reactive that returns the current title + df: Reactive that returns the filtered data frame + """ + self._chat = chat + self._sql = sql + self._title = title + self._df = df + + def chat(self) -> chatlas.Chat: + """ + Get the chat object for this session. + + Returns: + The chat object + """ + return self._chat() + + def sql(self) -> str: + """ + Reactively read the current SQL query that is in effect. + + Returns: + The current SQL query as a string, or `""` if no query has been set. + """ + return self._sql() + + def title(self) -> str | None: + """ + Reactively read the current title that is in effect. The title is a + short description of the current query that the LLM provides to us + whenever it generates a new SQL query. It can be used as a status string + for the data dashboard. + + Returns: + The current title as a string, or `None` if no title has been set + due to no SQL query being set. + """ + return self._title() + + def df(self) -> pd.DataFrame: + """ + Reactively read the current filtered data frame that is in effect. + + Returns: + The current filtered data frame as a pandas DataFrame. If no query + has been set, this will return the unfiltered data frame from the + data source. + """ + return self._df() + + def __getitem__(self, key: str) -> Any: + """ + Allow access to configuration parameters like a dictionary. For + backwards compatibility only; new code should use the attributes + directly instead. + """ + if key == "chat": + return self.chat + elif key == "sql": + return self.sql + elif key == "title": + return self.title + elif key == "df": + return self.df + + def system_prompt( data_source: DataSource, data_description: Optional[str] = None, @@ -190,7 +279,7 @@ def sidebar(id: str, width: int = 400, height: str = "100%", **kwargs) -> ui.Sid @module.server def server( input: Inputs, output: Outputs, session: Session, querychat_config: QueryChatConfig -) -> Dict[str, Any]: +) -> QueryChat: """ Initialize the querychat server. @@ -219,8 +308,8 @@ def _(): create_chat_callback = querychat_config.create_chat_callback # Reactive values to store state - current_title = reactive.Value(None) - current_query = reactive.Value("") + current_title: reactive.Value[str | None] = reactive.Value(None) + current_query: reactive.Value[str] = reactive.Value("") @reactive.Calc def filtered_df(): @@ -326,9 +415,4 @@ async def greet_on_startup(): await chat_ui.append_message_stream(stream) # Return the interface for other components to use - return { - "chat": chat, - "sql": current_query.get, - "title": current_title.get, - "df": filtered_df, - } + return QueryChat(chat, current_query.get, current_title.get, filtered_df) From aeb87dd060fbafb1c973d94c7041ab20ccf71dd8 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 16:17:43 -0700 Subject: [PATCH 08/11] Auto-create sqlite database for example --- .gitignore | 3 ++- python-package/examples/app-database.py | 9 ++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 32d0462bc..1639e0578 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ __pycache__/ animation.screenflow/ README_files/ README.html -.DS_Store \ No newline at end of file +.DS_Store +python-package/examples/titanic.db diff --git a/python-package/examples/app-database.py b/python-package/examples/app-database.py index cfee136e6..c196b3e79 100644 --- a/python-package/examples/app-database.py +++ b/python-package/examples/app-database.py @@ -1,4 +1,3 @@ -import sqlite3 from pathlib import Path import querychat @@ -10,8 +9,12 @@ # Load titanic data and create SQLite database db_path = Path(__file__).parent / "titanic.db" engine = create_engine("sqlite:///" + str(db_path)) -# titanic = load_dataset("titanic") -# titanic.to_sql("titanic", conn, if_exists="replace", index=False) + +if not db_path.exists(): + # For example purposes, we'll create the database if it doesn't exist. Don't + # do this in your app! + titanic = load_dataset("titanic") + titanic.to_sql("titanic", engine, if_exists="replace", index=False) with open(Path(__file__).parent / "greeting.md", "r") as f: greeting = f.read() From c38b567189b73dee742715c5983ae32d57adc6c1 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 16:38:25 -0700 Subject: [PATCH 09/11] Have init() take data frame or sqlalchemy engine directly ...instead of requiring explicit DataSource subclass creation --- python-package/examples/app-database.py | 4 ++-- python-package/examples/app-dataframe.py | 4 ++-- python-package/querychat/querychat.py | 19 ++++++++++++++----- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/python-package/examples/app-database.py b/python-package/examples/app-database.py index c196b3e79..9769cc172 100644 --- a/python-package/examples/app-database.py +++ b/python-package/examples/app-database.py @@ -1,7 +1,6 @@ from pathlib import Path import querychat -from querychat.datasource import SQLAlchemySource from seaborn import load_dataset from shiny import App, render, ui from sqlalchemy import create_engine @@ -23,7 +22,8 @@ # 1. Configure querychat querychat_config = querychat.init( - SQLAlchemySource(engine, "titanic"), + engine, + "titanic", greeting=greeting, data_description=data_desc, ) diff --git a/python-package/examples/app-dataframe.py b/python-package/examples/app-dataframe.py index 13d224fbb..1a1fd8588 100644 --- a/python-package/examples/app-dataframe.py +++ b/python-package/examples/app-dataframe.py @@ -1,7 +1,6 @@ from pathlib import Path import querychat -from querychat.datasource import DataFrameSource from seaborn import load_dataset from shiny import App, render, ui @@ -14,7 +13,8 @@ # 1. Configure querychat querychat_config = querychat.init( - DataFrameSource(titanic, "titanic"), + titanic, + "titanic", greeting=greeting, data_description=data_desc, ) diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index 093dec163..aec6bba76 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -9,10 +9,11 @@ import chevron import narwhals as nw import pandas as pd +import sqlalchemy from narwhals.typing import IntoFrame from shiny import Inputs, Outputs, Session, module, reactive, ui -from .datasource import DataSource +from .datasource import DataFrameSource, DataSource, SQLAlchemySource class CreateChatCallback(Protocol): @@ -73,7 +74,7 @@ def chat(self) -> chatlas.Chat: Returns: The chat object """ - return self._chat() + return self._chat def sql(self) -> str: """ @@ -187,7 +188,8 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: def init( - data_source: DataSource, + data_source: IntoFrame | sqlalchemy.Engine, + table_name: str, greeting: Optional[str] = None, data_description: Optional[str] = None, extra_instructions: Optional[str] = None, @@ -207,6 +209,13 @@ def init( Returns: A QueryChatConfig object that can be passed to server() """ + + data_source_obj: DataSource + if isinstance(data_source, sqlalchemy.Engine): + data_source_obj = SQLAlchemySource(data_source, table_name) + else: + data_source_obj = DataFrameSource(nw.from_native(data_source).to_pandas(), table_name) + # Process greeting if greeting is None: print( @@ -217,7 +226,7 @@ def init( # Create the system prompt, or use the override _system_prompt = system_prompt_override or system_prompt( - data_source, data_description, extra_instructions + data_source_obj, data_description, extra_instructions ) # Default chat function if none provided @@ -226,7 +235,7 @@ def init( ) return QueryChatConfig( - data_source=data_source, + data_source=data_source_obj, system_prompt=_system_prompt, greeting=greeting, create_chat_callback=create_chat_callback, From 57922b3fe2eeda722f28ca35b60543a9d4223d15 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 17:11:26 -0700 Subject: [PATCH 10/11] Use GPT-4.1 by default, not GPT-4, yuck --- python-package/querychat/querychat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index ed5583627..167560c51 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -133,7 +133,7 @@ def init( # Default chat function if none provided create_chat_callback = create_chat_callback or partial( - chatlas.ChatOpenAI, model="gpt-4" + chatlas.ChatOpenAI, model="gpt-4.1" ) return QueryChatConfig( From a08764bf130895a895fdff7c2d535ef40855f156 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 17:23:12 -0700 Subject: [PATCH 11/11] Update README --- python-package/README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python-package/README.md b/python-package/README.md index be8057ea0..9b29fb193 100644 --- a/python-package/README.md +++ b/python-package/README.md @@ -56,7 +56,7 @@ def server(input, output, session): # chat["df"]() reactive. @render.data_frame def data_table(): - return chat["df"]() + return chat.df() # Create Shiny app @@ -171,8 +171,8 @@ which you can then pass via: ```python querychat_config = querychat.init( - df=titanic, - table_name="titanic", + titanic, + "titanic", data_description=Path("data_description.md").read_text() ) ``` @@ -185,8 +185,8 @@ You can add additional instructions of your own to the end of the system prompt, ```python querychat_config = querychat.init( - df=titanic, - table_name="titanic", + titanic, + "titanic", extra_instructions=[ "You're speaking to a British audience--please use appropriate spelling conventions.", "Use lots of emojis! πŸ˜ƒ Emojis everywhere, 🌍 emojis forever. ♾️", @@ -218,8 +218,8 @@ def my_chat_func(system_prompt: str) -> chatlas.Chat: my_chat_func = partial(chatlas.ChatAnthropic, model="claude-3-7-sonnet-latest") querychat_config = querychat.init( - df=titanic, - table_name="titanic", + titanic, + "titanic", create_chat_callback=my_chat_func ) ```