Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Context manager for Cloud Spanner batched writes."""

from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import Mutation
from google.cloud.spanner_v1 import TransactionOptions

Expand Down Expand Up @@ -123,6 +124,7 @@ class Batch(_BatchBase):
"""

committed = None
commit_stats = None
"""Timestamp at which the batch was successfully committed."""

def _check_state(self):
Expand All @@ -136,9 +138,13 @@ def _check_state(self):
if self.committed is not None:
raise ValueError("Batch already committed")

def commit(self):
def commit(self, return_commit_stats=False):
"""Commit mutations to the database.

:type return_commit_stats: bool
:param return_commit_stats:
If true, the response will return commit stats which can be accessed though commit_stats.

:rtype: datetime
:returns: timestamp of the committed changes.
"""
Expand All @@ -148,14 +154,16 @@ def commit(self):
metadata = _metadata_with_prefix(database.name)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
trace_attributes = {"num_mutations": len(self._mutations)}
request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
single_use_transaction=txn_options,
return_commit_stats=return_commit_stats,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(
session=self._session.name,
mutations=self._mutations,
single_use_transaction=txn_options,
metadata=metadata,
)
response = api.commit(request=request, metadata=metadata,)
self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
return self.committed

def __enter__(self):
Expand Down
40 changes: 38 additions & 2 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import functools
import grpc
import logging
import re
import threading

Expand Down Expand Up @@ -99,14 +100,18 @@ class Database(object):

_spanner_api = None

def __init__(self, database_id, instance, ddl_statements=(), pool=None):
def __init__(
self, database_id, instance, ddl_statements=(), pool=None, logger=None
):
self.database_id = database_id
self._instance = instance
self._ddl_statements = _check_ddl_statements(ddl_statements)
self._local = threading.local()
self._state = None
self._create_time = None
self._restore_info = None
self.log_commit_stats = False
self._logger = logger

if pool is None:
pool = BurstyPool()
Expand Down Expand Up @@ -216,6 +221,31 @@ def ddl_statements(self):
"""
return self._ddl_statements

@property
def logger(self):
"""Logger used by the database.

The default logger will log commit stats at the log level INFO using
`sys.stderr`.

:rtype: :class:`logging.Logger` or `None`
:returns: the logger
"""
if self._logger is None:
self._logger = logging.getLogger(self.name)
self._logger.setLevel(logging.INFO)

ch = logging.StreamHandler()
ch.setLevel(logging.INFO)

formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
ch.setFormatter(formatter)

self._logger.addHandler(ch)
return self._logger

@property
def spanner_api(self):
"""Helper for session-related API calls."""
Expand Down Expand Up @@ -624,8 +654,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
try:
if exc_type is None:
self._batch.commit()
self._batch.commit(return_commit_stats=self._database.log_commit_stats)
finally:
if self._database.log_commit_stats:
self._database.logger.info(
"Transaction mutation count: {}".format(
self._batch.commit_stats.mutation_count
)
)
self._database._pool.put(self._session)


Expand Down
6 changes: 4 additions & 2 deletions google/cloud/spanner_v1/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def delete(self):

api.delete_instance(name=self.name, metadata=metadata)

def database(self, database_id, ddl_statements=(), pool=None):
def database(self, database_id, ddl_statements=(), pool=None, logger=None):
"""Factory to create a database within this instance.

:type database_id: str
Expand All @@ -374,7 +374,9 @@ def database(self, database_id, ddl_statements=(), pool=None):
:rtype: :class:`~google.cloud.spanner_v1.database.Database`
:returns: a database owned by this instance.
"""
return Database(database_id, self, ddl_statements=ddl_statements, pool=pool)
return Database(
database_id, self, ddl_statements=ddl_statements, pool=pool, logger=logger
)

def list_databases(self, page_size=None):
"""List databases for the instance.
Expand Down
8 changes: 7 additions & 1 deletion google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,20 @@ def run_in_transaction(self, func, *args, **kw):
raise

try:
txn.commit()
txn.commit(return_commit_stats=self._database.log_commit_stats)
except Aborted as exc:
del self._transaction
_delay_until_retry(exc, deadline, attempts)
except GoogleAPICallError:
del self._transaction
raise
else:
if self._database.log_commit_stats:
self._database.logger.info(
"Transaction mutation count: {}".format(
txn.commit_stats.mutation_count
)
)
return return_value


Expand Down
23 changes: 16 additions & 7 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_merge_query_options,
_metadata_with_prefix,
)
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import ExecuteBatchDmlRequest
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import TransactionSelector
Expand All @@ -42,6 +43,7 @@ class Transaction(_SnapshotBase, _BatchBase):
committed = None
"""Timestamp at which the transaction was successfully committed."""
rolled_back = False
commit_stats = None
_multi_use = True
_execute_sql_count = 0

Expand Down Expand Up @@ -119,9 +121,13 @@ def rollback(self):
self.rolled_back = True
del self._session._transaction

