Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,5 @@ test.py


!/.venv/

sqlbot-xpack
4 changes: 2 additions & 2 deletions backend/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
# from apps.system.models.user import SQLModel # noqa
# from apps.settings.models.setting_models import SQLModel
#from apps.chat.models.chat_model import SQLModel
#from apps.terminology.models.terminology_model import SQLModel
#from apps.custom_prompt.models.custom_prompt_model import SQLModel
from apps.terminology.models.terminology_model import SQLModel
from sqlbot_xpack.custom_prompt.models.custom_prompt_model import SQLModel
#from apps.data_training.models.data_training_model import SQLModel
# from apps.dashboard.models.dashboard_model import SQLModel
from common.core.config import settings # noqa
Expand Down
31 changes: 31 additions & 0 deletions backend/alembic/versions/069_term_custom_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""069_term_custom_prompt

Revision ID: 1f82cad3546e
Revises: a1b2c3d4e5f6
Create Date: 2026-06-15 14:51:12.280391

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '1f82cad3546e'
down_revision = 'a1b2c3d4e5f6'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('custom_prompt', sa.Column('advanced_application', sa.BigInteger(), nullable=True))
op.add_column('terminology', sa.Column('advanced_application', sa.BigInteger(), nullable=True))
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('terminology', 'advanced_application')
op.drop_column('custom_prompt', 'advanced_application')
# ### end Alembic commands ###
44 changes: 29 additions & 15 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
import orjson
import pandas as pd
import requests
import sqlparse
import sqlglot
from sqlglot import exp
import sqlparse
from langchain.chat_models.base import BaseChatModel
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, BaseMessageChunk
Expand All @@ -24,6 +23,7 @@
from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts
from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum
from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil
from sqlglot import exp
from sqlmodel import Session

from apps.ai_model.model_factory import LLMConfig, LLMFactory, get_default_config
Expand Down Expand Up @@ -68,7 +68,6 @@
i18n = I18n()



def extract_tables_from_sql(sql: str, ds_type: str = None) -> set:
"""从 SQL 中提取表名(使用 sqlglot 解析,可信)"""
tables = set()
Expand Down Expand Up @@ -340,38 +339,53 @@ def get_fields_from_chart(self, _session: Session):
return format_chart_fields(chart_info)

def filter_terminology_template(self, _session: Session, oid: int = None, ds_id: int = None):
self.current_logs[OperationEnum.FILTER_TERMS] = start_log(session=_session,
operate=OperationEnum.FILTER_TERMS,
record_id=self.record.id, local_operation=True)
calculate_oid = oid
calculate_ds_id = ds_id
if self.current_assistant:
calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.current_user.oid
if self.current_assistant.type == 1:
calculate_ds_id = None
self.current_logs[OperationEnum.FILTER_TERMS] = start_log(session=_session,
operate=OperationEnum.FILTER_TERMS,
record_id=self.record.id, local_operation=True)
if self.current_assistant and self.current_assistant.type == 1:
self.chat_question.terminologies, term_list = get_terminology_template(_session,
self.chat_question.question,
calculate_oid,
None, self.current_assistant.id)
else:
self.chat_question.terminologies, term_list = get_terminology_template(_session,
self.chat_question.question,
calculate_oid,
calculate_ds_id)

self.chat_question.terminologies, term_list = get_terminology_template(_session, self.chat_question.question,
calculate_oid, calculate_ds_id)
self.current_logs[OperationEnum.FILTER_TERMS] = end_log(session=_session,
log=self.current_logs[OperationEnum.FILTER_TERMS],
full_message=term_list)

