diff --git a/google/cloud/vision/entity.py b/google/cloud/vision/entity.py index 4324b17de939..3c8f810c37c7 100644 --- a/google/cloud/vision/entity.py +++ b/google/cloud/vision/entity.py @@ -55,7 +55,7 @@ def from_api_repr(cls, response): :rtype: :class:`~google.cloud.vision.entiy.EntityAnnotation` :returns: Instance of ``EntityAnnotation``. """ - bounds = Bounds.from_api_repr(response['boundingPoly']) + bounds = Bounds.from_api_repr(response.get('boundingPoly')) description = response['description'] locations = [LocationInformation.from_api_repr(location) for location in response.get('locations', [])] diff --git a/google/cloud/vision/geometry.py b/google/cloud/vision/geometry.py index acac2a911710..4e6af390213d 100644 --- a/google/cloud/vision/geometry.py +++ b/google/cloud/vision/geometry.py @@ -31,13 +31,14 @@ def from_api_repr(cls, response_vertices): :type response_vertices: dict :param response_vertices: List of vertices. - :rtype: :class:`~google.cloud.vision.geometry.BoundsBase` - :returns: Instance of BoundsBase with populated verticies. + :rtype: :class:`~google.cloud.vision.geometry.BoundsBase` or None + :returns: Instance of BoundsBase with populated verticies or None. """ - vertices = [] - for vertex in response_vertices['vertices']: - vertices.append(Vertex(vertex.get('x', None), - vertex.get('y', None))) + if not response_vertices: + return None + + vertices = [Vertex(vertex.get('x', None), vertex.get('y', None)) for + vertex in response_vertices.get('vertices', [])] return cls(vertices) @property diff --git a/google/cloud/vision/image.py b/google/cloud/vision/image.py index 04315045cb4a..caba1ce10dfb 100644 --- a/google/cloud/vision/image.py +++ b/google/cloud/vision/image.py @@ -94,6 +94,7 @@ def _detect_annotation(self, feature): :class:`~google.cloud.vision.entity.EntityAnnotation`. """ reverse_types = { + 'LABEL_DETECTION': 'labelAnnotations', 'LANDMARK_DETECTION': 'landmarkAnnotations', 'LOGO_DETECTION': 'logoAnnotations', } @@ -122,6 +123,18 @@ def detect_faces(self, limit=10): return faces + def detect_labels(self, limit=10): + """Detect labels that describe objects in an image. + + :type limit: int + :param limit: The maximum number of labels to try and detect. + + :rtype: list + :returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation` + """ + feature = Feature(FeatureTypes.LABEL_DETECTION, limit) + return self._detect_annotation(feature) + def detect_landmarks(self, limit=10): """Detect landmarks in an image. diff --git a/unit_tests/vision/_fixtures.py b/unit_tests/vision/_fixtures.py index 2df26b0e6cfc..2c927b9bbca1 100644 --- a/unit_tests/vision/_fixtures.py +++ b/unit_tests/vision/_fixtures.py @@ -1,3 +1,28 @@ +LABEL_DETECTION_RESPONSE = { + 'responses': [ + { + 'labelAnnotations': [ + { + 'mid': '/m/0k4j', + 'description': 'automobile', + 'score': 0.9776855 + }, + { + 'mid': '/m/07yv9', + 'description': 'vehicle', + 'score': 0.947987 + }, + { + 'mid': '/m/07r04', + 'description': 'truck', + 'score': 0.88429511 + } + ] + } + ] +} + + LANDMARK_DETECTION_RESPONSE = { 'responses': [ { diff --git a/unit_tests/vision/test_client.py b/unit_tests/vision/test_client.py index 54e157f5a3f1..34b01f6bf3f3 100644 --- a/unit_tests/vision/test_client.py +++ b/unit_tests/vision/test_client.py @@ -114,6 +114,27 @@ def test_face_detection_from_content(self): image_request['image']['content']) self.assertEqual(5, image_request['features'][0]['maxResults']) + def test_label_detection_from_source(self): + from google.cloud.vision.entity import EntityAnnotation + from unit_tests.vision._fixtures import (LABEL_DETECTION_RESPONSE as + RETURNED) + credentials = _Credentials() + client = self._makeOne(project=self.PROJECT, credentials=credentials) + client.connection = _Connection(RETURNED) + + image = client.image(source_uri=_IMAGE_SOURCE) + labels = image.detect_labels(limit=3) + self.assertEqual(3, len(labels)) + self.assertTrue(isinstance(labels[0], EntityAnnotation)) + image_request = client.connection._requested[0]['data']['requests'][0] + self.assertEqual(_IMAGE_SOURCE, + image_request['image']['source']['gcs_image_uri']) + self.assertEqual(3, image_request['features'][0]['maxResults']) + self.assertEqual('automobile', labels[0].description) + self.assertEqual('vehicle', labels[1].description) + self.assertEqual('/m/0k4j', labels[0].mid) + self.assertEqual('/m/07yv9', labels[1].mid) + def test_landmark_detection_from_source(self): from google.cloud.vision.entity import EntityAnnotation from unit_tests.vision._fixtures import (LANDMARK_DETECTION_RESPONSE as