diff --git a/prompt-service/src/unstract/prompt_service/helpers/auth.py b/prompt-service/src/unstract/prompt_service/helpers/auth.py index 3c74e5abe1..cd9fd58128 100644 --- a/prompt-service/src/unstract/prompt_service/helpers/auth.py +++ b/prompt-service/src/unstract/prompt_service/helpers/auth.py @@ -21,8 +21,8 @@ def validate_bearer_token(token: str | None) -> bool: platform_key_table = f'"{DB_SCHEMA}".{DBTableV2.PLATFORM_KEY}' with db_context(): - query = f"SELECT * FROM {platform_key_table} WHERE key = '{token}'" - cursor = db.execute_sql(query) + query = f"SELECT * FROM {platform_key_table} WHERE key = %s" + cursor = db.execute_sql(query, (token,)) result_row = cursor.fetchone() cursor.close() if not result_row or len(result_row) == 0: @@ -66,12 +66,10 @@ def get_account_from_bearer_token(token: str | None) -> str: platform_key_table = DBTableV2.PLATFORM_KEY organization_table = DBTableV2.ORGANIZATION - query = f"SELECT organization_id FROM {platform_key_table} WHERE key='{token}'" - organization = DBUtils.execute_query(query) - query_org = ( - f"SELECT schema_name FROM {organization_table} WHERE id='{organization}'" - ) - schema_name: str = DBUtils.execute_query(query_org) + query = f"SELECT organization_id FROM {platform_key_table} WHERE key=%s" + organization = DBUtils.execute_query(query, (token,)) + query_org = f"SELECT schema_name FROM {organization_table} WHERE id=%s" + schema_name: str = DBUtils.execute_query(query_org, (organization,)) return schema_name @staticmethod diff --git a/unstract/connectors/src/unstract/connectors/databases/bigquery/bigquery.py b/unstract/connectors/src/unstract/connectors/databases/bigquery/bigquery.py index c9808ee454..d9e05a7055 100644 --- a/unstract/connectors/src/unstract/connectors/databases/bigquery/bigquery.py +++ b/unstract/connectors/src/unstract/connectors/databases/bigquery/bigquery.py @@ -15,6 +15,11 @@ BigQueryNotFoundException, ColumnMissingException, ) +from unstract.connectors.databases.sql_safety import ( + 
QuoteStyle, + safe_identifier, + validate_identifier, +) from unstract.connectors.databases.unstract_db import UnstractDB from unstract.connectors.exceptions import ConnectorError @@ -68,6 +73,9 @@ def can_write() -> bool: def can_read() -> bool: return True + def get_quote_style(self) -> QuoteStyle: + return QuoteStyle.BACKTICK + @staticmethod def _sanitize_for_bigquery(data: Any) -> Any: """BigQuery-specific float sanitization for PARSE_JSON compatibility. @@ -118,7 +126,7 @@ def get_engine(self) -> Any: info=self.json_credentials ) - def execute(self, query: str) -> Any: + def execute(self, query: str, params: Any = None) -> Any: try: query_job = self.get_engine().query(query) return query_job.result() @@ -169,8 +177,9 @@ def get_create_table_base_query(self, table: str) -> str: "Please ensure the BigQuery table is in the form of " "{project}.{dataset}.{table}." ) + qt = safe_identifier(table, QuoteStyle.BACKTICK, allow_dots=True) sql_query = ( - f"CREATE TABLE IF NOT EXISTS {table} " + f"CREATE TABLE IF NOT EXISTS {qt} " f"(id STRING," f"created_by STRING, created_at TIMESTAMP, " f"metadata JSON, " @@ -183,9 +192,11 @@ def get_create_table_base_query(self, table: str) -> str: return sql_query def prepare_multi_column_migration(self, table_name: str, column_name: str) -> str: + qt = safe_identifier(table_name, QuoteStyle.BACKTICK, allow_dots=True) + qc = safe_identifier(f"{column_name}_v2", QuoteStyle.BACKTICK) sql_query = ( - f"ALTER TABLE {table_name} " - f"ADD COLUMN {column_name}_v2 JSON, " + f"ALTER TABLE {qt} " + f"ADD COLUMN {qc} JSON, " f"ADD COLUMN metadata JSON, " f"ADD COLUMN user_field_1 BOOL, " f"ADD COLUMN user_field_2 INT64, " @@ -195,9 +206,8 @@ def prepare_multi_column_migration(self, table_name: str, column_name: str) -> s ) return sql_query - @staticmethod def get_sql_insert_query( - table_name: str, sql_keys: list[str], sql_values: list[str] | None = None + self, table_name: str, sql_keys: list[str], sql_values: list[str] | None = None ) -> 
str: """Function to generate parameterised insert sql query. @@ -209,15 +219,16 @@ def get_sql_insert_query( Returns: str: returns a string with parameterised insert sql query """ + qt = safe_identifier(table_name, QuoteStyle.BACKTICK, allow_dots=True) # BigQuery uses @ parameterization, ignore sql_values for now # Escape column names with backticks to handle special characters like underscores - escaped_keys = [f"`{key}`" for key in sql_keys] + escaped_keys = [safe_identifier(key, QuoteStyle.BACKTICK) for key in sql_keys] keys_str = ",".join(escaped_keys) - # Also escape parameter names with backticks to handle underscores in parameter names + # safe_identifier above already validates each key escaped_params = [f"@`{key}`" for key in sql_keys] values_placeholder = ",".join(escaped_params) - return f"INSERT INTO {table_name} ({keys_str}) VALUES ({values_placeholder})" + return f"INSERT INTO {qt} ({keys_str}) VALUES ({values_placeholder})" def execute_query( self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any @@ -339,12 +350,27 @@ def get_information_schema(self, table_name: str) -> dict[str, str]: project = bigquery_table_parts[0].lower() dataset = bigquery_table_parts[1] table = bigquery_table_parts[2] + # Validate identifier parts to prevent injection in schema path + validate_identifier(project) + validate_identifier(dataset) + validate_identifier(table) query = ( "SELECT column_name, data_type FROM " - f"{project}.{dataset}.INFORMATION_SCHEMA.COLUMNS WHERE " - f"table_name = '{table}'" + f"`{project}`.`{dataset}`.INFORMATION_SCHEMA.COLUMNS WHERE " + "table_name = @table_name" + ) + from google.cloud import bigquery as bq_module + + job_config = bq_module.QueryJobConfig( + query_parameters=[ + bq_module.ScalarQueryParameter("table_name", "STRING", table) + ] ) - results = self.execute(query=query) + try: + query_job = self.get_engine().query(query, job_config=job_config) + results = query_job.result() + except Exception as e: + raise 
ConnectorError(str(e)) from e # If table doesn't exist, execute returns None if results is None: diff --git a/unstract/connectors/src/unstract/connectors/databases/mariadb/mariadb.py b/unstract/connectors/src/unstract/connectors/databases/mariadb/mariadb.py index fdbede8f2d..99dfb1c0cb 100644 --- a/unstract/connectors/src/unstract/connectors/databases/mariadb/mariadb.py +++ b/unstract/connectors/src/unstract/connectors/databases/mariadb/mariadb.py @@ -6,6 +6,10 @@ from pymysql.connections import Connection from unstract.connectors.databases.mysql_handler import MysqlHandler +from unstract.connectors.databases.sql_safety import ( + QuoteStyle, + safe_identifier, +) from unstract.connectors.databases.unstract_db import UnstractDB from unstract.connectors.exceptions import ConnectorError @@ -72,6 +76,9 @@ def get_engine(self) -> Connection: # type: ignore[type-arg] ) raise ConnectorError(error_msg) from e + def get_quote_style(self) -> QuoteStyle: + return QuoteStyle.BACKTICK + def sql_to_db_mapping(self, value: Any, column_name: str | None = None) -> str: return str(MysqlHandler.sql_to_db_mapping(value=value, column_name=column_name)) @@ -97,8 +104,9 @@ def get_create_table_base_query(self, table: str) -> str: Returns: str: generates a create sql base query with the constant columns """ + qt = safe_identifier(table, QuoteStyle.BACKTICK) sql_query = ( - f"CREATE TABLE IF NOT EXISTS {table} " + f"CREATE TABLE IF NOT EXISTS {qt} " f"(id LONGTEXT, " f"created_by LONGTEXT, created_at TIMESTAMP, " f"metadata JSON, " @@ -115,18 +123,20 @@ def get_information_schema(self, table_name: str) -> dict[str, str]: query = ( "SELECT column_name, data_type FROM " "information_schema.columns WHERE " - f"UPPER(table_name) = UPPER('{table_name}') AND table_schema = '{self.database}'" + "UPPER(table_name) = UPPER(%s) AND table_schema = %s" ) - results = self.execute(query=query) + results = self.execute(query=query, params=(table_name, self.database)) column_types: dict[str, str] = 
self.get_db_column_types( columns_with_types=results ) return column_types def prepare_multi_column_migration(self, table_name: str, column_name: str) -> str: + qt = safe_identifier(table_name, QuoteStyle.BACKTICK) + qc = safe_identifier(f"{column_name}_v2", QuoteStyle.BACKTICK) sql_query = ( - f"ALTER TABLE {table_name} " - f"ADD COLUMN {column_name}_v2 JSON, " + f"ALTER TABLE {qt} " + f"ADD COLUMN {qc} JSON, " f"ADD COLUMN metadata JSON, " f"ADD COLUMN user_field_1 BOOLEAN DEFAULT FALSE, " f"ADD COLUMN user_field_2 BIGINT DEFAULT 0, " diff --git a/unstract/connectors/src/unstract/connectors/databases/mssql/mssql.py b/unstract/connectors/src/unstract/connectors/databases/mssql/mssql.py index b4355dac61..8239bc4254 100644 --- a/unstract/connectors/src/unstract/connectors/databases/mssql/mssql.py +++ b/unstract/connectors/src/unstract/connectors/databases/mssql/mssql.py @@ -13,6 +13,10 @@ InvalidSyntaxException, ) from unstract.connectors.databases.exceptions_helper import ExceptionHelper +from unstract.connectors.databases.sql_safety import ( + QuoteStyle, + safe_identifier, +) from unstract.connectors.databases.unstract_db import UnstractDB logger = logging.getLogger(__name__) @@ -63,6 +67,9 @@ def can_write() -> bool: def can_read() -> bool: return True + def get_quote_style(self) -> QuoteStyle: + return QuoteStyle.SQUARE_BRACKET + def sql_to_db_mapping(self, value: Any, column_name: str | None = None) -> str: """Gets the python datatype of value and converts python datatype to corresponding DB datatype. @@ -105,29 +112,40 @@ def get_create_table_base_query(self, table: str) -> str: Returns: str: generates a create sql base query with the constant columns """ - # Parse schema and table name for existence check + # Parse schema and table name for existence check. + # validate_identifier() is called by safe_identifier() and rejects any + # SQL metacharacters, so the values used in WHERE are safe. if "." 
in table: # Handle schema.table format like "[schema].[table]" parts = table.rsplit(".", 1) schema_name, table_name = parts[0], parts[1] + safe_schema = safe_identifier(schema_name, QuoteStyle.SQUARE_BRACKET) + safe_table = safe_identifier(table_name, QuoteStyle.SQUARE_BRACKET) + # schema_name/table_name are validated by safe_identifier above existence_check = ( f"IF NOT EXISTS (" f"SELECT * FROM INFORMATION_SCHEMA.TABLES " - f"WHERE TABLE_SCHEMA = '{schema_name}' AND TABLE_NAME = '{table_name}'" + f"WHERE TABLE_SCHEMA = '{schema_name}' " + f"AND TABLE_NAME = '{table_name}'" f")" ) + quoted_full_table = f"{safe_schema}.{safe_table}" else: # Handle unqualified table names (default to dbo schema) + safe_table = safe_identifier(table, QuoteStyle.SQUARE_BRACKET) + # table is validated by safe_identifier above existence_check = ( f"IF NOT EXISTS (" f"SELECT * FROM INFORMATION_SCHEMA.TABLES " - f"WHERE TABLE_SCHEMA = 'dbo' AND TABLE_NAME = '{table}'" + f"WHERE TABLE_SCHEMA = 'dbo' " + f"AND TABLE_NAME = '{table}'" f")" ) + quoted_full_table = safe_table sql_query = ( f"{existence_check} " - f" CREATE TABLE {table} " + f" CREATE TABLE {quoted_full_table} " f"(id NVARCHAR(MAX), " f"created_by NVARCHAR(MAX), created_at DATETIMEOFFSET, " f"metadata NVARCHAR(MAX), " @@ -147,14 +165,16 @@ def prepare_multi_column_migration( MSSQL doesn't support adding multiple columns in a single ALTER TABLE statement, so we return a list of individual statements like Snowflake. 
""" + qt = safe_identifier(table_name, QuoteStyle.SQUARE_BRACKET, allow_dots=True) + qc = safe_identifier(f"{column_name}_v2", QuoteStyle.SQUARE_BRACKET) sql_statements = [ - f"ALTER TABLE {table_name} ADD {column_name}_v2 NVARCHAR(MAX)", - f"ALTER TABLE {table_name} ADD metadata NVARCHAR(MAX)", - f"ALTER TABLE {table_name} ADD user_field_1 BIT DEFAULT 0", - f"ALTER TABLE {table_name} ADD user_field_2 INT DEFAULT 0", - f"ALTER TABLE {table_name} ADD user_field_3 NVARCHAR(MAX) DEFAULT NULL", - f"ALTER TABLE {table_name} ADD status NVARCHAR(10)", - f"ALTER TABLE {table_name} ADD error_message NVARCHAR(MAX)", + f"ALTER TABLE {qt} ADD {qc} NVARCHAR(MAX)", + f"ALTER TABLE {qt} ADD metadata NVARCHAR(MAX)", + f"ALTER TABLE {qt} ADD user_field_1 BIT DEFAULT 0", + f"ALTER TABLE {qt} ADD user_field_2 INT DEFAULT 0", + f"ALTER TABLE {qt} ADD user_field_3 NVARCHAR(MAX) DEFAULT NULL", + f"ALTER TABLE {qt} ADD status NVARCHAR(10)", + f"ALTER TABLE {qt} ADD error_message NVARCHAR(MAX)", ] return sql_statements @@ -214,19 +234,19 @@ def get_information_schema(self, table_name: str) -> dict[str, str]: parts = table_name.rsplit(".", 1) schema_name, table_only = parts[0], parts[1] query = ( - f"SELECT column_name, data_type FROM " - f"information_schema.columns WHERE " - f"table_schema = '{schema_name}' AND table_name = '{table_only}'" + "SELECT column_name, data_type FROM " + "information_schema.columns WHERE " + "table_schema = %s AND table_name = %s" ) + results = self.execute(query=query, params=(schema_name, table_only)) else: # Handle unqualified table names (default to dbo) query = ( - f"SELECT column_name, data_type FROM " - f"information_schema.columns WHERE " - f"table_schema = 'dbo' AND table_name = '{table_name}'" + "SELECT column_name, data_type FROM " + "information_schema.columns WHERE " + "table_schema = 'dbo' AND table_name = %s" ) - - results = self.execute(query=query) + results = self.execute(query=query, params=(table_name,)) column_types: dict[str, str] = 
self.get_db_column_types( columns_with_types=results ) diff --git a/unstract/connectors/src/unstract/connectors/databases/mysql/mysql.py b/unstract/connectors/src/unstract/connectors/databases/mysql/mysql.py index ecb5b60916..27a756fbf8 100644 --- a/unstract/connectors/src/unstract/connectors/databases/mysql/mysql.py +++ b/unstract/connectors/src/unstract/connectors/databases/mysql/mysql.py @@ -5,6 +5,10 @@ from pymysql.connections import Connection from unstract.connectors.databases.mysql_handler import MysqlHandler +from unstract.connectors.databases.sql_safety import ( + QuoteStyle, + safe_identifier, +) from unstract.connectors.databases.unstract_db import UnstractDB @@ -52,15 +56,18 @@ def can_read() -> bool: def get_string_type(self) -> str: return "longtext" + def get_quote_style(self) -> QuoteStyle: + return QuoteStyle.BACKTICK + def get_information_schema(self, table_name: str) -> dict[str, str]: """Get information schema for MySQL database.""" # Try case-insensitive search since MySQL table names can be case-sensitive query = ( "SELECT column_name, data_type FROM " "information_schema.columns WHERE " - f"UPPER(table_name) = UPPER('{table_name}') AND table_schema = '{self.database}'" + "UPPER(table_name) = UPPER(%s) AND table_schema = %s" ) - results = self.execute(query=query) + results = self.execute(query=query, params=(table_name, self.database)) column_types: dict[str, str] = self.get_db_column_types( columns_with_types=results ) @@ -89,8 +96,9 @@ def get_create_table_base_query(self, table: str) -> str: Returns: str: generates a create sql base query with the constant columns """ + qt = safe_identifier(table, QuoteStyle.BACKTICK) sql_query = ( - f"CREATE TABLE IF NOT EXISTS {table} " + f"CREATE TABLE IF NOT EXISTS {qt} " f"(id LONGTEXT, " f"created_by LONGTEXT, created_at TIMESTAMP, " f"metadata JSON, " @@ -103,9 +111,11 @@ def get_create_table_base_query(self, table: str) -> str: return sql_query def prepare_multi_column_migration(self, table_name: 
str, column_name: str) -> str: + qt = safe_identifier(table_name, QuoteStyle.BACKTICK) + qc = safe_identifier(f"{column_name}_v2", QuoteStyle.BACKTICK) sql_query = ( - f"ALTER TABLE {table_name} " - f"ADD COLUMN {column_name}_v2 JSON, " + f"ALTER TABLE {qt} " + f"ADD COLUMN {qc} JSON, " f"ADD COLUMN metadata JSON, " f"ADD COLUMN user_field_1 BOOLEAN DEFAULT FALSE, " f"ADD COLUMN user_field_2 BIGINT DEFAULT 0, " diff --git a/unstract/connectors/src/unstract/connectors/databases/oracle_db/oracle_db.py b/unstract/connectors/src/unstract/connectors/databases/oracle_db/oracle_db.py index e57150f8c9..bf1d42e9c1 100644 --- a/unstract/connectors/src/unstract/connectors/databases/oracle_db/oracle_db.py +++ b/unstract/connectors/src/unstract/connectors/databases/oracle_db/oracle_db.py @@ -7,6 +7,11 @@ from oracledb.connection import Connection from unstract.connectors.constants import DatabaseTypeConstants +from unstract.connectors.databases.sql_safety import ( + QuoteStyle, + safe_identifier, + validate_identifier, +) from unstract.connectors.databases.unstract_db import UnstractDB logger = logging.getLogger(__name__) @@ -63,6 +68,9 @@ def can_write() -> bool: def can_read() -> bool: return True + def get_quote_style(self) -> QuoteStyle: + return QuoteStyle.DOUBLE_QUOTE + def get_engine(self) -> Connection: con = oracledb.connect( config_dir=self.config_dir, @@ -108,8 +116,9 @@ def get_create_table_base_query(self, table: str) -> str: Returns: str: generates a create sql base query with the constant columns """ + qt = safe_identifier(table, QuoteStyle.DOUBLE_QUOTE) sql_query = ( - f"CREATE TABLE {table} " + f"CREATE TABLE {qt} " f"(id VARCHAR2(32767) , " f"created_by VARCHAR2(32767), created_at TIMESTAMP, " f"metadata CLOB, " @@ -136,19 +145,20 @@ def prepare_multi_column_migration(self, table_name: str, column_name: str) -> l Each column addition requires a separate ALTER TABLE statement. 
""" # Return one ALTER statement per column for Oracle compatibility + qt = safe_identifier(table_name, QuoteStyle.DOUBLE_QUOTE) + qc = safe_identifier(f"{column_name}_v2", QuoteStyle.DOUBLE_QUOTE) return [ - f"ALTER TABLE {table_name} ADD {column_name}_v2 CLOB", - f"ALTER TABLE {table_name} ADD metadata CLOB", - f"ALTER TABLE {table_name} ADD user_field_1 NUMBER(1) DEFAULT 0", - f"ALTER TABLE {table_name} ADD user_field_2 NUMBER DEFAULT 0", - f"ALTER TABLE {table_name} ADD user_field_3 VARCHAR2(32767) DEFAULT NULL", - f"ALTER TABLE {table_name} ADD status VARCHAR2(10)", - f"ALTER TABLE {table_name} ADD error_message VARCHAR2(32767)", + f"ALTER TABLE {qt} ADD {qc} CLOB", + f"ALTER TABLE {qt} ADD metadata CLOB", + f"ALTER TABLE {qt} ADD user_field_1 NUMBER(1) DEFAULT 0", + f"ALTER TABLE {qt} ADD user_field_2 NUMBER DEFAULT 0", + f"ALTER TABLE {qt} ADD user_field_3 VARCHAR2(32767) DEFAULT NULL", + f"ALTER TABLE {qt} ADD status VARCHAR2(10)", + f"ALTER TABLE {qt} ADD error_message VARCHAR2(32767)", ] - @staticmethod def get_sql_insert_query( - table_name: str, sql_keys: list[str], sql_values: list[str] | None = None + self, table_name: str, sql_keys: list[str], sql_values: list[str] | None = None ) -> str: """Function to generate parameterised insert sql query. @@ -160,6 +170,11 @@ def get_sql_insert_query( Returns: str: returns a string with parameterised insert sql query """ + qt = safe_identifier(table_name, QuoteStyle.DOUBLE_QUOTE) + # Validate column names but don't quote — Oracle normalizes + # unquoted to UPPERCASE, matching existing table schemas. 
+ for k in sql_keys: + validate_identifier(k) columns = ", ".join(sql_keys) values = [] for key in sql_keys: @@ -167,7 +182,7 @@ def get_sql_insert_query( values.append("TO_TIMESTAMP(:created_at, 'YYYY-MM-DD HH24:MI:SS.FF')") else: values.append(f":{key}") - return f"INSERT INTO {table_name} ({columns}) VALUES ({', '.join(values)})" + return f"INSERT INTO {qt} ({columns}) VALUES ({', '.join(values)})" def execute_query( self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any @@ -208,9 +223,9 @@ def get_information_schema(self, table_name: str) -> dict[str, str]: query = ( "SELECT column_name, data_type FROM " "user_tab_columns WHERE " - f"table_name = UPPER('{table_name}')" + "table_name = UPPER(:table_name)" ) - results = self.execute(query=query) + results = self.execute(query=query, params={"table_name": table_name}) column_types: dict[str, str] = self.get_db_column_types( columns_with_types=results ) diff --git a/unstract/connectors/src/unstract/connectors/databases/postgresql/postgresql.py b/unstract/connectors/src/unstract/connectors/databases/postgresql/postgresql.py index b3400ea034..b2343dc10f 100644 --- a/unstract/connectors/src/unstract/connectors/databases/postgresql/postgresql.py +++ b/unstract/connectors/src/unstract/connectors/databases/postgresql/postgresql.py @@ -3,10 +3,16 @@ from typing import Any import psycopg2 +from psycopg2 import sql as psycopg2_sql from psycopg2.extensions import connection from unstract.connectors.constants import DatabaseTypeConstants from unstract.connectors.databases.psycopg_handler import PsycoPgHandler +from unstract.connectors.databases.sql_safety import ( + QuoteStyle, + safe_identifier, + validate_identifier, +) from unstract.connectors.databases.unstract_db import UnstractDB @@ -72,6 +78,9 @@ def can_write() -> bool: def can_read() -> bool: return True + def get_quote_style(self) -> QuoteStyle: + return QuoteStyle.DOUBLE_QUOTE + def sql_to_db_mapping(self, value: Any, column_name: str | None = None) -> 
str: """Gets the python datatype of value and converts python datatype to corresponding DB datatype. @@ -134,8 +143,13 @@ def get_engine(self) -> connection: # Set schema explicitly only if schema is specified (avoids PgBouncer issues) if self.schema: + validate_identifier(self.schema) with con.cursor() as cur: - cur.execute(f"SET search_path TO {self.schema};") + cur.execute( + psycopg2_sql.SQL("SET search_path TO {}").format( + psycopg2_sql.Identifier(self.schema) + ) + ) return con @@ -164,9 +178,10 @@ def get_create_table_base_query(self, table: str) -> str: def prepare_multi_column_migration(self, table_name: str, column_name: str) -> str: quoted_table = self._quote_identifier(table_name) + quoted_col_v2 = safe_identifier(f"{column_name}_v2", QuoteStyle.DOUBLE_QUOTE) sql_query = ( f"ALTER TABLE {quoted_table} " - f"ADD COLUMN {column_name}_v2 JSONB, " + f"ADD COLUMN {quoted_col_v2} JSONB, " f"ADD COLUMN metadata JSONB, " f"ADD COLUMN user_field_1 BOOLEAN DEFAULT FALSE, " f"ADD COLUMN user_field_2 INTEGER DEFAULT 0, " @@ -195,7 +210,7 @@ def _quote_identifier(identifier: str) -> str: PostgreSQL identifiers with special characters must be enclosed in double quotes. This method adds proper quoting for table names containing hyphens, spaces, - or other special characters. + or other special characters. Embedded double quotes are escaped. 
Args: identifier (str): Table name or column name to quote @@ -203,9 +218,7 @@ def _quote_identifier(identifier: str) -> str: Returns: str: Properly quoted identifier safe for PostgreSQL """ - # Always quote the identifier to handle special characters like hyphens - # This is safe even for valid identifiers and prevents SQL injection - return f'"{identifier}"' + return safe_identifier(identifier, QuoteStyle.DOUBLE_QUOTE) def get_sql_insert_query( self, table_name: str, sql_keys: list[str], sql_values: list[str] | None = None @@ -223,6 +236,7 @@ def get_sql_insert_query( str: INSERT query with properly quoted table name """ quoted_table = self._quote_identifier(table_name) - keys_str = ", ".join(sql_keys) + quoted_keys = [safe_identifier(k, QuoteStyle.DOUBLE_QUOTE) for k in sql_keys] + keys_str = ", ".join(quoted_keys) values_placeholder = ", ".join(["%s"] * len(sql_keys)) return f"INSERT INTO {quoted_table} ({keys_str}) VALUES ({values_placeholder})" diff --git a/unstract/connectors/src/unstract/connectors/databases/redshift/redshift.py b/unstract/connectors/src/unstract/connectors/databases/redshift/redshift.py index 0acb6786d9..7c82bf2e72 100644 --- a/unstract/connectors/src/unstract/connectors/databases/redshift/redshift.py +++ b/unstract/connectors/src/unstract/connectors/databases/redshift/redshift.py @@ -7,6 +7,11 @@ from unstract.connectors.constants import DatabaseTypeConstants from unstract.connectors.databases.psycopg_handler import PsycoPgHandler +from unstract.connectors.databases.sql_safety import ( + QuoteStyle, + safe_identifier, + validate_identifier, +) from unstract.connectors.databases.unstract_db import UnstractDB @@ -58,7 +63,11 @@ def can_write() -> bool: def can_read() -> bool: return True + def get_quote_style(self) -> QuoteStyle: + return QuoteStyle.DOUBLE_QUOTE + def get_engine(self) -> connection: + validate_identifier(self.schema) return psycopg2.connect( host=self.host, port=self.port, @@ -96,8 +105,9 @@ def sql_to_db_mapping(self, 
value: Any, column_name: str | None = None) -> str: return str(mapping.get(data_type, DatabaseTypeConstants.REDSHIFT_VARCHAR)) def get_create_table_base_query(self, table: str) -> str: + quoted_table = safe_identifier(table, QuoteStyle.DOUBLE_QUOTE) sql_query = ( - f"CREATE TABLE IF NOT EXISTS {table} " + f"CREATE TABLE IF NOT EXISTS {quoted_table} " f"(id VARCHAR(65535) ," f"created_by VARCHAR(65535), created_at TIMESTAMP, " f"metadata SUPER, " @@ -125,14 +135,16 @@ def prepare_multi_column_migration(self, table_name: str, column_name: str) -> l or use dynamic SQL to make these operations idempotent. """ # Return one ALTER statement per column for Redshift compatibility + qt = safe_identifier(table_name, QuoteStyle.DOUBLE_QUOTE) + qc = safe_identifier(f"{column_name}_v2", QuoteStyle.DOUBLE_QUOTE) return [ - f"ALTER TABLE {table_name} ADD COLUMN {column_name}_v2 SUPER;", - f"ALTER TABLE {table_name} ADD COLUMN metadata SUPER;", - f"ALTER TABLE {table_name} ADD COLUMN user_field_1 BOOLEAN DEFAULT FALSE;", - f"ALTER TABLE {table_name} ADD COLUMN user_field_2 INTEGER DEFAULT 0;", - f"ALTER TABLE {table_name} ADD COLUMN user_field_3 VARCHAR(65535) DEFAULT NULL;", - f"ALTER TABLE {table_name} ADD COLUMN status VARCHAR(256);", - f"ALTER TABLE {table_name} ADD COLUMN error_message VARCHAR(65535);", + f"ALTER TABLE {qt} ADD COLUMN {qc} SUPER;", + f"ALTER TABLE {qt} ADD COLUMN metadata SUPER;", + f"ALTER TABLE {qt} ADD COLUMN user_field_1 BOOLEAN DEFAULT FALSE;", + f"ALTER TABLE {qt} ADD COLUMN user_field_2 INTEGER DEFAULT 0;", + f"ALTER TABLE {qt} ADD COLUMN user_field_3 VARCHAR(65535) DEFAULT NULL;", + f"ALTER TABLE {qt} ADD COLUMN status VARCHAR(256);", + f"ALTER TABLE {qt} ADD COLUMN error_message VARCHAR(65535);", ] def execute_query( diff --git a/unstract/connectors/src/unstract/connectors/databases/snowflake/snowflake.py b/unstract/connectors/src/unstract/connectors/databases/snowflake/snowflake.py index d7b192081f..4070a438b5 100644 --- 
a/unstract/connectors/src/unstract/connectors/databases/snowflake/snowflake.py +++ b/unstract/connectors/src/unstract/connectors/databases/snowflake/snowflake.py @@ -8,6 +8,11 @@ from unstract.connectors.constants import DatabaseTypeConstants from unstract.connectors.databases.exceptions import SnowflakeProgrammingException +from unstract.connectors.databases.sql_safety import ( + QuoteStyle, + safe_identifier, + validate_identifier, +) from unstract.connectors.databases.unstract_db import UnstractDB from unstract.connectors.exceptions import ConnectorError @@ -61,6 +66,9 @@ def can_write() -> bool: def can_read() -> bool: return True + def get_quote_style(self) -> QuoteStyle: + return QuoteStyle.DOUBLE_QUOTE + def sql_to_db_mapping(self, value: Any, column_name: str | None = None) -> str: """Gets the python datatype of value and converts python datatype to corresponding DB datatype. @@ -103,8 +111,9 @@ def get_engine(self) -> Any: return con def get_create_table_base_query(self, table: str) -> str: + qt = safe_identifier(table, QuoteStyle.DOUBLE_QUOTE) sql_query = ( - f"CREATE TABLE IF NOT EXISTS {table} " + f"CREATE TABLE IF NOT EXISTS {qt} " f"(id TEXT ," f"created_by TEXT, created_at TIMESTAMP, " f"metadata VARIANT, " @@ -122,14 +131,16 @@ def prepare_multi_column_migration(self, table_name: str, column_name: str) -> l Snowflake doesn't support multiple ADD COLUMN clauses in a single statement, so we return a list of individual ALTER TABLE statements. 
""" + qt = safe_identifier(table_name, QuoteStyle.DOUBLE_QUOTE) + qc = safe_identifier(f"{column_name}_v2", QuoteStyle.DOUBLE_QUOTE) sql_statements = [ - f"ALTER TABLE {table_name} ADD COLUMN {column_name}_v2 VARIANT", - f"ALTER TABLE {table_name} ADD COLUMN metadata VARIANT", - f"ALTER TABLE {table_name} ADD COLUMN user_field_1 BOOLEAN DEFAULT FALSE", - f"ALTER TABLE {table_name} ADD COLUMN user_field_2 INT DEFAULT 0", - f"ALTER TABLE {table_name} ADD COLUMN user_field_3 TEXT DEFAULT NULL", - f"ALTER TABLE {table_name} ADD COLUMN status VARCHAR", - f"ALTER TABLE {table_name} ADD COLUMN error_message VARCHAR", + f"ALTER TABLE {qt} ADD COLUMN {qc} VARIANT", + f"ALTER TABLE {qt} ADD COLUMN metadata VARIANT", + f"ALTER TABLE {qt} ADD COLUMN user_field_1 BOOLEAN DEFAULT FALSE", + f"ALTER TABLE {qt} ADD COLUMN user_field_2 INT DEFAULT 0", + f"ALTER TABLE {qt} ADD COLUMN user_field_3 TEXT DEFAULT NULL", + f"ALTER TABLE {qt} ADD COLUMN status VARCHAR", + f"ALTER TABLE {qt} ADD COLUMN error_message VARCHAR", ] return sql_statements @@ -166,7 +177,7 @@ def execute_query( logger.error(f"SQL Query: {sql_query}") logger.error(f"SQL Values: {sql_values}") raise SnowflakeProgrammingException( - detail=f"{e.msg} | SQL: {sql_query} | Values: {sql_values}", + detail=e.msg, database=self.database, schema=self.schema, table_name=table_name, @@ -175,7 +186,8 @@ def execute_query( def get_information_schema(self, table_name: str) -> dict[str, str]: import snowflake.connector.errors as SnowflakeError - query = f"describe table {table_name}" + qt = safe_identifier(table_name, QuoteStyle.DOUBLE_QUOTE, allow_dots=True) + query = f"describe table {qt}" column_types: dict[str, str] = {} try: results = self.execute(query=query) @@ -320,7 +332,7 @@ def get_sql_values_for_query( return sql_values def get_sql_insert_query( - self, table_name: str, sql_keys: list[str], sql_values: list[str] = None + self, table_name: str, sql_keys: list[str], sql_values: list[str] | None = None ) -> str: 
"""Generate SQL insert query for Snowflake with special handling for VARIANT columns. @@ -332,6 +344,11 @@ def get_sql_insert_query( Returns: str: Complete SQL insert query with VARIANT columns handled appropriately """ + qt = safe_identifier(table_name, QuoteStyle.DOUBLE_QUOTE) + # Validate column names but don't quote — Snowflake normalizes + # unquoted to UPPERCASE, matching existing table schemas. + for k in sql_keys: + validate_identifier(k) keys_str = ",".join(sql_keys) if sql_values: @@ -345,8 +362,8 @@ def get_sql_insert_query( if has_sql_fragments: # Build complete SQL with SELECT format for VARIANT columns values_str = ",".join(str(v) for v in sql_values) - return f"INSERT INTO {table_name} ({keys_str}) SELECT {values_str}" + return f"INSERT INTO {qt} ({keys_str}) SELECT {values_str}" # Fall back to parameterized format for standard queries values_placeholder = ",".join(["%s" for _ in sql_keys]) - return f"INSERT INTO {table_name} ({keys_str}) VALUES ({values_placeholder})" + return f"INSERT INTO {qt} ({keys_str}) VALUES ({values_placeholder})" diff --git a/unstract/connectors/src/unstract/connectors/databases/sql_safety.py b/unstract/connectors/src/unstract/connectors/databases/sql_safety.py new file mode 100644 index 0000000000..35d38be1ee --- /dev/null +++ b/unstract/connectors/src/unstract/connectors/databases/sql_safety.py @@ -0,0 +1,110 @@ +"""SQL identifier safety utilities for preventing SQL injection. + +Provides validation and quoting for SQL identifiers (table names, column names, +schema names) across different database engines. Used by all database connectors +to ensure user-supplied identifiers cannot inject arbitrary SQL. + +Defense-in-depth approach: +1. validate_identifier() - allowlist regex rejects SQL metacharacters +2. quote_identifier() - DB-specific quoting with proper escaping +3. 
safe_identifier() - validate + quote combined +""" + +import re +from enum import Enum + + +class QuoteStyle(Enum): + """Database-specific identifier quoting styles.""" + + DOUBLE_QUOTE = "double" # PostgreSQL, Redshift, Snowflake, Oracle + BACKTICK = "backtick" # MySQL, MariaDB, BigQuery + SQUARE_BRACKET = "bracket" # MSSQL + + +# Allowlist patterns for SQL identifiers. +# Permits: letters, digits, underscores, hyphens (common in table names). +# Intentionally excludes $ (Oracle/PG), # (MSSQL temp tables), and spaces +# to keep the strictest safe default. Extend if a deployment needs them. +_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_-]*$") + + +def validate_identifier(name: str, allow_dots: bool = False) -> str: + """Validate a SQL identifier against an allowlist pattern. + + Args: + name: The identifier to validate (table name, column name, schema name). + allow_dots: If True, allows dot-separated qualified names + (e.g., BigQuery's ``project.dataset.table``). + + Returns: + The validated identifier string (unchanged). + + Raises: + ValueError: If the identifier contains disallowed characters. + """ + if not name or not name.strip(): + raise ValueError("SQL identifier cannot be empty") + + if allow_dots and "." in name: + parts = name.split(".") + for part in parts: + validate_identifier(part, allow_dots=False) + return name + + if not _IDENTIFIER_PATTERN.match(name): + raise ValueError( + f"Invalid SQL identifier: '{name}'. " + "Only letters, digits, underscores, and hyphens are allowed. " + "Must start with a letter or underscore." + ) + return name + + +def quote_identifier(name: str, style: QuoteStyle) -> str: + """Quote a single identifier using DB-specific quoting with escaping. + + Escapes any embedded quote characters to prevent breakout. + + Args: + name: The identifier to quote. + style: The quoting style for the target database. + + Returns: + The quoted identifier string. 
+ """ + if style == QuoteStyle.DOUBLE_QUOTE: + escaped = name.replace('"', '""') + return f'"{escaped}"' + elif style == QuoteStyle.BACKTICK: + escaped = name.replace("`", "``") + return f"`{escaped}`" + elif style == QuoteStyle.SQUARE_BRACKET: + escaped = name.replace("]", "]]") + return f"[{escaped}]" + else: + raise ValueError(f"Unknown quote style: {style}") + + +def safe_identifier(name: str, style: QuoteStyle, allow_dots: bool = False) -> str: + """Validate AND quote a SQL identifier. + + For dot-qualified names (e.g., ``schema.table``), splits on dots + and validates+quotes each component separately. + + Args: + name: The identifier to make safe. + style: The quoting style for the target database. + allow_dots: If True, handles dot-separated qualified names. + + Returns: + The validated and quoted identifier string. + + Raises: + ValueError: If any component fails validation. + """ + if allow_dots and "." in name: + parts = name.split(".") + return ".".join(safe_identifier(part, style, allow_dots=False) for part in parts) + validate_identifier(name) + return quote_identifier(name, style) diff --git a/unstract/connectors/src/unstract/connectors/databases/unstract_db.py b/unstract/connectors/src/unstract/connectors/databases/unstract_db.py index dfe6048236..ce176724ac 100644 --- a/unstract/connectors/src/unstract/connectors/databases/unstract_db.py +++ b/unstract/connectors/src/unstract/connectors/databases/unstract_db.py @@ -7,6 +7,11 @@ from typing import Any from unstract.connectors.base import UnstractConnector +from unstract.connectors.databases.sql_safety import ( + QuoteStyle, + safe_identifier, + validate_identifier, +) from unstract.connectors.enums import ConnectorMode from unstract.connectors.exceptions import ConnectorError @@ -68,6 +73,15 @@ def python_social_auth_backend() -> str: def get_engine(self) -> Any: pass + @abstractmethod + def get_quote_style(self) -> QuoteStyle: + """Return the identifier quoting style for this database. 
+ + Returns: + QuoteStyle: The quoting convention used by this database engine. + """ + pass + def test_credentials(self) -> bool: """To test credentials for a DB connector.""" try: @@ -76,10 +90,13 @@ def test_credentials(self) -> bool: raise ConnectorError(f"Error while connecting to DB: {str(e)}") from e return True - def execute(self, query: str) -> Any: + def execute(self, query: str, params: Any = None) -> Any: try: with self.get_engine().cursor() as cursor: - cursor.execute(query) + if params is not None: + cursor.execute(query, params) + else: + cursor.execute(query) return cursor.fetchall() except Exception as e: raise ConnectorError(str(e)) from e @@ -154,13 +171,17 @@ def create_table_query( for key, val in database_entry.items(): if key not in PERMANENT_COLUMNS: sql_type = self.sql_to_db_mapping(val, column_name=key) + # Validate column name to block injection but don't quote — + # databases like Snowflake/Oracle normalize unquoted to + # UPPERCASE, and quoting would preserve lowercase, causing + # a mismatch with existing table schemas. + validate_identifier(key) sql_query += f"{key} {sql_type}, " return sql_query.rstrip(", ") + ")" - @staticmethod def get_sql_insert_query( - table_name: str, sql_keys: list[str], sql_values: list[str] = None + self, table_name: str, sql_keys: list[str], sql_values: list[str] | None = None ) -> str: """Function to generate parameterised insert sql query. @@ -172,10 +193,17 @@ def get_sql_insert_query( Returns: str: returns a string with parameterised insert sql query """ - # Base implementation ignores sql_values and returns parameterized query + style = self.get_quote_style() + quoted_table = safe_identifier(table_name, style, allow_dots=True) + # Validate column names to block injection but don't quote — + # databases like Snowflake/Oracle normalize unquoted to UPPERCASE, + # and quoting would preserve lowercase, causing a mismatch with + # existing table schemas where columns were created unquoted. 
+ for k in sql_keys: + validate_identifier(k) keys_str = ",".join(sql_keys) values_placeholder = ",".join(["%s" for _ in sql_keys]) - return f"INSERT INTO {table_name} ({keys_str}) VALUES ({values_placeholder})" + return f"INSERT INTO {quoted_table} ({keys_str}) VALUES ({values_placeholder})" @abstractmethod def execute_query( @@ -204,9 +232,9 @@ def get_information_schema(self, table_name: str) -> dict[str, str]: query = ( "SELECT column_name, data_type FROM " "information_schema.columns WHERE " - f"table_name = '{table_name}'" + "table_name = %s" ) - results = self.execute(query=query) + results = self.execute(query=query, params=(table_name,)) column_types: dict[str, str] = self.get_db_column_types( columns_with_types=results ) diff --git a/unstract/connectors/tests/databases/test_sql_safety.py b/unstract/connectors/tests/databases/test_sql_safety.py new file mode 100644 index 0000000000..d38385bad6 --- /dev/null +++ b/unstract/connectors/tests/databases/test_sql_safety.py @@ -0,0 +1,169 @@ +import unittest + +from unstract.connectors.databases.sql_safety import ( + QuoteStyle, + quote_identifier, + safe_identifier, + validate_identifier, +) + + +class TestValidateIdentifier(unittest.TestCase): + """Tests for validate_identifier() allowlist regex.""" + + def test_valid_simple_identifiers(self): + valid = ["my_table", "users", "Column1", "_private", "table_name_v2", "a"] + for name in valid: + self.assertEqual(validate_identifier(name), name) + + def test_valid_hyphenated_identifiers(self): + valid = ["my-table", "some-long-name", "data-2024"] + for name in valid: + self.assertEqual(validate_identifier(name), name) + + def test_valid_dot_qualified_names(self): + self.assertEqual( + validate_identifier("schema.table", allow_dots=True), "schema.table" + ) + self.assertEqual( + validate_identifier("project.dataset.table", allow_dots=True), + "project.dataset.table", + ) + + def test_reject_semicolon(self): + with self.assertRaises(ValueError): + 
validate_identifier("public; DROP TABLE users; --") + + def test_reject_single_quote(self): + with self.assertRaises(ValueError): + validate_identifier("test' OR '1'='1") + + def test_reject_double_quote(self): + with self.assertRaises(ValueError): + validate_identifier('" OR 1=1 --') + + def test_reject_spaces(self): + with self.assertRaises(ValueError): + validate_identifier("table name") + + def test_reject_parentheses(self): + with self.assertRaises(ValueError): + validate_identifier("x(); DROP TABLE y") + + def test_reject_empty_string(self): + with self.assertRaises(ValueError): + validate_identifier("") + + def test_reject_whitespace_only(self): + with self.assertRaises(ValueError): + validate_identifier(" ") + + def test_reject_starts_with_digit(self): + with self.assertRaises(ValueError): + validate_identifier("1table") + + def test_reject_dots_without_flag(self): + with self.assertRaises(ValueError): + validate_identifier("schema.table") + + def test_reject_dot_with_invalid_part(self): + with self.assertRaises(ValueError): + validate_identifier("valid.'; DROP TABLE x", allow_dots=True) + + def test_reject_real_world_payloads(self): + payloads = [ + "public; CREATE TABLE sqli_proof(pwned text); --", + "public; SELECT pg_ls_dir('/etc'); --", + "dbo.results' UNION SELECT name, 'a' FROM sysobjects--", + "results') OR '1'='1' --", + "x TEXT); DROP TABLE users; CREATE TABLE dummy(y", + ] + for payload in payloads: + with self.assertRaises(ValueError, msg=f"Should reject: {payload}"): + validate_identifier(payload) + + +class TestQuoteIdentifier(unittest.TestCase): + """Tests for quote_identifier() DB-specific quoting.""" + + def test_double_quote_style(self): + self.assertEqual( + quote_identifier("my_table", QuoteStyle.DOUBLE_QUOTE), '"my_table"' + ) + + def test_double_quote_escapes_embedded(self): + self.assertEqual( + quote_identifier('my"table', QuoteStyle.DOUBLE_QUOTE), '"my""table"' + ) + + def test_backtick_style(self): + self.assertEqual( + 
quote_identifier("my_table", QuoteStyle.BACKTICK), "`my_table`" + ) + + def test_backtick_escapes_embedded(self): + self.assertEqual( + quote_identifier("my`table", QuoteStyle.BACKTICK), "`my``table`" + ) + + def test_square_bracket_style(self): + self.assertEqual( + quote_identifier("my_table", QuoteStyle.SQUARE_BRACKET), "[my_table]" + ) + + def test_square_bracket_escapes_embedded(self): + self.assertEqual( + quote_identifier("my]table", QuoteStyle.SQUARE_BRACKET), "[my]]table]" + ) + + def test_hyphenated_name(self): + self.assertEqual( + quote_identifier("my-table", QuoteStyle.DOUBLE_QUOTE), '"my-table"' + ) + self.assertEqual( + quote_identifier("my-table", QuoteStyle.BACKTICK), "`my-table`" + ) + self.assertEqual( + quote_identifier("my-table", QuoteStyle.SQUARE_BRACKET), "[my-table]" + ) + + +class TestSafeIdentifier(unittest.TestCase): + """Tests for safe_identifier() — validate + quote combined.""" + + def test_simple_identifier(self): + self.assertEqual(safe_identifier("users", QuoteStyle.DOUBLE_QUOTE), '"users"') + self.assertEqual(safe_identifier("users", QuoteStyle.BACKTICK), "`users`") + self.assertEqual( + safe_identifier("users", QuoteStyle.SQUARE_BRACKET), "[users]" + ) + + def test_dot_qualified_identifier(self): + result = safe_identifier( + "project.dataset.table", QuoteStyle.BACKTICK, allow_dots=True + ) + self.assertEqual(result, "`project`.`dataset`.`table`") + + result = safe_identifier( + "dbo.my_table", QuoteStyle.SQUARE_BRACKET, allow_dots=True + ) + self.assertEqual(result, "[dbo].[my_table]") + + def test_injection_rejected(self): + with self.assertRaises(ValueError): + safe_identifier("public; DROP TABLE x; --", QuoteStyle.DOUBLE_QUOTE) + + def test_injection_in_qualified_part(self): + with self.assertRaises(ValueError): + safe_identifier( + "valid.'; DROP TABLE x", QuoteStyle.BACKTICK, allow_dots=True + ) + + def test_hyphenated_table(self): + self.assertEqual( + safe_identifier("my-table", QuoteStyle.DOUBLE_QUOTE), '"my-table"' + ) 
+ + +if __name__ == "__main__": + unittest.main() diff --git a/x2text-service/app/authentication_middleware.py b/x2text-service/app/authentication_middleware.py index f43c0542bb..6c0f1c2fc0 100644 --- a/x2text-service/app/authentication_middleware.py +++ b/x2text-service/app/authentication_middleware.py @@ -28,8 +28,8 @@ def validate_bearer_token(cls, token: str | None) -> bool: current_app.logger.error("Authentication failed. Empty bearer token") return False platform_key_table = f'"{Env.DB_SCHEMA}".{DBTable.PLATFORM_KEY}' - query = f"SELECT * FROM {platform_key_table} WHERE key = '{token}'" - cursor = be_db.execute_sql(query) + query = f"SELECT * FROM {platform_key_table} WHERE key = %s" + cursor = be_db.execute_sql(query, (token,)) result_row = cursor.fetchone() cursor.close() if not result_row or len(result_row) == 0: