diff --git a/doc/release/RELEASE-NOTES.md b/doc/release/RELEASE-NOTES.md index aec8ab496..d1e71c9b7 100644 --- a/doc/release/RELEASE-NOTES.md +++ b/doc/release/RELEASE-NOTES.md @@ -19,7 +19,7 @@ This project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html * *Describe schema changes here.* ### Code/API changes -* *Describe code/API changes here.* +* [OSDEV-1581](https://opensupplyhub.atlassian.net/browse/OSDEV-1581) - Added support for Geohex grid aggregation to the GET `/api/v1/production-locations/` endpoint. To receive the Geohex grid aggregation list in the response, it is necessary to pass the `aggregation` parameter with a value of `geohex_grid` and optionally specify `geohex_grid_precision` with an integer between 0 and 15. If `geohex_grid_precision` is not defined, the default value of 5 will be used. ### Architecture/Environment changes * *Describe architecture/environment changes here.* diff --git a/src/django/api/serializers/v1/production_locations_serializer.py b/src/django/api/serializers/v1/production_locations_serializer.py index 8debfa013..fef08876b 100644 --- a/src/django/api/serializers/v1/production_locations_serializer.py +++ b/src/django/api/serializers/v1/production_locations_serializer.py @@ -25,6 +25,8 @@ class ProductionLocationsSerializer(Serializer): # These params are checking considering serialize_params output size = IntegerField(required=False) + address = CharField(required=False) + description = CharField(required=False) number_of_workers_min = IntegerField(required=False) number_of_workers_max = IntegerField(required=False) percent_female_workers_min = FloatField(required=False) @@ -45,6 +47,15 @@ class ProductionLocationsSerializer(Serializer): choices=['asc', 'desc'], required=False ) + aggregation = ChoiceField( + choices=['geohex_grid'], + required=False, + ) + geohex_grid_precision = IntegerField( + min_value=0, + max_value=15, + required=False, + ) def validate(self, data): validators = [ diff --git a/src/django/api/services/opensearch/search.py b/src/django/api/services/opensearch/search.py index df3351399..72af7c9c8 100644 --- a/src/django/api/services/opensearch/search.py +++ b/src/django/api/services/opensearch/search.py @@ -46,11 +46,26 @@ def __prepare_opensearch_response(self, response): else: logger.warning(f"Missing '_source' in hit: {hit}") - return { + response_data = { "count": total_hits, - "data": data + "data": data, } + geohex_buckets = ( + response.get("aggregations", {}) + .get("grouped", {}) + .get("buckets", []) + ) + + if geohex_buckets: + response_data.update({ + "aggregations": { + "geohex_grid": geohex_buckets + } + }) + + return response_data + @staticmethod def __remove_null_values(obj): if isinstance(obj, dict): diff --git a/src/django/api/tests/test_moderation_events_query_builder.py b/src/django/api/tests/test_moderation_events_query_builder.py index cb0226c8b..d4ca1272d 100644 --- a/src/django/api/tests/test_moderation_events_query_builder.py +++ b/src/django/api/tests/test_moderation_events_query_builder.py @@ -1,4 +1,3 @@ -import unittest from django.test import TestCase from api.views.v1.opensearch_query_builder. \ moderation_events_query_builder import ModerationEventsQueryBuilder @@ -88,6 +87,21 @@ def test_add_sort(self): expected = {'created_at': {'order': 'asc'}} self.assertIn(expected, self.builder.query_body['sort']) + def test_add_sort_with_default_order(self): + self.builder.add_sort('created_at') + expected = {'created_at': {'order': 'desc'}} + self.assertIn(expected, self.builder.query_body['sort']) + + def test_add_sort_name(self): + self.builder.add_sort('name', 'asc') + expected = {'cleaned_data.name': {'order': 'asc'}} + self.assertIn(expected, self.builder.query_body['sort']) + + def test_add_sort_address(self): + self.builder.add_sort('address', 'asc') + expected = {'cleaned_data.address': {'order': 'asc'}} + self.assertIn(expected, self.builder.query_body['sort']) + def test_add_sort_country(self): self.builder.add_sort('country', 'asc') expected = {'cleaned_data.country.name': {'order': 'asc'}} @@ -121,7 +135,3 @@ def test_get_final_query_body(self): 'sort': [] } self.assertEqual(final_query, expected) - - -if __name__ == '__main__': - unittest.main() diff --git a/src/django/api/tests/test_opensearch_response_formatter.py b/src/django/api/tests/test_opensearch_response_formatter.py index d85eeca80..54124c3d2 100644 --- a/src/django/api/tests/test_opensearch_response_formatter.py +++ b/src/django/api/tests/test_opensearch_response_formatter.py @@ -188,6 +188,47 @@ def test_prepare_opensearch_response_rename_lon_field(self, mock_logger): self.assertEqual(result, expected_result) mock_logger.warning.assert_not_called() + @patch('api.services.opensearch.search.logger') + def test_prepare_opensearch_response_with_aggregation_data( + self, + mock_logger + ): + response = { + "hits": { + "total": {"value": 10}, + "hits": [ + {"_source": {"field1": "value1"}}, + {"_source": {"field2": "value2"}} + ] + }, + "aggregations": { + "grouped": { + "buckets": [ + {"key": "value1"}, + {"key": "value2"} + ] + } + } + } + expected_result = { + "count": 10, + "data": [ + {"field1": "value1"}, + {"field2": "value2"} + ], + "aggregations": { + "geohex_grid": [ + {"key": "value1"}, + {"key": "value2"} + ] + } + } + + result = self.service. \ + _OpenSearchService__prepare_opensearch_response(response) + self.assertEqual(result, expected_result) + mock_logger.warning.assert_not_called() + if __name__ == '__main__': unittest.main() diff --git a/src/django/api/tests/test_production_locations_query_builder.py b/src/django/api/tests/test_production_locations_query_builder.py index 5b3fe1405..f78e962f6 100644 --- a/src/django/api/tests/test_production_locations_query_builder.py +++ b/src/django/api/tests/test_production_locations_query_builder.py @@ -1,4 +1,3 @@ -import unittest from django.test import TestCase from api.views.v1.opensearch_query_builder. \ production_locations_query_builder import ProductionLocationsQueryBuilder @@ -26,20 +25,6 @@ def test_add_match(self): self.builder.query_body['query']['bool']['must'] ) - def test_add_multi_match(self): - self.builder.add_multi_match('test query') - expected = { - 'multi_match': { - 'query': 'test query', - 'fields': ['name^2', 'address', 'description', 'local_name'], - 'fuzziness': 2 - } - } - self.assertIn( - expected, - self.builder.query_body['query']['bool']['must'] - ) - def test_add_terms_for_standard_field(self): self.builder.add_terms('country', ['US', 'CA']) expected = {'terms': {'country.alpha_2': ['US', 'CA']}} @@ -181,6 +166,11 @@ def test_add_sort(self): expected = {'name.keyword': {'order': 'desc'}} self.assertIn(expected, self.builder.query_body['sort']) + def test_add_sort_with_default_order(self): + self.builder.add_sort('name') + expected = {'name.keyword': {'order': 'asc'}} + self.assertIn(expected, self.builder.query_body['sort']) + def test_add_search_after(self): search_after_value = 'test_value' search_after_id = 'test_id' @@ -214,6 +204,55 @@ def test_get_final_query_body(self): } self.assertEqual(final_query, expected) + def test_add_multi_match(self): + self.builder.add_multi_match( + 'test query' + ) + expected = { + 'multi_match': { + 'query': 'test query', + 'fields': ['name^2', 'address', 'description', 'local_name'], + 'fuzziness': 2, + } + } + self.assertIn( + expected, self.builder.query_body['query']['bool']['must'] + ) + + def test_add_aggregations_with_precision(self): + aggregation = 'geohex_grid' + geohex_grid_precision = 5 + self.builder.add_aggregations( + aggregation, + geohex_grid_precision + ) + expected = { + 'grouped': { + 'geohex_grid': { + 'field': 'coordinates', + 'precision': geohex_grid_precision + } + } + } + self.assertIn('aggregations', self.builder.query_body) + self.assertEqual(expected, self.builder.query_body['aggregations']) + + def test_add_aggregations_without_precision(self): + aggregation = 'geohex_grid' + self.builder.add_aggregations( + aggregation + ) + expected = { + 'grouped': { + 'geohex_grid': {'field': 'coordinates'} + } + } + self.assertIn('aggregations', self.builder.query_body) + self.assertEqual(expected, self.builder.query_body['aggregations']) -if __name__ == '__main__': - unittest.main() + def test_add_aggregations_where_aggregation_is_not_geohex_grid(self): + aggregation = 'test_aggregation' + self.builder.add_aggregations( + aggregation + ) + self.assertNotIn('aggregations', self.builder.query_body) diff --git a/src/django/api/tests/test_v1_utils.py b/src/django/api/tests/test_v1_utils.py index a22244f23..b69265f9a 100644 --- a/src/django/api/tests/test_v1_utils.py +++ b/src/django/api/tests/test_v1_utils.py @@ -1,48 +1,16 @@ from django.test import TestCase from django.http import QueryDict -from rest_framework.serializers import ( - CharField, - ChoiceField, - FloatField, - IntegerField, - ListField, - Serializer -) from rest_framework.response import Response from api.views.v1.utils import ( serialize_params, handle_value_error, handle_opensearch_exception ) +from api.serializers.v1.production_locations_serializer \ + import ProductionLocationsSerializer from api.services.opensearch.search import OpenSearchServiceException -class TestProductionLocationsSerializer(Serializer): - size = IntegerField(required=False) - address = CharField(required=False) - description = CharField(required=False) - search_after_value = CharField(required=False) - search_after_id = CharField(required=False) - number_of_workers_min = IntegerField(required=False) - number_of_workers_max = IntegerField(required=False) - percent_female_workers_min = FloatField(required=False) - percent_female_workers_max = FloatField(required=False) - coordinates_lat = FloatField(required=False) - coordinates_lng = FloatField(required=False) - country = ListField( - child=CharField(required=False), - required=False - ) - sort_by = ChoiceField( - choices=['name', 'address'], - required=False - ) - order_by = ChoiceField( - choices=['asc', 'desc'], - required=False - ) - - class V1UtilsTests(TestCase): def test_serialize_params_with_deep_object(self): @@ -56,7 +24,7 @@ def test_serialize_params_with_deep_object(self): 'coordinates[lng]': '56.78', }) serialized_params, error_response = \ - serialize_params(TestProductionLocationsSerializer, query_dict) + serialize_params(ProductionLocationsSerializer, query_dict) self.assertIsNone(error_response) self.assertEqual(serialized_params['number_of_workers_min'], 10) self.assertEqual(serialized_params['number_of_workers_max'], 50) @@ -74,10 +42,12 @@ def test_serialize_params_with_single_values(self): 'search_after[value]': 'abc123', 'sort_by': 'name', 'order_by': 'asc', - 'size': 10 + 'size': 10, + 'aggregation': 'geohex_grid', + 'geohex_grid_precision': 2, }) serialized_params, error_response = \ - serialize_params(TestProductionLocationsSerializer, query_dict) + serialize_params(ProductionLocationsSerializer, query_dict) self.assertIsNone(error_response) self.assertEqual(serialized_params['address'], '123 Main St') self.assertEqual( @@ -92,17 +62,21 @@ def test_serialize_params_with_single_values(self): self.assertEqual(serialized_params['sort_by'], 'name') self.assertEqual(serialized_params['order_by'], 'asc') self.assertEqual(serialized_params['size'], 10) + self.assertEqual(serialized_params['aggregation'], 'geohex_grid') + self.assertEqual(serialized_params['geohex_grid_precision'], 2) def test_serialize_params_with_mixed_values(self): query_dict = QueryDict('', mutable=True) query_dict.update({ 'number_of_workers[min]': '10', + 'number_of_workers[max]': '50', 'address': '123 Main St', }) serialized_params, error_response = \ - serialize_params(TestProductionLocationsSerializer, query_dict) + serialize_params(ProductionLocationsSerializer, query_dict) self.assertIsNone(error_response) self.assertEqual(serialized_params['number_of_workers_min'], 10) + self.assertEqual(serialized_params['number_of_workers_max'], 50) self.assertEqual(serialized_params['address'], '123 Main St') def test_serialize_params_invalid(self): @@ -111,8 +85,8 @@ def test_serialize_params_invalid(self): 'number_of_workers[min]': 'not_a_number', 'size': 'not_a_number' }) - serialized_params, error_response = \ - serialize_params(TestProductionLocationsSerializer, query_dict) + _, error_response = \ + serialize_params(ProductionLocationsSerializer, query_dict) self.assertIsNotNone(error_response) self.assertEqual( error_response['detail'], @@ -121,14 +95,78 @@ def test_serialize_params_invalid(self): self.assertIn( { 'field': 'number_of_workers_min', - 'detail': 'A Valid Integer Is Required.' + 'detail': 'A valid integer is required.' }, error_response['errors'] ) self.assertIn( { 'field': 'size', - 'detail': 'A Valid Integer Is Required.' + 'detail': 'A valid integer is required.' + }, + error_response['errors'] + ) + + def test_serialize_invalid_aggregation(self): + query_dict = QueryDict('', mutable=True) + query_dict.update({ + 'aggregation': 'invalid_aggregation', + }) + _, error_response = \ + serialize_params(ProductionLocationsSerializer, query_dict) + self.assertIsNotNone(error_response) + self.assertIn( + { + 'field': 'aggregation', + 'detail': '"invalid_aggregation" is not a valid choice.' + }, + error_response['errors'] + ) + + def test_serialize_invalid_precision_type(self): + query_dict = QueryDict('', mutable=True) + query_dict.update({ + 'geohex_grid_precision': 'not_a_number', + }) + _, error_response = \ + serialize_params(ProductionLocationsSerializer, query_dict) + self.assertIsNotNone(error_response) + self.assertIn( + { + 'field': 'geohex_grid_precision', + 'detail': 'A valid integer is required.' + }, + error_response['errors'] + ) + + def test_serialize_precision_value_too_low(self): + query_dict = QueryDict('', mutable=True) + query_dict.update({ + 'geohex_grid_precision': '-1', + }) + _, error_response = \ + serialize_params(ProductionLocationsSerializer, query_dict) + self.assertIsNotNone(error_response) + self.assertIn( + { + 'field': 'geohex_grid_precision', + 'detail': 'Ensure this value is greater than or equal to 0.' + }, + error_response['errors'] + ) + + def test_serialize_precision_value_too_high(self): + query_dict = QueryDict('', mutable=True) + query_dict.update({ + 'geohex_grid_precision': '16', + }) + _, error_response = \ + serialize_params(ProductionLocationsSerializer, query_dict) + self.assertIsNotNone(error_response) + self.assertIn( + { + 'field': 'geohex_grid_precision', + 'detail': 'Ensure this value is less than or equal to 15.' }, error_response['errors'] ) diff --git a/src/django/api/views/v1/opensearch_query_builder/opensearch_query_builder.py b/src/django/api/views/v1/opensearch_query_builder/opensearch_query_builder.py index 16cfbde63..a8b46a3fb 100644 --- a/src/django/api/views/v1/opensearch_query_builder/opensearch_query_builder.py +++ b/src/django/api/views/v1/opensearch_query_builder/opensearch_query_builder.py @@ -21,20 +21,6 @@ def add_match(self, field, value, fuzziness=None): } self.query_body['query']['bool']['must'].append(match_query) - def add_multi_match(self, query): - self.query_body['query']['bool']['must'].append({ - 'multi_match': { - 'query': query, - 'fields': [ - f'{V1_PARAMETERS_LIST.NAME}^2', - V1_PARAMETERS_LIST.ADDRESS, - V1_PARAMETERS_LIST.DESCRIPTION, - V1_PARAMETERS_LIST.LOCAL_NAME - ], - 'fuzziness': self.default_fuzziness - } - }) - def add_range(self, field, query_params): if field in { V1_PARAMETERS_LIST.NUMBER_OF_WORKERS, @@ -151,14 +137,6 @@ def add_search_after(self, search_after_value, search_after_id, id_type): search_after_id ] - @abstractmethod - def add_sort(self, field, order_by=None): - pass - - @abstractmethod - def add_terms(self, field, values): - pass - def get_final_query_body(self): return self.query_body @@ -174,3 +152,11 @@ def _build_os_id(self, values): } } ) + + @abstractmethod + def add_sort(self, field, order_by=None): + pass + + @abstractmethod + def add_terms(self, field, values): + pass diff --git a/src/django/api/views/v1/opensearch_query_builder/opensearch_query_director.py b/src/django/api/views/v1/opensearch_query_builder/opensearch_query_director.py index 6787b24b7..fb6a3c68c 100644 --- a/src/django/api/views/v1/opensearch_query_builder/opensearch_query_director.py +++ b/src/django/api/views/v1/opensearch_query_builder/opensearch_query_director.py @@ -31,6 +31,44 @@ def __init__(self, builder): V1_PARAMETERS_LIST.DATE_LT: 'range', } + def build_query(self, query_params): + self.__builder.reset() + + self.__process_template_fields(query_params) + self.__process_sorting(query_params) + self.__process_search_after(query_params) + self.__process_pagination(query_params) + self.__process_size(query_params) + self.__process_multi_match(query_params) + self.__process_aggregation(query_params) + + return self.__builder.get_final_query_body() + + def __process_template_fields(self, query_params): + for field, query_type in self.__opensearch_template_fields.items(): + self.__process_query_field(field, query_type, query_params) + + def __process_query_field(self, field, query_type, query_params): + if query_type == "match": + value = query_params.get(field) + self.__add_match_query(field, value) + return + + if query_type == "terms": + values = query_params.getlist(field) + self.__add_terms_query(field, values) + return + + if query_type == "range": + self.__add_range_query(field, query_params) + return + + if query_type == "geo_distance": + lat = query_params.get(f"{field}[lat]") + lng = query_params.get(f"{field}[lng]") + distance = query_params.get("distance", "10km") + self.__add_geo_distance_query(field, lat, lng, distance) + def __add_match_query(self, field, value): if value: self.__builder.add_match(field, value, fuzziness='2') @@ -44,62 +82,52 @@ def __add_range_query(self, field, query_params): def __add_geo_distance_query(self, field, lat, lng, distance): if lat and lng: self.__builder.add_geo_distance( - field, - float(lat), - float(lng), - distance + field, float(lat), float(lng), distance ) - def build_query(self, query_params): - self.__builder.reset() - - for field, query_type in self.__opensearch_template_fields.items(): - if query_type == "match": - value = query_params.get(field) - self.__add_match_query(field, value) - continue - - if query_type == "terms": - values = query_params.getlist(field) - self.__add_terms_query(field, values) - continue - - if query_type == "range": - self.__add_range_query(field, query_params) - continue - - if query_type == "geo_distance": - lat = query_params.get(f"{field}[lat]") - lng = query_params.get(f"{field}[lng]") - distance = query_params.get("distance", "10km") - self.__add_geo_distance_query(field, lat, lng, - distance) - + def __process_sorting(self, query_params): sort_by = query_params.get(V1_PARAMETERS_LIST.SORT_BY) + if sort_by: order_by = query_params.get(V1_PARAMETERS_LIST.ORDER_BY) self.__builder.add_sort(sort_by, order_by) - search_after_id = query_params. \ - get(V1_PARAMETERS_LIST.SEARCH_AFTER + "[id]") - search_after_value = query_params. \ - get(V1_PARAMETERS_LIST.SEARCH_AFTER + "[value]") + def __process_search_after(self, query_params): + search_after_id = query_params.get( + V1_PARAMETERS_LIST.SEARCH_AFTER + "[id]" + ) + search_after_value = query_params.get( + V1_PARAMETERS_LIST.SEARCH_AFTER + "[value]" + ) + if search_after_id and search_after_value: self.__builder.add_search_after( - search_after_value, - search_after_id + search_after_value, search_after_id ) + def __process_pagination(self, query_params): paginate_from = query_params.get(V1_PARAMETERS_LIST.FROM) + if paginate_from: self.__builder.add_from(paginate_from) + def __process_size(self, query_params): size = query_params.get(V1_PARAMETERS_LIST.SIZE) + if size: self.__builder.add_size(size) + def __process_multi_match(self, query_params): multi_match_query = query_params.get(V1_PARAMETERS_LIST.QUERY) - if multi_match_query: + + if multi_match_query and hasattr(self.__builder, 'add_multi_match'): self.__builder.add_multi_match(multi_match_query) - return self.__builder.get_final_query_body() + def __process_aggregation(self, query_params): + aggregation = query_params.get(V1_PARAMETERS_LIST.AGGREGATION) + geohex_grid_precision = query_params.get( + V1_PARAMETERS_LIST.GEOHEX_GRID_PRECISION + ) + + if aggregation and hasattr(self.__builder, 'add_aggregations'): + self.__builder.add_aggregations(aggregation, geohex_grid_precision) diff --git a/src/django/api/views/v1/opensearch_query_builder/production_locations_query_builder.py b/src/django/api/views/v1/opensearch_query_builder/production_locations_query_builder.py index 77fda9848..aa9b7b868 100644 --- a/src/django/api/views/v1/opensearch_query_builder/production_locations_query_builder.py +++ b/src/django/api/views/v1/opensearch_query_builder/production_locations_query_builder.py @@ -98,3 +98,34 @@ def add_search_after( id_type='os_id' ): super().add_search_after(search_after_value, search_after_id, id_type) + + def add_multi_match(self, query): + self.query_body['query']['bool']['must'].append({ + 'multi_match': { + 'query': query, + 'fields': [ + f'{V1_PARAMETERS_LIST.NAME}^2', + V1_PARAMETERS_LIST.ADDRESS, + V1_PARAMETERS_LIST.DESCRIPTION, + V1_PARAMETERS_LIST.LOCAL_NAME + ], + 'fuzziness': self.default_fuzziness + } + }) + + def add_aggregations(self, aggregation, geohex_grid_precision=None): + if aggregation == 'geohex_grid': + aggregation_config = { + 'field': 'coordinates' + } + + if geohex_grid_precision: + aggregation_config['precision'] = geohex_grid_precision + + self.query_body['aggregations'] = { + 'grouped': { + 'geohex_grid': aggregation_config + } + } + + return self.query_body diff --git a/src/django/api/views/v1/parameters_list.py b/src/django/api/views/v1/parameters_list.py index 1133a76bf..73dbd8e8a 100644 --- a/src/django/api/views/v1/parameters_list.py +++ b/src/django/api/views/v1/parameters_list.py @@ -31,3 +31,5 @@ class V1_PARAMETERS_LIST: MODERATION_ID = 'moderation_id' DATE_GTE = 'date_gte' DATE_LT = 'date_lt' + AGGREGATION = 'aggregation' + GEOHEX_GRID_PRECISION = 'geohex_grid_precision' diff --git a/src/django/api/views/v1/utils.py b/src/django/api/views/v1/utils.py index cd27b7f7d..2416851bf 100644 --- a/src/django/api/views/v1/utils.py +++ b/src/django/api/views/v1/utils.py @@ -35,7 +35,9 @@ def serialize_params(serializer_class, query_params): V1_PARAMETERS_LIST.ORDER_BY, V1_PARAMETERS_LIST.SIZE, V1_PARAMETERS_LIST.DATE_GTE, - V1_PARAMETERS_LIST.DATE_LT + V1_PARAMETERS_LIST.DATE_LT, + V1_PARAMETERS_LIST.GEOHEX_GRID_PRECISION, + V1_PARAMETERS_LIST.AGGREGATION, ]: flattened_query_params[key] = value[0] else: @@ -52,18 +54,18 @@ def serialize_params(serializer_class, query_params): for field, error_list in params.errors.items(): error_response['errors'].append({ 'field': field, - 'detail': error_list[0].title() + 'detail': error_list[0].capitalize() }) # Handle errors that come from serializers detail_errors = params.errors.get('detail') if detail_errors: - error_response['detail'] = detail_errors[0].title() + error_response['detail'] = detail_errors[0].capitalize() if 'detail' in params.errors and 'errors' in params.errors: for error_item in params.errors.get('errors', []): error_response['errors'].append({ - 'field': error_item.get('field', '').title(), - 'detail': error_item.get('detail', '').title() + 'field': error_item.get('field', ''), + 'detail': error_item.get('detail', '').capitalize() }) return None, error_response diff --git a/src/tests/v1/test_moderation_events.py b/src/tests/v1/test_moderation_events.py index f003f2fa2..a398d44c5 100644 --- a/src/tests/v1/test_moderation_events.py +++ b/src/tests/v1/test_moderation_events.py @@ -146,7 +146,7 @@ def test_size_exceeds_max_limit(self): ) self.assertEqual(response.status_code, 400) result = response.json() - self.assertEqual(result['detail'], 'The Request Query Is Invalid.') + self.assertEqual(result['detail'], 'The request query is invalid.') def test_search_after_pagination(self): # Step 1: Get the first set of results @@ -215,10 +215,10 @@ def test_date_gte_greater_than_date_lt(self): self.assertEqual(response.status_code, 400) error = result['errors'][0] - self.assertEqual(error['field'], 'Date_Gte') + self.assertEqual(error['field'], 'date_gte') self.assertEqual( error['detail'], - "The 'Date_Gte' Must Be Less Than Or Equal To 'Date_Lt'." + "The 'date_gte' must be less than or equal to 'date_lt'." ) def test_valid_country(self): @@ -245,8 +245,8 @@ def test_invalid_country(self): self.assertEqual(len(result['errors']), 1) error = result['errors'][0] - self.assertEqual(error['field'], 'Country') - self.assertEqual(error['detail'], "'Usa' Is Not A Valid Alpha-2 Country Code.") + self.assertEqual(error['field'], 'country') + self.assertEqual(error['detail'], "'usa' is not a valid alpha-2 country code.") def test_valid_moderation_id(self): valid_moderation_id = '3b50d60f-85b2-4a17-9f8d-5d3e1fc68198' @@ -272,5 +272,5 @@ def test_invalid_moderation_id(self): self.assertEqual(len(result['errors']), 1) error = result['errors'][0] - self.assertEqual(error['field'], 'Moderation_Id') - self.assertEqual(error['detail'], "Invalid Uuid(S): 123!.") + self.assertEqual(error['field'], 'moderation_id') + self.assertEqual(error['detail'], "Invalid uuid(s): 123!.") diff --git a/src/tests/v1/test_production_locations.py b/src/tests/v1/test_production_locations.py index a99b62985..be0b8f9e5 100644 --- a/src/tests/v1/test_production_locations.py +++ b/src/tests/v1/test_production_locations.py @@ -104,3 +104,17 @@ def test_production_locations_history_os_id(self): self.assertEqual( result['data'][0]['historical_os_id'], 'US20203545HUE4L' ) + + def test_production_locations_aggregations(self): + query = "?aggregation=geohex_grid&geohex_grid_precision=2" + response = requests.get( + f"{self.root_url}/api/v1/production-locations/{query}", + headers=self.basic_headers, + ) + + result = response.json() + self.assertIsNotNone(result['aggregations']) + self.assertIsNotNone(result['aggregations']['geohex_grid'][0]['key']) + self.assertIsNotNone( + result['aggregations']['geohex_grid'][0]['doc_count'] + )