Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions marimo/_cli/file_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,52 @@ def get_github_src_url(url: str) -> str:
return f"https://raw.githubusercontent.com{path}"


def is_gist_src(url: str) -> bool:
    """Report whether ``url`` is hosted on one of GitHub's Gist domains.

    Returns ``False`` when ``is_url`` rejects the input; otherwise the
    parsed hostname must be ``gist.github.com`` or
    ``gist.githubusercontent.com``.
    """
    if is_url(url):
        host = urllib.parse.urlparse(url).hostname
        return host in ("gist.github.com", "gist.githubusercontent.com")
    return False


def get_gist_src_url(url: str) -> str:
    """Resolve a Gist URL to the raw URL of a Python or Markdown file.

    If ``url`` already points at a raw file (its path contains a ``raw``
    segment) it is returned unchanged once its extension has been checked.
    Otherwise the gist is looked up at ``https://api.github.com/gists/<id>``
    and the ``raw_url`` of the first listed ``.py`` or ``.md`` file is
    returned.

    Args:
        url: A gist page URL or a direct link to a raw gist file.

    Returns:
        The raw URL of a Python or Markdown file.

    Raises:
        ValueError: If no gist id can be read from ``url``, the gist has no
            files, or none of its files is a Python or Markdown file.
    """
    wanted_suffixes = (".py", ".md")
    # Drop empty segments so leading, trailing or doubled slashes do not
    # shift the position of the gist id.
    path_parts = [
        part for part in urllib.parse.urlparse(url).path.split("/") if part
    ]

    if "raw" in path_parts:
        # Compare case-insensitively so a raw link accepts the same files
        # as the API lookup below (e.g. ``NOTEBOOK.PY``).
        if not path_parts[-1].lower().endswith(wanted_suffixes):
            raise ValueError("No python or markdown files found in the Gist")
        return url

    if not path_parts:
        raise ValueError(f"Could not find a Gist id in URL: {url}")
    # Gist pages are normally ``/<user>/<gist_id>``; a path holding a single
    # segment is treated as the id itself instead of raising IndexError.
    gist_id = path_parts[1] if len(path_parts) > 1 else path_parts[0]

    api_url = f"https://api.github.com/gists/{gist_id}"
    api_response = requests.get(api_url, headers=USER_AGENT_HEADER)
    api_response.raise_for_status()
    gist_data = api_response.json()

    files_dict = gist_data.get("files", {})
    if not files_dict:
        raise ValueError("No files found in the Gist")

    raw_url = next(
        (
            file_info["raw_url"]
            for filename, file_info in files_dict.items()
            if filename.lower().endswith(wanted_suffixes)
        ),
        "",
    )
    if not raw_url:
        raise ValueError("No python or markdown files found in the Gist")

    return raw_url


class FileReader(abc.ABC):
@abc.abstractmethod
def can_read(self, name: str) -> bool:
Expand Down Expand Up @@ -178,6 +224,18 @@ def read(self, name: str) -> tuple[str, str]:
return content, os.path.basename(url)


class GistSourceReader(FileReader):
    """File reader for notebooks referenced by a GitHub Gist URL."""

    def can_read(self, name: str) -> bool:
        """Return whether ``name`` is a Gist URL per ``is_gist_src``."""
        # The original evaluated the same predicate twice, joined by ``or``;
        # a single call is equivalent.
        return is_gist_src(name)

    def read(self, name: str) -> tuple[str, str]:
        """Fetch the gist file behind ``name``.

        The URL is first resolved with ``get_gist_src_url`` and then
        downloaded.

        Returns:
            A ``(content, filename)`` tuple, where ``filename`` is the
            basename of the resolved URL.
        """
        url = get_gist_src_url(name)
        response = requests.get(url, headers=USER_AGENT_HEADER)
        response.raise_for_status()
        content = response.text()
        return content, os.path.basename(url)


class GenericURLReader(FileReader):
def can_read(self, name: str) -> bool:
return is_url(name)
Expand All @@ -198,6 +256,7 @@ def __init__(self) -> None:
GitHubIssueReader(),
StaticNotebookReader(),
GitHubSourceReader(),
GistSourceReader(),
GenericURLReader(),
]

Expand Down
122 changes: 121 additions & 1 deletion tests/_cli/test_file_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tempfile
import urllib.error
from pathlib import Path
from unittest.mock import patch
from unittest.mock import Mock, patch

import click
import pytest
Expand All @@ -13,10 +13,13 @@
from marimo._cli.file_path import (
FileContentReader,
GenericURLReader,
GistSourceReader,
GitHubIssueReader,
GitHubSourceReader,
LocalFileReader,
StaticNotebookReader,
get_gist_src_url,
is_gist_src,
is_github_src,
validate_name,
)
Expand Down Expand Up @@ -182,6 +185,7 @@ def test_file_content_reader() -> None:
StaticNotebookReader, "read"
) as mock_static_notebook_read,
patch.object(GitHubSourceReader, "read") as mock_github_source_read,
patch.object(GistSourceReader, "read") as mock_gist_source_read,
patch.object(GenericURLReader, "read") as mock_generic_url_read,
):
mock_local_read.return_value = ("local content", "local.py")
Expand All @@ -191,6 +195,7 @@ def test_file_content_reader() -> None:
"notebook.py",
)
mock_github_source_read.return_value = ("github content", "github.py")
mock_gist_source_read.return_value = ("gist content", "gist.py")
mock_generic_url_read.return_value = ("url content", "url.py")

