Skip to content

Commit 56f0103

Browse files
dpgaspar authored and michael-s-molina committed
fix: adds the ability to disallow SQL functions per engine (#28639)
1 parent 6d5e38c commit 56f0103

File tree

10 files changed

+130
-17
lines changed

10 files changed

+130
-17
lines changed

superset-frontend/cypress-base/cypress.config.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ export default defineConfig({
2929
videoUploadOnPasses: false,
3030
viewportWidth: 1280,
3131
viewportHeight: 1024,
32-
projectId: 'ukwxzo',
32+
projectId: 'ud5x2f',
3333
retries: {
3434
runMode: 2,
3535
openMode: 0,

superset-frontend/src/dashboard/components/Header/HeaderActionsDropdown/index.jsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19-
import React from 'react';
19+
import React, { PureComponent } from 'react';
2020
import PropTypes from 'prop-types';
2121
import { isEmpty } from 'lodash';
2222
import { connect } from 'react-redux';

superset/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,15 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
12031203
# as such `create_engine(url, **params)`
12041204
DB_CONNECTION_MUTATOR = None
12051205

1206+
# A set of disallowed SQL functions per engine. This is used to restrict the use of
1207+
# unsafe SQL functions in SQL Lab and Charts. The keys of the dictionary are the engine
1208+
# names, and the values are sets of disallowed functions.
1209+
DISALLOWED_SQL_FUNCTIONS: dict[str, set[str]] = {
1210+
"postgresql": {"version", "query_to_xml", "inet_server_addr", "inet_client_addr"},
1211+
"clickhouse": {"url"},
1212+
"mysql": {"version"},
1213+
}
1214+
12061215

12071216
# A function that intercepts the SQL to be executed and can alter it.
12081217
# The use case can be around adding some sort of comment header

superset/db_engine_specs/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from superset.constants import TimeGrain as TimeGrainConstants
6060
from superset.databases.utils import make_url_safe
6161
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
62+
from superset.exceptions import DisallowedSQLFunction
6263
from superset.sql_parse import ParsedQuery, Table
6364
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
6465
from superset.utils import core as utils
@@ -1584,6 +1585,11 @@ def execute( # pylint: disable=unused-argument
15841585
"""
15851586
if not cls.allows_sql_comments:
15861587
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
1588+
disallowed_functions = current_app.config["DISALLOWED_SQL_FUNCTIONS"].get(
1589+
cls.engine, set()
1590+
)
1591+
if sql_parse.check_sql_functions_exist(query, disallowed_functions, cls.engine):
1592+
raise DisallowedSQLFunction(disallowed_functions)
15871593

15881594
if cls.arraysize:
15891595
cursor.arraysize = cls.arraysize

superset/db_engine_specs/trino.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from typing import Any, TYPE_CHECKING
2424

2525
import simplejson as json
26-
from flask import current_app
26+
from flask import current_app, Flask
2727
from sqlalchemy.engine.reflection import Inspector
2828
from sqlalchemy.engine.url import URL
2929
from sqlalchemy.exc import NoSuchTableError
@@ -206,19 +206,27 @@ def execute_with_cursor(cls, cursor: Cursor, sql: str, query: Query) -> None:
206206
execute_result: dict[str, Any] = {}
207207
execute_event = threading.Event()
208208

209-
def _execute(results: dict[str, Any], event: threading.Event) -> None:
209+
def _execute(
210+
results: dict[str, Any], event: threading.Event, app: Flask
211+
) -> None:
210212
logger.debug("Query %d: Running query: %s", query_id, sql)
211213

212214
try:
213-
cls.execute(cursor, sql)
215+
with app.app_context():
216+
cls.execute(cursor, sql)
214217
except Exception as ex: # pylint: disable=broad-except
215218
results["error"] = ex
216219
finally:
217220
event.set()
218221

219222
execute_thread = threading.Thread(
220223
target=_execute,
221-
args=(execute_result, execute_event),
224+
args=(
225+
execute_result,
226+
execute_event,
227+
# pylint: disable=protected-access
228+
current_app._get_current_object(),
229+
),
222230
)
223231
execute_thread.start()
224232

superset/exceptions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,18 @@ def __init__(self, exc: ValidationError, payload: dict[str, Any]):
295295
extra={"messages": exc.messages, "payload": payload},
296296
)
297297
super().__init__(error)
298+
299+
300+
class DisallowedSQLFunction(SupersetErrorException):
301+
"""
302+
Disallowed function found on SQL statement
303+
"""
304+
305+
def __init__(self, functions: set[str]):
306+
super().__init__(
307+
SupersetError(
308+
message=f"SQL statement contains disallowed function(s): {functions}",
309+
error_type=SupersetErrorType.SYNTAX_ERROR,
310+
level=ErrorLevel.ERROR,
311+
)
312+
)

superset/sql_parse.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from sqlparse import keywords
3737
from sqlparse.lexer import Lexer
3838
from sqlparse.sql import (
39+
Function,
3940
Identifier,
4041
IdentifierList,
4142
Parenthesis,
@@ -219,6 +220,19 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
219220
return cte, remainder
220221

221222

223+
def check_sql_functions_exist(
224+
sql: str, function_list: set[str], engine: str | None = None
225+
) -> bool:
226+
"""
227+
Check if the SQL statement contains any of the specified functions.
228+
229+
:param sql: The SQL statement
230+
:param function_list: The set of function names to search for
231+
:param engine: The engine to use for parsing the SQL statement
232+
"""
233+
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
234+
235+
222236
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
223237
"""
224238
Strips comments from a SQL statement, does a simple test first
@@ -288,6 +302,34 @@ def tables(self) -> set[Table]:
288302
self._tables = self._extract_tables_from_sql()
289303
return self._tables
290304

305+
def _check_functions_exist_in_token(
306+
self, token: Token, functions: set[str]
307+
) -> bool:
308+
if (
309+
isinstance(token, Function)
310+
and token.get_name() is not None
311+
and token.get_name().lower() in functions
312+
):
313+
return True
314+
if hasattr(token, "tokens"):
315+
for inner_token in token.tokens:
316+
if self._check_functions_exist_in_token(inner_token, functions):
317+
return True
318+
return False
319+
320+
def check_functions_exist(self, functions: set[str]) -> bool:
321+
"""
322+
Check if the SQL statement contains any of the specified functions.
323+
324+
:param functions: A set of functions to search for
325+
:return: True if the statement contains any of the specified functions
326+
"""
327+
for statement in self._parsed:
328+
for token in statement.tokens:
329+
if self._check_functions_exist_in_token(token, functions):
330+
return True
331+
return False
332+
291333
def _extract_tables_from_sql(self) -> set[Table]:
292334
"""
293335
Extract all table references in a query.

tests/unit_tests/db_engine_specs/test_trino.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def test_handle_cursor_early_cancel(
396396
assert cancel_query_mock.call_args is None
397397

398398

399-
def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
399+
def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
400400
"""Test that `execute_with_cursor` fetches query ID from the cursor"""
401401
from superset.db_engine_specs.trino import TrinoEngineSpec
402402

@@ -411,16 +411,20 @@ def _mock_execute(*args, **kwargs):
411411
mock_cursor.query_id = query_id
412412

413413
mock_cursor.execute.side_effect = _mock_execute
414+
with patch.dict(
415+
"superset.config.DISALLOWED_SQL_FUNCTIONS",
416+
{},
417+
clear=True,
418+
):
419+
TrinoEngineSpec.execute_with_cursor(
420+
cursor=mock_cursor,
421+
sql="SELECT 1 FROM foo",
422+
query=mock_query,
423+
)
414424

415-
TrinoEngineSpec.execute_with_cursor(
416-
cursor=mock_cursor,
417-
sql="SELECT 1 FROM foo",
418-
query=mock_query,
419-
)
420-
421-
mock_query.set_extra_json_key.assert_called_once_with(
422-
key=QUERY_CANCEL_KEY, value=query_id
423-
)
425+
mock_query.set_extra_json_key.assert_called_once_with(
426+
key=QUERY_CANCEL_KEY, value=query_id
427+
)
424428

425429

426430
def test_get_columns(mocker: MockerFixture):

tests/unit_tests/sql_parse_tests.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from superset.sql_parse import (
3434
add_table_name,
35+
check_sql_functions_exist,
3536
extract_table_references,
3637
extract_tables_from_jinja_sql,
3738
get_rls_for_table,
@@ -1292,6 +1293,31 @@ def test_strip_comments_from_sql() -> None:
12921293
)
12931294

12941295

1296+
def test_check_sql_functions_exist() -> None:
1297+
"""
1298+
Test that disallowed function calls are detected, while a table that merely shares a function's name is not flagged.
1299+
"""
1300+
assert not (
1301+
check_sql_functions_exist("select a, b from version", {"version"}, "postgresql")
1302+
)
1303+
1304+
assert check_sql_functions_exist("select version()", {"version"}, "postgresql")
1305+
1306+
assert check_sql_functions_exist(
1307+
"select version from version()", {"version"}, "postgresql"
1308+
)
1309+
1310+
assert check_sql_functions_exist(
1311+
"select 1, a.version from (select version from version()) as a",
1312+
{"version"},
1313+
"postgresql",
1314+
)
1315+
1316+
assert check_sql_functions_exist(
1317+
"select 1, a.version from (select version()) as a", {"version"}, "postgresql"
1318+
)
1319+
1320+
12951321
def test_sanitize_clause_valid():
12961322
# regular clauses
12971323
assert sanitize_clause("col = 1") == "col = 1"

tests/unit_tests/utils/csv_tests.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import sys
19+
1820
import pandas as pd
1921
import pyarrow as pa
2022
import pytest
@@ -57,6 +59,7 @@ def test_escape_value():
5759
assert result == "' =10+2"
5860

5961

62+
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires Python 3.10 or later")
6063
def test_df_to_escaped_csv():
6164
df = pd.DataFrame(
6265
data={
@@ -87,7 +90,7 @@ def test_df_to_escaped_csv():
8790
["col_a"],
8891
["'=func()"],
8992
["-10"],
90-
["\"'=cmd\\|' /C calc'!A0\""],
93+
[r"'=cmd\\|' /C calc'!A0"],
9194
['"\'""""=b"'],
9295
["' =a"],
9396
["\x00"],

0 commit comments

Comments
 (0)