def commit(self):
def commit(self, return_commit_stats=False):
"""Commit mutations to the database.

:type return_commit_stats: bool
:param return_commit_stats:
If true, the response will return commit stats which can be accessed though commit_stats.

:rtype: datetime
:returns: timestamp of the committed changes.
:raises ValueError: if there are no mutations to commit.
Expand All @@ -132,14 +138,17 @@ def commit(self):
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
trace_attributes = {"num_mutations": len(self._mutations)}
request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
transaction_id=self._transaction_id,
return_commit_stats=return_commit_stats,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(
session=self._session.name,
mutations=self._mutations,
transaction_id=self._transaction_id,
metadata=metadata,
)
response = api.commit(request=request, metadata=metadata,)
self.committed = response.commit_timestamp
if return_commit_stats:
self.commit_stats = response.commit_stats
del self._session._transaction
return self.committed

Expand Down
16 changes: 8 additions & 8 deletions tests/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,17 +339,17 @@ def __init__(self, **kwargs):
self.__dict__.update(**kwargs)

def commit(
self,
session,
mutations,
transaction_id="",
single_use_transaction=None,
metadata=None,
self, request=None, metadata=None,
):
from google.api_core.exceptions import Unknown

assert transaction_id == ""
self._committed = (session, mutations, single_use_transaction, metadata)
assert request.transaction_id == b""
self._committed = (
request.session,
request.mutations,
request.single_use_transaction,
metadata,
)
if self._rpc_error:
raise Unknown("error")
return self._commit_response
86 changes: 85 additions & 1 deletion tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def test_ctor_defaults(self):
self.assertIs(database._instance, instance)
self.assertEqual(list(database.ddl_statements), [])
self.assertIsInstance(database._pool, BurstyPool)
self.assertFalse(database.log_commit_stats)
self.assertIsNone(database._logger)
# BurstyPool does not create sessions during 'bind()'.
self.assertTrue(database._pool._sessions.empty())

Expand Down Expand Up @@ -145,6 +147,18 @@ def test_ctor_w_ddl_statements_ok(self):
self.assertIs(database._instance, instance)
self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS)

def test_ctor_w_explicit_logger(self):
from logging import Logger

instance = _Instance(self.INSTANCE_NAME)
logger = mock.create_autospec(Logger, instance=True)
database = self._make_one(self.DATABASE_ID, instance, logger=logger)
self.assertEqual(database.database_id, self.DATABASE_ID)
self.assertIs(database._instance, instance)
self.assertEqual(list(database.ddl_statements), [])
self.assertFalse(database.log_commit_stats)
self.assertEqual(database._logger, logger)

def test_from_pb_bad_database_name(self):
from google.cloud.spanner_admin_database_v1 import Database

Expand Down Expand Up @@ -249,6 +263,24 @@ def test_restore_info(self):
)
self.assertEqual(database.restore_info, restore_info)

def test_logger_property_default(self):
import logging

instance = _Instance(self.INSTANCE_NAME)
pool = _Pool()
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
logger = logging.getLogger(database.name)
self.assertEqual(database.logger, logger)

def test_logger_property_custom(self):
import logging

instance = _Instance(self.INSTANCE_NAME)
pool = _Pool()
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
logger = database._logger = mock.create_autospec(logging.Logger, instance=True)
self.assertEqual(database.logger, logger)

def test_spanner_api_property_w_scopeless_creds(self):

client = _Client()
Expand Down Expand Up @@ -1263,6 +1295,7 @@ def test_ctor(self):

def test_context_mgr_success(self):
import datetime
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import CommitResponse
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud._helpers import UTC
Expand Down Expand Up @@ -1290,13 +1323,59 @@ def test_context_mgr_success(self):

expected_txn_options = TransactionOptions(read_write={})

request = CommitRequest(
session=self.SESSION_NAME,
mutations=[],
single_use_transaction=expected_txn_options,
)
api.commit.assert_called_once_with(
request=request, metadata=[("google-cloud-resource-prefix", database.name)],
)

def test_context_mgr_w_commit_stats(self):
import datetime
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import CommitResponse
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud._helpers import UTC
from google.cloud._helpers import _datetime_to_pb_timestamp
from google.cloud.spanner_v1.batch import Batch

now = datetime.datetime.utcnow().replace(tzinfo=UTC)
now_pb = _datetime_to_pb_timestamp(now)
commit_stats = CommitResponse.CommitStats(mutation_count=4)
response = CommitResponse(commit_timestamp=now_pb, commit_stats=commit_stats)
database = _Database(self.DATABASE_NAME)
database.log_commit_stats = True
api = database.spanner_api = self._make_spanner_client()
api.commit.return_value = response
pool = database._pool = _Pool()
session = _Session(database)
pool.put(session)
checkout = self._make_one(database)

with checkout as batch:
self.assertIsNone(pool._session)
self.assertIsInstance(batch, Batch)
self.assertIs(batch._session, session)

self.assertIs(pool._session, session)
self.assertEqual(batch.committed, now)

expected_txn_options = TransactionOptions(read_write={})

request = CommitRequest(
session=self.SESSION_NAME,
mutations=[],
single_use_transaction=expected_txn_options,
metadata=[("google-cloud-resource-prefix", database.name)],
return_commit_stats=True,
)
api.commit.assert_called_once_with(
request=request, metadata=[("google-cloud-resource-prefix", database.name)],
)

database.logger.info.assert_called_once_with("Transaction mutation count: 4")

def test_context_mgr_failure(self):
from google.cloud.spanner_v1.batch import Batch

Expand Down Expand Up @@ -1883,10 +1962,15 @@ def __init__(self, name):


class _Database(object):
log_commit_stats = False

def __init__(self, name, instance=None):
self.name = name
self.database_id = name.rsplit("/", 1)[1]
self._instance = instance
from logging import Logger

self.logger = mock.create_autospec(Logger, instance=True)


class _Pool(object):
Expand Down
Loading