diff --git a/superset/views/redirects.py b/superset/views/redirects.py index 02dc587e71f5..1be79b6461a2 100644 --- a/superset/views/redirects.py +++ b/superset/views/redirects.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import logging +from typing import Optional + from flask import flash, request, Response from flask_appbuilder import expose from flask_appbuilder.security.decorators import has_access_api @@ -24,11 +27,22 @@ from superset.typing import FlaskResponse from superset.views.base import BaseSupersetView +logger = logging.getLogger(__name__) + class R(BaseSupersetView): # pylint: disable=invalid-name """used for short urls""" + @staticmethod + def _validate_url(url: Optional[str] = None) -> bool: + if url and ( + url.startswith("//superset/dashboard/") + or url.startswith("//superset/explore/") + ): + return True + return False + @event_logger.log_this @expose("/") def index(self, url_id: int) -> FlaskResponse: # pylint: disable=no-self-use @@ -38,8 +52,9 @@ def index(self, url_id: int) -> FlaskResponse: # pylint: disable=no-self-use if url.url.startswith(explore_url): explore_url += f"r={url_id}" return redirect(explore_url[1:]) - - return redirect(url.url[1:]) + if self._validate_url(url.url): + return redirect(url.url[1:]) + return redirect("/") flash("URL to nowhere...", "danger") return redirect("/") @@ -49,6 +64,9 @@ def index(self, url_id: int) -> FlaskResponse: # pylint: disable=no-self-use @expose("/shortner/", methods=["POST"]) def shortner(self) -> FlaskResponse: # pylint: disable=no-self-use url = request.form.get("data") + if not self._validate_url(url): + logger.warning("Invalid URL: %s", url) + return Response(f"Invalid URL: {url}", 400) obj = models.Url(url=url) db.session.add(obj) db.session.commit() diff --git a/tests/core_tests.py b/tests/core_tests.py index 3bc230a19406..111964aebaa7 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -634,6 +634,28 @@ def test_shortner(self): resp = self.client.post("/r/shortner/", data=dict(data=data)) assert re.search(r"\/r\/[0-9]+", resp.data.decode("utf-8")) + def test_shortner_invalid(self): + self.login(username="admin") + invalid_urls = [ + "hhttp://invalid.com", + "hhttps://invalid.com", + "www.invalid.com", + ] + for invalid_url in invalid_urls: + resp = self.client.post("/r/shortner/", data=dict(data=invalid_url)) + assert resp.status_code == 400 + + def test_redirect_invalid(self): + model_url = models.Url(url="hhttp://invalid.com") + db.session.add(model_url) + db.session.commit() + + self.login(username="admin") + response = self.client.get(f"/r/{model_url.id}") + assert response.headers["Location"] == "http://localhost/" + db.session.delete(model_url) + db.session.commit() + @skipUnless( (is_feature_enabled("KV_STORE")), "skipping as /kv/ endpoints are not enabled" )