|
36 | 36 | from sqlparse import keywords |
37 | 37 | from sqlparse.lexer import Lexer |
38 | 38 | from sqlparse.sql import ( |
| 39 | + Function, |
39 | 40 | Identifier, |
40 | 41 | IdentifierList, |
41 | 42 | Parenthesis, |
@@ -219,6 +220,19 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]: |
219 | 220 | return cte, remainder |
220 | 221 |
|
221 | 222 |
|
| 223 | +def check_sql_functions_exist( |
| 224 | + sql: str, function_list: set[str], engine: str | None = None |
| 225 | +) -> bool: |
| 226 | + """ |
| 227 | + Check if the SQL statement contains any of the specified functions. |
| 228 | +
|
| 229 | + :param sql: The SQL statement |
| 230 | + :param function_list: The list of functions to search for |
| 231 | + :param engine: The engine to use for parsing the SQL statement |
| 232 | + """ |
| 233 | + return ParsedQuery(sql, engine=engine).check_functions_exist(function_list) |
| 234 | + |
| 235 | + |
222 | 236 | def strip_comments_from_sql(statement: str, engine: str | None = None) -> str: |
223 | 237 | """ |
224 | 238 | Strips comments from a SQL statement, does a simple test first |
@@ -288,6 +302,34 @@ def tables(self) -> set[Table]: |
288 | 302 | self._tables = self._extract_tables_from_sql() |
289 | 303 | return self._tables |
290 | 304 |
|
| 305 | + def _check_functions_exist_in_token( |
| 306 | + self, token: Token, functions: set[str] |
| 307 | + ) -> bool: |
| 308 | + if ( |
| 309 | + isinstance(token, Function) |
| 310 | + and token.get_name() is not None |
| 311 | + and token.get_name().lower() in functions |
| 312 | + ): |
| 313 | + return True |
| 314 | + if hasattr(token, "tokens"): |
| 315 | + for inner_token in token.tokens: |
| 316 | + if self._check_functions_exist_in_token(inner_token, functions): |
| 317 | + return True |
| 318 | + return False |
| 319 | + |
| 320 | + def check_functions_exist(self, functions: set[str]) -> bool: |
| 321 | + """ |
| 322 | + Check if the SQL statement contains any of the specified functions. |
| 323 | +
|
| 324 | + :param functions: A set of functions to search for |
| 325 | + :return: True if the statement contains any of the specified functions |
| 326 | + """ |
| 327 | + for statement in self._parsed: |
| 328 | + for token in statement.tokens: |
| 329 | + if self._check_functions_exist_in_token(token, functions): |
| 330 | + return True |
| 331 | + return False |
| 332 | + |
291 | 333 | def _extract_tables_from_sql(self) -> set[Table]: |
292 | 334 | """ |
293 | 335 | Extract all table references in a query. |
|
0 commit comments