diff --git a/mlserver/grpc/converters.py b/mlserver/grpc/converters.py index 29fbab9be..4827f6623 100644 --- a/mlserver/grpc/converters.py +++ b/mlserver/grpc/converters.py @@ -357,7 +357,7 @@ def from_types( class RepositoryIndexRequestConverter: @classmethod def to_types( - cls, pb_object: mr_pb.RepositoryIndexRequest + cls, pb_object: Union[pb.RepositoryIndexRequest, mr_pb.RepositoryIndexRequest] ) -> types.RepositoryIndexRequest: return types.RepositoryIndexRequest( ready=pb_object.ready, @@ -366,46 +366,67 @@ def to_types( @classmethod def from_types( cls, type_object: types.RepositoryIndexRequest - ) -> mr_pb.RepositoryIndexRequest: + ) -> Union[pb.RepositoryIndexRequest, mr_pb.RepositoryIndexRequest]: raise NotImplementedError("Implement me") class RepositoryIndexResponseConverter: @classmethod def to_types( - cls, pb_object: mr_pb.RepositoryIndexResponse + cls, pb_object: Union[pb.RepositoryIndexResponse, mr_pb.RepositoryIndexResponse] ) -> types.RepositoryIndexResponse: raise NotImplementedError("Implement me") @classmethod def from_types( - cls, type_object: types.RepositoryIndexResponse - ) -> mr_pb.RepositoryIndexResponse: - return mr_pb.RepositoryIndexResponse( - models=[ - RepositoryIndexResponseItemConverter.from_types(model) - for model in type_object - ] - ) + cls, + type_object: types.RepositoryIndexResponse, + use_model_repository: bool = False, + ) -> Union[pb.RepositoryIndexResponse, mr_pb.RepositoryIndexResponse]: + models = [ + RepositoryIndexResponseItemConverter.from_types( + model, use_model_repository=use_model_repository + ) + for model in type_object + ] + if use_model_repository: + return mr_pb.RepositoryIndexResponse(models=models) # type: ignore + + return pb.RepositoryIndexResponse(models=models) # type: ignore class RepositoryIndexResponseItemConverter: @classmethod def to_types( - cls, pb_object: mr_pb.RepositoryIndexResponse.ModelIndex + cls, + pb_object: Union[ + pb.RepositoryIndexResponse.ModelIndex, + mr_pb.RepositoryIndexResponse.ModelIndex, + ], ) -> types.RepositoryIndexResponseItem: raise NotImplementedError("Implement me") @classmethod def from_types( - cls, type_object: types.RepositoryIndexResponseItem - ) -> mr_pb.RepositoryIndexResponse.ModelIndex: - model_index = mr_pb.RepositoryIndexResponse.ModelIndex( + cls, + type_object: types.RepositoryIndexResponseItem, + use_model_repository: bool = False, + ) -> Union[ + pb.RepositoryIndexResponse.ModelIndex, mr_pb.RepositoryIndexResponse.ModelIndex + ]: + model_index = pb.RepositoryIndexResponse.ModelIndex( name=type_object.name, state=type_object.state.value, reason=type_object.reason, ) + if use_model_repository: + model_index = mr_pb.RepositoryIndexResponse.ModelIndex( # type: ignore + name=type_object.name, + state=type_object.state.value, + reason=type_object.reason, + ) + if type_object.version is not None: model_index.version = type_object.version diff --git a/mlserver/grpc/dataplane_pb2.py b/mlserver/grpc/dataplane_pb2.py index 54b94d294..cc4290268 100644 --- a/mlserver/grpc/dataplane_pb2.py +++ b/mlserver/grpc/dataplane_pb2.py @@ -14,7 +14,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0f\x64\x61taplane.proto\x12\tinference"\x13\n\x11ServerLiveRequest""\n\x12ServerLiveResponse\x12\x0c\n\x04live\x18\x01 \x01(\x08"\x14\n\x12ServerReadyRequest"$\n\x13ServerReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08"2\n\x11ModelReadyRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t"#\n\x12ModelReadyResponse\x12\r\n\x05ready\x18\x01 
\x01(\x08"\x17\n\x15ServerMetadataRequest"K\n\x16ServerMetadataResponse\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\nextensions\x18\x03 \x03(\t"5\n\x14ModelMetadataRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t"\xc5\x04\n\x15ModelMetadataResponse\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08versions\x18\x02 \x03(\t\x12\x10\n\x08platform\x18\x03 \x01(\t\x12?\n\x06inputs\x18\x04 \x03(\x0b\x32/.inference.ModelMetadataResponse.TensorMetadata\x12@\n\x07outputs\x18\x05 \x03(\x0b\x32/.inference.ModelMetadataResponse.TensorMetadata\x12\x44\n\nparameters\x18\x06 \x03(\x0b\x32\x30.inference.ModelMetadataResponse.ParametersEntry\x1a\xe2\x01\n\x0eTensorMetadata\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64\x61tatype\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03\x12S\n\nparameters\x18\x04 \x03(\x0b\x32?.inference.ModelMetadataResponse.TensorMetadata.ParametersEntry\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01"\xd2\x06\n\x11ModelInferRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x15\n\rmodel_version\x18\x02 \x01(\t\x12\n\n\x02id\x18\x03 \x01(\t\x12@\n\nparameters\x18\x04 \x03(\x0b\x32,.inference.ModelInferRequest.ParametersEntry\x12=\n\x06inputs\x18\x05 \x03(\x0b\x32-.inference.ModelInferRequest.InferInputTensor\x12H\n\x07outputs\x18\x06 \x03(\x0b\x32\x37.inference.ModelInferRequest.InferRequestedOutputTensor\x1a\x94\x02\n\x10InferInputTensor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64\x61tatype\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03\x12Q\n\nparameters\x18\x04 \x03(\x0b\x32=.inference.ModelInferRequest.InferInputTensor.ParametersEntry\x12\x30\n\x08\x63ontents\x18\x05 \x01(\x0b\x32\x1e.inference.InferTensorContents\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01\x1a\xd5\x01\n\x1aInferRequestedOutputTensor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12[\n\nparameters\x18\x02 \x03(\x0b\x32G.inference.ModelInferRequest.InferRequestedOutputTensor.ParametersEntry\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01"\xb8\x04\n\x12ModelInferResponse\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x15\n\rmodel_version\x18\x02 \x01(\t\x12\n\n\x02id\x18\x03 \x01(\t\x12\x41\n\nparameters\x18\x04 \x03(\x0b\x32-.inference.ModelInferResponse.ParametersEntry\x12@\n\x07outputs\x18\x05 \x03(\x0b\x32/.inference.ModelInferResponse.InferOutputTensor\x1a\x97\x02\n\x11InferOutputTensor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64\x61tatype\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03\x12S\n\nparameters\x18\x04 \x03(\x0b\x32?.inference.ModelInferResponse.InferOutputTensor.ParametersEntry\x12\x30\n\x08\x63ontents\x18\x05 \x01(\x0b\x32\x1e.inference.InferTensorContents\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 
\x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01"i\n\x0eInferParameter\x12\x14\n\nbool_param\x18\x01 \x01(\x08H\x00\x12\x15\n\x0bint64_param\x18\x02 \x01(\x03H\x00\x12\x16\n\x0cstring_param\x18\x03 \x01(\tH\x00\x42\x12\n\x10parameter_choice"\xd0\x01\n\x13InferTensorContents\x12\x15\n\rbool_contents\x18\x01 \x03(\x08\x12\x14\n\x0cint_contents\x18\x02 \x03(\x05\x12\x16\n\x0eint64_contents\x18\x03 \x03(\x03\x12\x15\n\ruint_contents\x18\x04 \x03(\r\x12\x17\n\x0fuint64_contents\x18\x05 \x03(\x04\x12\x15\n\rfp32_contents\x18\x06 \x03(\x02\x12\x15\n\rfp64_contents\x18\x07 \x03(\x01\x12\x16\n\x0e\x62ytes_contents\x18\x08 \x03(\x0c\x32\xfc\x03\n\x14GRPCInferenceService\x12K\n\nServerLive\x12\x1c.inference.ServerLiveRequest\x1a\x1d.inference.ServerLiveResponse"\x00\x12N\n\x0bServerReady\x12\x1d.inference.ServerReadyRequest\x1a\x1e.inference.ServerReadyResponse"\x00\x12K\n\nModelReady\x12\x1c.inference.ModelReadyRequest\x1a\x1d.inference.ModelReadyResponse"\x00\x12W\n\x0eServerMetadata\x12 .inference.ServerMetadataRequest\x1a!.inference.ServerMetadataResponse"\x00\x12T\n\rModelMetadata\x12\x1f.inference.ModelMetadataRequest\x1a .inference.ModelMetadataResponse"\x00\x12K\n\nModelInfer\x12\x1c.inference.ModelInferRequest\x1a\x1d.inference.ModelInferResponse"\x00\x62\x06proto3' + b'\n\x0f\x64\x61taplane.proto\x12\tinference"\x13\n\x11ServerLiveRequest""\n\x12ServerLiveResponse\x12\x0c\n\x04live\x18\x01 \x01(\x08"\x14\n\x12ServerReadyRequest"$\n\x13ServerReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08"2\n\x11ModelReadyRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t"#\n\x12ModelReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08"\x17\n\x15ServerMetadataRequest"K\n\x16ServerMetadataResponse\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\nextensions\x18\x03 \x03(\t"5\n\x14ModelMetadataRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t"\xc5\x04\n\x15ModelMetadataResponse\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08versions\x18\x02 \x03(\t\x12\x10\n\x08platform\x18\x03 \x01(\t\x12?\n\x06inputs\x18\x04 \x03(\x0b\x32/.inference.ModelMetadataResponse.TensorMetadata\x12@\n\x07outputs\x18\x05 \x03(\x0b\x32/.inference.ModelMetadataResponse.TensorMetadata\x12\x44\n\nparameters\x18\x06 \x03(\x0b\x32\x30.inference.ModelMetadataResponse.ParametersEntry\x1a\xe2\x01\n\x0eTensorMetadata\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64\x61tatype\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03\x12S\n\nparameters\x18\x04 \x03(\x0b\x32?.inference.ModelMetadataResponse.TensorMetadata.ParametersEntry\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01"\xd2\x06\n\x11ModelInferRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x15\n\rmodel_version\x18\x02 \x01(\t\x12\n\n\x02id\x18\x03 \x01(\t\x12@\n\nparameters\x18\x04 \x03(\x0b\x32,.inference.ModelInferRequest.ParametersEntry\x12=\n\x06inputs\x18\x05 \x03(\x0b\x32-.inference.ModelInferRequest.InferInputTensor\x12H\n\x07outputs\x18\x06 \x03(\x0b\x32\x37.inference.ModelInferRequest.InferRequestedOutputTensor\x1a\x94\x02\n\x10InferInputTensor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64\x61tatype\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03\x12Q\n\nparameters\x18\x04 
\x03(\x0b\x32=.inference.ModelInferRequest.InferInputTensor.ParametersEntry\x12\x30\n\x08\x63ontents\x18\x05 \x01(\x0b\x32\x1e.inference.InferTensorContents\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01\x1a\xd5\x01\n\x1aInferRequestedOutputTensor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12[\n\nparameters\x18\x02 \x03(\x0b\x32G.inference.ModelInferRequest.InferRequestedOutputTensor.ParametersEntry\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01"\xb8\x04\n\x12ModelInferResponse\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x15\n\rmodel_version\x18\x02 \x01(\t\x12\n\n\x02id\x18\x03 \x01(\t\x12\x41\n\nparameters\x18\x04 \x03(\x0b\x32-.inference.ModelInferResponse.ParametersEntry\x12@\n\x07outputs\x18\x05 \x03(\x0b\x32/.inference.ModelInferResponse.InferOutputTensor\x1a\x97\x02\n\x11InferOutputTensor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64\x61tatype\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03\x12S\n\nparameters\x18\x04 \x03(\x0b\x32?.inference.ModelInferResponse.InferOutputTensor.ParametersEntry\x12\x30\n\x08\x63ontents\x18\x05 \x01(\x0b\x32\x1e.inference.InferTensorContents\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01"i\n\x0eInferParameter\x12\x14\n\nbool_param\x18\x01 \x01(\x08H\x00\x12\x15\n\x0bint64_param\x18\x02 \x01(\x03H\x00\x12\x16\n\x0cstring_param\x18\x03 \x01(\tH\x00\x42\x12\n\x10parameter_choice"\xd0\x01\n\x13InferTensorContents\x12\x15\n\rbool_contents\x18\x01 \x03(\x08\x12\x14\n\x0cint_contents\x18\x02 \x03(\x05\x12\x16\n\x0eint64_contents\x18\x03 \x03(\x03\x12\x15\n\ruint_contents\x18\x04 \x03(\r\x12\x17\n\x0fuint64_contents\x18\x05 \x03(\x04\x12\x15\n\rfp32_contents\x18\x06 \x03(\x02\x12\x15\n\rfp64_contents\x18\x07 \x03(\x01\x12\x16\n\x0e\x62ytes_contents\x18\x08 \x03(\x0c"\x8a\x01\n\x18ModelRepositoryParameter\x12\x14\n\nbool_param\x18\x01 \x01(\x08H\x00\x12\x15\n\x0bint64_param\x18\x02 \x01(\x03H\x00\x12\x16\n\x0cstring_param\x18\x03 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_param\x18\x04 \x01(\x0cH\x00\x42\x12\n\x10parameter_choice"@\n\x16RepositoryIndexRequest\x12\x17\n\x0frepository_name\x18\x01 \x01(\t\x12\r\n\x05ready\x18\x02 \x01(\x08"\xa4\x01\n\x17RepositoryIndexResponse\x12=\n\x06models\x18\x01 \x03(\x0b\x32-.inference.RepositoryIndexResponse.ModelIndex\x1aJ\n\nModelIndex\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\r\n\x05state\x18\x03 \x01(\t\x12\x0e\n\x06reason\x18\x04 \x01(\t"\xec\x01\n\x1aRepositoryModelLoadRequest\x12\x17\n\x0frepository_name\x18\x01 \x01(\t\x12\x12\n\nmodel_name\x18\x02 \x01(\t\x12I\n\nparameters\x18\x03 \x03(\x0b\x32\x35.inference.RepositoryModelLoadRequest.ParametersEntry\x1aV\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.inference.ModelRepositoryParameter:\x02\x38\x01"\x1d\n\x1bRepositoryModelLoadResponse"\xf0\x01\n\x1cRepositoryModelUnloadRequest\x12\x17\n\x0frepository_name\x18\x01 \x01(\t\x12\x12\n\nmodel_name\x18\x02 \x01(\t\x12K\n\nparameters\x18\x03 
\x03(\x0b\x32\x37.inference.RepositoryModelUnloadRequest.ParametersEntry\x1aV\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.inference.ModelRepositoryParameter:\x02\x38\x01"\x1f\n\x1dRepositoryModelUnloadResponse2\xae\x06\n\x14GRPCInferenceService\x12K\n\nServerLive\x12\x1c.inference.ServerLiveRequest\x1a\x1d.inference.ServerLiveResponse"\x00\x12N\n\x0bServerReady\x12\x1d.inference.ServerReadyRequest\x1a\x1e.inference.ServerReadyResponse"\x00\x12K\n\nModelReady\x12\x1c.inference.ModelReadyRequest\x1a\x1d.inference.ModelReadyResponse"\x00\x12W\n\x0eServerMetadata\x12 .inference.ServerMetadataRequest\x1a!.inference.ServerMetadataResponse"\x00\x12T\n\rModelMetadata\x12\x1f.inference.ModelMetadataRequest\x1a .inference.ModelMetadataResponse"\x00\x12K\n\nModelInfer\x12\x1c.inference.ModelInferRequest\x1a\x1d.inference.ModelInferResponse"\x00\x12Z\n\x0fRepositoryIndex\x12!.inference.RepositoryIndexRequest\x1a".inference.RepositoryIndexResponse"\x00\x12\x66\n\x13RepositoryModelLoad\x12%.inference.RepositoryModelLoadRequest\x1a&.inference.RepositoryModelLoadResponse"\x00\x12l\n\x15RepositoryModelUnload\x12\'.inference.RepositoryModelUnloadRequest\x1a(.inference.RepositoryModelUnloadResponse"\x00\x62\x06proto3' ) @@ -67,6 +67,30 @@ ] _INFERPARAMETER = DESCRIPTOR.message_types_by_name["InferParameter"] _INFERTENSORCONTENTS = DESCRIPTOR.message_types_by_name["InferTensorContents"] +_MODELREPOSITORYPARAMETER = DESCRIPTOR.message_types_by_name["ModelRepositoryParameter"] +_REPOSITORYINDEXREQUEST = DESCRIPTOR.message_types_by_name["RepositoryIndexRequest"] +_REPOSITORYINDEXRESPONSE = DESCRIPTOR.message_types_by_name["RepositoryIndexResponse"] +_REPOSITORYINDEXRESPONSE_MODELINDEX = _REPOSITORYINDEXRESPONSE.nested_types_by_name[ + "ModelIndex" +] +_REPOSITORYMODELLOADREQUEST = DESCRIPTOR.message_types_by_name[ + "RepositoryModelLoadRequest" +] +_REPOSITORYMODELLOADREQUEST_PARAMETERSENTRY = ( + _REPOSITORYMODELLOADREQUEST.nested_types_by_name["ParametersEntry"] +) +_REPOSITORYMODELLOADRESPONSE = DESCRIPTOR.message_types_by_name[ + "RepositoryModelLoadResponse" +] +_REPOSITORYMODELUNLOADREQUEST = DESCRIPTOR.message_types_by_name[ + "RepositoryModelUnloadRequest" +] +_REPOSITORYMODELUNLOADREQUEST_PARAMETERSENTRY = ( + _REPOSITORYMODELUNLOADREQUEST.nested_types_by_name["ParametersEntry"] +) +_REPOSITORYMODELUNLOADRESPONSE = DESCRIPTOR.message_types_by_name[ + "RepositoryModelUnloadResponse" +] ServerLiveRequest = _reflection.GeneratedProtocolMessageType( "ServerLiveRequest", (_message.Message,), @@ -331,6 +355,113 @@ ) _sym_db.RegisterMessage(InferTensorContents) +ModelRepositoryParameter = _reflection.GeneratedProtocolMessageType( + "ModelRepositoryParameter", + (_message.Message,), + { + "DESCRIPTOR": _MODELREPOSITORYPARAMETER, + "__module__": "dataplane_pb2" + # @@protoc_insertion_point(class_scope:inference.ModelRepositoryParameter) + }, +) +_sym_db.RegisterMessage(ModelRepositoryParameter) + +RepositoryIndexRequest = _reflection.GeneratedProtocolMessageType( + "RepositoryIndexRequest", + (_message.Message,), + { + "DESCRIPTOR": _REPOSITORYINDEXREQUEST, + "__module__": "dataplane_pb2" + # @@protoc_insertion_point(class_scope:inference.RepositoryIndexRequest) + }, +) +_sym_db.RegisterMessage(RepositoryIndexRequest) + +RepositoryIndexResponse = _reflection.GeneratedProtocolMessageType( + "RepositoryIndexResponse", + (_message.Message,), + { + "ModelIndex": _reflection.GeneratedProtocolMessageType( + "ModelIndex", + (_message.Message,), + { + "DESCRIPTOR": 
_REPOSITORYINDEXRESPONSE_MODELINDEX, + "__module__": "dataplane_pb2" + # @@protoc_insertion_point(class_scope:inference.RepositoryIndexResponse.ModelIndex) + }, + ), + "DESCRIPTOR": _REPOSITORYINDEXRESPONSE, + "__module__": "dataplane_pb2" + # @@protoc_insertion_point(class_scope:inference.RepositoryIndexResponse) + }, +) +_sym_db.RegisterMessage(RepositoryIndexResponse) +_sym_db.RegisterMessage(RepositoryIndexResponse.ModelIndex) + +RepositoryModelLoadRequest = _reflection.GeneratedProtocolMessageType( + "RepositoryModelLoadRequest", + (_message.Message,), + { + "ParametersEntry": _reflection.GeneratedProtocolMessageType( + "ParametersEntry", + (_message.Message,), + { + "DESCRIPTOR": _REPOSITORYMODELLOADREQUEST_PARAMETERSENTRY, + "__module__": "dataplane_pb2" + # @@protoc_insertion_point(class_scope:inference.RepositoryModelLoadRequest.ParametersEntry) + }, + ), + "DESCRIPTOR": _REPOSITORYMODELLOADREQUEST, + "__module__": "dataplane_pb2" + # @@protoc_insertion_point(class_scope:inference.RepositoryModelLoadRequest) + }, +) +_sym_db.RegisterMessage(RepositoryModelLoadRequest) +_sym_db.RegisterMessage(RepositoryModelLoadRequest.ParametersEntry) + +RepositoryModelLoadResponse = _reflection.GeneratedProtocolMessageType( + "RepositoryModelLoadResponse", + (_message.Message,), + { + "DESCRIPTOR": _REPOSITORYMODELLOADRESPONSE, + "__module__": "dataplane_pb2" + # @@protoc_insertion_point(class_scope:inference.RepositoryModelLoadResponse) + }, +) +_sym_db.RegisterMessage(RepositoryModelLoadResponse) + +RepositoryModelUnloadRequest = _reflection.GeneratedProtocolMessageType( + "RepositoryModelUnloadRequest", + (_message.Message,), + { + "ParametersEntry": _reflection.GeneratedProtocolMessageType( + "ParametersEntry", + (_message.Message,), + { + "DESCRIPTOR": _REPOSITORYMODELUNLOADREQUEST_PARAMETERSENTRY, + "__module__": "dataplane_pb2" + # @@protoc_insertion_point(class_scope:inference.RepositoryModelUnloadRequest.ParametersEntry) + }, + ), + "DESCRIPTOR": _REPOSITORYMODELUNLOADREQUEST, + "__module__": "dataplane_pb2" + # @@protoc_insertion_point(class_scope:inference.RepositoryModelUnloadRequest) + }, +) +_sym_db.RegisterMessage(RepositoryModelUnloadRequest) +_sym_db.RegisterMessage(RepositoryModelUnloadRequest.ParametersEntry) + +RepositoryModelUnloadResponse = _reflection.GeneratedProtocolMessageType( + "RepositoryModelUnloadResponse", + (_message.Message,), + { + "DESCRIPTOR": _REPOSITORYMODELUNLOADRESPONSE, + "__module__": "dataplane_pb2" + # @@protoc_insertion_point(class_scope:inference.RepositoryModelUnloadResponse) + }, +) +_sym_db.RegisterMessage(RepositoryModelUnloadResponse) + _GRPCINFERENCESERVICE = DESCRIPTOR.services_by_name["GRPCInferenceService"] if _descriptor._USE_C_DESCRIPTORS == False: @@ -351,6 +482,10 @@ _MODELINFERRESPONSE_INFEROUTPUTTENSOR_PARAMETERSENTRY._serialized_options = b"8\001" _MODELINFERRESPONSE_PARAMETERSENTRY._options = None _MODELINFERRESPONSE_PARAMETERSENTRY._serialized_options = b"8\001" + _REPOSITORYMODELLOADREQUEST_PARAMETERSENTRY._options = None + _REPOSITORYMODELLOADREQUEST_PARAMETERSENTRY._serialized_options = b"8\001" + _REPOSITORYMODELUNLOADREQUEST_PARAMETERSENTRY._options = None + _REPOSITORYMODELUNLOADREQUEST_PARAMETERSENTRY._serialized_options = b"8\001" _SERVERLIVEREQUEST._serialized_start = 30 _SERVERLIVEREQUEST._serialized_end = 49 _SERVERLIVERESPONSE._serialized_start = 51 @@ -403,6 +538,26 @@ _INFERPARAMETER._serialized_end = 2506 _INFERTENSORCONTENTS._serialized_start = 2509 _INFERTENSORCONTENTS._serialized_end = 2717 - 
_GRPCINFERENCESERVICE._serialized_start = 2720 - _GRPCINFERENCESERVICE._serialized_end = 3228 + _MODELREPOSITORYPARAMETER._serialized_start = 2720 + _MODELREPOSITORYPARAMETER._serialized_end = 2858 + _REPOSITORYINDEXREQUEST._serialized_start = 2860 + _REPOSITORYINDEXREQUEST._serialized_end = 2924 + _REPOSITORYINDEXRESPONSE._serialized_start = 2927 + _REPOSITORYINDEXRESPONSE._serialized_end = 3091 + _REPOSITORYINDEXRESPONSE_MODELINDEX._serialized_start = 3017 + _REPOSITORYINDEXRESPONSE_MODELINDEX._serialized_end = 3091 + _REPOSITORYMODELLOADREQUEST._serialized_start = 3094 + _REPOSITORYMODELLOADREQUEST._serialized_end = 3330 + _REPOSITORYMODELLOADREQUEST_PARAMETERSENTRY._serialized_start = 3244 + _REPOSITORYMODELLOADREQUEST_PARAMETERSENTRY._serialized_end = 3330 + _REPOSITORYMODELLOADRESPONSE._serialized_start = 3332 + _REPOSITORYMODELLOADRESPONSE._serialized_end = 3361 + _REPOSITORYMODELUNLOADREQUEST._serialized_start = 3364 + _REPOSITORYMODELUNLOADREQUEST._serialized_end = 3604 + _REPOSITORYMODELUNLOADREQUEST_PARAMETERSENTRY._serialized_start = 3244 + _REPOSITORYMODELUNLOADREQUEST_PARAMETERSENTRY._serialized_end = 3330 + _REPOSITORYMODELUNLOADRESPONSE._serialized_start = 3606 + _REPOSITORYMODELUNLOADRESPONSE._serialized_end = 3637 + _GRPCINFERENCESERVICE._serialized_start = 3640 + _GRPCINFERENCESERVICE._serialized_end = 4454 # @@protoc_insertion_point(module_scope) diff --git a/mlserver/grpc/dataplane_pb2.pyi b/mlserver/grpc/dataplane_pb2.pyi index a90a7ed93..91f22c32c 100644 --- a/mlserver/grpc/dataplane_pb2.pyi +++ b/mlserver/grpc/dataplane_pb2.pyi @@ -956,3 +956,311 @@ class InferTensorContents(google.protobuf.message.Message): ) -> None: ... global___InferTensorContents = InferTensorContents + +class ModelRepositoryParameter(google.protobuf.message.Message): + """ + Messages for the Repository API + + NOTE: These messages used to exist previously on a different protobuf + definition. However, they have now been merged with the main + GRPCInferenceService. + + A model repository parameter value. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... + BOOL_PARAM_FIELD_NUMBER: builtins.int + INT64_PARAM_FIELD_NUMBER: builtins.int + STRING_PARAM_FIELD_NUMBER: builtins.int + BYTES_PARAM_FIELD_NUMBER: builtins.int + bool_param: builtins.bool = ... + """A boolean parameter value.""" + + int64_param: builtins.int = ... + """An int64 parameter value.""" + + string_param: typing.Text = ... + """A string parameter value.""" + + bytes_param: builtins.bytes = ... + """A bytes parameter value.""" + def __init__( + self, + *, + bool_param: builtins.bool = ..., + int64_param: builtins.int = ..., + string_param: typing.Text = ..., + bytes_param: builtins.bytes = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "bool_param", + b"bool_param", + "bytes_param", + b"bytes_param", + "int64_param", + b"int64_param", + "parameter_choice", + b"parameter_choice", + "string_param", + b"string_param", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "bool_param", + b"bool_param", + "bytes_param", + b"bytes_param", + "int64_param", + b"int64_param", + "parameter_choice", + b"parameter_choice", + "string_param", + b"string_param", + ], + ) -> None: ... + def WhichOneof( + self, + oneof_group: typing_extensions.Literal["parameter_choice", b"parameter_choice"], + ) -> typing.Optional[ + typing_extensions.Literal[ + "bool_param", "int64_param", "string_param", "bytes_param" + ] + ]: ... 
+ +global___ModelRepositoryParameter = ModelRepositoryParameter + +class RepositoryIndexRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... + REPOSITORY_NAME_FIELD_NUMBER: builtins.int + READY_FIELD_NUMBER: builtins.int + repository_name: typing.Text = ... + """The name of the repository. If empty the index is returned + for all repositories. + """ + + ready: builtins.bool = ... + """If true return only models currently ready for inferencing.""" + def __init__( + self, + *, + repository_name: typing.Text = ..., + ready: builtins.bool = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "ready", b"ready", "repository_name", b"repository_name" + ], + ) -> None: ... + +global___RepositoryIndexRequest = RepositoryIndexRequest + +class RepositoryIndexResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... + class ModelIndex(google.protobuf.message.Message): + """Index entry for a model.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... + NAME_FIELD_NUMBER: builtins.int + VERSION_FIELD_NUMBER: builtins.int + STATE_FIELD_NUMBER: builtins.int + REASON_FIELD_NUMBER: builtins.int + name: typing.Text = ... + """The name of the model.""" + + version: typing.Text = ... + """The version of the model.""" + + state: typing.Text = ... + """The state of the model.""" + + reason: typing.Text = ... + """The reason, if any, that the model is in the given state.""" + def __init__( + self, + *, + name: typing.Text = ..., + version: typing.Text = ..., + state: typing.Text = ..., + reason: typing.Text = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "name", + b"name", + "reason", + b"reason", + "state", + b"state", + "version", + b"version", + ], + ) -> None: ... + MODELS_FIELD_NUMBER: builtins.int + @property + def models( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + global___RepositoryIndexResponse.ModelIndex + ]: + """An index entry for each model.""" + pass + def __init__( + self, + *, + models: typing.Optional[ + typing.Iterable[global___RepositoryIndexResponse.ModelIndex] + ] = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["models", b"models"] + ) -> None: ... + +global___RepositoryIndexResponse = RepositoryIndexResponse + +class RepositoryModelLoadRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... + class ParametersEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text = ... + @property + def value(self) -> global___ModelRepositoryParameter: ... + def __init__( + self, + *, + key: typing.Text = ..., + value: typing.Optional[global___ModelRepositoryParameter] = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["value", b"value"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal["key", b"key", "value", b"value"], + ) -> None: ... + REPOSITORY_NAME_FIELD_NUMBER: builtins.int + MODEL_NAME_FIELD_NUMBER: builtins.int + PARAMETERS_FIELD_NUMBER: builtins.int + repository_name: typing.Text = ... + """The name of the repository to load from. If empty the model + is loaded from any repository. + """ + + model_name: typing.Text = ... 
+ """The name of the model to load, or reload.""" + @property + def parameters( + self, + ) -> google.protobuf.internal.containers.MessageMap[ + typing.Text, global___ModelRepositoryParameter + ]: + """Optional model repository request parameters.""" + pass + def __init__( + self, + *, + repository_name: typing.Text = ..., + model_name: typing.Text = ..., + parameters: typing.Optional[ + typing.Mapping[typing.Text, global___ModelRepositoryParameter] + ] = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "model_name", + b"model_name", + "parameters", + b"parameters", + "repository_name", + b"repository_name", + ], + ) -> None: ... + +global___RepositoryModelLoadRequest = RepositoryModelLoadRequest + +class RepositoryModelLoadResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... + def __init__( + self, + ) -> None: ... + +global___RepositoryModelLoadResponse = RepositoryModelLoadResponse + +class RepositoryModelUnloadRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... + class ParametersEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text = ... + @property + def value(self) -> global___ModelRepositoryParameter: ... + def __init__( + self, + *, + key: typing.Text = ..., + value: typing.Optional[global___ModelRepositoryParameter] = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["value", b"value"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal["key", b"key", "value", b"value"], + ) -> None: ... + REPOSITORY_NAME_FIELD_NUMBER: builtins.int + MODEL_NAME_FIELD_NUMBER: builtins.int + PARAMETERS_FIELD_NUMBER: builtins.int + repository_name: typing.Text = ... + """The name of the repository from which the model was originally + loaded. If empty the repository is not considered. + """ + + model_name: typing.Text = ... + """The name of the model to unload.""" + @property + def parameters( + self, + ) -> google.protobuf.internal.containers.MessageMap[ + typing.Text, global___ModelRepositoryParameter + ]: + """Optional model repository request parameters.""" + pass + def __init__( + self, + *, + repository_name: typing.Text = ..., + model_name: typing.Text = ..., + parameters: typing.Optional[ + typing.Mapping[typing.Text, global___ModelRepositoryParameter] + ] = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "model_name", + b"model_name", + "parameters", + b"parameters", + "repository_name", + b"repository_name", + ], + ) -> None: ... + +global___RepositoryModelUnloadRequest = RepositoryModelUnloadRequest + +class RepositoryModelUnloadResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... + def __init__( + self, + ) -> None: ... 
+ +global___RepositoryModelUnloadResponse = RepositoryModelUnloadResponse diff --git a/mlserver/grpc/dataplane_pb2_grpc.py b/mlserver/grpc/dataplane_pb2_grpc.py index 85e8979f6..9ad255a0a 100644 --- a/mlserver/grpc/dataplane_pb2_grpc.py +++ b/mlserver/grpc/dataplane_pb2_grpc.py @@ -47,6 +47,21 @@ def __init__(self, channel): request_serializer=dataplane__pb2.ModelInferRequest.SerializeToString, response_deserializer=dataplane__pb2.ModelInferResponse.FromString, ) + self.RepositoryIndex = channel.unary_unary( + "/inference.GRPCInferenceService/RepositoryIndex", + request_serializer=dataplane__pb2.RepositoryIndexRequest.SerializeToString, + response_deserializer=dataplane__pb2.RepositoryIndexResponse.FromString, + ) + self.RepositoryModelLoad = channel.unary_unary( + "/inference.GRPCInferenceService/RepositoryModelLoad", + request_serializer=dataplane__pb2.RepositoryModelLoadRequest.SerializeToString, + response_deserializer=dataplane__pb2.RepositoryModelLoadResponse.FromString, + ) + self.RepositoryModelUnload = channel.unary_unary( + "/inference.GRPCInferenceService/RepositoryModelUnload", + request_serializer=dataplane__pb2.RepositoryModelUnloadRequest.SerializeToString, + response_deserializer=dataplane__pb2.RepositoryModelUnloadResponse.FromString, + ) class GRPCInferenceServiceServicer(object): @@ -91,6 +106,24 @@ def ModelInfer(self, request, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def RepositoryIndex(self, request, context): + """Get the index of model repository contents.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def RepositoryModelLoad(self, request, context): + """Load or reload a model from a repository.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def RepositoryModelUnload(self, request, context): + """Unload a model.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def add_GRPCInferenceServiceServicer_to_server(servicer, server): rpc_method_handlers = { @@ -124,6 +157,21 @@ def add_GRPCInferenceServiceServicer_to_server(servicer, server): request_deserializer=dataplane__pb2.ModelInferRequest.FromString, response_serializer=dataplane__pb2.ModelInferResponse.SerializeToString, ), + "RepositoryIndex": grpc.unary_unary_rpc_method_handler( + servicer.RepositoryIndex, + request_deserializer=dataplane__pb2.RepositoryIndexRequest.FromString, + response_serializer=dataplane__pb2.RepositoryIndexResponse.SerializeToString, + ), + "RepositoryModelLoad": grpc.unary_unary_rpc_method_handler( + servicer.RepositoryModelLoad, + request_deserializer=dataplane__pb2.RepositoryModelLoadRequest.FromString, + response_serializer=dataplane__pb2.RepositoryModelLoadResponse.SerializeToString, + ), + "RepositoryModelUnload": grpc.unary_unary_rpc_method_handler( + servicer.RepositoryModelUnload, + request_deserializer=dataplane__pb2.RepositoryModelUnloadRequest.FromString, + response_serializer=dataplane__pb2.RepositoryModelUnloadResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( "inference.GRPCInferenceService", rpc_method_handlers @@ -311,3 +359,90 @@ def ModelInfer( timeout, metadata, ) + + @staticmethod + def RepositoryIndex( + request, + target, + 
options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/inference.GRPCInferenceService/RepositoryIndex", + dataplane__pb2.RepositoryIndexRequest.SerializeToString, + dataplane__pb2.RepositoryIndexResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def RepositoryModelLoad( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/inference.GRPCInferenceService/RepositoryModelLoad", + dataplane__pb2.RepositoryModelLoadRequest.SerializeToString, + dataplane__pb2.RepositoryModelLoadResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def RepositoryModelUnload( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/inference.GRPCInferenceService/RepositoryModelUnload", + dataplane__pb2.RepositoryModelUnloadRequest.SerializeToString, + dataplane__pb2.RepositoryModelUnloadResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/mlserver/grpc/model_repository.py b/mlserver/grpc/model_repository.py new file mode 100644 index 000000000..cec05f57a --- /dev/null +++ b/mlserver/grpc/model_repository.py @@ -0,0 +1,57 @@ +from typing import Callable + +from . import model_repository_pb2 as mr_pb +from .converters import ( + RepositoryIndexRequestConverter, + RepositoryIndexResponseConverter, +) +from .model_repository_pb2_grpc import ModelRepositoryServiceServicer +from .utils import handle_mlserver_error +from .logging import logger +from ..handlers import ModelRepositoryHandlers + + +def _deprecated(f: Callable): + async def _inner(self, request, context): + logger.warning( + "DEPRECATED!! " + f"The `inference.model_repository.ModelRepositoryService/{f.__name__}` " + "endpoint has now been deprecated and will be removed in " + "MLServer 1.2.0. " + f"Please use the `inference.GRPCInferenceService/{f.__name__}` " + "endpoint instead." 
+ ) + return await f(self, request, context) + + return _inner + + +class ModelRepositoryServicer(ModelRepositoryServiceServicer): + def __init__(self, handlers: ModelRepositoryHandlers): + self._handlers = handlers + + @_deprecated + async def RepositoryIndex( + self, request: mr_pb.RepositoryIndexRequest, context + ) -> mr_pb.RepositoryIndexResponse: + payload = RepositoryIndexRequestConverter.to_types(request) + index = await self._handlers.index(payload) + return RepositoryIndexResponseConverter.from_types( # type: ignore + index, use_model_repository=True + ) + + @handle_mlserver_error + @_deprecated + async def RepositoryModelLoad( + self, request: mr_pb.RepositoryModelLoadRequest, context + ) -> mr_pb.RepositoryModelLoadResponse: + await self._handlers.load(request.model_name) + return mr_pb.RepositoryModelLoadResponse() + + @handle_mlserver_error + @_deprecated + async def RepositoryModelUnload( + self, request: mr_pb.RepositoryModelUnloadRequest, context + ) -> mr_pb.RepositoryModelUnloadResponse: + await self._handlers.unload(request.model_name) + return mr_pb.RepositoryModelUnloadResponse() diff --git a/mlserver/grpc/server.py b/mlserver/grpc/server.py index 06974585e..f4d75aadb 100644 --- a/mlserver/grpc/server.py +++ b/mlserver/grpc/server.py @@ -5,7 +5,8 @@ from ..handlers import DataPlane, ModelRepositoryHandlers from ..settings import Settings -from .servicers import InferenceServicer, ModelRepositoryServicer +from .servicers import InferenceServicer +from .model_repository import ModelRepositoryServicer from .dataplane_pb2_grpc import add_GRPCInferenceServiceServicer_to_server from .model_repository_pb2_grpc import add_ModelRepositoryServiceServicer_to_server from .interceptors import LoggingInterceptor, PromServerInterceptor @@ -27,7 +28,9 @@ def __init__( self._model_repository_handlers = model_repository_handlers def _create_server(self): - self._inference_servicer = InferenceServicer(self._data_plane) + self._inference_servicer = InferenceServicer( + self._data_plane, self._model_repository_handlers + ) self._model_repository_servicer = ModelRepositoryServicer( self._model_repository_handlers ) diff --git a/mlserver/grpc/servicers.py b/mlserver/grpc/servicers.py index 4fedfd7c2..62d6d2958 100644 --- a/mlserver/grpc/servicers.py +++ b/mlserver/grpc/servicers.py @@ -1,12 +1,7 @@ import grpc -from typing import Callable -from fastapi import status - from . import dataplane_pb2 as pb -from . 
import model_repository_pb2 as mr_pb from .dataplane_pb2_grpc import GRPCInferenceServiceServicer -from .model_repository_pb2_grpc import ModelRepositoryServiceServicer from .converters import ( ModelInferRequestConverter, ModelInferResponseConverter, @@ -15,40 +10,19 @@ RepositoryIndexRequestConverter, RepositoryIndexResponseConverter, ) -from .logging import logger -from .utils import to_headers, to_metadata +from .utils import to_headers, to_metadata, handle_mlserver_error from ..utils import insert_headers, extract_headers from ..handlers import DataPlane, ModelRepositoryHandlers -from ..errors import MLServerError - -STATUS_CODE_MAPPING = { - status.HTTP_400_BAD_REQUEST: grpc.StatusCode.INVALID_ARGUMENT, - status.HTTP_404_NOT_FOUND: grpc.StatusCode.NOT_FOUND, - status.HTTP_422_UNPROCESSABLE_ENTITY: grpc.StatusCode.FAILED_PRECONDITION, - status.HTTP_500_INTERNAL_SERVER_ERROR: grpc.StatusCode.INTERNAL, -} - - -def _grpc_status_code(err: MLServerError): - return STATUS_CODE_MAPPING.get(err.status_code, grpc.StatusCode.UNKNOWN) - - -def _handle_mlserver_error(f: Callable): - async def _inner(self, request, context): - try: - return await f(self, request, context) - except MLServerError as err: - logger.error(err) - await context.abort(code=_grpc_status_code(err), details=str(err)) - - return _inner class InferenceServicer(GRPCInferenceServiceServicer): - def __init__(self, data_plane: DataPlane): + def __init__( + self, data_plane: DataPlane, model_repository_handlers: ModelRepositoryHandlers + ): super().__init__() self._data_plane = data_plane + self._model_repository_handlers = model_repository_handlers async def ServerLive( self, request: pb.ServerLiveRequest, context @@ -76,7 +50,7 @@ async def ServerMetadata( metadata = await self._data_plane.metadata() return ServerMetadataResponseConverter.from_types(metadata) - @_handle_mlserver_error + @handle_mlserver_error async def ModelMetadata( self, request: pb.ModelMetadataRequest, context ) -> pb.ModelMetadataResponse: @@ -85,7 +59,7 @@ async def ModelMetadata( ) return ModelMetadataResponseConverter.from_types(metadata) - @_handle_mlserver_error + @handle_mlserver_error async def ModelInfer( self, request: pb.ModelInferRequest, context: grpc.ServicerContext ) -> pb.ModelInferResponse: @@ -106,28 +80,23 @@ async def ModelInfer( response = ModelInferResponseConverter.from_types(result) return response - -class ModelRepositoryServicer(ModelRepositoryServiceServicer): - def __init__(self, handlers: ModelRepositoryHandlers): - self._handlers = handlers - async def RepositoryIndex( - self, request: mr_pb.RepositoryIndexRequest, context - ) -> mr_pb.RepositoryIndexResponse: + self, request: pb.RepositoryIndexRequest, context + ) -> pb.RepositoryIndexResponse: payload = RepositoryIndexRequestConverter.to_types(request) - index = await self._handlers.index(payload) - return RepositoryIndexResponseConverter.from_types(index) + index = await self._model_repository_handlers.index(payload) + return RepositoryIndexResponseConverter.from_types(index) # type: ignore - @_handle_mlserver_error + @handle_mlserver_error async def RepositoryModelLoad( - self, request: mr_pb.RepositoryModelLoadRequest, context - ) -> mr_pb.RepositoryModelLoadResponse: - await self._handlers.load(request.model_name) - return mr_pb.RepositoryModelLoadResponse() + self, request: pb.RepositoryModelLoadRequest, context + ) -> pb.RepositoryModelLoadResponse: + await self._model_repository_handlers.load(request.model_name) + return pb.RepositoryModelLoadResponse() - 
@_handle_mlserver_error + @handle_mlserver_error async def RepositoryModelUnload( - self, request: mr_pb.RepositoryModelUnloadRequest, context - ) -> mr_pb.RepositoryModelUnloadResponse: - await self._handlers.unload(request.model_name) - return mr_pb.RepositoryModelUnloadResponse() + self, request: pb.RepositoryModelUnloadRequest, context + ) -> pb.RepositoryModelUnloadResponse: + await self._model_repository_handlers.unload(request.model_name) + return pb.RepositoryModelUnloadResponse() diff --git a/mlserver/grpc/utils.py b/mlserver/grpc/utils.py index b2bc1152c..f5e8d4839 100644 --- a/mlserver/grpc/utils.py +++ b/mlserver/grpc/utils.py @@ -1,7 +1,21 @@ -from typing import Dict, Tuple +import grpc + +from typing import Callable, Dict, Tuple +from fastapi import status from grpc import ServicerContext +from .logging import logger +from ..errors import MLServerError + + +STATUS_CODE_MAPPING = { + status.HTTP_400_BAD_REQUEST: grpc.StatusCode.INVALID_ARGUMENT, + status.HTTP_404_NOT_FOUND: grpc.StatusCode.NOT_FOUND, + status.HTTP_422_UNPROCESSABLE_ENTITY: grpc.StatusCode.FAILED_PRECONDITION, + status.HTTP_500_INTERNAL_SERVER_ERROR: grpc.StatusCode.INTERNAL, +} + def to_headers(context: ServicerContext) -> Dict[str, str]: metadata = context.invocation_metadata() + context.trailing_metadata() @@ -14,3 +28,18 @@ def to_headers(context: ServicerContext) -> Dict[str, str]: def to_metadata(headers: Dict[str, str]) -> Tuple[Tuple[str, str], ...]: return tuple(headers.items()) + + +def _grpc_status_code(err: MLServerError): + return STATUS_CODE_MAPPING.get(err.status_code, grpc.StatusCode.UNKNOWN) + + +def handle_mlserver_error(f: Callable): + async def _inner(self, request, context): + try: + return await f(self, request, context) + except MLServerError as err: + logger.error(err) + await context.abort(code=_grpc_status_code(err), details=str(err)) + + return _inner diff --git a/proto/dataplane.proto b/proto/dataplane.proto index 6a40c10f4..9a5b6cfa8 100644 --- a/proto/dataplane.proto +++ b/proto/dataplane.proto @@ -24,6 +24,18 @@ service GRPCInferenceService // Perform inference using a specific model. rpc ModelInfer(ModelInferRequest) returns (ModelInferResponse) {} + + // Get the index of model repository contents. + rpc RepositoryIndex(RepositoryIndexRequest) + returns (RepositoryIndexResponse) {} + + // Load or reload a model from a repository. + rpc RepositoryModelLoad(RepositoryModelLoadRequest) + returns (RepositoryModelLoadResponse) {} + + // Unload a model. + rpc RepositoryModelUnload(RepositoryModelUnloadRequest) + returns (RepositoryModelUnloadResponse) {} } @@ -303,3 +315,98 @@ message InferTensorContents // one-dimensional, row-major order of the tensor elements. repeated bytes bytes_contents = 8; } + +// +// Messages for the Repository API +// +// NOTE: These messages used to exist previously on a different protobuf +// definition. However, they have now been merged with the main +// GRPCInferenceService. +// + + +// A model repository parameter value. +message ModelRepositoryParameter +{ + // The parameter value can be a string, an int64, a boolean or bytes + oneof parameter_choice + { + // A boolean parameter value. + bool bool_param = 1; + + // An int64 parameter value. + int64 int64_param = 2; + + // A string parameter value. + string string_param = 3; + + // A bytes parameter value. + bytes bytes_param = 4; + } +} + + +message RepositoryIndexRequest +{ + // The name of the repository. If empty the index is returned + // for all repositories.
string repository_name = 1; + + // If true return only models currently ready for inferencing. + bool ready = 2; +} + +message RepositoryIndexResponse +{ + // Index entry for a model. + message ModelIndex { + // The name of the model. + string name = 1; + + // The version of the model. + string version = 2; + + // The state of the model. + string state = 3; + + // The reason, if any, that the model is in the given state. + string reason = 4; + } + + // An index entry for each model. + repeated ModelIndex models = 1; +} + +message RepositoryModelLoadRequest +{ + // The name of the repository to load from. If empty the model + // is loaded from any repository. + string repository_name = 1; + + // The name of the model to load, or reload. + string model_name = 2; + + // Optional model repository request parameters. + map<string, ModelRepositoryParameter> parameters = 3; +} + +message RepositoryModelLoadResponse +{ +} + +message RepositoryModelUnloadRequest +{ + // The name of the repository from which the model was originally + // loaded. If empty the repository is not considered. + string repository_name = 1; + + // The name of the model to unload. + string model_name = 2; + + // Optional model repository request parameters. + map<string, ModelRepositoryParameter> parameters = 3; +} + +message RepositoryModelUnloadResponse +{ +} diff --git a/tests/grpc/conftest.py b/tests/grpc/conftest.py index 7ff5b52a1..98ae6b094 100644 --- a/tests/grpc/conftest.py +++ b/tests/grpc/conftest.py @@ -11,9 +11,7 @@ from mlserver.handlers import DataPlane, ModelRepositoryHandlers from mlserver.settings import Settings from mlserver.grpc import dataplane_pb2 as pb -from mlserver.grpc import model_repository_pb2 as mr_pb from mlserver.grpc.dataplane_pb2_grpc import GRPCInferenceServiceStub -from mlserver.grpc.model_repository_pb2_grpc import ModelRepositoryServiceStub from mlserver.grpc import GRPCServer from ..conftest import TESTDATA_PATH @@ -37,11 +35,6 @@ def model_infer_request() -> pb.ModelInferRequest: return _read_testdata_pb(payload_path, pb.ModelInferRequest) -@pytest.fixture -def grpc_repository_index_request() -> mr_pb.RepositoryIndexRequest: - return mr_pb.RepositoryIndexRequest(ready=None) - - @pytest.fixture def grpc_parameters() -> Dict[str, pb.InferParameter]: return { @@ -52,19 +45,16 @@ @pytest.fixture -async def inference_service_stub( - grpc_server, settings: Settings -) -> AsyncGenerator[GRPCInferenceServiceStub, None]: - async with aio.insecure_channel(f"{settings.host}:{settings.grpc_port}") as channel: - yield GRPCInferenceServiceStub(channel) +def grpc_repository_index_request() -> pb.RepositoryIndexRequest: + return pb.RepositoryIndexRequest(ready=None) @pytest.fixture -async def model_repository_service_stub( +async def inference_service_stub( grpc_server, settings: Settings -) -> AsyncGenerator[ModelRepositoryServiceStub, None]: +) -> AsyncGenerator[GRPCInferenceServiceStub, None]: async with aio.insecure_channel(f"{settings.host}:{settings.grpc_port}") as channel: - yield ModelRepositoryServiceStub(channel) + yield GRPCInferenceServiceStub(channel) @pytest.fixture diff --git a/tests/grpc/test_converters.py b/tests/grpc/test_converters.py index 81ef03a95..7c2220096 100644 --- a/tests/grpc/test_converters.py +++ b/tests/grpc/test_converters.py @@ -84,30 +84,6 @@ def test_modelinferresponse_from_types(inference_response): ) -def test_repositoryindexrequest_to_types(grpc_repository_index_request): - repository_index_request = RepositoryIndexRequestConverter.to_types( - grpc_repository_index_request 
- ) - - assert repository_index_request.ready == grpc_repository_index_request.ready - - -def test_repositoryindexresponse_from_types(repository_index_response): - grpc_repository_index_request = RepositoryIndexResponseConverter.from_types( - repository_index_response - ) - - assert len(grpc_repository_index_request.models) == len(repository_index_response) - - for expected, grpc_model in zip( - repository_index_response, grpc_repository_index_request.models - ): - assert expected.name == grpc_model.name - assert expected.version == grpc_model.version - assert expected.state.value == grpc_model.state - assert expected.reason == grpc_model.reason - - def test_parameters_to_types(grpc_parameters): parameters = ParametersConverter.to_types(grpc_parameters) @@ -161,3 +137,29 @@ def test_inferoutputtensor_from_types( ): infer_output_tensor = InferOutputTensorConverter.from_types(response_output) assert infer_output_tensor == expected + + +def test_repositoryindexrequest_to_types(grpc_repository_index_request): + repository_index_request = RepositoryIndexRequestConverter.to_types( + grpc_repository_index_request + ) + + assert repository_index_request.ready == grpc_repository_index_request.ready + + +def test_repositoryindexresponse_from_types(repository_index_response): + grpc_repository_index_response = RepositoryIndexResponseConverter.from_types( + repository_index_response + ) + + assert isinstance(grpc_repository_index_response, pb.RepositoryIndexResponse) + assert len(grpc_repository_index_response.models) == len(repository_index_response) + + for expected, grpc_model in zip( + repository_index_response, grpc_repository_index_response.models + ): + assert isinstance(grpc_model, pb.RepositoryIndexResponse.ModelIndex) + assert expected.name == grpc_model.name + assert expected.version == grpc_model.version + assert expected.state.value == grpc_model.state + assert expected.reason == grpc_model.reason diff --git a/tests/grpc/test_model_repository.py b/tests/grpc/test_model_repository.py new file mode 100644 index 000000000..b8696dc8a --- /dev/null +++ b/tests/grpc/test_model_repository.py @@ -0,0 +1,110 @@ +""" +NOTE: These tests belong to the deprecated ModelRepository API, which has now +been merged with the GRPCInferenceService's dataplane. 
+""" + +import pytest +import grpc + +from typing import AsyncGenerator +from grpc import aio + +from mlserver.settings import Settings +from mlserver.grpc.converters import ( + RepositoryIndexRequestConverter, + RepositoryIndexResponseConverter, +) +from mlserver.grpc.model_repository_pb2_grpc import ModelRepositoryServiceStub +from mlserver.grpc import dataplane_pb2 as pb +from mlserver.grpc import model_repository_pb2 as mr_pb + + +@pytest.fixture +def grpc_repository_index_request() -> mr_pb.RepositoryIndexRequest: + return mr_pb.RepositoryIndexRequest(ready=None) + + +@pytest.fixture +async def model_repository_service_stub( + grpc_server, settings: Settings +) -> AsyncGenerator[ModelRepositoryServiceStub, None]: + async with aio.insecure_channel(f"{settings.host}:{settings.grpc_port}") as channel: + yield ModelRepositoryServiceStub(channel) + + +def test_repositoryindexrequest_to_types(grpc_repository_index_request): + repository_index_request = RepositoryIndexRequestConverter.to_types( + grpc_repository_index_request + ) + + assert repository_index_request.ready == grpc_repository_index_request.ready + + +def test_repositoryindexresponse_from_types(repository_index_response): + grpc_repository_index_response = RepositoryIndexResponseConverter.from_types( + repository_index_response, use_model_repository=True + ) + + assert isinstance(grpc_repository_index_response, mr_pb.RepositoryIndexResponse) + assert len(grpc_repository_index_response.models) == len(repository_index_response) + + for expected, grpc_model in zip( + repository_index_response, grpc_repository_index_response.models + ): + assert isinstance(grpc_model, mr_pb.RepositoryIndexResponse.ModelIndex) + assert expected.name == grpc_model.name + assert expected.version == grpc_model.version + assert expected.state.value == grpc_model.state + assert expected.reason == grpc_model.reason + + +async def test_model_repository_index( + model_repository_service_stub, grpc_repository_index_request +): + index = await model_repository_service_stub.RepositoryIndex( + grpc_repository_index_request + ) + + assert len(index.models) == 1 + + +async def test_model_repository_unload( + inference_service_stub, model_repository_service_stub, sum_model_settings +): + unload_request = mr_pb.RepositoryModelUnloadRequest( + model_name=sum_model_settings.name + ) + await model_repository_service_stub.RepositoryModelUnload(unload_request) + + with pytest.raises(grpc.RpcError): + await inference_service_stub.ModelMetadata( + pb.ModelMetadataRequest(name=sum_model_settings.name) + ) + + +async def test_model_repository_load( + inference_service_stub, model_repository_service_stub, sum_model_settings +): + await model_repository_service_stub.RepositoryModelUnload( + mr_pb.RepositoryModelLoadRequest(model_name=sum_model_settings.name) + ) + + load_request = mr_pb.RepositoryModelLoadRequest(model_name=sum_model_settings.name) + await model_repository_service_stub.RepositoryModelLoad(load_request) + + response = await inference_service_stub.ModelMetadata( + pb.ModelMetadataRequest(name=sum_model_settings.name) + ) + + assert response.name == sum_model_settings.name + + +async def test_model_repository_load_error( + inference_service_stub, model_repository_service_stub, sum_model_settings +): + with pytest.raises(grpc.RpcError) as err: + load_request = mr_pb.RepositoryModelLoadRequest(model_name="my-model") + await model_repository_service_stub.RepositoryModelLoad(load_request) + + assert err.value.code() == grpc.StatusCode.NOT_FOUND + assert 
err.value.details() == "Model my-model not found" diff --git a/tests/grpc/test_servicers.py b/tests/grpc/test_servicers.py index 02ac5b672..cbea033fc 100644 --- a/tests/grpc/test_servicers.py +++ b/tests/grpc/test_servicers.py @@ -2,7 +2,6 @@ import grpc from mlserver.grpc import dataplane_pb2 as pb -from mlserver.grpc import model_repository_pb2 as mr_pb from mlserver import __version__ @@ -94,22 +93,16 @@ async def test_model_infer_error(inference_service_stub, model_infer_request): async def test_model_repository_index( - model_repository_service_stub, grpc_repository_index_request + inference_service_stub, grpc_repository_index_request ): - index = await model_repository_service_stub.RepositoryIndex( - grpc_repository_index_request - ) + index = await inference_service_stub.RepositoryIndex(grpc_repository_index_request) assert len(index.models) == 1 -async def test_model_repository_unload( - inference_service_stub, model_repository_service_stub, sum_model_settings -): - unload_request = mr_pb.RepositoryModelUnloadRequest( - model_name=sum_model_settings.name - ) - await model_repository_service_stub.RepositoryModelUnload(unload_request) +async def test_model_repository_unload(inference_service_stub, sum_model_settings): + unload_request = pb.RepositoryModelUnloadRequest(model_name=sum_model_settings.name) + await inference_service_stub.RepositoryModelUnload(unload_request) with pytest.raises(grpc.RpcError): await inference_service_stub.ModelMetadata( @@ -117,15 +110,13 @@ async def test_model_repository_unload( ) -async def test_model_repository_load( - inference_service_stub, model_repository_service_stub, sum_model_settings -): - await model_repository_service_stub.RepositoryModelUnload( - mr_pb.RepositoryModelLoadRequest(model_name=sum_model_settings.name) +async def test_model_repository_load(inference_service_stub, sum_model_settings): + await inference_service_stub.RepositoryModelUnload( + pb.RepositoryModelUnloadRequest(model_name=sum_model_settings.name) ) - load_request = mr_pb.RepositoryModelLoadRequest(model_name=sum_model_settings.name) - await model_repository_service_stub.RepositoryModelLoad(load_request) + load_request = pb.RepositoryModelLoadRequest(model_name=sum_model_settings.name) + await inference_service_stub.RepositoryModelLoad(load_request) response = await inference_service_stub.ModelMetadata( pb.ModelMetadataRequest(name=sum_model_settings.name) @@ -134,12 +125,10 @@ async def test_model_repository_load( assert response.name == sum_model_settings.name -async def test_model_repository_load_error( - inference_service_stub, model_repository_service_stub, sum_model_settings -): +async def test_model_repository_load_error(inference_service_stub, sum_model_settings): with pytest.raises(grpc.RpcError) as err: - load_request = mr_pb.RepositoryModelLoadRequest(model_name="my-model") - await model_repository_service_stub.RepositoryModelLoad(load_request) + load_request = pb.RepositoryModelLoadRequest(model_name="my-model") + await inference_service_stub.RepositoryModelLoad(load_request) assert err.value.code() == grpc.StatusCode.NOT_FOUND assert err.value.details() == "Model my-model not found"
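
For reference, a minimal client sketch of the merged API (not part of the diff itself): after this change, the repository RPCs are served by `GRPCInferenceService`, so a single `GRPCInferenceServiceStub` can drive both inference and model repository operations. The server address and model name below are placeholder assumptions, and the sketch assumes an MLServer instance is already running with its gRPC endpoint exposed.

```python
# Sketch: calling the repository RPCs through the merged GRPCInferenceService.
# "localhost:8081" and "sum-model" are placeholders, not values defined here.
import asyncio

from grpc import aio

from mlserver.grpc import dataplane_pb2 as pb
from mlserver.grpc.dataplane_pb2_grpc import GRPCInferenceServiceStub


async def main() -> None:
    async with aio.insecure_channel("localhost:8081") as channel:
        stub = GRPCInferenceServiceStub(channel)

        # List every model in the repository, ready or not.
        index = await stub.RepositoryIndex(pb.RepositoryIndexRequest(ready=False))
        for model in index.models:
            print(model.name, model.state, model.reason)

        # Load (or reload) a model by name, then unload it again.
        await stub.RepositoryModelLoad(
            pb.RepositoryModelLoadRequest(model_name="sum-model")
        )
        await stub.RepositoryModelUnload(
            pb.RepositoryModelUnloadRequest(model_name="sum-model")
        )


if __name__ == "__main__":
    asyncio.run(main())
```

Clients still using `ModelRepositoryServiceStub` keep working during the deprecation window, but each call now logs the deprecation warning added in `mlserver/grpc/model_repository.py`.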
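
The error-handling decorator that previously lived in `mlserver/grpc/servicers.py` is now shared from `mlserver/grpc/utils.py`, so both servicers map `MLServerError` HTTP status codes onto gRPC status codes the same way. A hypothetical test sketch of that mapping, assuming `MLServerError` accepts a message and a `status_code`:

```python
# Hypothetical sketch of the mapping behaviour in mlserver/grpc/utils.py.
# _grpc_status_code is private; this only illustrates the fallback rule.
import grpc

from mlserver.errors import MLServerError
from mlserver.grpc.utils import _grpc_status_code


def test_grpc_status_code_mapping():
    # 404 is explicitly mapped to NOT_FOUND by STATUS_CODE_MAPPING...
    not_found = MLServerError("model not found", status_code=404)
    assert _grpc_status_code(not_found) == grpc.StatusCode.NOT_FOUND

    # ...while any status code outside the mapping falls back to UNKNOWN.
    teapot = MLServerError("unexpected", status_code=418)
    assert _grpc_status_code(teapot) == grpc.StatusCode.UNKNOWN
```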