Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion wren-ai-service/eval/data_curation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
# add wren-ai-service to sys.path
sys.path.append(f"{Path().parent.parent.resolve()}")
from eval.utils import (
add_quotes,
get_contexts_from_sql,
get_data_from_wren_engine,
get_ddl_commands,
get_documents_given_contexts,
)
from src.core.engine import add_quotes
from src.pipelines.indexing.db_schema import DDLChunker

load_dotenv()
Expand Down
18 changes: 4 additions & 14 deletions wren-ai-service/eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import uuid
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Tuple, get_args
from typing import Any, Dict, List, Literal, Optional, get_args

import aiohttp
import orjson
import psycopg2
import requests
import sqlglot
import tomlkit
import yaml
from dotenv import load_dotenv
Expand All @@ -19,21 +18,12 @@

import docker
from eval import WREN_ENGINE_API_URL, EvalSettings
from src.core.engine import add_quotes
from src.providers.engine.wren import WrenEngine

load_dotenv(".env", override=True)


def add_quotes(sql: str) -> Tuple[str, bool]:
    """Quote every identifier in *sql* using sqlglot.

    Returns:
        A ``(sql_text, success)`` pair: the transpiled SQL with all
        identifiers quoted and ``True`` on success, or the untouched
        input and ``False`` when sqlglot cannot process it.
    """
    try:
        return sqlglot.transpile(sql, read=None, identify=True)[0], True
    except Exception as e:
        # Best-effort: report the failure and hand back the original SQL.
        print(f"Error in adding quotes to SQL: {sql}")
        print(f"Error: {e}")
        return sql, False


async def get_data_from_wren_engine(
sql: str,
mdl_json: dict,
Expand All @@ -43,8 +33,8 @@ async def get_data_from_wren_engine(
timeout: float = 300,
limit: Optional[int] = None,
):
quoted_sql, no_error = add_quotes(sql)
assert no_error, f"Error in quoting SQL: {sql}"
quoted_sql, error = add_quotes(sql)
assert not error, f"Error in quoting SQL: {sql}"

if data_source == "duckdb":
async with aiohttp.request(
Expand Down
112 changes: 103 additions & 9 deletions wren-ai-service/src/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from typing import Any, Dict, Optional, Tuple

import aiohttp
import sqlglot
import sqlparse
from pydantic import BaseModel
from sqlglot.tokens import Token, Tokenizer, TokenType

logger = logging.getLogger("wren-ai-service")

Expand Down Expand Up @@ -49,17 +50,110 @@ def remove_limit_statement(sql: str) -> str:
return modified_sql


def squish_sql(sql: str) -> str:
    """Collapse *sql* onto a single line.

    Comments must be stripped before flattening: a ``--`` line comment that
    survived into the output would swallow the remainder of the statement
    once the newline that terminated it is replaced by a space.  The
    original ``strip_comments=False`` had exactly that bug.

    NOTE(review): the blanket newline replacement also rewrites newlines
    inside string literals — confirm callers never pass SQL containing
    multi-line string constants.
    """
    flattened = sqlparse.format(
        sql,
        strip_comments=True,  # drop `-- ...` so it can't comment out the rest of the line
        reindent=False,  # don't add newlines/indent
        keyword_case=None,  # don't change case
    )
    return flattened.replace("\n", " ").replace("\r", " ").strip()


def add_quotes(sql: str) -> Tuple[str, str]:
def _quote_sql_identifiers_by_tokens(sql: str, quote_char: str = '"') -> str:
    """
    Add quotes around identifiers using SQLGlot's tokenizer positions.

    Scans the flat token stream (no parse tree), records a text edit for
    every unquoted identifier, and applies the edits right-to-left so that
    earlier character offsets stay valid.  Wildcard selections (``t.*``)
    and function-call names (``foo(...)``, ``schema.foo(...)``) are left
    untouched.
    """

    def is_ident(tok: Token):
        # SQLGlot uses VAR for identifiers, but also treats SQL keywords as identifiers in some contexts
        return tok.token_type in (
            TokenType.VAR,
            TokenType.SCHEMA,
            TokenType.TABLE,
            TokenType.COLUMN,
            TokenType.DATABASE,
            TokenType.INDEX,
            TokenType.VIEW,
        )

    def is_already_quoted_text(text: str) -> bool:
        # Recognize ANSI ("x"), MySQL (`x`), and SQL Server ([x]) quoting.
        text = text.strip()
        return (
            (len(text) >= 2 and text[0] == '"' and text[-1] == '"')
            or (len(text) >= 2 and text[0] == "`" and text[-1] == "`")
            or (len(text) >= 2 and text[0] == "[" and text[-1] == "]")
        )

    toks = Tokenizer().tokenize(sql)
    n = len(toks)
    edits = []  # (start, end_exclusive, replacement)

    i = 0
    while i < n:
        t = toks[i]

        if not is_ident(t):
            i += 1
            continue

        # Check for wildcard pattern: IDENT DOT STAR (e.g., t.*)
        if (
            i + 2 < n
            and toks[i + 1].token_type == TokenType.DOT
            and toks[i + 2].token_type == TokenType.STAR
        ):
            i += 3  # Skip the entire wildcard pattern
            continue

        # Check if this is part of a dotted chain
        j = i
        chain_tokens = [t]  # Start with current identifier

        # Collect all tokens in the dotted chain: IDENT (DOT IDENT)*
        # Afterwards, `j` indexes the last identifier of the chain.
        while (
            j + 2 < n
            and toks[j + 1].token_type == TokenType.DOT
            and is_ident(toks[j + 2])
        ):
            chain_tokens.append(toks[j + 1])  # DOT
            chain_tokens.append(toks[j + 2])  # IDENT
            j += 2

        # If the next token after the chain is '(', it's a function call -> skip
        # (the whole chain is skipped, so `schema.func(...)` stays unquoted).
        if j + 1 < n and toks[j + 1].token_type == TokenType.L_PAREN:
            i = j + 1
            continue

        # Process each identifier in the chain separately to ensure all are quoted
        for k in range(
            0, len(chain_tokens), 2
        ):  # Process only identifiers (skip dots)
            ident_token = chain_tokens[k]
            # Token.start/Token.end are used as inclusive offsets into `sql`,
            # hence the `end + 1` when slicing and recording the edit span.
            token_text = sql[ident_token.start : ident_token.end + 1]

            if not is_already_quoted_text(token_text):
                replacement = f"{quote_char}{token_text}{quote_char}"
                edits.append((ident_token.start, ident_token.end + 1, replacement))

        i = j + 1

    # Apply edits right-to-left to keep offsets valid
    out = sql
    for start, end, repl in sorted(edits, key=lambda x: x[0], reverse=True):
        out = out[:start] + repl + out[end:]
    return out

try:
quoted_sql = sqlglot.transpile(
sql,
read=None,
identify=True,
error_level=sqlglot.ErrorLevel.RAISE,
unsupported_level=sqlglot.ErrorLevel.RAISE,
)[0]
sql = squish_sql(sql)
quoted_sql = _quote_sql_identifiers_by_tokens(sql)
except Exception as e:
logger.exception(f"Error in sqlglot.transpile to {sql}: {e}")
logger.exception(f"Error in adding quotes to {sql}: {e}")

return "", str(e)

Expand Down
Loading
Loading