Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions prompt-service/src/unstract/prompt_service/helpers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def validate_bearer_token(token: str | None) -> bool:

platform_key_table = f'"{DB_SCHEMA}".{DBTableV2.PLATFORM_KEY}'
with db_context():
query = f"SELECT * FROM {platform_key_table} WHERE key = '{token}'"
cursor = db.execute_sql(query)
query = f"SELECT * FROM {platform_key_table} WHERE key = %s"
cursor = db.execute_sql(query, (token,))
result_row = cursor.fetchone()
cursor.close()
if not result_row or len(result_row) == 0:
Expand Down Expand Up @@ -66,12 +66,10 @@ def get_account_from_bearer_token(token: str | None) -> str:
platform_key_table = DBTableV2.PLATFORM_KEY
organization_table = DBTableV2.ORGANIZATION

query = f"SELECT organization_id FROM {platform_key_table} WHERE key='{token}'"
organization = DBUtils.execute_query(query)
query_org = (
f"SELECT schema_name FROM {organization_table} WHERE id='{organization}'"
)
schema_name: str = DBUtils.execute_query(query_org)
query = f"SELECT organization_id FROM {platform_key_table} WHERE key=%s"
organization = DBUtils.execute_query(query, (token,))
query_org = f"SELECT schema_name FROM {organization_table} WHERE id=%s"
schema_name: str = DBUtils.execute_query(query_org, (organization,))
return schema_name

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
BigQueryNotFoundException,
ColumnMissingException,
)
from unstract.connectors.databases.sql_safety import (
QuoteStyle,
safe_identifier,
validate_identifier,
)
from unstract.connectors.databases.unstract_db import UnstractDB
from unstract.connectors.exceptions import ConnectorError

Expand Down Expand Up @@ -68,6 +73,9 @@ def can_write() -> bool:
# Capability flag: unconditionally reports True (reading is supported).
# NOTE(review): takes no `self`, so it is presumably decorated with
# @staticmethod just above this hunk — confirm against the full file.
def can_read() -> bool:
return True

# Return the identifier quoting style for this connector.
# Always QuoteStyle.BACKTICK, the same style this file passes explicitly to
# safe_identifier() when quoting table and column names.
def get_quote_style(self) -> QuoteStyle:
return QuoteStyle.BACKTICK

@staticmethod
def _sanitize_for_bigquery(data: Any) -> Any:
"""BigQuery-specific float sanitization for PARSE_JSON compatibility.
Expand Down Expand Up @@ -118,7 +126,7 @@ def get_engine(self) -> Any:
info=self.json_credentials
)

def execute(self, query: str) -> Any:
def execute(self, query: str, params: Any = None) -> Any:
try:
query_job = self.get_engine().query(query)
return query_job.result()
Expand Down Expand Up @@ -169,8 +177,9 @@ def get_create_table_base_query(self, table: str) -> str:
"Please ensure the BigQuery table is in the form of "
"{project}.{dataset}.{table}."
)
qt = safe_identifier(table, QuoteStyle.BACKTICK, allow_dots=True)
sql_query = (
f"CREATE TABLE IF NOT EXISTS {table} "
f"CREATE TABLE IF NOT EXISTS {qt} "
f"(id STRING,"
f"created_by STRING, created_at TIMESTAMP, "
f"metadata JSON, "
Expand All @@ -183,9 +192,11 @@ def get_create_table_base_query(self, table: str) -> str:
return sql_query

def prepare_multi_column_migration(self, table_name: str, column_name: str) -> str:
qt = safe_identifier(table_name, QuoteStyle.BACKTICK, allow_dots=True)
qc = safe_identifier(f"{column_name}_v2", QuoteStyle.BACKTICK)
sql_query = (
f"ALTER TABLE {table_name} "
f"ADD COLUMN {column_name}_v2 JSON, "
f"ALTER TABLE {qt} "
f"ADD COLUMN {qc} JSON, "
f"ADD COLUMN metadata JSON, "
f"ADD COLUMN user_field_1 BOOL, "
f"ADD COLUMN user_field_2 INT64, "
Expand All @@ -195,9 +206,8 @@ def prepare_multi_column_migration(self, table_name: str, column_name: str) -> s
)
return sql_query

@staticmethod
def get_sql_insert_query(
table_name: str, sql_keys: list[str], sql_values: list[str] | None = None
self, table_name: str, sql_keys: list[str], sql_values: list[str] | None = None
) -> str:
"""Function to generate parameterised insert sql query.

