-
Notifications
You must be signed in to change notification settings - Fork 99
feat: Implementation for batch dml in dbapi #1055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from enum import Enum | ||
| from typing import TYPE_CHECKING, List | ||
| from google.cloud.spanner_dbapi.checksum import ResultsChecksum | ||
| from google.cloud.spanner_dbapi.parsed_statement import ( | ||
| ParsedStatement, | ||
| StatementType, | ||
| Statement, | ||
| ) | ||
| from google.rpc.code_pb2 import ABORTED, OK | ||
| from google.api_core.exceptions import Aborted | ||
|
|
||
| from google.cloud.spanner_dbapi.utils import StreamedManyResultSets | ||
|
|
||
| if TYPE_CHECKING: | ||
| from google.cloud.spanner_dbapi.cursor import Cursor | ||
|
|
||
|
|
||
| class BatchDmlExecutor: | ||
| """Executor that is used when a DML batch is started. These batches only | ||
| accept DML statements. All DML statements are buffered locally and sent to | ||
| Spanner when runBatch() is called. | ||
| :type "Cursor": :class:`~google.cloud.spanner_dbapi.cursor.Cursor` | ||
| :param cursor: | ||
| """ | ||
|
|
||
| def __init__(self, cursor: "Cursor"): | ||
| self._cursor = cursor | ||
| self._connection = cursor.connection | ||
| self._statements: List[Statement] = [] | ||
|
|
||
| def execute_statement(self, parsed_statement: ParsedStatement): | ||
| """Executes the statement when dml batch is active by buffering the | ||
| statement in-memory. | ||
| :type parsed_statement: ParsedStatement | ||
| :param parsed_statement: parsed statement containing sql query and query | ||
| params | ||
| """ | ||
| from google.cloud.spanner_dbapi import ProgrammingError | ||
|
|
||
| if ( | ||
| parsed_statement.statement_type != StatementType.UPDATE | ||
| and parsed_statement.statement_type != StatementType.INSERT | ||
| ): | ||
| raise ProgrammingError( | ||
| "Only DML statements are allowed in batch " "DML mode." | ||
|
||
| ) | ||
| self._statements.append(parsed_statement.statement) | ||
|
|
||
| def run_batch_dml(self): | ||
| """Executes all the buffered statements on the active dml batch by | ||
| making a call to Spanner. | ||
| """ | ||
| return run_batch_dml(self._cursor, self._statements) | ||
|
|
||
|
|
||
| def run_batch_dml(cursor: "Cursor", statements: List[Statement]): | ||
| """Executes all the dml statements by making a batch call to Spanner. | ||
| :type cursor: Cursor | ||
| :param cursor: Database Cursor object | ||
| :type statements: List[Statement] | ||
| :param statements: list of statements to execute in batch | ||
| """ | ||
| from google.cloud.spanner_dbapi import OperationalError | ||
|
|
||
| connection = cursor.connection | ||
olavloite marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| many_result_set = StreamedManyResultSets() | ||
| statements_tuple = [] | ||
| for statement in statements: | ||
| statements_tuple.append(statement.get_tuple()) | ||
| if not connection._client_transaction_started: | ||
| res = connection.database.run_in_transaction(_do_batch_update, statements_tuple) | ||
| many_result_set.add_iter(res) | ||
| cursor._row_count = sum([max(val, 0) for val in res]) | ||
| else: | ||
| retried = False | ||
| while True: | ||
| try: | ||
| transaction = connection.transaction_checkout() | ||
| status, res = transaction.batch_update(statements_tuple) | ||
| many_result_set.add_iter(res) | ||
| res_checksum = ResultsChecksum() | ||
| res_checksum.consume_result(res) | ||
| res_checksum.consume_result(status.code) | ||
| if not retried: | ||
| connection._statements.append((statements, res_checksum)) | ||
| cursor._row_count = sum([max(val, 0) for val in res]) | ||
|
|
||
| if status.code == ABORTED: | ||
| connection._transaction = None | ||
| raise Aborted(status.message) | ||
| elif status.code != OK: | ||
| raise OperationalError(status.message) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should (could) this also include the status code?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will take it in a follow up PR |
||
| return many_result_set | ||
| except Aborted: | ||
| connection.retry_transaction() | ||
| retried = True | ||
|
|
||
|
|
||
| def _do_batch_update(transaction, statements): | ||
| from google.cloud.spanner_dbapi import OperationalError | ||
|
|
||
| status, res = transaction.batch_update(statements) | ||
| if status.code == ABORTED: | ||
| raise Aborted(status.message) | ||
| elif status.code != OK: | ||
| raise OperationalError(status.message) | ||
| return res | ||
|
|
||
|
|
||
| class BatchMode(Enum): | ||
| DML = 1 | ||
| DDL = 2 | ||
| NONE = 3 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,13 +13,14 @@ | |
| # limitations under the License. | ||
|
|
||
| """DB-API Connection for the Google Cloud Spanner.""" | ||
|
|
||
| import time | ||
| import warnings | ||
|
|
||
| from google.api_core.exceptions import Aborted | ||
| from google.api_core.gapic_v1.client_info import ClientInfo | ||
| from google.cloud import spanner_v1 as spanner | ||
| from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor | ||
| from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement | ||
| from google.cloud.spanner_v1 import RequestOptions | ||
| from google.cloud.spanner_v1.session import _get_retry_delay | ||
| from google.cloud.spanner_v1.snapshot import Snapshot | ||
|
|
@@ -28,7 +29,11 @@ | |
| from google.cloud.spanner_dbapi.checksum import _compare_checksums | ||
| from google.cloud.spanner_dbapi.checksum import ResultsChecksum | ||
| from google.cloud.spanner_dbapi.cursor import Cursor | ||
| from google.cloud.spanner_dbapi.exceptions import InterfaceError, OperationalError | ||
| from google.cloud.spanner_dbapi.exceptions import ( | ||
| InterfaceError, | ||
| OperationalError, | ||
| ProgrammingError, | ||
| ) | ||
| from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT | ||
| from google.cloud.spanner_dbapi.version import PY_VERSION | ||
|
|
||
|
|
@@ -111,6 +116,8 @@ def __init__(self, instance, database=None, read_only=False): | |
| # whether transaction started at Spanner. This means that we had | ||
| # made atleast one call to Spanner. | ||
| self._spanner_transaction_started = False | ||
| self._batch_mode = BatchMode.NONE | ||
| self._batch_dml_executor: BatchDmlExecutor = None | ||
|
|
||
| @property | ||
| def autocommit(self): | ||
|
|
@@ -196,6 +203,24 @@ def read_only(self, value): | |
| ) | ||
| self._read_only = value | ||
|
|
||
| @property | ||
| def batch_mode(self): | ||
|
||
| """_batch_mode flag for this connection. | ||
|
|
||
| :rtype: bool | ||
| :returns: _batch_mode flag value. | ||
| """ | ||
| return self._batch_mode | ||
|
|
||
| @batch_mode.setter | ||
| def batch_mode(self, value): | ||
| """`batch_mode` flag setter. | ||
|
|
||
| Args: | ||
| value (BatchMode) | ||
| """ | ||
| self._batch_mode = value | ||
|
||
|
|
||
| @property | ||
| def request_options(self): | ||
| """Options for the next SQL operations. | ||
|
|
@@ -310,7 +335,10 @@ def _rerun_previous_statements(self): | |
| statements, checksum = statement | ||
|
|
||
| transaction = self.transaction_checkout() | ||
| status, res = transaction.batch_update(statements) | ||
| statements_tuple = [] | ||
| for single_statement in statements: | ||
| statements_tuple.append(single_statement.get_tuple()) | ||
| status, res = transaction.batch_update(statements_tuple) | ||
|
|
||
| if status.code == ABORTED: | ||
| raise Aborted(status.details) | ||
|
|
@@ -476,14 +504,14 @@ def run_prior_DDL_statements(self): | |
|
|
||
| return self.database.update_ddl(ddl_statements).result() | ||
|
|
||
| def run_statement(self, statement, retried=False): | ||
| def run_statement(self, statement: Statement, retried=False): | ||
| """Run single SQL statement in begun transaction. | ||
|
|
||
| This method is never used in autocommit mode. In | ||
| !autocommit mode however it remembers every executed | ||
| SQL statement with its parameters. | ||
|
|
||
| :type statement: :class:`dict` | ||
| :type statement: :class:`Statement` | ||
| :param statement: SQL statement to execute. | ||
|
|
||
| :type retried: bool | ||
|
|
@@ -534,6 +562,39 @@ def validate(self): | |
| "Expected: [[1]]" % result | ||
| ) | ||
|
|
||
| @check_not_closed | ||
| def start_batch_dml(self, cursor): | ||
| if self.batch_mode is not BatchMode.NONE: | ||
| raise ProgrammingError( | ||
| "Cannot start a DML batch when a batch is already active" | ||
| ) | ||
| if self.read_only: | ||
| raise ProgrammingError( | ||
| "Cannot start a DML batch when the connection is in read-only mode" | ||
| ) | ||
| self.batch_mode = BatchMode.DML | ||
| self._batch_dml_executor = BatchDmlExecutor(cursor) | ||
|
|
||
| @check_not_closed | ||
| def execute_batch_dml_statement(self, parsed_statement: ParsedStatement): | ||
| if self.batch_mode is not BatchMode.DML: | ||
| raise ProgrammingError( | ||
| "Cannot execute statement when the BatchMode is not DML" | ||
| ) | ||
| self._batch_dml_executor.execute_statement(parsed_statement) | ||
|
|
||
| @check_not_closed | ||
| def run_batch(self): | ||
| if self.batch_mode is BatchMode.NONE: | ||
| raise ProgrammingError("Cannot run a batch when the BatchMode is not set") | ||
| try: | ||
| if self.batch_mode is BatchMode.DML: | ||
| many_result_set = self._batch_dml_executor.run_batch_dml() | ||
| finally: | ||
| self.batch_mode = BatchMode.NONE | ||
| self._batch_dml_executor = None | ||
| return many_result_set | ||
|
|
||
| def __enter__(self): | ||
| return self | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.