diff --git a/docs/configuration/command.md b/docs/configuration/command.md index 26df81221..5ecadfdec 100644 --- a/docs/configuration/command.md +++ b/docs/configuration/command.md @@ -28,6 +28,27 @@ UNFOLD = { Command results use infinite scrolling with a default page size of 100 results. When the last item becomes visible in the viewport, a new page of results is automatically loaded and appended to the existing list, allowing continuous browsing through search results. +## Search only specific models + +- `search_models` accepts `list` or `tuple` of allowed models which can be searched + +```python +UNFOLD = { + # ... + "COMMAND": { + "search_models": ["example.sample"], # List or tuple + # "search_models": "example.utils.search_models_callback" + }, + # ... +} + +# utils.py +def search_models_callback(request): + return [ + "example.sample", + ] +``` + ## Custom search callback The search callback feature provides a way to define a custom hook that can inject additional content into search results. This is particularly useful when you want to search for results from external sources or services beyond the Django admin interface. diff --git a/src/unfold/sites.py b/src/unfold/sites.py index 37018cf62..e03fdf387 100644 --- a/src/unfold/sites.py +++ b/src/unfold/sites.py @@ -210,12 +210,23 @@ def _search_apps( return results def _search_models( - self, request: HttpRequest, app_list: list[dict[str, Any]], search_term: str + self, + request: HttpRequest, + app_list: list[dict[str, Any]], + search_term: str, + allowed_models: Optional[list[str]] = None, ) -> list[SearchResult]: results = [] for app in app_list: for model in app["models"]: + # Skip models which are not allowed + if isinstance(allowed_models, (list, tuple)): + if model["model"]._meta.label.lower() not in [ + m.lower() for m in allowed_models + ]: + continue + admin_instance = self._registry.get(model["model"]) search_fields = admin_instance.get_search_fields(request) @@ -278,26 +289,41 @@ def search( results = cache_results else: results = self._search_apps(app_list, search_term) - search_models = self._get_config("COMMAND", request).get("search_models") - search_callback = self._get_config("COMMAND", request).get( - "search_callback" - ) if extended_search: - if search_callback: + if search_callback := self._get_config("COMMAND", request).get( + "search_callback" + ): results.extend( self._get_value(search_callback, request, search_term) ) - if search_models is True: - results.extend(self._search_models(request, app_list, search_term)) + search_models = self._get_value( + self._get_config("COMMAND", request).get("search_models"), request + ) + + if search_models is True or isinstance(search_models, (list, tuple)): + allowed_models = ( + search_models + if isinstance(search_models, (list, tuple)) + else None + ) + + results.extend( + self._search_models( + request, app_list, search_term, allowed_models + ) + ) cache.set(cache_key, results, timeout=CACHE_TIMEOUT) execution_time = time.time() - start_time - paginator = Paginator(results, PER_PAGE) + show_history = self._get_value( + self._get_config("COMMAND", request).get("show_history"), request + ) + return TemplateResponse( request, template=template_name, @@ -306,9 +332,7 @@ def search( "results": paginator.page(request.GET.get("page", 1)), "page_counter": (int(request.GET.get("page", 1)) - 1) * PER_PAGE, "execution_time": execution_time, - "command_show_history": self._get_config("COMMAND", request).get( - "show_history" - ), + "command_show_history": show_history, }, headers={ "HX-Trigger": "search", diff --git a/src/unfold/static/unfold/js/app.js b/src/unfold/static/unfold/js/app.js index a753147be..8f20c2977 100644 --- a/src/unfold/static/unfold/js/app.js +++ b/src/unfold/static/unfold/js/app.js @@ -137,14 +137,25 @@ function searchCommand() { return; } - this.items = document - .getElementById("command-results-list") - .querySelectorAll("li"); - this.totalItems = this.items.length; + const commandResultsList = document.getElementById( + "command-results-list" + ); + if (commandResultsList) { + this.items = commandResultsList.querySelectorAll("li"); + this.totalItems = this.items.length; + } else { + this.items = undefined; + this.totalItems = 0; + } if (event.target.id === "command-results") { this.currentIndex = 0; - this.totalItems = this.items.length; + + if (this.items) { + this.totalItems = this.items.length; + } else { + this.totalItems = 0; + } } this.hasResults = this.totalItems > 0; diff --git a/tests/test_command.py b/tests/test_command.py index e747d1bae..8aeaa9b51 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -108,3 +108,71 @@ def test_command_search_extended_model_with_permission( ) assert response.status_code == HTTPStatus.OK assert "sample-test-tag-with-permission" in response.content.decode() + + +@pytest.mark.django_db +@override_settings( + CACHES={"default": {"BACKEND": "django.core.cache.backends.dummy.DummyCache"}} +) +def test_command_allowed_models(admin_client, admin_user, tag_factory): + tag_factory(name="another-test-tag") + + with override_settings( + UNFOLD={ + **CONFIG_DEFAULTS, + **{ + "COMMAND": { + "search_models": False, + } + }, + } + ): + response = admin_client.get( + reverse("admin:search") + "?s=another-test-tag&extended=1" + ) + assert "another-test-tag" not in response.content.decode() + + with override_settings( + UNFOLD={ + **CONFIG_DEFAULTS, + **{ + "COMMAND": { + "search_models": True, + } + }, + } + ): + response = admin_client.get( + reverse("admin:search") + "?s=another-test-tag&extended=1" + ) + assert "another-test-tag" in response.content.decode() + + with override_settings( + UNFOLD={ + **CONFIG_DEFAULTS, + **{ + "COMMAND": { + "search_models": [], + } + }, + } + ): + response = admin_client.get( + reverse("admin:search") + "?s=another-test-tag&extended=1" + ) + assert "another-test-tag" not in response.content.decode() + + with override_settings( + UNFOLD={ + **CONFIG_DEFAULTS, + **{ + "COMMAND": { + "search_models": ["example.tag"], + } + }, + } + ): + response = admin_client.get( + reverse("admin:search") + "?s=another-test-tag&extended=1" + ) + assert "another-test-tag" in response.content.decode()