assert reader.read_file("local.py") == ("local content", "local.py")
Expand All @@ -200,6 +205,9 @@ def test_file_content_reader() -> None:
assert reader.read_file(
"https://github.com/marimo-team/marimo/blob/main/example.py"
) == ("github content", "github.py")
assert reader.read_file(
"https://gist.github.com/marimo-team/gist_id/example.py"
) == ("gist content", "gist.py")
assert reader.read_file("https://example.com/file.py") == (
"url content",
"url.py",
Expand Down Expand Up @@ -459,3 +467,115 @@ def test_generic_url_reader_with_edge_case_filenames(
content, filename = reader.read(url)
assert content == "print('Hello, world!')"
assert filename == expected_filename


@pytest.mark.parametrize(
    ("url", "expected"),
    [
        ("https://gist.github.com/user/12345", True),
        ("https://gist.githubusercontent.com/user/123/raw/file.py", True),
        ("https://github.com/marimo-team/marimo", False),
    ],
)
def test_is_gist_src(url: str, expected: bool) -> None:
    """Gist hosts are recognised; a plain GitHub repository URL is not."""
    outcome = is_gist_src(url)
    assert outcome == expected


def test_get_gist_src_url_success():
    """A gist page URL resolves to the raw URL of its ``.py`` file."""
    txt_raw_url = (
        "https://gist.githubusercontent.com/user/id/raw/rev/config.txt"
    )
    expected_raw_url = (
        "https://gist.githubusercontent.com/user/id/raw/rev/notebook.py"
    )

    # The non-matching file is listed first so it must be skipped.
    api_reply = Mock()
    api_reply.raise_for_status = Mock()
    api_reply.json.return_value = {
        "files": {
            "config.txt": {
                "filename": "config.txt",
                "raw_url": txt_raw_url,
            },
            "notebook.py": {
                "filename": "notebook.py",
                "raw_url": expected_raw_url,
            },
        }
    }

    with patch("marimo._utils.requests.get", return_value=api_reply):
        resolved = get_gist_src_url("https://gist.github.com/user/id")

    assert resolved == expected_raw_url


def test_get_gist_src_url_no_files_found():
    """An empty ``files`` mapping in the API reply raises ``ValueError``."""
    empty_reply = Mock()
    empty_reply.raise_for_status = Mock()
    empty_reply.json.return_value = {"files": {}}

    with (
        patch("marimo._utils.requests.get", return_value=empty_reply),
        pytest.raises(ValueError, match="No files found in the Gist"),
    ):
        get_gist_src_url("https://gist.github.com/user/id")


def test_get_gist_src_url_no_matching_files():
    """A gist holding only non-py/md files raises ``ValueError``."""
    txt_only_listing = {
        "files": {
            "config.txt": {
                "filename": "config.txt",
                "raw_url": "https://gist.githubusercontent.com/user/id/raw/rev/config.txt",
            }
        }
    }
    api_reply = Mock()
    api_reply.raise_for_status = Mock()
    api_reply.json.return_value = txt_only_listing

    with (
        patch("marimo._utils.requests.get", return_value=api_reply),
        pytest.raises(
            ValueError, match="No python or markdown files found in the Gist"
        ),
    ):
        get_gist_src_url("https://gist.github.com/user/id")


def test_gist_source_reader() -> None:
    """``GistSourceReader`` accepts gist URLs and reads through the API."""
    reader = GistSourceReader()
    gist_page_url = "https://gist.github.com/user/12345"

    assert reader.can_read(gist_page_url) is True
    assert reader.can_read("https://github.com/marimo-team/marimo") is False

    expected_raw_url = (
        "https://gist.githubusercontent.com/user/id/raw/rev/notebook.py"
    )
    expected_content = "print('Hello from Gist')"

    # Reply to the first request: the gist API listing.
    listing_response = Mock()
    listing_response.raise_for_status = Mock()
    listing_response.json.return_value = {
        "files": {
            "notebook.py": {
                "filename": "notebook.py",
                "raw_url": expected_raw_url,
            },
        }
    }

    # Reply to the second request: the raw file content.
    file_response = Response(200, expected_content.encode("utf-8"), {})

    with patch("marimo._utils.requests.get") as mock_get:
        mock_get.side_effect = [listing_response, file_response]
        content, filename = reader.read(gist_page_url)

    assert content == expected_content
    assert filename == "notebook.py"

    # Exactly two requests are expected: the API lookup, then the raw file.
    assert mock_get.call_count == 2
    first_call, second_call = mock_get.call_args_list
    assert "https://api.github.com/gists/12345" in first_call.args[0]
    assert second_call.args[0] == expected_raw_url
Loading