diff --git a/gapic/schema/api.py b/gapic/schema/api.py index f3dbb18971..11cb77f286 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -30,6 +30,7 @@ from google.api import http_pb2 # type: ignore from google.api import resource_pb2 # type: ignore from google.api import service_pb2 # type: ignore +from google.cloud import extended_operations_pb2 as ex_ops_pb2 # type: ignore from google.gapic.metadata import gapic_metadata_pb2 # type: ignore from google.longrunning import operations_pb2 # type: ignore from google.protobuf import descriptor_pb2 @@ -474,6 +475,20 @@ def requires_package(self, pkg: Tuple[str, ...]) -> bool: for message in proto.all_messages.values() ) + def get_custom_operation_service(self, method: "wrappers.Method") -> "wrappers.Service": + if not method.output.is_extended_operation: + raise ValueError( + f"Method is not an extended operation LRO: {method.name}") + + op_serv_name = self.naming.proto_package + "." + \ + method.options.Extensions[ex_ops_pb2.operation_service] + op_serv = self.services[op_serv_name] + if not op_serv.custom_polling_method: + raise ValueError( + f"Service is not an extended operation operation service: {op_serv.name}") + + return op_serv + class _ProtoBuilder: """A "builder class" for Proto objects. diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 6f059c3963..25fb11ae7b 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -360,7 +360,7 @@ def oneof_fields(self, include_optional=False): return oneof_fields @utils.cached_property - def is_diregapic_operation(self) -> bool: + def is_extended_operation(self) -> bool: if not self.name == "Operation": return False @@ -877,7 +877,7 @@ def __getattr__(self, name): @property def is_operation_polling_method(self): - return self.output.is_diregapic_operation and self.options.Extensions[ex_ops_pb2.operation_polling_method] + return self.output.is_extended_operation and self.options.Extensions[ex_ops_pb2.operation_polling_method] @utils.cached_property def client_output(self): diff --git a/tests/unit/schema/test_api.py b/tests/unit/schema/test_api.py index afa82c8cfd..eae44b84b1 100644 --- a/tests/unit/schema/test_api.py +++ b/tests/unit/schema/test_api.py @@ -22,6 +22,7 @@ from google.api import client_pb2 from google.api import resource_pb2 from google.api_core import exceptions +from google.cloud import extended_operations_pb2 as ex_ops_pb2 from google.gapic.metadata import gapic_metadata_pb2 from google.longrunning import operations_pb2 from google.protobuf import descriptor_pb2 @@ -1595,3 +1596,151 @@ def test_http_options(fs): method='get', uri='/v3/{name=projects/*/locations/*/operations/*}', body=None), wrappers.HttpRule(method='get', uri='/v3/{name=/locations/*/operations/*}', body=None)] } + + +def generate_basic_extended_operations_setup(): + T = descriptor_pb2.FieldDescriptorProto.Type + + operation = make_message_pb2( + name="Operation", + fields=( + make_field_pb2(name=name, type=T.Value("TYPE_STRING"), number=i) + for i, name in enumerate(("name", "status", "error_code", "error_message"), start=1) + ), + ) + + for f in operation.field: + options = descriptor_pb2.FieldOptions() + # Note: The field numbers were carefully chosen to be the corresponding enum values. + options.Extensions[ex_ops_pb2.operation_field] = f.number + f.options.MergeFrom(options) + + options = descriptor_pb2.MethodOptions() + options.Extensions[ex_ops_pb2.operation_polling_method] = True + + polling_method = descriptor_pb2.MethodDescriptorProto( + name="Get", + input_type="google.extended_operations.v1.stuff.GetOperation", + output_type="google.extended_operations.v1.stuff.Operation", + options=options, + ) + + delete_input_message = make_message_pb2(name="Input") + delete_output_message = make_message_pb2(name="Output") + ops_service = descriptor_pb2.ServiceDescriptorProto( + name="CustomOperations", + method=[ + polling_method, + descriptor_pb2.MethodDescriptorProto( + name="Delete", + input_type="google.extended_operations.v1.stuff.Input", + output_type="google.extended_operations.v1.stuff.Output", + ), + ], + ) + + request = make_message_pb2( + name="GetOperation", + fields=[ + make_field_pb2(name="name", type=T.Value("TYPE_STRING"), number=1) + ], + ) + + initial_opts = descriptor_pb2.MethodOptions() + initial_opts.Extensions[ex_ops_pb2.operation_service] = ops_service.name + initial_input_message = make_message_pb2(name="Initial") + initial_method = descriptor_pb2.MethodDescriptorProto( + name="CreateTask", + input_type="google.extended_operations.v1.stuff.GetOperation", + output_type="google.extended_operations.v1.stuff.Operation", + options=initial_opts, + ) + + regular_service = descriptor_pb2.ServiceDescriptorProto( + name="RegularService", + method=[ + initial_method, + ], + ) + + file_protos = [ + make_file_pb2( + name="extended_operations.proto", + package="google.extended_operations.v1.stuff", + messages=[ + operation, + request, + delete_output_message, + delete_input_message, + initial_input_message, + ], + services=[ + regular_service, + ops_service, + ], + ), + ] + + return file_protos + + +def test_extended_operations_lro_operation_service(): + file_protos = generate_basic_extended_operations_setup() + api_schema = api.API.build(file_protos) + initial_method = api_schema.services["google.extended_operations.v1.stuff.RegularService"].methods["CreateTask"] + + expected = api_schema.services['google.extended_operations.v1.stuff.CustomOperations'] + actual = api_schema.get_custom_operation_service(initial_method) + + assert expected is actual + + assert actual.custom_polling_method is actual.methods["Get"] + + +def test_extended_operations_lro_operation_service_no_annotation(): + file_protos = generate_basic_extended_operations_setup() + + api_schema = api.API.build(file_protos) + initial_method = api_schema.services["google.extended_operations.v1.stuff.RegularService"].methods["CreateTask"] + # It's easier to manipulate data structures after building the API. + del initial_method.options.Extensions[ex_ops_pb2.operation_service] + + with pytest.raises(KeyError): + api_schema.get_custom_operation_service(initial_method) + + +def test_extended_operations_lro_operation_service_no_such_service(): + file_protos = generate_basic_extended_operations_setup() + + api_schema = api.API.build(file_protos) + initial_method = api_schema.services["google.extended_operations.v1.stuff.RegularService"].methods["CreateTask"] + initial_method.options.Extensions[ex_ops_pb2.operation_service] = "UnrealService" + + with pytest.raises(KeyError): + api_schema.get_custom_operation_service(initial_method) + + +def test_extended_operations_lro_operation_service_not_an_lro(): + file_protos = generate_basic_extended_operations_setup() + + api_schema = api.API.build(file_protos) + initial_method = api_schema.services["google.extended_operations.v1.stuff.RegularService"].methods["CreateTask"] + # Hack to pretend that the initial_method is not an LRO + super(type(initial_method), initial_method).__setattr__( + "output", initial_method.input) + + with pytest.raises(ValueError): + api_schema.get_custom_operation_service(initial_method) + + +def test_extended_operations_lro_operation_service_no_polling_method(): + file_protos = generate_basic_extended_operations_setup() + + api_schema = api.API.build(file_protos) + initial_method = api_schema.services["google.extended_operations.v1.stuff.RegularService"].methods["CreateTask"] + + operation_service = api_schema.services["google.extended_operations.v1.stuff.CustomOperations"] + del operation_service.methods["Get"].options.Extensions[ex_ops_pb2.operation_polling_method] + + with pytest.raises(ValueError): + api_schema.get_custom_operation_service(initial_method) diff --git a/tests/unit/schema/wrappers/test_message.py b/tests/unit/schema/wrappers/test_message.py index 1519fadc67..7cd5910c3f 100644 --- a/tests/unit/schema/wrappers/test_message.py +++ b/tests/unit/schema/wrappers/test_message.py @@ -331,7 +331,7 @@ def test_required_fields(): assert set(request.required_fields) == {mass_kg, length_m, color} -def test_is_diregapic_operation(): +def test_is_extended_operation(): T = descriptor_pb2.FieldDescriptorProto.Type # Canonical Operation @@ -349,7 +349,7 @@ def test_is_diregapic_operation(): options.Extensions[ex_ops_pb2.operation_field] = f.number f.options.MergeFrom(options) - assert operation.is_diregapic_operation + assert operation.is_extended_operation # Missing a required field @@ -367,7 +367,7 @@ def test_is_diregapic_operation(): options.Extensions[ex_ops_pb2.operation_field] = f.number f.options.MergeFrom(options) - assert not missing.is_diregapic_operation + assert not missing.is_extended_operation # Named incorrectly @@ -383,7 +383,7 @@ def test_is_diregapic_operation(): options.Extensions[ex_ops_pb2.operation_field] = f.number f.options.MergeFrom(options) - assert not my_message.is_diregapic_operation + assert not my_message.is_extended_operation # Duplicated annotation for mapping in range(1, 5): @@ -401,4 +401,4 @@ def test_is_diregapic_operation(): f.options.MergeFrom(options) with pytest.raises(TypeError): - duplicate.is_diregapic_operation + duplicate.is_extended_operation diff --git a/tests/unit/schema/wrappers/test_service.py b/tests/unit/schema/wrappers/test_service.py index 7cd41799ed..33e83494f7 100644 --- a/tests/unit/schema/wrappers/test_service.py +++ b/tests/unit/schema/wrappers/test_service.py @@ -589,7 +589,7 @@ def test_operation_polling_method(): assert not user_service.custom_polling_method -def test_diregapic_lro_detection(): +def test_extended_operations_lro_detection(): T = descriptor_pb2.FieldDescriptorProto.Type operation = make_message(