wren-ai-service/eval/pipelines.py (9 changes: 7 additions & 2 deletions)

@@ -177,11 +177,14 @@ def __init__(
     ):
         super().__init__(meta)
 
-        _indexing = indexing.DBSchema(
+        _db_schema_indexing = indexing.DBSchema(
             **pipe_components["db_schema_indexing"],
             column_batch_size=settings.column_indexing_batch_size,
         )
-        deploy_model(mdl, _indexing)
+        _table_description_indexing = indexing.TableDescription(
+            **pipe_components["table_description_indexing"],
+        )
+        deploy_model(mdl, [_db_schema_indexing, _table_description_indexing])
 
         self._retrieval = retrieval.Retrieval(
             **pipe_components["db_schema_retrieval"],
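The eval setup now indexes table descriptions alongside the DB schema, so `deploy_model` receives a list of indexing pipelines instead of a single one. A minimal sketch of what a list-accepting `deploy_model` could look like; the `run(mdl=...)` entrypoint and the signature here are assumptions, not the eval package's real helper:

```python
# Hypothetical sketch only: assumes each indexing pipeline exposes an
# async run(mdl=...) coroutine, which may differ from the real eval helper.
import asyncio

def deploy_model(mdl, pipes: list) -> None:
    async def _deploy() -> None:
        # Index the same MDL through every pipeline (schema, table descriptions, ...)
        await asyncio.gather(*(pipe.run(mdl=mdl) for pipe in pipes))

    asyncio.run(_deploy())
```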

@@ -247,6 +250,7 @@ async def _process(self, prediction: dict, document: list, **_) -> dict:
             samples=prediction["samples"],
             has_calculated_field=prediction.get("has_calculated_field", False),
             has_metric=prediction.get("has_metric", False),
+            sql_generation_reasoning=prediction.get("reasoning", ""),
         )
 
         prediction["actual_output"] = actual_output

@@ -337,6 +341,7 @@ async def _process(self, prediction: dict, **_) -> dict:
             sql_samples=[],
             has_calculated_field=has_calculated_field,
             has_metric=has_metric,
+            sql_generation_reasoning=prediction.get("reasoning", ""),
         )
 
         prediction["actual_output"] = actual_output
wren-ai-service/src/pipelines/retrieval/retrieval.py (4 changes: 1 addition & 3 deletions)

@@ -268,9 +268,8 @@ def prompt(
"db_schemas token count is greater than 100,000, so we will prune columns"
)
db_schemas = [
ddl
build_table_ddl(construct_db_schema)[0]
for construct_db_schema in construct_db_schemas
for ddl, _ in build_table_ddl(construct_db_schema)
]

if history:
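This hunk fixes a tuple-unpacking bug: the old comprehension treated the return value of `build_table_ddl` as an iterable of `(ddl, _)` pairs, but the `[0]` index in the fix implies it returns a single tuple whose first element is the DDL string. A minimal repro under that assumption (the `build_table_ddl` body here is a stand-in, not the service's real implementation):

```python
# Minimal repro of the bug fixed above. Assumption: build_table_ddl returns
# one (ddl, columns) tuple per schema rather than a list of tuples.
def build_table_ddl(schema: dict) -> tuple[str, list[str]]:
    ddl = f"CREATE TABLE {schema['name']} (...);"
    return ddl, [c["name"] for c in schema.get("columns", [])]

schemas = [{"name": "orders", "columns": [{"name": "id"}, {"name": "total"}]}]

# Old comprehension: iterates over the returned tuple itself, then tries to
# unpack each element as (ddl, _). The first element is a string, so this
# raises ValueError (or silently misbehaves for two-character strings).
# db_schemas = [ddl for s in schemas for ddl, _ in build_table_ddl(s)]

# Fixed version: take the DDL string, the tuple's first element.
db_schemas = [build_table_ddl(s)[0] for s in schemas]
print(db_schemas)  # ['CREATE TABLE orders (...);']
```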

@@ -418,7 +417,6 @@ def __init__(
         # for the first time, we need to load the encodings
         _model = llm_provider.get_model()
         if "gpt-4o" in _model or "gpt-4o-mini" in _model:
-            allow_using_db_schemas_without_pruning = True
             _encoding = tiktoken.get_encoding("o200k_base")
         else:
             _encoding = tiktoken.get_encoding("cl100k_base")
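The hardcoded override that forced `allow_using_db_schemas_without_pruning = True` for gpt-4o-family models is removed; the flag is now driven by configuration (see the launcher change below). The encoding choice stays, since the 4o family uses the `o200k_base` tokenizer. A sketch of how the encoding and the 100,000-token threshold from the first hunk could fit together; the helper names are assumptions, not the service's real API:

```python
# Sketch under assumptions: pick_encoding / should_prune_columns are
# illustrative helper names, not functions from wren-ai-service.
import tiktoken

def pick_encoding(model_name: str) -> tiktoken.Encoding:
    if "gpt-4o" in model_name or "gpt-4o-mini" in model_name:
        return tiktoken.get_encoding("o200k_base")  # tokenizer for the 4o family
    return tiktoken.get_encoding("cl100k_base")  # tokenizer for earlier GPT models

def should_prune_columns(db_schemas: list[str], encoding: tiktoken.Encoding,
                         allow_without_pruning: bool) -> bool:
    # Prune only when full schemas are disallowed AND they exceed the budget.
    if allow_without_pruning:
        return False
    token_count = len(encoding.encode("\n".join(db_schemas)))
    return token_count > 100_000
```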
wren-launcher/utils/docker.go (4 changes: 4 additions & 0 deletions)

@@ -159,6 +159,10 @@ func PrepareConfigFileForOpenAI(projectDir string, generationModel string) error
 	config := string(content)
 	config = strings.ReplaceAll(config, "litellm_llm.gpt-4o-mini-2024-07-18", "litellm_llm."+generationModelToModelName[generationModel])
 
+	// replace allow_using_db_schemas_without_pruning setting
+	// enable this feature since OpenAI models have sufficient context window size to handle full schema
+	config = strings.ReplaceAll(config, "allow_using_db_schemas_without_pruning: false", "allow_using_db_schemas_without_pruning: true")
+
 	// write back to config.yaml
 	err = os.WriteFile(configPath, []byte(config), 0644)
 	if err != nil {
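The launcher flips the pruning flag in the generated config.yaml with a plain string substitution. The same transformation, sketched in Python for quick experimentation (the key name and both values are taken verbatim from the Go hunk above; the file path is an assumption):

```python
# Python sketch of the launcher's substitution; the config.yaml path is
# an assumption about where the launcher writes its generated config.
from pathlib import Path

config_path = Path("config.yaml")  # assumed location of the generated config
config = config_path.read_text()
config = config.replace(
    "allow_using_db_schemas_without_pruning: false",
    "allow_using_db_schemas_without_pruning: true",
)
config_path.write_text(config)
```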