diff --git a/core/google/api/core/helpers/grpc_helpers.py b/core/google/api/core/helpers/grpc_helpers.py new file mode 100644 index 000000000000..0f065b7086e1 --- /dev/null +++ b/core/google/api/core/helpers/grpc_helpers.py @@ -0,0 +1,104 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for :mod:`grpc`.""" + +import grpc +import six + +from google.api.core import exceptions + + +# The list of gRPC Callable interfaces that return iterators. +_STREAM_WRAP_CLASSES = ( + grpc.UnaryStreamMultiCallable, + grpc.StreamStreamMultiCallable, +) + + +def _patch_callable_name(callable_): + """Fix-up gRPC callable attributes. + + gRPC callable lack the ``__name__`` attribute which causes + :func:`functools.wraps` to error. This adds the attribute if needed. + """ + if not hasattr(callable_, '__name__'): + callable_.__name__ = callable_.__class__.__name__ + + +def _wrap_unary_errors(callable_): + """Map errors for Unary-Unary and Stream-Unary gRPC callables.""" + _patch_callable_name(callable_) + + @six.wraps(callable_) + def error_remapped_callable(*args, **kwargs): + try: + return callable_(*args, **kwargs) + except grpc.RpcError as exc: + six.raise_from(exceptions.from_grpc_error(exc), exc) + + return error_remapped_callable + + +def _wrap_stream_errors(callable_): + """Wrap errors for Unary-Stream and Stream-Stream gRPC callables. + + The callables that return iterators require a bit more logic to re-map + errors when iterating. This wraps both the initial invocation and the + iterator of the return value to re-map errors. + """ + _patch_callable_name(callable_) + + @six.wraps(callable_) + def error_remapped_callable(*args, **kwargs): + try: + result = callable_(*args, **kwargs) + # Note: we are patching the private grpc._channel._Rendezvous._next + # method as magic methods (__next__ in this case) can not be + # patched on a per-instance basis (see + # https://docs.python.org/3/reference/datamodel.html + # #special-lookup). + # In an ideal world, gRPC would return a *specific* interface + # from *StreamMultiCallables, but they return a God class that's + # a combination of basically every interface in gRPC making it + # untenable for us to implement a wrapper object using the same + # interface. + result._next = _wrap_unary_errors(result._next) + return result + except grpc.RpcError as exc: + six.raise_from(exceptions.from_grpc_error(exc), exc) + + return error_remapped_callable + + +def wrap_errors(callable_): + """Wrap a gRPC callable and map :class:`grpc.RpcErrors` to friendly error + classes. + + Errors raised by the gRPC callable are mapped to the appropriate + :class:`google.api.core.exceptions.GoogleAPICallError` subclasses. + The original `grpc.RpcError` (which is usually also a `grpc.Call`) is + available from the ``response`` property on the mapped exception. This + is useful for extracting metadata from the original error. + + Args: + callable_ (Callable): A gRPC callable. + + Returns: + Callable: The wrapped gRPC callable. + """ + if isinstance(callable_, _STREAM_WRAP_CLASSES): + return _wrap_stream_errors(callable_) + else: + return _wrap_unary_errors(callable_) diff --git a/core/tests/unit/api_core/helpers/test_grpc_helpers.py b/core/tests/unit/api_core/helpers/test_grpc_helpers.py new file mode 100644 index 000000000000..1b4f3a3025b6 --- /dev/null +++ b/core/tests/unit/api_core/helpers/test_grpc_helpers.py @@ -0,0 +1,130 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import grpc +import mock +import pytest + +from google.api.core import exceptions +from google.api.core.helpers import grpc_helpers + + +def test__patch_callable_name(): + callable = mock.Mock(spec=['__class__']) + callable.__class__ = mock.Mock(spec=['__name__']) + callable.__class__.__name__ = 'TestCallable' + + grpc_helpers._patch_callable_name(callable) + + assert callable.__name__ == 'TestCallable' + + +def test__patch_callable_name_no_op(): + callable = mock.Mock(spec=['__name__']) + callable.__name__ = 'test_callable' + + grpc_helpers._patch_callable_name(callable) + + assert callable.__name__ == 'test_callable' + + +class RpcErrorImpl(grpc.RpcError, grpc.Call): + def __init__(self, code): + super(RpcErrorImpl, self).__init__() + self._code = code + + def code(self): + return self._code + + def details(self): + return None + + +def test_wrap_unary_errors(): + grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT) + callable_ = mock.Mock(spec=['__call__'], side_effect=grpc_error) + + wrapped_callable = grpc_helpers._wrap_unary_errors(callable_) + + with pytest.raises(exceptions.InvalidArgument) as exc_info: + wrapped_callable(1, 2, three='four') + + callable_.assert_called_once_with(1, 2, three='four') + assert exc_info.value.response == grpc_error + + +def test_wrap_stream_errors_invocation(): + grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT) + callable_ = mock.Mock(spec=['__call__'], side_effect=grpc_error) + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + + with pytest.raises(exceptions.InvalidArgument) as exc_info: + wrapped_callable(1, 2, three='four') + + callable_.assert_called_once_with(1, 2, three='four') + assert exc_info.value.response == grpc_error + + +class RpcResponseIteratorImpl(object): + def __init__(self, exception): + self._exception = exception + + # Note: This matches grpc._channel._Rendezvous._next which is what is + # patched by _wrap_stream_errors. + def _next(self): + raise self._exception + + def __next__(self): # pragma: NO COVER + return self._next() + + def next(self): # pragma: NO COVER + return self._next() + + +def test_wrap_stream_errors_iterator(): + grpc_error = RpcErrorImpl(grpc.StatusCode.UNAVAILABLE) + response_iter = RpcResponseIteratorImpl(grpc_error) + callable_ = mock.Mock(spec=['__call__'], return_value=response_iter) + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + + got_iterator = wrapped_callable(1, 2, three='four') + + with pytest.raises(exceptions.ServiceUnavailable) as exc_info: + next(got_iterator) + + assert got_iterator == response_iter + callable_.assert_called_once_with(1, 2, three='four') + assert exc_info.value.response == grpc_error + + +@mock.patch('google.api.core.helpers.grpc_helpers._wrap_unary_errors') +def test_wrap_errors_non_streaming(wrap_unary_errors): + callable_ = mock.create_autospec(grpc.UnaryUnaryMultiCallable) + + result = grpc_helpers.wrap_errors(callable_) + + assert result == wrap_unary_errors.return_value + wrap_unary_errors.assert_called_once_with(callable_) + + +@mock.patch('google.api.core.helpers.grpc_helpers._wrap_stream_errors') +def test_wrap_errors_streaming(wrap_stream_errors): + callable_ = mock.create_autospec(grpc.UnaryStreamMultiCallable) + + result = grpc_helpers.wrap_errors(callable_) + + assert result == wrap_stream_errors.return_value + wrap_stream_errors.assert_called_once_with(callable_)