Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions ml_engine/online_prediction/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def predict_json(project, model, instances, version=None):


# [START predict_tf_records]
def predict_tf_records(project,
model,
example_bytes_list,
version=None):
def predict_examples(project,
model,
example_bytes_list,
version=None):
"""Send protocol buffer data to a deployed model for prediction.

Args:
Expand Down Expand Up @@ -119,7 +119,7 @@ def census_to_example_bytes(json_instance):
"""
import tensorflow as tf
feature_dict = {}
for key, data in json_instance.iteritems():
for key, data in six.iteritems(json_instance):
if isinstance(data, six.string_types):
feature_dict[key] = tf.train.Feature(
bytes_list=tf.train.BytesList(value=[str(data)]))
Expand Down Expand Up @@ -153,7 +153,7 @@ def main(project, model, version=None, force_tfrecord=False):
census_to_example_bytes(e)
for e in user_input
]
result = predict_tf_records(
result = predict_examples(
project, model, example_bytes_list, version=version)
else:
result = predict_json(
Expand Down
22 changes: 10 additions & 12 deletions ml_engine/online_prediction/predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@


MODEL = 'census'
VERSION = 'v1'
TF_RECORDS_VERSION = 'v1tfrecord'
JSON_VERSION = 'v1json'
EXAMPLES_VERSION = 'v1example'
PROJECT = 'python-docs-samples-tests'
JSON = {
'age': 25,
Expand All @@ -41,22 +41,21 @@
'native_country': ' United-States'
}
EXPECTED_OUTPUT = {
u'probabilities': [0.9942260384559631, 0.005774002522230148],
u'logits': [-5.148599147796631],
u'classes': 0,
u'logistic': [0.005774001590907574]
u'confidence': 0.7760371565818787,
u'predictions': u' <=50K'
}


def test_predict_json():
result = predict.predict_json(
PROJECT, MODEL, [JSON, JSON], version=VERSION)
PROJECT, MODEL, [JSON, JSON], version=JSON_VERSION)
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result


def test_predict_json_error():
with pytest.raises(RuntimeError):
predict.predict_json(PROJECT, MODEL, [{"foo": "bar"}], version=VERSION)
predict.predict_json(
PROJECT, MODEL, [{"foo": "bar"}], version=JSON_VERSION)


@pytest.mark.slow
Expand All @@ -66,9 +65,8 @@ def test_census_example_to_bytes():


@pytest.mark.slow
@pytest.mark.xfail('Single placeholder inputs broken in service b/35778449')
def test_predict_tfrecords():
def test_predict_examples():
b = predict.census_to_example_bytes(JSON)
result = predict.predict_tfrecords(
PROJECT, MODEL, [b, b], version=TF_RECORDS_VERSION)
result = predict.predict_examples(
PROJECT, MODEL, [b, b], version=EXAMPLES_VERSION)
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result