diff --git a/doc/release/RELEASE-NOTES.md b/doc/release/RELEASE-NOTES.md index 790df5374..235161cc8 100644 --- a/doc/release/RELEASE-NOTES.md +++ b/doc/release/RELEASE-NOTES.md @@ -11,6 +11,8 @@ This project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html ### Code/API changes * [OSDEV-2137](https://opensupplyhub.atlassian.net/browse/OSDEV-2137) - Switched to a custom, page-compatible keyset for the `/facilities-downloads` endpoint, enabling more efficient, cursor-based pagination and improved download performance and compatibility. +* [OSDEV-2068](https://opensupplyhub.atlassian.net/browse/OSDEV-2068) - Enabled users to download their own data without impacting free & purchased data-download allowances. Introduced `is_same_contributor` field in the GET `/api/facilities-downloads` response. + ### Release instructions * Ensure that the following commands are included in the `post_deployment` command: diff --git a/src/django/api/facilities_download_view_set.py b/src/django/api/facilities_download_view_set.py index 83ceb420b..aa768c5cd 100644 --- a/src/django/api/facilities_download_view_set.py +++ b/src/django/api/facilities_download_view_set.py @@ -9,6 +9,7 @@ FacilityDownloadSerializerEmbedMode from api.serializers.utils import get_embed_contributor_id_from_query_params from api.services.facilities_download_service import FacilitiesDownloadService +from api.serializers.facility.utils import is_same_contributor_for_queryset from api.constants import PaginationConfig @@ -44,10 +45,17 @@ def list(self, request): base_qs = FacilitiesDownloadService.get_filtered_queryset(request) + is_same_contributor = is_same_contributor_for_queryset( + base_qs, + request + ) + limit = None + if ( not switch_is_active('private_instance') and not self.__is_embed_mode() + and not is_same_contributor ): limit = FacilitiesDownloadService.get_download_limit(request) @@ -78,9 +86,14 @@ def list(self, request): ) list_serializer = self.get_serializer(items) - rows = [f['row'] for f in list_serializer.data] + rows = [facility_data['row'] for facility_data in list_serializer.data] headers = list_serializer.child.get_headers() - data = {'rows': rows, 'headers': headers} + + data = { + 'rows': rows, + 'headers': headers, + 'is_same_contributor': is_same_contributor + } payload = { 'next': next_link, @@ -94,18 +107,23 @@ def list(self, request): payload['count'] = base_qs.count() if is_last_page and limit: - total_records = (page - 1) * page_size + len(items) - prev_free = getattr(limit, 'free_download_records', 0) - prev_paid = getattr(limit, 'paid_download_records', 0) + # Charge for the full result set, not just the last page size + returned_count = base_qs.count() + + prev_free_amount = getattr(limit, 'free_download_records', 0) + prev_paid_amount = getattr(limit, 'paid_download_records', 0) + FacilitiesDownloadService.register_download_if_needed( limit, - total_records - ) - FacilitiesDownloadService.send_email_if_needed( - request, - limit, - prev_free, - prev_paid + returned_count, + is_same_contributor ) + if returned_count: + FacilitiesDownloadService.send_email_if_needed( + request, + limit, + prev_free_amount, + prev_paid_amount + ) return Response(payload) diff --git a/src/django/api/models/facility_download_limit.py b/src/django/api/models/facility_download_limit.py index 9cb6f99e1..e93a6ec85 100644 --- a/src/django/api/models/facility_download_limit.py +++ b/src/django/api/models/facility_download_limit.py @@ -53,17 +53,34 @@ class FacilityDownloadLimit(models.Model): objects = FacilityDownloadLimitManager() @transaction.atomic - def register_download(self, records_to_subtract): + def register_download( + self, + records_to_subtract: int, + ): self.refresh_from_db() - if self.free_download_records >= records_to_subtract: - self.free_download_records -= records_to_subtract + # Prevent overdrafts by capping to remaining quota + remaining_quota = (self.free_download_records or 0) + \ + (self.paid_download_records or 0) + to_subtract = min( + max(records_to_subtract, 0), + remaining_quota + ) + + if to_subtract == 0: + return + + if self.free_download_records >= to_subtract: + self.free_download_records -= to_subtract else: remaining_records = ( - records_to_subtract - self.free_download_records + to_subtract - self.free_download_records ) self.free_download_records = 0 - self.paid_download_records -= remaining_records + self.paid_download_records = max( + self.paid_download_records - remaining_records, + 0 + ) self.save() diff --git a/src/django/api/serializers/facility/utils.py b/src/django/api/serializers/facility/utils.py index 027566dba..969379bf7 100644 --- a/src/django/api/serializers/facility/utils.py +++ b/src/django/api/serializers/facility/utils.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import (Iterable, Union) from itertools import groupby from api.constants import FacilityClaimStatuses @@ -346,3 +346,22 @@ def add_http_prefix_to_url(value: str) -> str: ): value = f"https://{value}" return value + + +def is_same_contributor_for_queryset(queryset: Iterable, request) -> bool: + contributor = getattr(request.user, 'contributor', None) + if not contributor: + return False + current_user_contributor_id = contributor.id + + found_any_facility = False + for facility in queryset: + found_any_facility = True + facility_contributor_ids = [ + contributor.get('id') for contributor in facility.contributors + if contributor.get('id') is not None + ] + if current_user_contributor_id not in facility_contributor_ids: + return False + + return found_any_facility diff --git a/src/django/api/services/facilities_download_service.py b/src/django/api/services/facilities_download_service.py index 707d3035c..6c8ec5120 100644 --- a/src/django/api/services/facilities_download_service.py +++ b/src/django/api/services/facilities_download_service.py @@ -95,9 +95,28 @@ def enforce_limits(qs, limit, is_first_page): ) @staticmethod - def register_download_if_needed(limit, record_count): - if limit: - limit.register_download(record_count) + def check_pagination(page_queryset): + if page_queryset is None: + raise ValidationError("Invalid pageSize parameter") + return page_queryset + + @staticmethod + def register_download_if_needed( + limit: FacilityDownloadLimit, + records_returned: int, + is_same_contributor: bool = False + ): + if is_same_contributor or not limit: + return + try: + count = int(records_returned) + except (TypeError, ValueError): + count = 0 + + if count <= 0: + return + + limit.register_download(count) @staticmethod def send_email_if_needed( diff --git a/src/django/api/tests/test_facilities_download_viewset.py b/src/django/api/tests/test_facilities_download_viewset.py index 6ce4774aa..feef0822e 100644 --- a/src/django/api/tests/test_facilities_download_viewset.py +++ b/src/django/api/tests/test_facilities_download_viewset.py @@ -2,9 +2,10 @@ from rest_framework.test import APITestCase from django.urls import reverse from django.contrib.auth.models import Group -from unittest.mock import patch +from unittest.mock import patch, MagicMock from api.models.user import User +from api.models.contributor.contributor import Contributor from api.constants import FeatureGroups from api.models.facility_download_limit import FacilityDownloadLimit from django.utils import timezone @@ -418,7 +419,7 @@ def test_query_parameters(self): user = self.create_user() self.login_user(user) - response = self.get_facility_downloads({"countries": "IN"}) + response = self.get_facility_downloads({"countries": ["IN"]}) self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [ @@ -483,9 +484,7 @@ def test_new_user_has_current_date_in_updated_at(self): self.assertEqual(limit.updated_at.date(), current_date.date()) def test_old_user_has_release_date_in_updated_at(self): - # The record has been added to FacilityDownloadLimit. user = self.create_user() - # Simulation old user. FacilityDownloadLimit.objects.filter(user=user).delete() self.login_user(user) release_date = make_aware(datetime(2025, 7, 12)) @@ -518,8 +517,287 @@ def test_api_user_not_limited_by_download_count(self): user = self.create_user(is_api_user=True) self.login_user(user) - # Make multiple downloads that would exceed the limit for regular - # users. for _ in range(5): response = self.get_facility_downloads() self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_is_same_contributor_true_when_all_facilities_belong_to_user(self): + user = self.create_user() + self.login_user(user) + + contributor = Contributor.objects.create( + admin=user, + name="Test Contributor", + contrib_type="Brand / Retailer" + ) + + with patch( + 'api.services.facilities_download_service.' + 'FacilitiesDownloadService.get_filtered_queryset' + ) as mock_get_queryset: + mock_queryset = MagicMock() + mock_facility = MagicMock() + mock_facility.contributors = [{'id': contributor.id}] + mock_queryset.__iter__.return_value = [mock_facility] + mock_queryset.count.return_value = 1 + mock_get_queryset.return_value = mock_queryset + + response = self.get_facility_downloads( + {'contributors': [str(contributor.id)]} + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertTrue( + response.data['results']['is_same_contributor'] + ) + + def test_is_same_contributor_false_when_mixed_contributors(self): + user = self.create_user() + self.login_user(user) + + contributor = Contributor.objects.create( + admin=user, + name="Test Contributor", + contrib_type="Brand / Retailer" + ) + + with patch( + 'api.services.facilities_download_service.' + 'FacilitiesDownloadService.get_filtered_queryset' + ) as mock_get_queryset: + mock_queryset = MagicMock() + mock_facility1 = MagicMock() + mock_facility1.contributors = [{'id': contributor.id}] + mock_facility2 = MagicMock() + mock_facility2.contributors = [{'id': 999}] + mock_queryset.__iter__.return_value = [ + mock_facility1, + mock_facility2 + ] + mock_queryset.count.return_value = 2 + mock_get_queryset.return_value = mock_queryset + + response = self.get_facility_downloads( + {'contributors': [str(contributor.id)]} + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertFalse( + response.data['results']['is_same_contributor'] + ) + + def test_is_same_contributor_false_when_user_has_no_contributor(self): + user = self.create_user() + self.login_user(user) + + response = self.get_facility_downloads({'contributors': ['123']}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertFalse( + response.data['results']['is_same_contributor'] + ) + + def test_is_same_contributor_with_combine_contributors_and_logic(self): + user = self.create_user() + self.login_user(user) + + contributor = Contributor.objects.create( + admin=user, + name="Test Contributor", + contrib_type="Brand / Retailer" + ) + + with patch( + 'api.services.facilities_download_service.' + 'FacilitiesDownloadService.get_filtered_queryset' + ) as mock_get_queryset: + mock_queryset = MagicMock() + mock_facility = MagicMock() + mock_facility.contributors = [ + {'id': contributor.id}, + {'id': 456} + ] + mock_queryset.__iter__.return_value = [mock_facility] + mock_queryset.count.return_value = 1 + mock_get_queryset.return_value = mock_queryset + + response = self.get_facility_downloads({ + 'contributors': [str(contributor.id), '456'], + 'combine_contributors': 'AND' + }) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertTrue( + response.data['results']['is_same_contributor'] + ) + + def test_multi_page_download_decrements_free_by_total_count(self): + user = self.create_user() + self.login_user(user) + + limit = FacilityDownloadLimit.objects.create( + user=user, + free_download_records=20, + paid_download_records=0, + ) + + # Request first page to get the total count from the payload + resp_page1 = self.get_facility_downloads({"pageSize": 10, "page": 1}) + self.assertEqual(resp_page1.status_code, status.HTTP_200_OK) + total_count = resp_page1.data.get("count") + self.assertIsNotNone(total_count) + + # Quotas should remain unchanged after first page + limit.refresh_from_db() + self.assertEqual(limit.free_download_records, 20) + self.assertEqual(limit.paid_download_records, 0) + + # Request last page to trigger quota registration using total_count + # Patch email/checkout to avoid external calls during tests + with patch( + 'api.services.facilities_download_service.' + 'FacilitiesDownloadService.send_email_if_needed', + return_value=None + ): + resp_page2 = self.get_facility_downloads({ + "pageSize": 10, + "page": 2 + }) + self.assertEqual(resp_page2.status_code, status.HTTP_200_OK) + + limit.refresh_from_db() + expected_free = max(20 - total_count, 0) + self.assertEqual(limit.free_download_records, expected_free) + self.assertEqual(limit.paid_download_records, 0) + + def test_multi_page_download_consumes_paid_when_free_insufficient(self): + user = self.create_user() + self.login_user(user) + + limit = FacilityDownloadLimit.objects.create( + user=user, + free_download_records=5, + paid_download_records=20, + ) + + resp_page1 = self.get_facility_downloads({"pageSize": 10, "page": 1}) + self.assertEqual(resp_page1.status_code, status.HTTP_200_OK) + total_count = resp_page1.data.get("count") + self.assertIsNotNone(total_count) + + # Trigger decrement on last page + with patch( + 'api.services.facilities_download_service.' + 'FacilitiesDownloadService.send_email_if_needed', + return_value=None + ): + resp_page2 = self.get_facility_downloads({ + "pageSize": 10, + "page": 2 + }) + self.assertEqual(resp_page2.status_code, status.HTTP_200_OK) + + limit.refresh_from_db() + self.assertEqual(limit.free_download_records, 0) + expected_paid = max(20 - max(total_count - 5, 0), 0) + self.assertEqual(limit.paid_download_records, expected_paid) + + def test_is_same_contributor_with_empty_queryset(self): + """Test is_same_contributor with empty queryset.""" + user = self.create_user() + self.login_user(user) + + contributor = Contributor.objects.create( + admin=user, + name="Test Contributor", + contrib_type="Brand / Retailer" + ) + + with patch( + 'api.services.facilities_download_service.' + 'FacilitiesDownloadService.get_filtered_queryset' + ) as mock_get_queryset: + mock_queryset = MagicMock() + mock_queryset.__iter__.return_value = [] + mock_queryset.count.return_value = 0 + mock_get_queryset.return_value = mock_queryset + + response = self.get_facility_downloads( + {'contributors': [str(contributor.id)]} + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertFalse( + response.data['results']['is_same_contributor'] + ) + + def test_is_same_contributor_with_multiple_contributors_or_logic(self): + user = self.create_user() + self.login_user(user) + + contributor = Contributor.objects.create( + admin=user, + name="Test Contributor", + contrib_type="Brand / Retailer" + ) + + with patch( + 'api.services.facilities_download_service.' + 'FacilitiesDownloadService.get_filtered_queryset' + ) as mock_get_queryset: + mock_queryset = MagicMock() + mock_facility = MagicMock() + mock_facility.contributors = [ + {'id': contributor.id}, + {'id': 456}, + {'id': 789} + ] + mock_queryset.__iter__.return_value = [mock_facility] + mock_queryset.count.return_value = 1 + mock_get_queryset.return_value = mock_queryset + + response = self.get_facility_downloads({ + 'contributors': [str(contributor.id), '456'] + }) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertTrue( + response.data['results']['is_same_contributor'] + ) + + +def test_exhausted_quota_all_mine_still_allowed(self): + user = self.create_user() + self.login_user(user) + + contributor = Contributor.objects.create( + admin=user, + name="Test Contributor", + contrib_type="Brand / Retailer" + ) + + limit = FacilityDownloadLimit.objects.create( + user=user, + free_download_records=0, + paid_download_records=0, + ) + + with patch( + 'api.services.facilities_download_service.' + 'FacilitiesDownloadService.get_filtered_queryset' + ) as mock_get_queryset: + mock_queryset = MagicMock() + mock_facility = MagicMock() + mock_facility.contributors = [{'id': contributor.id}] + mock_queryset.__iter__.return_value = [mock_facility] + mock_queryset.count.return_value = 1 + mock_get_queryset.return_value = mock_queryset + + resp = self.get_facility_downloads({ + 'contributors': [str(contributor.id)] + }) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + # Quotas must remain unchanged since it's an own-data download + limit.refresh_from_db() + self.assertEqual(limit.free_download_records, 0) + self.assertEqual(limit.paid_download_records, 0) diff --git a/src/django/api/tests/test_facility_download.py b/src/django/api/tests/test_facility_download.py index 826a91062..8ed40298a 100644 --- a/src/django/api/tests/test_facility_download.py +++ b/src/django/api/tests/test_facility_download.py @@ -416,6 +416,9 @@ def get_rows(self, response): def get_headers(self, response): return response.data["results"]["headers"] + def get_is_same_contributor(self, response): + return response.data["results"].get("is_same_contributor") + def test_download_is_fetched(self): response = self.get_facility_download() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -427,6 +430,8 @@ def test_default_headers_are_created(self): headers = self.get_headers(response) expected_headers = self.default_headers self.assertEqual(headers, expected_headers) + # Ensure flag is present at top-level results for UI logic + self.assertIn("is_same_contributor", response.data["results"]) def test_embed_headers_exclude_contributor(self): params = "embed=1&contributors={}".format(self.contributor.id) @@ -478,6 +483,8 @@ def test_base_row_is_created(self): ] self.assertEqual(len(base_row), len(expected_base_row)) self.assertEqual(base_row, expected_base_row) + # Assert the flag exists and is boolean + self.assertIsInstance(self.get_is_same_contributor(response), bool) def test_contrib_rows_are_created_for_no_contrib_values(self): params = "embed=1&contributors={}&q={}".format( @@ -627,6 +634,8 @@ def test_handles_additional_list_items(self): headers = self.get_headers(response) self.assertEqual(headers, self.default_headers) + # Flag should exist regardless of headers + self.assertIn("is_same_contributor", response.data["results"]) rows = self.get_rows(response) diff --git a/src/django/api/tests/test_facility_download_limit_model.py b/src/django/api/tests/test_facility_download_limit_model.py index 1f610f380..bc0080ac2 100644 --- a/src/django/api/tests/test_facility_download_limit_model.py +++ b/src/django/api/tests/test_facility_download_limit_model.py @@ -1,5 +1,6 @@ from django.test import TestCase from api.models import User, FacilityDownloadLimit +from api.services.facilities_download_service import FacilitiesDownloadService class FacilityDownloadLimitModelTest(TestCase): @@ -40,3 +41,77 @@ def test_register_download(self): limit.paid_download_records, expected_paid_download_records_after ) + + def test_register_download_if_needed_with_is_same_contributor_true(self): + """ + Test that when is_same_contributor=True, + download limits are NOT decremented. + """ + limit = FacilityDownloadLimit.objects.create( + user=self.user, + free_download_records=100, + paid_download_records=200 + ) + + initial_free = limit.free_download_records + initial_paid = limit.paid_download_records + records_to_subtract = 50 + + FacilitiesDownloadService.register_download_if_needed( + limit, records_to_subtract, is_same_contributor=True + ) + + limit.refresh_from_db() + self.assertEqual(limit.free_download_records, initial_free) + self.assertEqual(limit.paid_download_records, initial_paid) + + def test_register_download_if_needed_with_is_same_contributor_false(self): + """ + Test that when is_same_contributor=False, + download limits are decremented. + """ + limit = FacilityDownloadLimit.objects.create( + user=self.user, + free_download_records=100, + paid_download_records=200 + ) + + initial_free = limit.free_download_records + initial_paid = limit.paid_download_records + records_to_subtract = 50 + + FacilitiesDownloadService.register_download_if_needed( + limit, records_to_subtract, is_same_contributor=False + ) + + limit.refresh_from_db() + self.assertEqual( + limit.free_download_records, initial_free - records_to_subtract + ) + self.assertEqual(limit.paid_download_records, initial_paid) + + def test_register_download_if_needed_exceeds_free_limit(self): + """ + Test that when records exceed free limit, + paid records are used. + """ + limit = FacilityDownloadLimit.objects.create( + user=self.user, + free_download_records=50, + paid_download_records=200 + ) + + initial_free = limit.free_download_records + initial_paid = limit.paid_download_records + records_to_subtract = 100 # Exceeds free limit + + FacilitiesDownloadService.register_download_if_needed( + limit, records_to_subtract, is_same_contributor=False + ) + + limit.refresh_from_db() + self.assertEqual(limit.free_download_records, 0) + self.assertEqual( + limit.paid_download_records, + initial_paid - (records_to_subtract - initial_free) + ) diff --git a/src/django/api/views/facility/facilities_view_set.py b/src/django/api/views/facility/facilities_view_set.py index da6f7fa02..8f9db10fb 100644 --- a/src/django/api/views/facility/facilities_view_set.py +++ b/src/django/api/views/facility/facilities_view_set.py @@ -90,6 +90,7 @@ from api.serializers.facility.facility_list_page_parameter_serializer \ import FacilityListPageParameterSerializer from api.throttles import DataUploadThrottle +from api.serializers.facility.utils import is_same_contributor_for_queryset from api.views.disabled_pagination_inspector import DisabledPaginationInspector @@ -254,32 +255,55 @@ def list(self, request): context = {'request': request} - if page_queryset is not None: - should_serialize_details = params.validated_data['detail'] - should_serialize_number_of_public_contributors = \ - params.validated_data["number_of_public_contributors"] - exclude_fields = [] - - if not should_serialize_details: - exclude_fields.extend([ - 'contributor_fields', - 'extended_fields', - 'contributors', - 'sector']) - if not should_serialize_number_of_public_contributors: - exclude_fields.extend(['number_of_public_contributors']) + should_serialize_details = params.validated_data['detail'] + should_serialize_number_of_public_contributors = \ + params.validated_data["number_of_public_contributors"] + exclude_fields = [] + + if not should_serialize_details: + exclude_fields.extend([ + 'contributor_fields', + 'extended_fields', + 'contributors', + 'sector']) + if not should_serialize_number_of_public_contributors: + exclude_fields.extend(['number_of_public_contributors']) + if page_queryset is not None: serializer = FacilityIndexSerializer(page_queryset, many=True, context=context, exclude_fields=exclude_fields) - response = self.get_paginated_response(serializer.data) - response.data['extent'] = extent - return response - response_data = FacilityIndexSerializer(queryset, many=True, - context=context).data - response_data['extent'] = extent - return Response(response_data) + is_same_contributor = is_same_contributor_for_queryset( + page_queryset, + request + ) + + page = self.get_paginated_response(serializer.data) + page.data['extent'] = extent + page.data['params'] = params.validated_data + page.data['is_same_contributor'] = is_same_contributor + return page + + # Non-paginated response + is_same_contributor = is_same_contributor_for_queryset( + queryset, + request + ) + + serializer = FacilityIndexSerializer(queryset, many=True, + context=context, + exclude_fields=exclude_fields) + + response = { + 'type': 'FeatureCollection', + 'features': serializer.data, + 'is_same_contributor': is_same_contributor, + } + if extent is not None: + response['extent'] = extent + response['params'] = params.validated_data + return Response(response) @swagger_auto_schema(manual_parameters=facility_parameters) def retrieve(self, request, pk=None): diff --git a/src/react/src/__tests__/components/DownloadFacilitiesButton.test.js b/src/react/src/__tests__/components/DownloadFacilitiesButton.test.js index 56c6ca122..3643a97a1 100644 --- a/src/react/src/__tests__/components/DownloadFacilitiesButton.test.js +++ b/src/react/src/__tests__/components/DownloadFacilitiesButton.test.js @@ -72,7 +72,7 @@ describe('DownloadFacilitiesButton component', () => { - , + , { container } ); @@ -292,4 +292,61 @@ describe('DownloadFacilitiesButton component', () => { expect(screen.getByText(expectedTooltipText)).toBeInTheDocument() ); }); + + test('shows same-contributor tooltip when isSameContributor is true for logged-in users', async () => { + const props = { + disabled: false, + userAllowedRecords: 1000, + isSameContributor: true, + }; + const customState = { + auth: { + user: { + user: { + isAnon: false, + }, + }, + }, + embeddedMap: { embed: false }, + }; + + const expectedTooltipText = 'You are downloading data for the same contributor as your account. Downloading data for the same contributor is free.'; + + const { getByRole } = renderComponent(props, customState); + const button = getByRole('button', { name: 'Download' }); + fireEvent.mouseOver(button); + + await waitFor(() => + expect(screen.getByText(expectedTooltipText)).toBeInTheDocument() + ); + }); + + test('anonymous users see anonymous tooltip even when isSameContributor is true', async () => { + const props = { + disabled: true, + userAllowedRecords: 1000, + isSameContributor: true, + }; + const customState = { + auth: { + user: { + user: { + isAnon: true, + }, + }, + }, + embeddedMap: { embed: false }, + }; + + const expectedTooltipText = 'Log in or sign up to download this dataset.'; + + const { getByRole } = renderComponent(props, customState); + const button = getByRole('button', { name: 'Download' }); + expect(button).toBeDisabled(); + fireEvent.mouseOver(button); + + await waitFor(() => + expect(screen.getByText(expectedTooltipText)).toBeInTheDocument() + ); + }); }); diff --git a/src/react/src/__tests__/components/DownloadLimitInfo.test.js b/src/react/src/__tests__/components/DownloadLimitInfo.test.js index 9d0515efd..587a2c4ac 100644 --- a/src/react/src/__tests__/components/DownloadLimitInfo.test.js +++ b/src/react/src/__tests__/components/DownloadLimitInfo.test.js @@ -18,7 +18,7 @@ describe('DownloadLimitInfo component', () => { ).toBeInTheDocument(); expect( getByText( - /This search includes more production locations than you have available for download. You may purchase additional downloads to continue\./ + /This search includes more production locations than you have available for download\. You may purchase additional downloads to continue\./ ) ).toBeInTheDocument(); expect( diff --git a/src/react/src/components/DownloadButtonWithFlags.jsx b/src/react/src/components/DownloadButtonWithFlags.jsx new file mode 100644 index 000000000..8d795456e --- /dev/null +++ b/src/react/src/components/DownloadButtonWithFlags.jsx @@ -0,0 +1,58 @@ +import React from 'react'; +import { bool, number, func } from 'prop-types'; +import FeatureFlag from './FeatureFlag'; +import DownloadFacilitiesButton from './DownloadFacilitiesButton'; +import { PRIVATE_INSTANCE, FACILITIES_DOWNLOAD_LIMIT } from '../util/constants'; + +function DownloadButtonWithFlags({ + embed, + facilitiesCount, + isSameContributor, + userAllowedRecords, + setLoginRequiredDialogIsOpen, +}) { + const count = facilitiesCount == null ? 0 : facilitiesCount; + + return ( + FACILITIES_DOWNLOAD_LIMIT} + upgrade={ + !embed && + !isSameContributor && + count > userAllowedRecords + } + userAllowedRecords={userAllowedRecords} + setLoginRequiredDialogIsOpen={setLoginRequiredDialogIsOpen} + facilitiesCount={count} + isSameContributor={isSameContributor} + /> + } + > + FACILITIES_DOWNLOAD_LIMIT} + userAllowedRecords={FACILITIES_DOWNLOAD_LIMIT} + setLoginRequiredDialogIsOpen={setLoginRequiredDialogIsOpen} + facilitiesCount={count} + isSameContributor={isSameContributor} + /> + + ); +} + +DownloadButtonWithFlags.propTypes = { + embed: bool.isRequired, + facilitiesCount: number, + isSameContributor: bool, + userAllowedRecords: number.isRequired, + setLoginRequiredDialogIsOpen: func.isRequired, +}; + +DownloadButtonWithFlags.defaultProps = { + facilitiesCount: 0, + isSameContributor: false, +}; + +export default DownloadButtonWithFlags; diff --git a/src/react/src/components/DownloadFacilitiesButton.jsx b/src/react/src/components/DownloadFacilitiesButton.jsx index 20eee1680..ac13d034f 100644 --- a/src/react/src/components/DownloadFacilitiesButton.jsx +++ b/src/react/src/components/DownloadFacilitiesButton.jsx @@ -72,6 +72,7 @@ const DownloadFacilitiesButton = ({ classes, theme, facilitiesCount, + isSameContributor, }) => { const [anchorEl, setAnchorEl] = React.useState(null); const isPrivateInstance = includes(activeFeatureFlags, PRIVATE_INSTANCE); @@ -132,6 +133,7 @@ const DownloadFacilitiesButton = ({ upgrade, classes, facilitiesCount, + isSameContributor, }), [ user, @@ -141,6 +143,7 @@ const DownloadFacilitiesButton = ({ upgrade, classes, facilitiesCount, + isSameContributor, ], ); @@ -198,6 +201,7 @@ DownloadFacilitiesButton.defaultProps = { logDownloadError: null, checkoutUrl: null, checkoutUrlError: null, + isSameContributor: false, }; DownloadFacilitiesButton.propTypes = { @@ -213,6 +217,7 @@ DownloadFacilitiesButton.propTypes = { classes: object.isRequired, activeFeatureFlags: arrayOf(string).isRequired, facilitiesCount: number.isRequired, + isSameContributor: bool, }; function mapStateToProps({ diff --git a/src/react/src/components/FeatureFlag.jsx b/src/react/src/components/FeatureFlag.jsx index 6af76db97..da82bb3a9 100644 --- a/src/react/src/components/FeatureFlag.jsx +++ b/src/react/src/components/FeatureFlag.jsx @@ -1,7 +1,6 @@ import React from 'react'; import { arrayOf, bool, node } from 'prop-types'; import { connect } from 'react-redux'; -import includes from 'lodash/includes'; import { featureFlagPropType } from '../util/propTypes'; @@ -16,20 +15,21 @@ function FeatureFlag({ alternative, activeFeatureFlags, fetching, + isSameContributor, }) { if (fetching) { return null; } - if (!includes(activeFeatureFlags, flag)) { - return alternative; - } + const shouldRenderChildren = + isSameContributor || activeFeatureFlags.includes(flag); - return <>{children}; + return shouldRenderChildren ? <>{children} : alternative; } FeatureFlag.defaultProps = { alternative: null, + isSameContributor: false, }; FeatureFlag.propTypes = { @@ -38,6 +38,7 @@ FeatureFlag.propTypes = { alternative: node, activeFeatureFlags: arrayOf(featureFlagPropType).isRequired, fetching: bool.isRequired, + isSameContributor: bool, }; function mapStateToProps({ diff --git a/src/react/src/components/FilterSidebarFacilitiesTab.jsx b/src/react/src/components/FilterSidebarFacilitiesTab.jsx index e1e819ca0..781556fb5 100644 --- a/src/react/src/components/FilterSidebarFacilitiesTab.jsx +++ b/src/react/src/components/FilterSidebarFacilitiesTab.jsx @@ -5,10 +5,6 @@ import { withRouter } from 'react-router'; import { Link } from 'react-router-dom'; import CircularProgress from '@material-ui/core/CircularProgress'; import LinearProgress from '@material-ui/core/LinearProgress'; -import Dialog from '@material-ui/core/Dialog'; -import DialogTitle from '@material-ui/core/DialogTitle'; -import DialogContent from '@material-ui/core/DialogContent'; -import DialogActions from '@material-ui/core/DialogActions'; import Typography from '@material-ui/core/Typography'; import Button from '@material-ui/core/Button'; import List from '@material-ui/core/List'; @@ -24,7 +20,6 @@ import noop from 'lodash/noop'; import CopySearch from './CopySearch'; import FeatureFlag from './FeatureFlag'; -import DownloadFacilitiesButton from './DownloadFacilitiesButton'; import ShowOnly from './ShowOnly'; import { @@ -44,13 +39,7 @@ import { import { facilityCollectionPropType, userPropType } from '../util/propTypes'; -import { - REPORT_A_FACILITY, - authLoginFormRoute, - authRegisterFormRoute, - PRIVATE_INSTANCE, - FACILITIES_DOWNLOAD_LIMIT, -} from '../util/constants'; +import { REPORT_A_FACILITY } from '../util/constants'; import { makeFacilityDetailLink } from '../util/util'; import { useMergeButtonClickHandler } from '../util/hooks'; @@ -61,6 +50,8 @@ import { filterSidebarStyles } from '../util/styles'; import BadgeClaimed from './BadgeClaimed'; import CopyLinkIcon from './CopyLinkIcon'; import { useResultListHeight } from '../util/useHeightSubtract'; +import DownloadButtonWithFlags from './DownloadButtonWithFlags'; +import LoginRequiredDialog from './LoginRequiredDialog'; const makeFacilitiesTabStyles = theme => ({ noResultsTextStyles: Object.freeze({ @@ -205,6 +196,7 @@ function FilterSidebarFacilitiesTab({ updateTargetOSID, fetchTargetFacility, classes, + isSameContributor, }) { const [loginRequiredDialogIsOpen, setLoginRequiredDialogIsOpen] = useState( false, @@ -324,11 +316,6 @@ function FilterSidebarFacilitiesTab({ const facilitiesCount = get(data, 'count', null); - const LoginLink = props => ; - const RegisterLink = props => ( - - ); - const progress = facilitiesCount ? (get(downloadData, 'results.rows', []).length * 100) / facilitiesCount : 0; @@ -350,37 +337,15 @@ function FilterSidebarFacilitiesTab({ /> ) : ( - FACILITIES_DOWNLOAD_LIMIT - } - upgrade={ - !embed && - facilitiesCount > - user.allowed_records_number - } - userAllowedRecords={user.allowed_records_number} - setLoginRequiredDialogIsOpen={ - setLoginRequiredDialogIsOpen - } - facilitiesCount={facilitiesCount} - /> + - FACILITIES_DOWNLOAD_LIMIT - } - userAllowedRecords={FACILITIES_DOWNLOAD_LIMIT} - setLoginRequiredDialogIsOpen={ - setLoginRequiredDialogIsOpen - } - /> - + /> )} - - - - - ) : ( -
- )} - + setLoginRequiredDialogIsOpen(false)} + /> ); } @@ -651,6 +574,7 @@ FilterSidebarFacilitiesTab.defaultProps = { }, data: null, error: null, + isSameContributor: false, }; FilterSidebarFacilitiesTab.propTypes = { @@ -669,6 +593,7 @@ FilterSidebarFacilitiesTab.propTypes = { updateToMergeOSID: func.isRequired, fetchToMergeFacility: func.isRequired, updateTargetOSID: func.isRequired, + isSameContributor: bool, fetchTargetFacility: func.isRequired, }; @@ -708,6 +633,7 @@ function mapStateToProps({ facilityToMergeOSID, scrollTop, embed: !!embed, + isSameContributor: get(data, 'is_same_contributor', false), }; } diff --git a/src/react/src/components/FilterSidebarHeader.jsx b/src/react/src/components/FilterSidebarHeader.jsx index 3e9a31f7e..4b7c74acd 100644 --- a/src/react/src/components/FilterSidebarHeader.jsx +++ b/src/react/src/components/FilterSidebarHeader.jsx @@ -19,6 +19,7 @@ const FilterSidebarHeader = ({ classes, embed, user, + isSameContributor, }) => (

@@ -43,6 +44,7 @@ const FilterSidebarHeader = ({ > } > <> @@ -57,6 +59,7 @@ FilterSidebarHeader.propTypes = { embed: string.isRequired, classes: object.isRequired, user: userPropType.isRequired, + isSameContributor: bool.isRequired, }; const mapStateToProps = ({ @@ -71,6 +74,7 @@ const mapStateToProps = ({ embed, facilitiesCount: get(facilities, 'count', null), user, + isSameContributor: get(facilities, 'is_same_contributor', false), }); export default withStyles(filterSidebarHeaderStyles)( diff --git a/src/react/src/components/LoginRequiredDialog.jsx b/src/react/src/components/LoginRequiredDialog.jsx new file mode 100644 index 000000000..e7fda1aa6 --- /dev/null +++ b/src/react/src/components/LoginRequiredDialog.jsx @@ -0,0 +1,64 @@ +import React from 'react'; +import { bool, func } from 'prop-types'; +import Dialog from '@material-ui/core/Dialog'; +import DialogTitle from '@material-ui/core/DialogTitle'; +import DialogContent from '@material-ui/core/DialogContent'; +import DialogActions from '@material-ui/core/DialogActions'; +import Typography from '@material-ui/core/Typography'; +import Button from '@material-ui/core/Button'; +import RouterLink from './RouterLink'; +import { authLoginFormRoute, authRegisterFormRoute } from '../util/constants'; + +function LoginRequiredDialog({ open, onClose }) { + return ( + + {open ? ( + <> + Log In To Download + + + You must log in with an Open Supply Hub account + before downloading your search results. + + + + + + + + + ) : ( +
+ )} +
+ ); +} + +LoginRequiredDialog.propTypes = { + open: bool.isRequired, + onClose: func.isRequired, +}; + +export default LoginRequiredDialog; diff --git a/src/react/src/components/NonVectorTileFilterSidebarFacilitiesTab.jsx b/src/react/src/components/NonVectorTileFilterSidebarFacilitiesTab.jsx index e4991c722..831089c1e 100644 --- a/src/react/src/components/NonVectorTileFilterSidebarFacilitiesTab.jsx +++ b/src/react/src/components/NonVectorTileFilterSidebarFacilitiesTab.jsx @@ -3,10 +3,6 @@ import { arrayOf, bool, func, number, string } from 'prop-types'; import { connect } from 'react-redux'; import { Link } from 'react-router-dom'; import CircularProgress from '@material-ui/core/CircularProgress'; -import Dialog from '@material-ui/core/Dialog'; -import DialogTitle from '@material-ui/core/DialogTitle'; -import DialogContent from '@material-ui/core/DialogContent'; -import DialogActions from '@material-ui/core/DialogActions'; import Typography from '@material-ui/core/Typography'; import Button from '@material-ui/core/Button'; import List from '@material-ui/core/List'; @@ -20,8 +16,6 @@ import includes from 'lodash/includes'; import lowerCase from 'lodash/lowerCase'; import ControlledTextInput from './ControlledTextInput'; -import DownloadFacilitiesButton from './DownloadFacilitiesButton'; -import FeatureFlag from './FeatureFlag'; import { toggleFilterModal, @@ -30,18 +24,13 @@ import { import { facilityCollectionPropType, userPropType } from '../util/propTypes'; -import { - authLoginFormRoute, - authRegisterFormRoute, - PRIVATE_INSTANCE, - FACILITIES_DOWNLOAD_LIMIT, -} from '../util/constants'; - import { makeFacilityDetailLink, getValueFromEvent } from '../util/util'; import COLOURS from '../util/COLOURS'; import { filterSidebarStyles } from '../util/styles'; +import DownloadButtonWithFlags from './DownloadButtonWithFlags'; +import LoginRequiredDialog from './LoginRequiredDialog'; const SEARCH_TERM_INPUT = 'SEARCH_TERM_INPUT'; @@ -111,6 +100,7 @@ function NonVectorTileFilterSidebarFacilitiesTab({ updateFilterText, classes, user, + isSameContributor, }) { const [loginRequiredDialogIsOpen, setLoginRequiredDialogIsOpen] = useState( false, @@ -215,47 +205,20 @@ function NonVectorTileFilterSidebarFacilitiesTab({ ? `Displaying ${filteredFacilities.length} facilities of ${facilitiesCount} results` : `Displaying ${filteredFacilities.length} facilities`; - const LoginLink = props => ; - const RegisterLink = props => ( - - ); - const listHeaderInsetComponent = (
{headerDisplayString} - FACILITIES_DOWNLOAD_LIMIT - } - upgrade={ - !embed && - facilitiesCount > - user.allowed_records_number - } - userAllowedRecords={user.allowed_records_number} - setLoginRequiredDialogIsOpen={ - setLoginRequiredDialogIsOpen - } - facilitiesCount={facilitiesCount} - /> + - FACILITIES_DOWNLOAD_LIMIT - } - userAllowedRecords={FACILITIES_DOWNLOAD_LIMIT} - setLoginRequiredDialogIsOpen={ - setLoginRequiredDialogIsOpen - } - /> - + />
@@ -322,52 +285,10 @@ function NonVectorTileFilterSidebarFacilitiesTab({ />
- - {loginRequiredDialogIsOpen ? ( - <> - Log In To Download - - - You must log in with an Open Supply Hub account - before downloading your search results. - - - - - - - - - ) : ( -
- )} -
+ setLoginRequiredDialogIsOpen(false)} + /> ); } @@ -376,6 +297,7 @@ NonVectorTileFilterSidebarFacilitiesTab.defaultProps = { data: null, error: null, user: null, + isSameContributor: false, }; NonVectorTileFilterSidebarFacilitiesTab.propTypes = { @@ -388,6 +310,7 @@ NonVectorTileFilterSidebarFacilitiesTab.propTypes = { updateFilterText: func.isRequired, embed: bool.isRequired, user: userPropType, + isSameContributor: bool, }; function mapStateToProps({ @@ -411,6 +334,7 @@ function mapStateToProps({ filterText, windowHeight, embed: !!embed, + isSameContributor: get(data, 'is_same_contributor', false), }; } diff --git a/src/react/src/components/RouterLink.jsx b/src/react/src/components/RouterLink.jsx new file mode 100644 index 000000000..f84ccbf1a --- /dev/null +++ b/src/react/src/components/RouterLink.jsx @@ -0,0 +1,14 @@ +import React from 'react'; +import { Link } from 'react-router-dom'; +import { oneOfType, string, object } from 'prop-types'; + +const RouterLink = React.forwardRef((props, ref) => { + const { to, ...other } = props; + return ; +}); + +RouterLink.propTypes = { + to: oneOfType([string, object]).isRequired, +}; + +export default RouterLink; diff --git a/src/react/src/util/getTooltipForFacilitiesDownload.jsx b/src/react/src/util/getTooltipForFacilitiesDownload.jsx index 4f2fe029e..42119bdcd 100644 --- a/src/react/src/util/getTooltipForFacilitiesDownload.jsx +++ b/src/react/src/util/getTooltipForFacilitiesDownload.jsx @@ -13,6 +13,7 @@ const getTooltipForFacilitiesDownload = ({ upgrade, classes, facilitiesCount, + isSameContributor, }) => { const tooltipTexts = { availableDownloads: `Registered users can download up to ${FREE_FACILITIES_DOWNLOAD_LIMIT} production @@ -25,6 +26,8 @@ const getTooltipForFacilitiesDownload = ({ continue.`, anonymousUser: 'Log in or sign up to download this dataset.', embeddedOrPrivateInstance: `Downloads are supported for searches resulting in ${FACILITIES_DOWNLOAD_LIMIT} production locations or less.`, + sameContributor: + 'You are downloading data for the same contributor as your account. Downloading data for the same contributor is free.', }; // Determine base tooltip. @@ -32,6 +35,8 @@ const getTooltipForFacilitiesDownload = ({ if (isEmbedded || isPrivateInstance) { tooltipText = tooltipTexts.embeddedOrPrivateInstance; + } else if (isSameContributor) { + tooltipText = tooltipTexts.sameContributor; } else if (upgrade) { tooltipText = userAllowedRecords === 0 diff --git a/src/react/src/util/util.js b/src/react/src/util/util.js index e21ab8027..64cc73604 100644 --- a/src/react/src/util/util.js +++ b/src/react/src/util/util.js @@ -1739,7 +1739,9 @@ export const processDromoResults = ( return; } - const headers = Object.keys(results[0]); + const headers = Object.keys(results[0]).filter( + header => header !== 'is_same_contributor', + ); const csvRows = results.map(row => headers.map(header => formatCSVField(row[header])).join(','), );