diff --git a/backend/device_registry/api_views.py b/backend/device_registry/api_views.py index 35ffa55f8..7dfde9701 100644 --- a/backend/device_registry/api_views.py +++ b/backend/device_registry/api_views.py @@ -9,6 +9,7 @@ from .models import Device, DeviceInfo, PortScan from device_registry.serializers import DeviceSerializer from django.db import IntegrityError +from django.shortcuts import get_object_or_404 from rest_framework import permissions from rest_framework import status from rest_framework.decorators import api_view, renderer_classes, permission_classes @@ -338,3 +339,20 @@ def mtls_renew_cert_view(request, format=None): 'certificate_expires': certificate_expires, 'claim_token': device_object.claim_token, }) + + +@api_view(['GET']) +@permission_classes((permissions.IsAuthenticated,)) +def claim_by_link(request): + params = request.query_params + device = get_object_or_404( + Device, + claim_token=params['claim-token'], + device_id=params['device-id'], + owner__isnull=True + ) + if device: + device.owner = request.user + device.save() + return Response(f'Device {device.device_id} claimed!') + return Response('Device not found', status=status.HTTP_404_NOT_FOUND) diff --git a/backend/device_registry/tests.py b/backend/device_registry/tests.py index 9f55b4dba..e883448f0 100644 --- a/backend/device_registry/tests.py +++ b/backend/device_registry/tests.py @@ -13,7 +13,7 @@ from django.utils import timezone from django.test import TestCase, RequestFactory from rest_framework.test import APIRequestFactory -from .api_views import mtls_ping_view +from .api_views import mtls_ping_view, claim_by_link from .models import Device, DeviceInfo, PortScan @@ -242,3 +242,29 @@ def test_active_inactive(self): def test_get_expiration_date(self): exp_date = self.device0.get_cert_expiration_date() self.assertEqual(exp_date.date(), datetime.date(2019, 4, 4)) + + +class ClaimLinkTest(TestCase): + def setUp(self): + self.api = RequestFactory() + self.device0 = Device.objects.create( + device_id='device0.d.wott-dev.local', + claim_token='token' + ) + self.user0 = User.objects.create_user('test') + + def test_claim_get_view(self): + request = self.api.get(f'/api/v0.2/claim-device/?device-id={self.device0.device_id}&claim-token={self.device0.claim_token}') + request.user = self.user0 + self.assertFalse(self.device0.claimed()) + response = claim_by_link(request) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, f'Device {self.device0.device_id} claimed!') + self.device0.refresh_from_db() + self.assertTrue(self.device0.claimed()) + + def test_claim_get_404(self): + request = self.api.get(f'/claim-device/?device-id=none&claim-token=none') + request.user = self.user0 + response = claim_by_link(request) + self.assertEqual(response.status_code, 404) diff --git a/backend/device_registry/urls.py b/backend/device_registry/urls.py index 082f82893..96cf04dec 100644 --- a/backend/device_registry/urls.py +++ b/backend/device_registry/urls.py @@ -28,6 +28,9 @@ path('api/{}/sign-csr'.format(api_version), api_views.sign_new_device_view, name='sign-device-cert'), + path('api/{}/claim-device'.format(api_version), + api_views.claim_by_link, + name='claim-by-link'), ] # Only load if mTLS