diff --git a/ml_engine/online_prediction/predict.py b/ml_engine/online_prediction/predict.py index 5727177420a..8d6bb8fa181 100644 --- a/ml_engine/online_prediction/predict.py +++ b/ml_engine/online_prediction/predict.py @@ -62,10 +62,10 @@ def predict_json(project, model, instances, version=None): # [START predict_tf_records] -def predict_tf_records(project, - model, - example_bytes_list, - version=None): +def predict_examples(project, + model, + example_bytes_list, + version=None): """Send protocol buffer data to a deployed model for prediction. Args: @@ -119,7 +119,7 @@ def census_to_example_bytes(json_instance): """ import tensorflow as tf feature_dict = {} - for key, data in json_instance.iteritems(): + for key, data in six.iteritems(json_instance): if isinstance(data, six.string_types): feature_dict[key] = tf.train.Feature( bytes_list=tf.train.BytesList(value=[str(data)])) @@ -153,7 +153,7 @@ def main(project, model, version=None, force_tfrecord=False): census_to_example_bytes(e) for e in user_input ] - result = predict_tf_records( + result = predict_examples( project, model, example_bytes_list, version=version) else: result = predict_json( diff --git a/ml_engine/online_prediction/predict_test.py b/ml_engine/online_prediction/predict_test.py index 90930cae611..81203960bd2 100644 --- a/ml_engine/online_prediction/predict_test.py +++ b/ml_engine/online_prediction/predict_test.py @@ -22,8 +22,8 @@ MODEL = 'census' -VERSION = 'v1' -TF_RECORDS_VERSION = 'v1tfrecord' +JSON_VERSION = 'v1json' +EXAMPLES_VERSION = 'v1example' PROJECT = 'python-docs-samples-tests' JSON = { 'age': 25, @@ -41,22 +41,21 @@ 'native_country': ' United-States' } EXPECTED_OUTPUT = { - u'probabilities': [0.9942260384559631, 0.005774002522230148], - u'logits': [-5.148599147796631], - u'classes': 0, - u'logistic': [0.005774001590907574] + u'confidence': 0.7760371565818787, + u'predictions': u' <=50K' } def test_predict_json(): result = predict.predict_json( - PROJECT, MODEL, [JSON, JSON], version=VERSION) + PROJECT, MODEL, [JSON, JSON], version=JSON_VERSION) assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result def test_predict_json_error(): with pytest.raises(RuntimeError): - predict.predict_json(PROJECT, MODEL, [{"foo": "bar"}], version=VERSION) + predict.predict_json( + PROJECT, MODEL, [{"foo": "bar"}], version=JSON_VERSION) @pytest.mark.slow @@ -66,9 +65,8 @@ def test_census_example_to_bytes(): @pytest.mark.slow -@pytest.mark.xfail('Single placeholder inputs broken in service b/35778449') -def test_predict_tfrecords(): +def test_predict_examples(): b = predict.census_to_example_bytes(JSON) - result = predict.predict_tfrecords( - PROJECT, MODEL, [b, b], version=TF_RECORDS_VERSION) + result = predict.predict_examples( + PROJECT, MODEL, [b, b], version=EXAMPLES_VERSION) assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result