Expand All @@ -209,15 +219,16 @@ def get_sql_insert_query(
Returns:
str: returns a string with parameterised insert sql query
"""
qt = safe_identifier(table_name, QuoteStyle.BACKTICK, allow_dots=True)
# BigQuery uses @ parameterization, ignore sql_values for now
# Escape column names with backticks to handle special characters like underscores
escaped_keys = [f"`{key}`" for key in sql_keys]
escaped_keys = [safe_identifier(key, QuoteStyle.BACKTICK) for key in sql_keys]
keys_str = ",".join(escaped_keys)

# Also escape parameter names with backticks to handle underscores in parameter names
# safe_identifier above already validates each key
escaped_params = [f"@`{key}`" for key in sql_keys]
values_placeholder = ",".join(escaped_params)
return f"INSERT INTO {table_name} ({keys_str}) VALUES ({values_placeholder})"
return f"INSERT INTO {qt} ({keys_str}) VALUES ({values_placeholder})"

def execute_query(
self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
Expand Down Expand Up @@ -339,12 +350,27 @@ def get_information_schema(self, table_name: str) -> dict[str, str]:
project = bigquery_table_parts[0].lower()
dataset = bigquery_table_parts[1]
table = bigquery_table_parts[2]
# Validate identifier parts to prevent injection in schema path
validate_identifier(project)
validate_identifier(dataset)
validate_identifier(table)
query = (
"SELECT column_name, data_type FROM "
f"{project}.{dataset}.INFORMATION_SCHEMA.COLUMNS WHERE "
f"table_name = '{table}'"
f"`{project}`.`{dataset}`.INFORMATION_SCHEMA.COLUMNS WHERE "
"table_name = @table_name"
)
from google.cloud import bigquery as bq_module

job_config = bq_module.QueryJobConfig(
query_parameters=[
bq_module.ScalarQueryParameter("table_name", "STRING", table)
]
)
results = self.execute(query=query)
try:
query_job = self.get_engine().query(query, job_config=job_config)
results = query_job.result()
except Exception as e:
raise ConnectorError(str(e))

# If table doesn't exist, execute returns None
if results is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from pymysql.connections import Connection

from unstract.connectors.databases.mysql_handler import MysqlHandler
from unstract.connectors.databases.sql_safety import (
QuoteStyle,
safe_identifier,
)
from unstract.connectors.databases.unstract_db import UnstractDB
from unstract.connectors.exceptions import ConnectorError

Expand Down Expand Up @@ -72,6 +76,9 @@ def get_engine(self) -> Connection: # type: ignore[type-arg]
)
raise ConnectorError(error_msg) from e

# Return the identifier quoting style for this connector.
# Always QuoteStyle.BACKTICK, matching the style this file hands to
# safe_identifier() for table and column names.
def get_quote_style(self) -> QuoteStyle:
return QuoteStyle.BACKTICK

# Delegate the mapping to MysqlHandler.sql_to_db_mapping, forwarding `value`
# and `column_name` unchanged, and coerce whatever it returns to `str`.
# NOTE(review): presumably the result is a DB column type name — confirm
# against MysqlHandler.
def sql_to_db_mapping(self, value: Any, column_name: str | None = None) -> str:
return str(MysqlHandler.sql_to_db_mapping(value=value, column_name=column_name))

Expand All @@ -97,8 +104,9 @@ def get_create_table_base_query(self, table: str) -> str:
Returns:
str: generates a create sql base query with the constant columns
"""
qt = safe_identifier(table, QuoteStyle.BACKTICK)
sql_query = (
f"CREATE TABLE IF NOT EXISTS {table} "
f"CREATE TABLE IF NOT EXISTS {qt} "
f"(id LONGTEXT, "
f"created_by LONGTEXT, created_at TIMESTAMP, "
f"metadata JSON, "
Expand All @@ -115,18 +123,20 @@ def get_information_schema(self, table_name: str) -> dict[str, str]:
query = (
"SELECT column_name, data_type FROM "
"information_schema.columns WHERE "
f"UPPER(table_name) = UPPER('{table_name}') AND table_schema = '{self.database}'"
"UPPER(table_name) = UPPER(%s) AND table_schema = %s"
)
results = self.execute(query=query)
results = self.execute(query=query, params=(table_name, self.database))
column_types: dict[str, str] = self.get_db_column_types(
columns_with_types=results
)
return column_types

def prepare_multi_column_migration(self, table_name: str, column_name: str) -> str:
qt = safe_identifier(table_name, QuoteStyle.BACKTICK)
qc = safe_identifier(f"{column_name}_v2", QuoteStyle.BACKTICK)
sql_query = (
f"ALTER TABLE {table_name} "
f"ADD COLUMN {column_name}_v2 JSON, "
f"ALTER TABLE {qt} "
f"ADD COLUMN {qc} JSON, "
f"ADD COLUMN metadata JSON, "
f"ADD COLUMN user_field_1 BOOLEAN DEFAULT FALSE, "
f"ADD COLUMN user_field_2 BIGINT DEFAULT 0, "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
InvalidSyntaxException,
)
from unstract.connectors.databases.exceptions_helper import ExceptionHelper
from unstract.connectors.databases.sql_safety import (
QuoteStyle,
safe_identifier,
)
from unstract.connectors.databases.unstract_db import UnstractDB

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -63,6 +67,9 @@ def can_write() -> bool:
def can_read() -> bool:
return True

# Return the identifier quoting style for this connector.
# Always QuoteStyle.SQUARE_BRACKET, the same style this file passes to
# safe_identifier() when quoting schema, table and column names.
def get_quote_style(self) -> QuoteStyle:
return QuoteStyle.SQUARE_BRACKET

def sql_to_db_mapping(self, value: Any, column_name: str | None = None) -> str:
"""Gets the python datatype of value and converts python datatype to
corresponding DB datatype.
Expand Down Expand Up @@ -105,29 +112,40 @@ def get_create_table_base_query(self, table: str) -> str:
Returns:
str: generates a create sql base query with the constant columns
"""
# Parse schema and table name for existence check
# Parse schema and table name for existence check.
# validate_identifier() is called by safe_identifier() and rejects any
# SQL metacharacters, so the values used in WHERE are safe.
if "." in table:
# Handle schema.table format like "[schema].[table]"
parts = table.rsplit(".", 1)
schema_name, table_name = parts[0], parts[1]
safe_schema = safe_identifier(schema_name, QuoteStyle.SQUARE_BRACKET)
safe_table = safe_identifier(table_name, QuoteStyle.SQUARE_BRACKET)
# schema_name/table_name are validated by safe_identifier above
existence_check = (
f"IF NOT EXISTS ("
f"SELECT * FROM INFORMATION_SCHEMA.TABLES "
f"WHERE TABLE_SCHEMA = '{schema_name}' AND TABLE_NAME = '{table_name}'"
f"WHERE TABLE_SCHEMA = '{schema_name}' "
f"AND TABLE_NAME = '{table_name}'"
f")"
)
quoted_full_table = f"{safe_schema}.{safe_table}"
else:
# Handle unqualified table names (default to dbo schema)
safe_table = safe_identifier(table, QuoteStyle.SQUARE_BRACKET)
# table is validated by safe_identifier above
existence_check = (
f"IF NOT EXISTS ("
f"SELECT * FROM INFORMATION_SCHEMA.TABLES "
f"WHERE TABLE_SCHEMA = 'dbo' AND TABLE_NAME = '{table}'"
f"WHERE TABLE_SCHEMA = 'dbo' "
f"AND TABLE_NAME = '{table}'"
f")"
)
quoted_full_table = safe_table

sql_query = (
f"{existence_check} "
f" CREATE TABLE {table} "
f" CREATE TABLE {quoted_full_table} "
f"(id NVARCHAR(MAX), "
f"created_by NVARCHAR(MAX), created_at DATETIMEOFFSET, "
f"metadata NVARCHAR(MAX), "
Expand All @@ -147,14 +165,16 @@ def prepare_multi_column_migration(
MSSQL doesn't support adding multiple columns in a single ALTER TABLE statement,
so we return a list of individual statements like Snowflake.
"""
qt = safe_identifier(table_name, QuoteStyle.SQUARE_BRACKET, allow_dots=True)
qc = safe_identifier(f"{column_name}_v2", QuoteStyle.SQUARE_BRACKET)
sql_statements = [
f"ALTER TABLE {table_name} ADD {column_name}_v2 NVARCHAR(MAX)",
f"ALTER TABLE {table_name} ADD metadata NVARCHAR(MAX)",
f"ALTER TABLE {table_name} ADD user_field_1 BIT DEFAULT 0",
f"ALTER TABLE {table_name} ADD user_field_2 INT DEFAULT 0",
f"ALTER TABLE {table_name} ADD user_field_3 NVARCHAR(MAX) DEFAULT NULL",
f"ALTER TABLE {table_name} ADD status NVARCHAR(10)",
f"ALTER TABLE {table_name} ADD error_message NVARCHAR(MAX)",
f"ALTER TABLE {qt} ADD {qc} NVARCHAR(MAX)",
f"ALTER TABLE {qt} ADD metadata NVARCHAR(MAX)",
f"ALTER TABLE {qt} ADD user_field_1 BIT DEFAULT 0",
f"ALTER TABLE {qt} ADD user_field_2 INT DEFAULT 0",
f"ALTER TABLE {qt} ADD user_field_3 NVARCHAR(MAX) DEFAULT NULL",
f"ALTER TABLE {qt} ADD status NVARCHAR(10)",
f"ALTER TABLE {qt} ADD error_message NVARCHAR(MAX)",
]
return sql_statements

Expand Down Expand Up @@ -214,19 +234,19 @@ def get_information_schema(self, table_name: str) -> dict[str, str]:
parts = table_name.rsplit(".", 1)
schema_name, table_only = parts[0], parts[1]
query = (
f"SELECT column_name, data_type FROM "
f"information_schema.columns WHERE "
f"table_schema = '{schema_name}' AND table_name = '{table_only}'"
"SELECT column_name, data_type FROM "
"information_schema.columns WHERE "
"table_schema = %s AND table_name = %s"
)
results = self.execute(query=query, params=(schema_name, table_only))
else:
# Handle unqualified table names (default to dbo)
query = (
f"SELECT column_name, data_type FROM "
f"information_schema.columns WHERE "
f"table_schema = 'dbo' AND table_name = '{table_name}'"
"SELECT column_name, data_type FROM "
"information_schema.columns WHERE "
"table_schema = 'dbo' AND table_name = %s"
)

results = self.execute(query=query)
results = self.execute(query=query, params=(table_name,))
column_types: dict[str, str] = self.get_db_column_types(
columns_with_types=results
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from pymysql.connections import Connection

from unstract.connectors.databases.mysql_handler import MysqlHandler
from unstract.connectors.databases.sql_safety import (
QuoteStyle,
safe_identifier,
)
from unstract.connectors.databases.unstract_db import UnstractDB


Expand Down Expand Up @@ -52,15 +56,18 @@ def can_read() -> bool:
# Return the fixed literal "longtext".
# NOTE(review): presumably the column type callers use for string values —
# confirm against the call sites, which are not visible here.
def get_string_type(self) -> str:
return "longtext"

# Return the identifier quoting style for this connector.
# Always QuoteStyle.BACKTICK, matching the style this file hands to
# safe_identifier() for table and column names.
def get_quote_style(self) -> QuoteStyle:
return QuoteStyle.BACKTICK

def get_information_schema(self, table_name: str) -> dict[str, str]:
"""Get information schema for MySQL database."""
# Try case-insensitive search since MySQL table names can be case-sensitive
query = (
"SELECT column_name, data_type FROM "
"information_schema.columns WHERE "
f"UPPER(table_name) = UPPER('{table_name}') AND table_schema = '{self.database}'"
"UPPER(table_name) = UPPER(%s) AND table_schema = %s"
)
results = self.execute(query=query)
results = self.execute(query=query, params=(table_name, self.database))
column_types: dict[str, str] = self.get_db_column_types(
columns_with_types=results
)
Expand Down Expand Up @@ -89,8 +96,9 @@ def get_create_table_base_query(self, table: str) -> str:
Returns:
str: generates a create sql base query with the constant columns
"""
qt = safe_identifier(table, QuoteStyle.BACKTICK)
sql_query = (
f"CREATE TABLE IF NOT EXISTS {table} "
f"CREATE TABLE IF NOT EXISTS {qt} "
f"(id LONGTEXT, "
f"created_by LONGTEXT, created_at TIMESTAMP, "
f"metadata JSON, "
Expand All @@ -103,9 +111,11 @@ def get_create_table_base_query(self, table: str) -> str:
return sql_query

def prepare_multi_column_migration(self, table_name: str, column_name: str) -> str:
qt = safe_identifier(table_name, QuoteStyle.BACKTICK)
qc = safe_identifier(f"{column_name}_v2", QuoteStyle.BACKTICK)
sql_query = (
f"ALTER TABLE {table_name} "
f"ADD COLUMN {column_name}_v2 JSON, "
f"ALTER TABLE {qt} "
f"ADD COLUMN {qc} JSON, "
f"ADD COLUMN metadata JSON, "
f"ADD COLUMN user_field_1 BOOLEAN DEFAULT FALSE, "
f"ADD COLUMN user_field_2 BIGINT DEFAULT 0, "
Expand Down
Loading
Loading