Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Context manager for Cloud Spanner batched writes."""

from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import Mutation
from google.cloud.spanner_v1 import TransactionOptions

Expand Down Expand Up @@ -123,6 +124,7 @@ class Batch(_BatchBase):
"""

committed = None
commit_stats = None
"""Timestamp at which the batch was successfully committed."""

def _check_state(self):
Expand All @@ -136,9 +138,13 @@ def _check_state(self):
if self.committed is not None:
raise ValueError("Batch already committed")

def commit(self):
def commit(self, return_commit_stats=False):
"""Commit mutations to the database.

:type return_commit_stats: bool
:param return_commit_stats:
If true, the response will return commit stats which can be accessed though commit_stats.

:rtype: datetime
:returns: timestamp of the committed changes.
"""
Expand All @@ -148,14 +154,16 @@ def commit(self):
metadata = _metadata_with_prefix(database.name)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
trace_attributes = {"num_mutations": len(self._mutations)}
request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
single_use_transaction=txn_options,
return_commit_stats=return_commit_stats,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(
session=self._session.name,
mutations=self._mutations,
single_use_transaction=txn_options,
metadata=metadata,
)
response = api.commit(request=request, metadata=metadata,)
self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
return self.committed

def __enter__(self):
Expand Down
39 changes: 37 additions & 2 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import functools
import grpc
import logging
import re
import threading

Expand Down Expand Up @@ -95,11 +96,19 @@ class Database(object):
:param pool: (Optional) session pool to be used by database. If not
passed, the database will construct an instance of
:class:`~google.cloud.spanner_v1.pool.BurstyPool`.

:type logger: `logging.Logger`
:param logger: (Optional) a custom logger that is used if `log_commit_stats`
is `True` to log commit statistics. If not passed, a logger
will be created when needed that will log the commit statistics
to stdout.
"""

_spanner_api = None

def __init__(self, database_id, instance, ddl_statements=(), pool=None):
def __init__(
self, database_id, instance, ddl_statements=(), pool=None, logger=None
):
self.database_id = database_id
self._instance = instance
self._ddl_statements = _check_ddl_statements(ddl_statements)
Expand All @@ -109,6 +118,8 @@ def __init__(self, database_id, instance, ddl_statements=(), pool=None):
self._restore_info = None
self._version_retention_period = None
self._earliest_version_time = None
self.log_commit_stats = False
self._logger = logger

if pool is None:
pool = BurstyPool()
Expand Down Expand Up @@ -237,6 +248,25 @@ def ddl_statements(self):
"""
return self._ddl_statements

@property
def logger(self):
"""Logger used by the database.

The default logger will log commit stats at the log level INFO using
`sys.stderr`.

:rtype: :class:`logging.Logger` or `None`
:returns: the logger
"""
if self._logger is None:
self._logger = logging.getLogger(self.name)
self._logger.setLevel(logging.INFO)

ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
self._logger.addHandler(ch)
return self._logger

@property
def spanner_api(self):
"""Helper for session-related API calls."""
Expand Down Expand Up @@ -647,8 +677,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
try:
if exc_type is None:
self._batch.commit()
self._batch.commit(return_commit_stats=self._database.log_commit_stats)
finally:
if self._database.log_commit_stats and self._batch.commit_stats:
self._database.logger.info(
"CommitStats: {}".format(self._batch.commit_stats),
extra={"commit_stats": self._batch.commit_stats},
)
self._database._pool.put(self._session)


Expand Down
12 changes: 10 additions & 2 deletions google/cloud/spanner_v1/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def delete(self):

api.delete_instance(name=self.name, metadata=metadata)

def database(self, database_id, ddl_statements=(), pool=None):
def database(self, database_id, ddl_statements=(), pool=None, logger=None):
"""Factory to create a database within this instance.

:type database_id: str
Expand All @@ -371,10 +371,18 @@ def database(self, database_id, ddl_statements=(), pool=None):
:class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`.
:param pool: (Optional) session pool to be used by database.

:type logger: `logging.Logger`
:param logger: (Optional) a custom logger that is used if `log_commit_stats`
is `True` to log commit statistics. If not passed, a logger
will be created when needed that will log the commit statistics
to stdout.

:rtype: :class:`~google.cloud.spanner_v1.database.Database`
:returns: a database owned by this instance.
"""
return Database(database_id, self, ddl_statements=ddl_statements, pool=pool)
return Database(
database_id, self, ddl_statements=ddl_statements, pool=pool, logger=logger
)

def list_databases(self, page_size=None):
"""List databases for the instance.
Expand Down
7 changes: 6 additions & 1 deletion google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,19 @@ def run_in_transaction(self, func, *args, **kw):
raise

try:
txn.commit()
txn.commit(return_commit_stats=self._database.log_commit_stats)
except Aborted as exc:
del self._transaction
_delay_until_retry(exc, deadline, attempts)
except GoogleAPICallError:
del self._transaction
raise
else:
if self._database.log_commit_stats and txn.commit_stats:
self._database.logger.info(
"CommitStats: {}".format(txn.commit_stats),
extra={"commit_stats": txn.commit_stats},
)
return return_value


Expand Down
23 changes: 16 additions & 7 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_merge_query_options,
_metadata_with_prefix,
)
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import ExecuteBatchDmlRequest
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import TransactionSelector
Expand All @@ -42,6 +43,7 @@ class Transaction(_SnapshotBase, _BatchBase):
committed = None
"""Timestamp at which the transaction was successfully committed."""
rolled_back = False
commit_stats = None
_multi_use = True
_execute_sql_count = 0

Expand Down Expand Up @@ -119,9 +121,13 @@ def rollback(self):
self.rolled_back = True
del self._session._transaction

def commit(self):
def commit(self, return_commit_stats=False):
"""Commit mutations to the database.

:type return_commit_stats: bool
:param return_commit_stats:
If true, the response will return commit stats which can be accessed though commit_stats.

:rtype: datetime
:returns: timestamp of the committed changes.
:raises ValueError: if there are no mutations to commit.
Expand All @@ -132,14 +138,17 @@ def commit(self):
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
trace_attributes = {"num_mutations": len(self._mutations)}
request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
transaction_id=self._transaction_id,
return_commit_stats=return_commit_stats,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(
session=self._session.name,
mutations=self._mutations,
transaction_id=self._transaction_id,
metadata=metadata,
)
response = api.commit(request=request, metadata=metadata,)
self.committed = response.commit_timestamp
if return_commit_stats:
self.commit_stats = response.commit_stats
del self._session._transaction
return self.committed

Expand Down
16 changes: 8 additions & 8 deletions tests/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,17 +339,17 @@ def __init__(self, **kwargs):
self.__dict__.update(**kwargs)

def commit(
self,
session,
mutations,
transaction_id="",
single_use_transaction=None,
metadata=None,
self, request=None, metadata=None,
):
from google.api_core.exceptions import Unknown

assert transaction_id == ""
self._committed = (session, mutations, single_use_transaction, metadata)
assert request.transaction_id == b""
self._committed = (
request.session,
request.mutations,
request.single_use_transaction,
metadata,
)
if self._rpc_error:
raise Unknown("error")
return self._commit_response
Loading