9 changes: 7 additions & 2 deletions wren-ai-service/eval/pipelines.py
@@ -177,11 +177,14 @@ def __init__(
     ):
         super().__init__(meta)
 
-        _indexing = indexing.DBSchema(
+        _db_schema_indexing = indexing.DBSchema(
             **pipe_components["db_schema_indexing"],
             column_batch_size=settings.column_indexing_batch_size,
         )
-        deploy_model(mdl, _indexing)
+        _table_description_indexing = indexing.TableDescription(
+            **pipe_components["table_description_indexing"],
+        )
+        deploy_model(mdl, [_db_schema_indexing, _table_description_indexing])
 
         self._retrieval = retrieval.Retrieval(
             **pipe_components["db_schema_retrieval"],
@@ -247,6 +250,7 @@ async def _process(self, prediction: dict, document: list, **_) -> dict:
             samples=prediction["samples"],
             has_calculated_field=prediction.get("has_calculated_field", False),
             has_metric=prediction.get("has_metric", False),
+            sql_generation_reasoning=prediction.get("reasoning", ""),
         )
 
         prediction["actual_output"] = actual_output
@@ -337,6 +341,7 @@ async def _process(self, prediction: dict, **_) -> dict:
             sql_samples=[],
             has_calculated_field=has_calculated_field,
             has_metric=has_metric,
+            sql_generation_reasoning=prediction.get("reasoning", ""),
         )
 
         prediction["actual_output"] = actual_output
8 changes: 4 additions & 4 deletions wren-ai-service/src/pipelines/retrieval/retrieval.py
@@ -268,9 +268,8 @@ def prompt(
"db_schemas token count is greater than 100,000, so we will prune columns"
)
db_schemas = [
ddl
build_table_ddl(construct_db_schema)[0]
for construct_db_schema in construct_db_schemas
for ddl, _ in build_table_ddl(construct_db_schema)
]

if history:
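This branch changed because `build_table_ddl` now returns a single tuple whose first element is the DDL string, instead of an iterable of `(ddl, ...)` pairs that the old nested comprehension flattened. An illustrative sketch of the new comprehension, with a stand-in `build_table_ddl` whose return shape beyond index 0 is an assumption:

```python
# Stand-in for the real build_table_ddl; returning (ddl, column_ddls)
# is an assumption made for illustration only.
def build_table_ddl(schema: dict) -> tuple[str, list[str]]:
    cols = ", ".join(c["name"] for c in schema.get("columns", []))
    return f"CREATE TABLE {schema['name']} ({cols})", []

construct_db_schemas = [
    {"name": "orders", "columns": [{"name": "id"}, {"name": "total"}]},
]

# New form: one DDL string per schema, taking element 0 of the returned tuple.
db_schemas = [build_table_ddl(s)[0] for s in construct_db_schemas]
print(db_schemas)  # ['CREATE TABLE orders (id, total)']
```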
@@ -417,8 +416,9 @@ def __init__(

         # for the first time, we need to load the encodings
         _model = llm_provider.get_model()
-        if "gpt-4o" in _model or "gpt-4o-mini" in _model:
-            allow_using_db_schemas_without_pruning = True
+        if allow_using_db_schemas_without_pruning and (
+            "gpt-4o" in _model or "gpt-4o-mini" in _model
+        ):
             _encoding = tiktoken.get_encoding("o200k_base")
         else:
             _encoding = tiktoken.get_encoding("cl100k_base")