Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions superset/views/redirects.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Optional

from flask import flash, request, Response
from flask_appbuilder import expose
from flask_appbuilder.security.decorators import has_access_api
Expand All @@ -24,11 +27,22 @@
from superset.typing import FlaskResponse
from superset.views.base import BaseSupersetView

logger = logging.getLogger(__name__)


class R(BaseSupersetView): # pylint: disable=invalid-name

"""used for short urls"""

@staticmethod
def _validate_url(url: Optional[str] = None) -> bool:
if url and (
url.startswith("//superset/dashboard/")
or url.startswith("//superset/explore/")
):
return True
return False

@event_logger.log_this
@expose("/<int:url_id>")
def index(self, url_id: int) -> FlaskResponse: # pylint: disable=no-self-use
Expand All @@ -38,8 +52,9 @@ def index(self, url_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if url.url.startswith(explore_url):
explore_url += f"r={url_id}"
return redirect(explore_url[1:])

return redirect(url.url[1:])
if self._validate_url(url.url):
return redirect(url.url[1:])
return redirect("/")

flash("URL to nowhere...", "danger")
return redirect("/")
Expand All @@ -49,6 +64,9 @@ def index(self, url_id: int) -> FlaskResponse: # pylint: disable=no-self-use
@expose("/shortner/", methods=["POST"])
def shortner(self) -> FlaskResponse: # pylint: disable=no-self-use
url = request.form.get("data")
if not self._validate_url(url):
logger.warning("Invalid URL: %s", url)
return Response(f"Invalid URL: {url}", 400)
obj = models.Url(url=url)
db.session.add(obj)
db.session.commit()
Expand Down
22 changes: 22 additions & 0 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,28 @@ def test_shortner(self):
resp = self.client.post("/r/shortner/", data=dict(data=data))
assert re.search(r"\/r\/[0-9]+", resp.data.decode("utf-8"))

def test_shortner_invalid(self):
self.login(username="admin")
invalid_urls = [
"hhttp://invalid.com",
"hhttps://invalid.com",
"www.invalid.com",
]
for invalid_url in invalid_urls:
resp = self.client.post("/r/shortner/", data=dict(data=invalid_url))
assert resp.status_code == 400

def test_redirect_invalid(self):
model_url = models.Url(url="hhttp://invalid.com")
db.session.add(model_url)
db.session.commit()

self.login(username="admin")
response = self.client.get(f"/r/{model_url.id}")
assert response.headers["Location"] == "http://localhost/"
db.session.delete(model_url)
db.session.commit()

@skipUnless(
(is_feature_enabled("KV_STORE")), "skipping as /kv/ endpoints are not enabled"
)
Expand Down