
Commit aa35cfd

netfs authored and tensorflower-gardener committed
Switch from beta to GA gRPC API.
PiperOrigin-RevId: 206612637
1 parent 1e74469 commit aa35cfd

7 files changed

Lines changed: 230 additions & 293 deletions
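
For readers migrating their own clients, the diffs below all follow the same pattern: the grpc.beta channel and the generated beta_create_PredictionService_stub helper are replaced by a plain grpc channel and the stub class from the new prediction_service_pb2_grpc module. A minimal before/after sketch, based on the example-client diffs in this commit (the 'localhost:9000' address and model name are illustrative, not taken from the commit):

    # Before this commit (grpc.beta API):
    #   from grpc.beta import implementations
    #   from tensorflow_serving.apis import prediction_service_pb2
    #   host, port = server.split(':')
    #   channel = implementations.insecure_channel(host, int(port))
    #   stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    # After this commit (GA gRPC API):
    import grpc
    from tensorflow_serving.apis import predict_pb2
    from tensorflow_serving.apis import prediction_service_pb2_grpc

    server = 'localhost:9000'  # illustrative address
    channel = grpc.insecure_channel(server)  # GA channels take a single 'host:port' string
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'my_model'  # hypothetical model name
    result = stub.Predict(request, 5.0)  # 5 second timeout, as in the example clients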


tensorflow_serving/apis/BUILD

Lines changed: 4 additions & 1 deletion
@@ -178,7 +178,10 @@ serving_proto_library(
 
 py_library(
     name = "prediction_service_proto_py_pb2",
-    srcs = ["prediction_service_pb2.py"],
+    srcs = [
+        "prediction_service_pb2.py",
+        "prediction_service_pb2_grpc.py",
+    ],
     srcs_version = "PY2AND3",
     deps = [
         ":classification_proto_py_pb2",
tensorflow_serving/apis/prediction_service_pb2.py

Lines changed: 59 additions & 254 deletions
Large diffs are not rendered by default.

tensorflow_serving/apis/prediction_service_pb2_grpc.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+# source: tensorflow_serving/apis/prediction_service.proto
+# To regenerate run
+# python -m grpc.tools.protoc --python_out=. --grpc_python_out=. -I. tensorflow_serving/apis/prediction_service.proto
+import grpc
+
+from tensorflow_serving.apis import classification_pb2 as tensorflow__serving_dot_apis_dot_classification__pb2
+from tensorflow_serving.apis import get_model_metadata_pb2 as tensorflow__serving_dot_apis_dot_get__model__metadata__pb2
+from tensorflow_serving.apis import inference_pb2 as tensorflow__serving_dot_apis_dot_inference__pb2
+from tensorflow_serving.apis import predict_pb2 as tensorflow__serving_dot_apis_dot_predict__pb2
+from tensorflow_serving.apis import regression_pb2 as tensorflow__serving_dot_apis_dot_regression__pb2
+
+
+class PredictionServiceStub(object):
+  """open source marker; do not remove
+  PredictionService provides access to machine-learned models loaded by
+  model_servers.
+  """
+
+  def __init__(self, channel):
+    """Constructor.
+
+    Args:
+      channel: A grpc.Channel.
+    """
+    self.Classify = channel.unary_unary(
+        '/tensorflow.serving.PredictionService/Classify',
+        request_serializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.SerializeToString,
+        response_deserializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.FromString,
+        )
+    self.Regress = channel.unary_unary(
+        '/tensorflow.serving.PredictionService/Regress',
+        request_serializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.SerializeToString,
+        response_deserializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.FromString,
+        )
+    self.Predict = channel.unary_unary(
+        '/tensorflow.serving.PredictionService/Predict',
+        request_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString,
+        response_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.FromString,
+        )
+    self.MultiInference = channel.unary_unary(
+        '/tensorflow.serving.PredictionService/MultiInference',
+        request_serializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceRequest.SerializeToString,
+        response_deserializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceResponse.FromString,
+        )
+    self.GetModelMetadata = channel.unary_unary(
+        '/tensorflow.serving.PredictionService/GetModelMetadata',
+        request_serializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString,
+        response_deserializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.FromString,
+        )
+
+
+class PredictionServiceServicer(object):
+  """open source marker; do not remove
+  PredictionService provides access to machine-learned models loaded by
+  model_servers.
+  """
+
+  def Classify(self, request, context):
+    """Classify.
+    """
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
+
+  def Regress(self, request, context):
+    """Regress.
+    """
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
+
+  def Predict(self, request, context):
+    """Predict -- provides access to loaded TensorFlow model.
+    """
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
+
+  def MultiInference(self, request, context):
+    """MultiInference API for multi-headed models.
+    """
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
+
+  def GetModelMetadata(self, request, context):
+    """GetModelMetadata - provides access to metadata for loaded models.
+    """
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
+
+
+def add_PredictionServiceServicer_to_server(servicer, server):
+  rpc_method_handlers = {
+      'Classify': grpc.unary_unary_rpc_method_handler(
+          servicer.Classify,
+          request_deserializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.FromString,
+          response_serializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.SerializeToString,
+      ),
+      'Regress': grpc.unary_unary_rpc_method_handler(
+          servicer.Regress,
+          request_deserializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.FromString,
+          response_serializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.SerializeToString,
+      ),
+      'Predict': grpc.unary_unary_rpc_method_handler(
+          servicer.Predict,
+          request_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString,
+          response_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.SerializeToString,
+      ),
+      'MultiInference': grpc.unary_unary_rpc_method_handler(
+          servicer.MultiInference,
+          request_deserializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceRequest.FromString,
+          response_serializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceResponse.SerializeToString,
+      ),
+      'GetModelMetadata': grpc.unary_unary_rpc_method_handler(
+          servicer.GetModelMetadata,
+          request_deserializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.FromString,
+          response_serializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString,
+      ),
+  }
+  generic_handler = grpc.method_handlers_generic_handler(
+      'tensorflow.serving.PredictionService', rpc_method_handlers)
+  server.add_generic_rpc_handlers((generic_handler,))

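The generated module above also carries the server-side registration helper, add_PredictionServiceServicer_to_server. That helper is not exercised by this commit's diffs; the following is a hypothetical sketch of how a GA-style server could wire a servicer into grpc.server (the executor size, port, and servicer behavior are illustrative assumptions):

    from concurrent import futures

    import grpc
    from tensorflow_serving.apis import predict_pb2
    from tensorflow_serving.apis import prediction_service_pb2_grpc


    class _StubPredictionService(prediction_service_pb2_grpc.PredictionServiceServicer):
      """Hypothetical servicer; only Predict is overridden, other RPCs stay UNIMPLEMENTED."""

      def Predict(self, request, context):
        # A real implementation would run the requested model here.
        return predict_pb2.PredictResponse()


    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    prediction_service_pb2_grpc.add_PredictionServiceServicer_to_server(
        _StubPredictionService(), server)
    server.add_insecure_port('[::]:9000')  # illustrative port
    server.start()
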
tensorflow_serving/example/inception_client.py

Lines changed: 4 additions & 5 deletions
@@ -22,11 +22,11 @@
 
 # This is a placeholder for a Google-internal import.
 
-from grpc.beta import implementations
+import grpc
 import tensorflow as tf
 
 from tensorflow_serving.apis import predict_pb2
-from tensorflow_serving.apis import prediction_service_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc
 
 
 tf.app.flags.DEFINE_string('server', 'localhost:9000',
@@ -36,9 +36,8 @@
 
 
 def main(_):
-  host, port = FLAGS.server.split(':')
-  channel = implementations.insecure_channel(host, int(port))
-  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+  channel = grpc.insecure_channel(FLAGS.server)
+  stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
   # Send request
   with open(FLAGS.image, 'rb') as f:
     # See prediction_service.proto for gRPC request/response details.

tensorflow_serving/example/mnist_client.py

Lines changed: 4 additions & 5 deletions
@@ -32,12 +32,12 @@
 
 # This is a placeholder for a Google-internal import.
 
-from grpc.beta import implementations
+import grpc
 import numpy
 import tensorflow as tf
 
 from tensorflow_serving.apis import predict_pb2
-from tensorflow_serving.apis import prediction_service_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc
 import mnist_input_data
 
 
@@ -137,9 +137,8 @@ def do_inference(hostport, work_dir, concurrency, num_tests):
     IOError: An error occurred processing test data set.
   """
   test_data_set = mnist_input_data.read_data_sets(work_dir).test
-  host, port = hostport.split(':')
-  channel = implementations.insecure_channel(host, int(port))
-  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+  channel = grpc.insecure_channel(hostport)
+  stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
   result_counter = _ResultCounter(num_tests, concurrency)
   for _ in range(num_tests):
     request = predict_pb2.PredictRequest()

tensorflow_serving/model_servers/tensorflow_model_server_test.py

Lines changed: 16 additions & 23 deletions
@@ -29,9 +29,6 @@
 # This is a placeholder for a Google-internal import.
 
 import grpc
-from grpc.beta import implementations
-from grpc.beta import interfaces as beta_interfaces
-from grpc.framework.interfaces.face import face
 import tensorflow as tf
 
 from tensorflow.core.framework import types_pb2
@@ -42,7 +39,7 @@
 from tensorflow_serving.apis import inference_pb2
 from tensorflow_serving.apis import model_service_pb2_grpc
 from tensorflow_serving.apis import predict_pb2
-from tensorflow_serving.apis import prediction_service_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc
 from tensorflow_serving.apis import regression_pb2
 
 FLAGS = flags.FLAGS
@@ -70,12 +67,12 @@ def WaitForServerReady(port):
 
     try:
       # Send empty request to missing model
-      channel = implementations.insecure_channel('localhost', port)
-      stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+      channel = grpc.insecure_channel('localhost:{}'.format(port))
+      stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
       stub.Predict(request, RPC_TIMEOUT)
-    except face.AbortionError as error:
+    except grpc.RpcError as error:
       # Missing model error will have details containing 'Servable'
-      if 'Servable' in error.details:
+      if 'Servable' in error.details():
        print 'Server is ready'
        break
 
@@ -199,9 +196,8 @@ def VerifyPredictRequest(self,
    if specify_output:
      request.output_filter.append('y')
    # Send request
-    host, port = model_server_address.split(':')
-    channel = implementations.insecure_channel(host, int(port))
-    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+    channel = grpc.insecure_channel(model_server_address)
+    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    result = stub.Predict(request, RPC_TIMEOUT)  # 5 secs timeout
    # Verify response
    self.assertTrue('y' in result.outputs)
@@ -313,9 +309,8 @@ def testClassify(self):
    example.features.feature['x'].float_list.value.extend([2.0])
 
    # Send request
-    host, port = model_server_address.split(':')
-    channel = implementations.insecure_channel(host, int(port))
-    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+    channel = grpc.insecure_channel(model_server_address)
+    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    result = stub.Classify(request, RPC_TIMEOUT)  # 5 secs timeout
    # Verify response
    self.assertEquals(1, len(result.result.classifications))
@@ -345,9 +340,8 @@ def testRegress(self):
    example.features.feature['x'].float_list.value.extend([2.0])
 
    # Send request
-    host, port = model_server_address.split(':')
-    channel = implementations.insecure_channel(host, int(port))
-    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+    channel = grpc.insecure_channel(model_server_address)
+    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    result = stub.Regress(request, RPC_TIMEOUT)  # 5 secs timeout
    # Verify response
    self.assertEquals(1, len(result.result.regressions))
@@ -381,9 +375,8 @@ def testMultiInference(self):
    example.features.feature['x'].float_list.value.extend([2.0])
 
    # Send request
-    host, port = model_server_address.split(':')
-    channel = implementations.insecure_channel(host, int(port))
-    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+    channel = grpc.insecure_channel(model_server_address)
+    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    result = stub.MultiInference(request, RPC_TIMEOUT)  # 5 secs timeout
 
    # Verify response
@@ -451,13 +444,13 @@ def _TestBadModel(self):
    model_server_address = self.RunServer(PickUnusedPort(), 'default',
                                          model_path,
                                          wait_for_server_ready=False)
-    with self.assertRaises(face.AbortionError) as error:
+    with self.assertRaises(grpc.RpcError) as ectxt:
      self.VerifyPredictRequest(
          model_server_address, expected_output=3.0,
          expected_version=self._GetModelVersion(model_path),
          signature_name='')
-    self.assertIs(beta_interfaces.StatusCode.FAILED_PRECONDITION,
-                  error.exception.code)
+    self.assertIs(grpc.StatusCode.FAILED_PRECONDITION,
+                  ectxt.exception.code())
 
  def _TestBadModelUpconvertedSavedModel(self):
    """Test Predict against a bad upconverted SavedModel model export."""

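The test changes above also show the GA error-handling surface: grpc.RpcError replaces face.AbortionError, and details and code become methods rather than attributes. A small hypothetical sketch of catching a failed Predict call under the GA API (the address and model name are illustrative assumptions, not from this commit):

    import grpc
    from tensorflow_serving.apis import predict_pb2
    from tensorflow_serving.apis import prediction_service_pb2_grpc

    channel = grpc.insecure_channel('localhost:9000')  # illustrative address
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'missing_model'  # hypothetical model name

    try:
      stub.Predict(request, 5.0)
    except grpc.RpcError as error:
      # In the GA API, code() and details() are methods.
      print(error.code())     # e.g. grpc.StatusCode.FAILED_PRECONDITION or NOT_FOUND
      print(error.details())  # error message text, e.g. mentioning the missing servable
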
tensorflow_serving/model_servers/tensorflow_model_server_test_client.py

Lines changed: 4 additions & 5 deletions
@@ -19,13 +19,13 @@
 
 # This is a placeholder for a Google-internal import.
 
-from grpc.beta import implementations
+import grpc
 import tensorflow as tf
 
 from tensorflow.core.framework import types_pb2
 from tensorflow.python.platform import flags
 from tensorflow_serving.apis import predict_pb2
-from tensorflow_serving.apis import prediction_service_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc
 
 
 tf.app.flags.DEFINE_string('server', 'localhost:8500',
@@ -41,9 +41,8 @@ def main(_):
  request.inputs['x'].float_val.append(2.0)
  request.output_filter.append('y')
  # Send request
-  host, port = FLAGS.server.split(':')
-  channel = implementations.insecure_channel(host, int(port))
-  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+  channel = grpc.insecure_channel(FLAGS.server)
+  stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
  print stub.Predict(request, 5.0)  # 5 secs timeout
 
 
