Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions backend/apps/chat/models/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ class ChatInfo(BaseModel):
ds_type: str = ''
datasource_name: str = ''
datasource_exists: bool = True
recommended_question: Optional[str] = None
recommended_generate: Optional[bool] = False
recommended_question: Optional[str] = None
recommended_generate: Optional[bool] = False
records: List[ChatRecord | dict] = []


Expand Down Expand Up @@ -237,9 +237,9 @@ def sql_user_question(self, current_time: str, change_title: bool):
def chart_sys_question(self):
return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)

def chart_user_question(self, chart_type: Optional[str] = None):
def chart_user_question(self, chart_type: Optional[str] = '', schema: Optional[str] = ''):
return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule,
chart_type=chart_type)
chart_type=chart_type, schema=schema)

def analysis_sys_question(self):
return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies,
Expand Down
25 changes: 18 additions & 7 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
_ds = session.get(CoreDatasource, chat_question.datasource_id)
if _ds:
if _ds.oid != current_user.oid:
raise SingleMessageError(f"Datasource with id {chat_question.datasource_id} does not belong to current workspace")
raise SingleMessageError(
f"Datasource with id {chat_question.datasource_id} does not belong to current workspace")
chat.datasource = _ds.id
chat.engine_type = _ds.type_name
# save chat
Expand Down Expand Up @@ -410,7 +411,8 @@ def generate_recommend_questions_task(self, _session: Session):
reasoning_content=full_thinking_text,
token_usage=token_usage)
self.record = save_recommend_question_answer(session=_session, record_id=self.record.id,
answer={'content': full_guess_text}, articles_number=self.articles_number)
answer={'content': full_guess_text},
articles_number=self.articles_number)

yield {'recommended_question': self.record.recommended_question}

Expand Down Expand Up @@ -716,9 +718,9 @@ def generate_assistant_filter(self, _session: Session, sql, tables: List):
return None
return self.build_table_filter(session=_session, sql=sql, filters=filters)

def generate_chart(self, _session: Session, chart_type: Optional[str] = ''):
def generate_chart(self, _session: Session, chart_type: Optional[str] = '', schema: Optional[str] = ''):
# append current question
self.chart_message.append(HumanMessage(self.chat_question.chart_user_question(chart_type)))
self.chart_message.append(HumanMessage(self.chat_question.chart_user_question(chart_type, schema)))

self.current_logs[OperationEnum.GENERATE_CHART] = start_log(session=_session,
ai_modal_id=self.chat_question.ai_modal_id,
Expand Down Expand Up @@ -1079,9 +1081,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
sqlbot_temp_sql_text = None
assistant_dynamic_sql = None
# row permission
sql, tables = self.check_sql(res=full_sql_text)
if ((not self.current_assistant or is_page_embedded) and is_normal_user(
self.current_user)) or use_dynamic_ds:
sql, tables = self.check_sql(res=full_sql_text)
sql_result = None

if use_dynamic_ds:
Expand Down Expand Up @@ -1167,7 +1169,16 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
return

# generate chart
chart_res = self.generate_chart(_session, chart_type)
used_tables_schema = self.out_ds_instance.get_db_schema(
self.ds.id, self.chat_question.question, embedding=False,
table_list=tables) if self.out_ds_instance else get_table_schema(
session=_session,
current_user=self.current_user,
ds=self.ds,
question=self.chat_question.question,
embedding=False, table_list=tables)
SQLBotLogUtil.info('used_tables_schema: \n' + used_tables_schema)
chart_res = self.generate_chart(_session, chart_type, used_tables_schema)
full_chart_text = ''
for chunk in chart_res:
full_chart_text += chunk.get('content')
Expand Down Expand Up @@ -1482,7 +1493,7 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict):
y = None
series = None
multi_quota_fields = []
multi_quota_name =None
multi_quota_name = None

if chart.get('axis'):
axis_data = chart.get('axis')
Expand Down
10 changes: 9 additions & 1 deletion backend/apps/datasource/crud/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core


def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
embedding: bool = True) -> str:
embedding: bool = True, table_list: list[str] = None) -> str:
schema_str = ""
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
if len(table_objs) == 0:
Expand All @@ -435,6 +435,10 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
tables = []
all_tables = [] # temp save all tables
for obj in table_objs:
# 如果传入了table_list,则只处理在列表中的表
if table_list is not None and obj.table.table_name not in table_list:
continue

schema_table = ''
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
table_comment = ''
Expand Down Expand Up @@ -462,6 +466,10 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
tables.append(t_obj)
all_tables.append(t_obj)

# 如果没有符合过滤条件的表,直接返回
if not tables:
return schema_str

# do table embedding
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
tables = calc_table_embedding(tables, question)
Expand Down
7 changes: 6 additions & 1 deletion backend/apps/system/crud/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,19 @@ def get_simple_ds_list(self):
else:
raise Exception("Datasource list is not found.")

def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> str:
def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
table_list: list[str] = None) -> str:
ds = self.get_ds(ds_id)
schema_str = ""
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
tables = []
i = 0
for table in ds.tables:
# 如果传入了 table_list,则只处理在列表中的表
if table_list is not None and table.name not in table_list:
continue

i += 1
schema_table = ''
schema_table += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {table.name}"
Expand Down
9 changes: 8 additions & 1 deletion backend/templates/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,11 @@ template:
<Instruction>
你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL,查询数据并进行图表展示。
你当前的任务是根据给定SQL语句和用户问题,生成数据可视化图表的配置项。
用户的提问在<user-question>内,<sql>内是给定需要参考的SQL,<chart-type>内是推荐你生成的图表类型
用户会提供给你如下信息,帮助你生成配置项:
<user-question>:用户的提问
<sql>:需要参考的SQL
<m-schema>:以 M-Schema 格式提供 SQL 内用到表的数据库表结构信息,你可以参考字段名与字段备注来生成图表使用到的字段名
<chart-type>:推荐你生成的图表类型
</Instruction>

你必须遵守以下规则:
Expand Down Expand Up @@ -455,6 +459,9 @@ template:
<sql>
{sql}
</sql>
<m-schema>
{schema}
</m-schema>
<chart-type>
{chart_type}
</chart-type>
Expand Down