def filter_custom_prompts(self, _session: Session, custom_prompt_type: CustomPromptTypeEnum, oid: int = None,
ds_id: int = None):
if SQLBotLicenseUtil.valid():
self.current_logs[OperationEnum.FILTER_CUSTOM_PROMPT] = start_log(session=_session,
operate=OperationEnum.FILTER_CUSTOM_PROMPT,
record_id=self.record.id,
local_operation=True)
calculate_oid = oid
calculate_ds_id = ds_id
if self.current_assistant:
calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.current_user.oid
if self.current_assistant.type == 1:
calculate_ds_id = None
self.current_logs[OperationEnum.FILTER_CUSTOM_PROMPT] = start_log(session=_session,
operate=OperationEnum.FILTER_CUSTOM_PROMPT,
record_id=self.record.id,
local_operation=True)
self.chat_question.custom_prompt, prompt_list = find_custom_prompts(_session, custom_prompt_type,
calculate_oid,
calculate_ds_id)
if self.current_assistant and self.current_assistant.type == 1:
self.chat_question.custom_prompt, prompt_list = find_custom_prompts(_session,
custom_prompt_type,
calculate_oid,
None, self.current_assistant.id)
else:
self.chat_question.custom_prompt, prompt_list = find_custom_prompts(_session,
custom_prompt_type,
calculate_oid,
calculate_ds_id)
self.current_logs[OperationEnum.FILTER_CUSTOM_PROMPT] = end_log(session=_session,
log=self.current_logs[
OperationEnum.FILTER_CUSTOM_PROMPT],
Expand Down
13 changes: 11 additions & 2 deletions backend/apps/terminology/api/terminology.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def inner():
"description": obj.description,
"all_data_sources": 'N' if obj.specific_ds else 'Y',
"datasource": ', '.join(obj.datasource_names) if obj.datasource_names and obj.specific_ds else '',
"advanced_application_name": obj.advanced_application_name or '',
}
data_list.append(_data)

Expand All @@ -91,6 +92,7 @@ def inner():
fields.append(AxisObj(name=trans('i18n_terminology.term_description'), value='description'))
fields.append(AxisObj(name=trans('i18n_terminology.effective_data_sources'), value='datasource'))
fields.append(AxisObj(name=trans('i18n_terminology.all_data_sources'), value='all_data_sources'))
fields.append(AxisObj(name=trans('i18n_data_training.advanced_application'), value='advanced_application_name'))

md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list)

Expand Down Expand Up @@ -119,6 +121,7 @@ def inner():
"description": trans('i18n_terminology.term_description_template_example_1'),
"all_data_sources": 'N',
"datasource": trans('i18n_terminology.effective_data_sources_template_example_1'),
"advanced_application_name": '',
}
data_list.append(_data1)
_data2 = {
Expand All @@ -127,6 +130,7 @@ def inner():
"description": trans('i18n_terminology.term_description_template_example_2'),
"all_data_sources": 'Y',
"datasource": '',
"advanced_application_name": '',
}
data_list.append(_data2)

Expand All @@ -136,6 +140,7 @@ def inner():
fields.append(AxisObj(name=trans('i18n_terminology.term_description_template'), value='description'))
fields.append(AxisObj(name=trans('i18n_terminology.effective_data_sources_template'), value='datasource'))
fields.append(AxisObj(name=trans('i18n_terminology.all_data_sources_template'), value='all_data_sources'))
fields.append(AxisObj(name=trans('i18n_data_training.advanced_application'), value='advanced_application_name'))

md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list)

Expand Down Expand Up @@ -180,7 +185,7 @@ async def upload_excel(trans: Trans, current_user: CurrentUser, file: UploadFile

oid = current_user.oid

use_cols = [0, 1, 2, 3, 4]
use_cols = [0, 1, 2, 3, 4, 5]

def inner():

Expand Down Expand Up @@ -217,9 +222,11 @@ def inner():
3].strip() else []
all_datasource = True if pd.notna(row[4]) and row[4].lower().strip() in ['y', 'yes', 'true'] else False
specific_ds = False if all_datasource else True
advanced_application_name = row[5].strip() if pd.notna(row[5]) and row[5].strip() else None

import_data.append(TerminologyInfo(word=word, description=description, other_words=other_words,
datasource_names=datasource_names, specific_ds=specific_ds))
datasource_names=datasource_names, specific_ds=specific_ds,
advanced_application_name=advanced_application_name))

res = batch_create_terminology(session, import_data, oid, trans)

Expand All @@ -237,6 +244,7 @@ def inner():
"all_data_sources": 'N' if obj['data'].specific_ds else 'Y',
"datasource": ', '.join(obj['data'].datasource_names) if obj['data'].datasource_names and obj[
'data'].specific_ds else '',
"advanced_application_name": obj['data'].advanced_application_name or '',
"errors": obj['errors']
}
data_list.append(_data)
Expand All @@ -247,6 +255,7 @@ def inner():
fields.append(AxisObj(name=trans('i18n_terminology.term_description'), value='description'))
fields.append(AxisObj(name=trans('i18n_terminology.effective_data_sources'), value='datasource'))
fields.append(AxisObj(name=trans('i18n_terminology.all_data_sources'), value='all_data_sources'))
fields.append(AxisObj(name=trans('i18n_data_training.advanced_application'), value='advanced_application_name'))
fields.append(AxisObj(name=trans('i18n_data_training.error_info'), value='errors'))

md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list)
Expand Down
Loading
Loading