|
35 | 35 | LABEL_FILE = os.path.join(_SYS_TESTS_DIR, 'data', 'car.jpg') |
36 | 36 | LANDMARK_FILE = os.path.join(_SYS_TESTS_DIR, 'data', 'landmark.jpg') |
37 | 37 | TEXT_FILE = os.path.join(_SYS_TESTS_DIR, 'data', 'text.jpg') |
| 38 | +FULL_TEXT_FILE = os.path.join(_SYS_TESTS_DIR, 'data', 'full-text.jpg') |
38 | 39 |
|
39 | 40 |
|
40 | 41 | class Config(object): |
@@ -80,6 +81,107 @@ def _pb_not_implemented_skip(self, message): |
80 | 81 | self.skipTest(message) |
81 | 82 |
|
82 | 83 |
|
| 84 | +class TestVisionFullText(unittest.TestCase): |
| 85 | + def setUp(self): |
| 86 | + self.to_delete_by_case = [] |
| 87 | + |
| 88 | + def tearDown(self): |
| 89 | + for value in self.to_delete_by_case: |
| 90 | + value.delete() |
| 91 | + |
| 92 | + def _assert_full_text(self, full_text): |
| 93 | + from google.cloud.vision.text import TextAnnotation |
| 94 | + |
| 95 | + self.assertIsInstance(full_text, TextAnnotation) |
| 96 | + self.assertIsInstance(full_text.text, six.text_type) |
| 97 | + self.assertEqual(len(full_text.pages), 1) |
| 98 | + self.assertIsInstance(full_text.pages[0].width, int) |
| 99 | + self.assertIsInstance(full_text.pages[0].height, int) |
| 100 | + |
| 101 | + def test_detect_full_text_content(self): |
| 102 | + client = Config.CLIENT |
| 103 | + with open(FULL_TEXT_FILE, 'rb') as image_file: |
| 104 | + image = client.image(content=image_file.read()) |
| 105 | + full_text = image.detect_full_text() |
| 106 | + self._assert_full_text(full_text) |
| 107 | + |
| 108 | + def test_detect_full_text_filename(self): |
| 109 | + client = Config.CLIENT |
| 110 | + image = client.image(filename=FULL_TEXT_FILE) |
| 111 | + full_text = image.detect_full_text() |
| 112 | + self._assert_full_text(full_text) |
| 113 | + |
| 114 | + def test_detect_full_text_gcs(self): |
| 115 | + bucket_name = Config.TEST_BUCKET.name |
| 116 | + blob_name = 'full-text.jpg' |
| 117 | + blob = Config.TEST_BUCKET.blob(blob_name) |
| 118 | + self.to_delete_by_case.append(blob) # Clean-up. |
| 119 | + with open(FULL_TEXT_FILE, 'rb') as file_obj: |
| 120 | + blob.upload_from_file(file_obj) |
| 121 | + |
| 122 | + source_uri = 'gs://%s/%s' % (bucket_name, blob_name) |
| 123 | + |
| 124 | + client = Config.CLIENT |
| 125 | + image = client.image(source_uri=source_uri) |
| 126 | + full_text = image.detect_full_text() |
| 127 | + self._assert_full_text(full_text) |
| 128 | + |
| 129 | + |
| 130 | +class TestVisionClientCropHint(BaseVisionTestCase): |
| 131 | + def setUp(self): |
| 132 | + self.to_delete_by_case = [] |
| 133 | + |
| 134 | + def tearDown(self): |
| 135 | + for value in self.to_delete_by_case: |
| 136 | + value.delete() |
| 137 | + |
| 138 | + def _assert_crop_hint(self, hint): |
| 139 | + from google.cloud.vision.crop_hint import CropHint |
| 140 | + from google.cloud.vision.geometry import Bounds |
| 141 | + |
| 142 | + self.assertIsInstance(hint, CropHint) |
| 143 | + self.assertIsInstance(hint.bounds, Bounds) |
| 144 | + self.assertGreater(hint.bounds.vertices, 1) |
| 145 | + self.assertIsInstance(hint.confidence, (int, float)) |
| 146 | + self.assertIsInstance(hint.importance_fraction, float) |
| 147 | + |
| 148 | + def test_detect_crop_hints_content(self): |
| 149 | + client = Config.CLIENT |
| 150 | + with open(FACE_FILE, 'rb') as image_file: |
| 151 | + image = client.image(content=image_file.read()) |
| 152 | + crop_hints = image.detect_crop_hints( |
| 153 | + aspect_ratios=[1.3333, 1.7777], limit=2) |
| 154 | + self.assertEqual(len(crop_hints), 2) |
| 155 | + for hint in crop_hints: |
| 156 | + self._assert_crop_hint(hint) |
| 157 | + |
| 158 | + def test_detect_crop_hints_filename(self): |
| 159 | + client = Config.CLIENT |
| 160 | + image = client.image(filename=FACE_FILE) |
| 161 | + crop_hints = image.detect_crop_hints( |
| 162 | + aspect_ratios=[1.3333, 1.7777], limit=2) |
| 163 | + self.assertEqual(len(crop_hints), 2) |
| 164 | + for hint in crop_hints: |
| 165 | + self._assert_crop_hint(hint) |
| 166 | + |
| 167 | + def test_detect_crop_hints_gcs(self): |
| 168 | + bucket_name = Config.TEST_BUCKET.name |
| 169 | + blob_name = 'faces.jpg' |
| 170 | + blob = Config.TEST_BUCKET.blob(blob_name) |
| 171 | + self.to_delete_by_case.append(blob) # Clean-up. |
| 172 | + with open(FACE_FILE, 'rb') as file_obj: |
| 173 | + blob.upload_from_file(file_obj) |
| 174 | + |
| 175 | + source_uri = 'gs://%s/%s' % (bucket_name, blob_name) |
| 176 | + client = Config.CLIENT |
| 177 | + image = client.image(source_uri=source_uri) |
| 178 | + crop_hints = image.detect_crop_hints( |
| 179 | + aspect_ratios=[1.3333, 1.7777], limit=2) |
| 180 | + self.assertEqual(len(crop_hints), 2) |
| 181 | + for hint in crop_hints: |
| 182 | + self._assert_crop_hint(hint) |
| 183 | + |
| 184 | + |
83 | 185 | class TestVisionClientLogo(unittest.TestCase): |
84 | 186 | def setUp(self): |
85 | 187 | self.to_delete_by_case = [] |
@@ -559,3 +661,82 @@ def test_batch_detect_gcs(self): |
559 | 661 |
|
560 | 662 | self.assertEqual(len(results[1].logos), 0) |
561 | 663 | self.assertEqual(len(results[1].faces), 2) |
| 664 | + |
| 665 | + |
| 666 | +class TestVisionWebAnnotation(BaseVisionTestCase): |
| 667 | + def setUp(self): |
| 668 | + self.to_delete_by_case = [] |
| 669 | + |
| 670 | + def tearDown(self): |
| 671 | + for value in self.to_delete_by_case: |
| 672 | + value.delete() |
| 673 | + |
| 674 | + def _assert_web_entity(self, web_entity): |
| 675 | + from google.cloud.vision.web import WebEntity |
| 676 | + |
| 677 | + self.assertIsInstance(web_entity, WebEntity) |
| 678 | + self.assertIsInstance(web_entity.entity_id, six.text_type) |
| 679 | + self.assertIsInstance(web_entity.score, float) |
| 680 | + self.assertIsInstance(web_entity.description, six.text_type) |
| 681 | + |
| 682 | + def _assert_web_image(self, web_image): |
| 683 | + from google.cloud.vision.web import WebImage |
| 684 | + |
| 685 | + self.assertIsInstance(web_image, WebImage) |
| 686 | + self.assertIsInstance(web_image.url, six.text_type) |
| 687 | + self.assertIsInstance(web_image.score, float) |
| 688 | + |
| 689 | + def _assert_web_page(self, web_page): |
| 690 | + from google.cloud.vision.web import WebPage |
| 691 | + |
| 692 | + self.assertIsInstance(web_page, WebPage) |
| 693 | + self.assertIsInstance(web_page.url, six.text_type) |
| 694 | + self.assertIsInstance(web_page.score, float) |
| 695 | + |
| 696 | + def _assert_web_images(self, web_images, limit): |
| 697 | + self.assertEqual(len(web_images.web_entities), limit) |
| 698 | + for web_entity in web_images.web_entities: |
| 699 | + self._assert_web_entity(web_entity) |
| 700 | + |
| 701 | + self.assertEqual(len(web_images.full_matching_images), limit) |
| 702 | + for web_image in web_images.full_matching_images: |
| 703 | + self._assert_web_image(web_image) |
| 704 | + |
| 705 | + self.assertEqual(len(web_images.partial_matching_images), limit) |
| 706 | + for web_image in web_images.partial_matching_images: |
| 707 | + self._assert_web_image(web_image) |
| 708 | + |
| 709 | + self.assertEqual(len(web_images.pages_with_matching_images), limit) |
| 710 | + for web_page in web_images.pages_with_matching_images: |
| 711 | + self._assert_web_page(web_page) |
| 712 | + |
| 713 | + def test_detect_web_images_from_content(self): |
| 714 | + client = Config.CLIENT |
| 715 | + with open(LANDMARK_FILE, 'rb') as image_file: |
| 716 | + image = client.image(content=image_file.read()) |
| 717 | + limit = 5 |
| 718 | + web_images = image.detect_web(limit=limit) |
| 719 | + self._assert_web_images(web_images, limit) |
| 720 | + |
| 721 | + def test_detect_web_images_from_gcs(self): |
| 722 | + client = Config.CLIENT |
| 723 | + bucket_name = Config.TEST_BUCKET.name |
| 724 | + blob_name = 'landmark.jpg' |
| 725 | + blob = Config.TEST_BUCKET.blob(blob_name) |
| 726 | + self.to_delete_by_case.append(blob) # Clean-up. |
| 727 | + with open(LANDMARK_FILE, 'rb') as file_obj: |
| 728 | + blob.upload_from_file(file_obj) |
| 729 | + |
| 730 | + source_uri = 'gs://%s/%s' % (bucket_name, blob_name) |
| 731 | + |
| 732 | + image = client.image(source_uri=source_uri) |
| 733 | + limit = 5 |
| 734 | + web_images = image.detect_web(limit=limit) |
| 735 | + self._assert_web_images(web_images, limit) |
| 736 | + |
| 737 | + def test_detect_web_images_from_filename(self): |
| 738 | + client = Config.CLIENT |
| 739 | + image = client.image(filename=LANDMARK_FILE) |
| 740 | + limit = 5 |
| 741 | + web_images = image.detect_web(limit=limit) |
| 742 | + self._assert_web_images(web_images, limit) |
0 commit comments