From bd0ed115bb355eb3d9eb2517c1613a5a91e6982d Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Wed, 29 Jan 2025 15:33:43 -0500 Subject: [PATCH 1/5] feat: add mo.ui.app_meta().request --- marimo/_messaging/context.py | 23 +++ marimo/_plugins/ui/_impl/chat/chat.py | 1 + marimo/_plugins/ui/_impl/tables/format.py | 53 ++++--- marimo/_pyodide/bootstrap.py | 8 +- marimo/_runtime/app/kernel_runner.py | 2 +- marimo/_runtime/app_meta.py | 15 ++ marimo/_runtime/context/types.py | 7 + marimo/_runtime/requests.py | 121 ++++++++++++++- marimo/_runtime/runtime.py | 18 ++- .../utils/set_ui_element_request_manager.py | 10 +- marimo/_server/api/endpoints/execution.py | 12 +- marimo/_server/export/__init__.py | 5 +- marimo/_server/models/models.py | 16 +- marimo/_server/sessions.py | 16 +- marimo/_smoke_tests/requests.py | 43 ++++++ marimo/_utils/lists.py | 12 ++ tests/_messaging/test_context_vars.py | 99 ++++++++++++ .../ui/_impl/tables/test_table_utils.py | 9 +- tests/_runtime/test_requests.py | 144 ++++++++++++++++++ tests/_server/api/endpoints/test_execution.py | 100 +++++++++++- 20 files changed, 660 insertions(+), 54 deletions(-) create mode 100644 marimo/_smoke_tests/requests.py create mode 100644 marimo/_utils/lists.py create mode 100644 tests/_messaging/test_context_vars.py create mode 100644 tests/_runtime/test_requests.py diff --git a/marimo/_messaging/context.py b/marimo/_messaging/context.py index 68c7d134394..3a5556ec2c6 100644 --- a/marimo/_messaging/context.py +++ b/marimo/_messaging/context.py @@ -1,12 +1,18 @@ # Copyright 2024 Marimo. All rights reserved. +from __future__ import annotations + import uuid from contextvars import ContextVar from dataclasses import dataclass from typing import Any, Optional +from marimo._runtime.requests import HTTPRequest + RunId_t = str RUN_ID_CTX = ContextVar[Optional[RunId_t]]("run_id") +HTTP_REQUEST_CTX = ContextVar[Optional[HTTPRequest]]("http_request") + @dataclass class run_id_context: @@ -22,3 +28,20 @@ def __enter__(self) -> None: def __exit__(self, *_: Any) -> None: RUN_ID_CTX.reset(self.token) + + +@dataclass +class http_request_context: + """Context manager for setting and unsetting the HTTP request.""" + + request: Optional[HTTPRequest] + + def __init__(self, request: Optional[HTTPRequest]) -> None: + assert request is not None + self.request = request + + def __enter__(self) -> None: + self.token = HTTP_REQUEST_CTX.set(self.request) + + def __exit__(self, *_: Any) -> None: + HTTP_REQUEST_CTX.reset(self.token) diff --git a/marimo/_plugins/ui/_impl/chat/chat.py b/marimo/_plugins/ui/_impl/chat/chat.py index 3cbcefc9060..a988ad5159b 100644 --- a/marimo/_plugins/ui/_impl/chat/chat.py +++ b/marimo/_plugins/ui/_impl/chat/chat.py @@ -220,6 +220,7 @@ async def _send_prompt(self, args: SendMessageRequest) -> str: SetUIElementValueRequest( object_ids=[self._id], values=[{"messages": self._chat_history}], + request=None, ) ) diff --git a/marimo/_plugins/ui/_impl/tables/format.py b/marimo/_plugins/ui/_impl/tables/format.py index 7ba700df2ad..ffce47e6d1e 100644 --- a/marimo/_plugins/ui/_impl/tables/format.py +++ b/marimo/_plugins/ui/_impl/tables/format.py @@ -14,34 +14,37 @@ def format_value( if format_mapping is None: return value - if value is None: - if col in format_mapping: - formatter = format_mapping[col] - if callable(formatter): - return formatter(value) + if col not in format_mapping: return value - if col in format_mapping: - formatter = format_mapping[col] - try: - if isinstance(formatter, str): - # Handle numeric formatting specially to preserve signs and separators - if isinstance(value, (int, float)): - # Keep integers as integers for 'd' format specifier - if isinstance(value, int) and "d" in formatter: - return formatter.format(value) - # Convert to float for float formatting - return formatter.format(float(value)) - return formatter.format(value) - if callable(formatter): - return formatter(value) - except Exception as e: - import logging + formatter = format_mapping[col] + + # If the value is None, we don't want to format it + # with strings for formatting, but we do want to + # format it with callables. + if value is None and isinstance(formatter, str): + return value + + try: + if isinstance(formatter, str): + # Handle numeric formatting specially to preserve signs and separators + if isinstance(value, (int, float)): + # Keep integers as integers for 'd' format specifier + if isinstance(value, int) and "d" in formatter: + return formatter.format(value) + # Convert to float for float formatting + return formatter.format(float(value)) + return formatter.format(value) + if callable(formatter): + return formatter(value) + except Exception as e: + import logging + + logging.warning( + f"Error formatting for value {value} in column {col}: {str(e)}" + ) + return value - logging.warning( - f"Error formatting for value {value} in column {col}: {str(e)}" - ) - return value return value diff --git a/marimo/_pyodide/bootstrap.py b/marimo/_pyodide/bootstrap.py index 504595c87f3..5f5527f64aa 100644 --- a/marimo/_pyodide/bootstrap.py +++ b/marimo/_pyodide/bootstrap.py @@ -35,7 +35,11 @@ def instantiate( app = session.app_manager.app execution_requests = tuple( - ExecutionRequest(cell_id=cell_data.cell_id, code=cell_data.code) + ExecutionRequest( + cell_id=cell_data.cell_id, + code=cell_data.code, + request=None, + ) for cell_data in app.cell_manager.cell_data() ) @@ -43,7 +47,7 @@ def instantiate( CreationRequest( execution_requests=execution_requests, set_ui_element_value_request=SetUIElementValueRequest( - object_ids=[], values=[] + object_ids=[], values=[], request=None ), auto_run=auto_instantiate, ) diff --git a/marimo/_runtime/app/kernel_runner.py b/marimo/_runtime/app/kernel_runner.py index 0d580af203d..d5c2735ab9c 100644 --- a/marimo/_runtime/app/kernel_runner.py +++ b/marimo/_runtime/app/kernel_runner.py @@ -110,7 +110,7 @@ def globals(self) -> dict[CellId_t, Any]: async def run(self, cells_to_run: set[CellId_t]) -> RunOutput: execution_requests = [ - ExecutionRequest(cell_id=cid, code=cell._cell.code) + ExecutionRequest(cell_id=cid, code=cell._cell.code, request=None) for cid in cells_to_run if (cell := self.app.cell_manager.cell_data_at(cid).cell) is not None diff --git a/marimo/_runtime/app_meta.py b/marimo/_runtime/app_meta.py index 63e94041abc..944211ae4da 100644 --- a/marimo/_runtime/app_meta.py +++ b/marimo/_runtime/app_meta.py @@ -10,6 +10,7 @@ get_context, ) from marimo._runtime.context.utils import RunMode, get_mode +from marimo._runtime.requests import HTTPRequest @mddoc @@ -75,3 +76,17 @@ def mode(self) -> Optional[RunMode]: - None: The mode could not be determined """ return get_mode() + + @property + def request(self) -> Optional[HTTPRequest]: + """ + The current HTTP request if any. + + Returns: + Optional[Request]: The current request object if available, None otherwise. + """ + try: + context = get_context() + return context.request + except ContextNotInitializedError: + return None diff --git a/marimo/_runtime/context/types.py b/marimo/_runtime/context/types.py index 0a87692abcf..827a017e4d0 100644 --- a/marimo/_runtime/context/types.py +++ b/marimo/_runtime/context/types.py @@ -13,10 +13,12 @@ from typing import TYPE_CHECKING, Any, Iterator, Optional from marimo._config.config import MarimoConfig +from marimo._messaging.context import HTTP_REQUEST_CTX from marimo._messaging.types import Stderr, Stdout from marimo._runtime import dataflow from marimo._runtime.cell_lifecycle_registry import CellLifecycleRegistry from marimo._runtime.functions import FunctionRegistry +from marimo._runtime.requests import HTTPRequest if TYPE_CHECKING: from marimo._ast.app import InternalApp @@ -100,6 +102,11 @@ def marimo_config(self) -> MarimoConfig: """ pass + @property + def request(self) -> Optional[HTTPRequest]: + """Get the current request context if any.""" + return HTTP_REQUEST_CTX.get(None) + @property @abc.abstractmethod def cell_id(self) -> Optional[CellId_t]: diff --git a/marimo/_runtime/requests.py b/marimo/_runtime/requests.py index d2396e3029e..f7a972cf234 100644 --- a/marimo/_runtime/requests.py +++ b/marimo/_runtime/requests.py @@ -2,14 +2,30 @@ from __future__ import annotations import time -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union +from collections import defaultdict +from dataclasses import asdict, dataclass, field +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterator, + List, + Mapping, + Optional, + Tuple, + TypeVar, + Union, +) from uuid import uuid4 from marimo._ast.cell import CellId_t from marimo._config.config import MarimoConfig from marimo._data.models import DataTableSource +if TYPE_CHECKING: + from starlette.datastructures import URL + from starlette.requests import HTTPConnection + UIElementId = str CompletionRequestId = str FunctionCallId = str @@ -21,10 +37,90 @@ SerializedCLIArgs = Dict[str, ListOrValue[Primitive]] +@dataclass +class HTTPRequest(Mapping[str, Any]): + """ + A class that mimics the Request object from Starlette or FastAPI. + + It is a subset and pickle-able version of the Request object. + """ + + url: dict[str, Any] # Serialized URL + base_url: dict[str, Any] # Serialized URL + headers: dict[str, str] # Raw headers + query_params: dict[str, list[str]] # Raw query params + path_params: dict[str, Any] + cookies: dict[str, str] + user: Any + + # We don't include session or auth because they may contain + # information that the app author does not want to expose. + + # session: dict[str, Any] + # auth: Any + + def __getitem__(self, key: str) -> Any: + return self.__dict__[key] + + def __iter__(self) -> Iterator[str]: + return iter(self.__dict__) + + def __len__(self) -> int: + return len(self.__dict__) + + def _display_(self) -> Any: + return asdict(self) + + @staticmethod + def from_request(request: HTTPConnection) -> "HTTPRequest": + def _url_to_dict(url: URL) -> dict[str, Any]: + return { + "path": url.path, + "port": url.port, + "scheme": url.scheme, + "netloc": url.netloc, + "query": url.query, + "hostname": url.hostname, + } + + # Convert URL to dict + url_dict = _url_to_dict(request.url) + + # Convert base_url to dict + base_url_dict = _url_to_dict(request.base_url) + + # Convert query params to dict[str, list[str]] + query_params: dict[str, list[str]] = defaultdict(list) + for k, v in request.query_params.multi_items(): + query_params[k].append(str(v)) + + # Convert headers to dict, remove all marimo-specific headers + headers: dict[str, str] = {} + for k, v in request.headers.items(): + if not k.startswith(("marimo", "x-marimo")): + headers[k] = v + + return HTTPRequest( + url=url_dict, + base_url=base_url_dict, + headers=headers, + query_params=query_params, + path_params=request.path_params, + cookies=request.cookies, + user=request["user"] if "user" in request else {}, + # Left out for now. This may contain information that the app author + # does not want to expose. + # session=request.session if "session" in request else {}, + # auth=request.auth if "auth" in request else {}, + ) + + @dataclass class ExecutionRequest: cell_id: CellId_t code: str + # incoming request, e.g. from Starlette or FastAPI + request: Optional[HTTPRequest] timestamp: float = field(default_factory=time.time) @@ -38,6 +134,8 @@ class ExecuteMultipleRequest: cell_ids: List[CellId_t] # code to register/run for each cell codes: List[str] + # incoming request, e.g. from Starlette or FastAPI + request: Optional[Any] # time at which the request was received timestamp: float = field(default_factory=time.time) @@ -45,7 +143,10 @@ class ExecuteMultipleRequest: def execution_requests(self) -> List[ExecutionRequest]: return [ ExecutionRequest( - cell_id=cell_id, code=code, timestamp=self.timestamp + cell_id=cell_id, + code=code, + request=self.request, + timestamp=self.timestamp, ) for cell_id, code in zip(self.cell_ids, self.codes) ] @@ -59,6 +160,8 @@ def __post_init__(self) -> None: @dataclass class ExecuteScratchpadRequest: code: str + # incoming request, e.g. from Starlette or FastAPI + request: Optional[Any] @dataclass @@ -70,6 +173,8 @@ class RenameRequest: class SetUIElementValueRequest: object_ids: List[UIElementId] values: List[Any] + # Incoming request, e.g. from Starlette or FastAPI + request: Optional[HTTPRequest] # uniquely identifies the request token: str = field(default_factory=lambda: str(uuid4())) @@ -81,12 +186,17 @@ def __post_init__(self) -> None: @staticmethod def from_ids_and_values( ids_and_values: List[Tuple[UIElementId, Any]], + request: Optional[HTTPRequest] = None, ) -> SetUIElementValueRequest: if not ids_and_values: - return SetUIElementValueRequest(object_ids=[], values=[]) + return SetUIElementValueRequest( + object_ids=[], values=[], request=request + ) object_ids, values = zip(*ids_and_values) return SetUIElementValueRequest( - object_ids=list(object_ids), values=list(values) + object_ids=list(object_ids), + values=list(values), + request=request, ) @property @@ -129,6 +239,7 @@ class CreationRequest: execution_requests: Tuple[ExecutionRequest, ...] set_ui_element_value_request: SetUIElementValueRequest auto_run: bool + request: Optional[HTTPRequest] = None @dataclass diff --git a/marimo/_runtime/runtime.py b/marimo/_runtime/runtime.py index ba211c442c4..ae0e8fac038 100644 --- a/marimo/_runtime/runtime.py +++ b/marimo/_runtime/runtime.py @@ -31,7 +31,7 @@ ) from marimo._dependencies.dependencies import DependencyManager from marimo._messaging.cell_output import CellChannel -from marimo._messaging.context import run_id_context +from marimo._messaging.context import http_request_context, run_id_context from marimo._messaging.errors import ( Error, MarimoInterruptionError, @@ -1593,7 +1593,9 @@ async def set_ui_element_value( child_context.app is not None and await child_context.app.set_ui_element_value( SetUIElementValueRequest( - object_ids=[object_id], values=[value] + object_ids=[object_id], + values=[value], + request=request.request, ) ) ): @@ -2043,13 +2045,16 @@ async def handle_message(self, request: ControlRequest) -> None: with self.lock_globals(): LOGGER.debug("Handling control request: %s", request) if isinstance(request, CreationRequest): - await self.instantiate(request) + with http_request_context(request.request): + await self.instantiate(request) CompletedRun().broadcast() elif isinstance(request, ExecuteMultipleRequest): - await self.run(request.execution_requests) + with http_request_context(request.request): + await self.run(request.execution_requests) CompletedRun().broadcast() elif isinstance(request, ExecuteScratchpadRequest): - await self.run_scratchpad(request.code) + with http_request_context(request.request): + await self.run_scratchpad(request.code) elif isinstance(request, ExecuteStaleRequest): await self.run_stale_cells() elif isinstance(request, RenameRequest): @@ -2059,7 +2064,8 @@ async def handle_message(self, request: ControlRequest) -> None: elif isinstance(request, SetUserConfigRequest): self.set_user_config(request) elif isinstance(request, SetUIElementValueRequest): - await self.set_ui_element_value(request) + with http_request_context(request.request): + await self.set_ui_element_value(request) CompletedRun().broadcast() elif isinstance(request, FunctionCallRequest): status, ret, _ = await self.function_call_request(request) diff --git a/marimo/_runtime/utils/set_ui_element_request_manager.py b/marimo/_runtime/utils/set_ui_element_request_manager.py index beace666ac9..6f315a92fc4 100644 --- a/marimo/_runtime/utils/set_ui_element_request_manager.py +++ b/marimo/_runtime/utils/set_ui_element_request_manager.py @@ -14,8 +14,10 @@ class SetUIElementRequestManager: def __init__( self, - set_ui_element_queue: QueueType[SetUIElementValueRequest] - | asyncio.Queue[SetUIElementValueRequest], + set_ui_element_queue: ( + QueueType[SetUIElementValueRequest] + | asyncio.Queue[SetUIElementValueRequest] + ), ) -> None: self._set_ui_element_queue = set_ui_element_queue self._processed_request_tokens: set[str] = set() @@ -51,8 +53,10 @@ def _merge_set_ui_element_requests( for request in requests: for ui_id, value in request.ids_and_values: merged[ui_id] = value + last_request = requests[-1] return SetUIElementValueRequest( object_ids=list(merged.keys()), values=list(merged.values()), - token="", + token=last_request.token, + request=last_request.request, ) diff --git a/marimo/_server/api/endpoints/execution.py b/marimo/_server/api/endpoints/execution.py index e8116bf3425..3ee0f86fbc9 100644 --- a/marimo/_server/api/endpoints/execution.py +++ b/marimo/_server/api/endpoints/execution.py @@ -12,6 +12,7 @@ from marimo._messaging.ops import Alert from marimo._runtime.requests import ( FunctionCallRequest, + HTTPRequest, SetUIElementValueRequest, ) from marimo._server.api.deps import AppState @@ -62,7 +63,10 @@ async def set_ui_element_values( body = await parse_request(request, cls=UpdateComponentValuesRequest) app_state.require_current_session().put_control_request( SetUIElementValueRequest( - object_ids=body.object_ids, values=body.values, token=str(uuid4()) + object_ids=body.object_ids, + values=body.values, + token=str(uuid4()), + request=HTTPRequest.from_request(request), ), from_consumer_id=ConsumerId(app_state.require_current_session_id()), ) @@ -91,7 +95,10 @@ async def instantiate( """ app_state = AppState(request) body = await parse_request(request, cls=InstantiateRequest) - app_state.require_current_session().instantiate(body) + app_state.require_current_session().instantiate( + body, + http_request=HTTPRequest.from_request(request), + ) return SuccessResponse() @@ -168,6 +175,7 @@ async def run_cell( """ # noqa: E501 app_state = AppState(request) body = await parse_request(request, cls=RunRequest) + body.request = HTTPRequest.from_request(request) app_state.require_current_session().put_control_request( body.as_execution_request(), from_consumer_id=ConsumerId(app_state.require_current_session_id()), diff --git a/marimo/_server/export/__init__.py b/marimo/_server/export/__init__.py index e7257ddf9a7..fc2cf567347 100644 --- a/marimo/_server/export/__init__.py +++ b/marimo/_server/export/__init__.py @@ -282,7 +282,10 @@ def connection_state(self) -> ConnectionState: ) # Run the notebook to completion once - session.instantiate(InstantiateRequest(object_ids=[], values=[])) + session.instantiate( + InstantiateRequest(object_ids=[], values=[]), + http_request=None, + ) await instantiated_event.wait() # Process console messages # diff --git a/marimo/_server/models/models.py b/marimo/_server/models/models.py index 7c9e4dbe81a..3c32fefaf37 100644 --- a/marimo/_server/models/models.py +++ b/marimo/_server/models/models.py @@ -10,6 +10,7 @@ from marimo._runtime.requests import ( ExecuteMultipleRequest, ExecuteScratchpadRequest, + HTTPRequest, RenameRequest, ) @@ -83,9 +84,15 @@ class RunRequest: cell_ids: List[CellId_t] # code to register/run for each cell codes: List[str] + # incoming request, e.g. from Starlette or FastAPI + request: Optional[HTTPRequest] = None def as_execution_request(self) -> ExecuteMultipleRequest: - return ExecuteMultipleRequest(cell_ids=self.cell_ids, codes=self.codes) + return ExecuteMultipleRequest( + cell_ids=self.cell_ids, + codes=self.codes, + request=self.request, + ) # Validate same length def __post_init__(self) -> None: @@ -97,9 +104,14 @@ def __post_init__(self) -> None: @dataclass class RunScratchpadRequest: code: str + # incoming request, e.g. from Starlette or FastAPI + request: Optional[Any] = None def as_execution_request(self) -> ExecuteScratchpadRequest: - return ExecuteScratchpadRequest(code=self.code) + return ExecuteScratchpadRequest( + code=self.code, + request=self.request, + ) @dataclass diff --git a/marimo/_server/sessions.py b/marimo/_server/sessions.py index a6542b53600..65dd8eee42c 100644 --- a/marimo/_server/sessions.py +++ b/marimo/_server/sessions.py @@ -51,6 +51,7 @@ CreationRequest, ExecuteMultipleRequest, ExecutionRequest, + HTTPRequest, SerializedCLIArgs, SerializedQueryParams, SetUIElementValueRequest, @@ -651,10 +652,19 @@ def close(self) -> None: self.heartbeat_task.cancel() self.kernel_manager.close_kernel() - def instantiate(self, request: InstantiateRequest) -> None: + def instantiate( + self, + request: InstantiateRequest, + *, + http_request: Optional[HTTPRequest], + ) -> None: """Instantiate the app.""" execution_requests = tuple( - ExecutionRequest(cell_id=cell_data.cell_id, code=cell_data.code) + ExecutionRequest( + cell_id=cell_data.cell_id, + code=cell_data.code, + request=http_request, + ) for cell_data in self.app_file_manager.app.cell_manager.cell_data() ) @@ -665,8 +675,10 @@ def instantiate(self, request: InstantiateRequest) -> None: object_ids=request.object_ids, values=request.values, token=str(uuid4()), + request=http_request, ), auto_run=request.auto_run, + request=http_request, ), from_consumer_id=None, ) diff --git a/marimo/_smoke_tests/requests.py b/marimo/_smoke_tests/requests.py new file mode 100644 index 00000000000..06afb2ba9ce --- /dev/null +++ b/marimo/_smoke_tests/requests.py @@ -0,0 +1,43 @@ +import marimo + +__generated_with = "0.10.17" +app = marimo.App(width="medium") + + +@app.cell +def _(): + import marimo as mo + return (mo,) + + +@app.cell +def _(mo): + refresh = mo.ui.refresh(default_interval="3s") + refresh + return (refresh,) + + +@app.cell +def _(mo, refresh): + refresh + user = mo.app_meta().request.user + [user, user.is_authenticated, user.display_name] + return (user,) + + +@app.cell +def _(mo, refresh): + refresh + list(mo.app_meta().request.keys()) + return + + +@app.cell +def _(mo, refresh): + refresh + mo.app_meta().request + return + + +if __name__ == "__main__": + app.run() diff --git a/marimo/_utils/lists.py b/marimo/_utils/lists.py new file mode 100644 index 00000000000..ed538b4f893 --- /dev/null +++ b/marimo/_utils/lists.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from typing import Iterable, TypeVar, Union + +T = TypeVar("T") + + +def first(iterable: Union[Iterable[T], T]) -> T: + if isinstance(iterable, Iterable): + return next(iter(iterable)) # type: ignore[no-any-return] + else: + return iterable diff --git a/tests/_messaging/test_context_vars.py b/tests/_messaging/test_context_vars.py new file mode 100644 index 00000000000..9ef51c75d89 --- /dev/null +++ b/tests/_messaging/test_context_vars.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import uuid + +import pytest + +from marimo._messaging.context import ( + HTTP_REQUEST_CTX, + RUN_ID_CTX, + http_request_context, + run_id_context, +) +from marimo._runtime.requests import HTTPRequest + + +class TestRunIDContext: + def test_run_id_is_uuid(self): + with run_id_context(): + run_id = RUN_ID_CTX.get() + # Verify it's a valid UUID string + assert uuid.UUID(run_id) + + def test_nested_contexts(self): + with run_id_context(): + outer_id = RUN_ID_CTX.get() + with run_id_context(): + inner_id = RUN_ID_CTX.get() + assert inner_id != outer_id + + # Verify we're back to outer context + assert RUN_ID_CTX.get() == outer_id + + def test_context_cleanup(self): + # Verify context is None before entering + with pytest.raises(LookupError): + RUN_ID_CTX.get() + + with run_id_context(): + assert RUN_ID_CTX.get() is not None + + # Verify context is cleaned up after exiting + with pytest.raises(LookupError): + RUN_ID_CTX.get() + + +class TestHTTPRequestContext: + @pytest.fixture + def mock_request(self): + return HTTPRequest( + url={"path": "/test"}, + base_url={"path": "/"}, + path_params={}, + cookies={}, + user={"is_authenticated": True}, + headers={}, + query_params={}, + ) + + def test_http_request_context(self, mock_request: HTTPRequest): + with http_request_context(mock_request): + assert HTTP_REQUEST_CTX.get() == mock_request + + def test_nested_contexts(self, mock_request: HTTPRequest): + request2 = HTTPRequest( + url={"path": "/test2"}, + base_url={"path": "/"}, + path_params={}, + cookies={}, + user={"is_authenticated": True}, + headers={}, + query_params={}, + ) + + with http_request_context(mock_request): + outer_req = HTTP_REQUEST_CTX.get() + with http_request_context(request2): + inner_req = HTTP_REQUEST_CTX.get() + assert inner_req != outer_req + assert inner_req == request2 + + # Verify we're back to outer context + assert HTTP_REQUEST_CTX.get() == outer_req + + def test_context_cleanup(self, mock_request: HTTPRequest): + # Verify context is None before entering + with pytest.raises(LookupError): + HTTP_REQUEST_CTX.get() + + with http_request_context(mock_request): + assert HTTP_REQUEST_CTX.get() == mock_request + + # Verify context is cleaned up after exiting + with pytest.raises(LookupError): + HTTP_REQUEST_CTX.get() + + def test_none_request_assertion(self): + with pytest.raises(AssertionError): + with http_request_context(None): + pass diff --git a/tests/_plugins/ui/_impl/tables/test_table_utils.py b/tests/_plugins/ui/_impl/tables/test_table_utils.py index 410831f35db..47958f30103 100644 --- a/tests/_plugins/ui/_impl/tables/test_table_utils.py +++ b/tests/_plugins/ui/_impl/tables/test_table_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations from marimo._plugins.ui._impl.tables.format import ( + FormatMapping, format_column, format_row, format_value, @@ -9,7 +10,7 @@ def test_format_value(): # Test with string formatter - format_mapping = {"col1": "{:.2f}"} + format_mapping: FormatMapping = {"col1": "{:.2f}"} assert format_value("col1", 123.456, format_mapping) == "123.46" # Test with callable formatter @@ -43,13 +44,13 @@ def test_format_value(): def test_format_row(): # Test with string formatter - format_mapping = {"col1": "{:.2f}", "col2": "{:.1f}"} + format_mapping: FormatMapping = {"col1": "{:.2f}", "col2": "{:.1f}"} row = {"col1": 123.456, "col2": 78.9} expected = {"col1": "123.46", "col2": "78.9"} assert format_row(row, format_mapping) == expected # Test with callable formatter - format_mapping = { + format_mapping: FormatMapping = { "col1": lambda x: f"${x:.2f}", "col2": lambda x: f"{x:.1f}%", } @@ -93,7 +94,7 @@ def test_format_row(): def test_format_column(): # Test with string formatter - format_mapping = {"col1": "{:.2f}"} + format_mapping: FormatMapping = {"col1": "{:.2f}"} values = [123.456, 78.9] expected = ["123.46", "78.90"] assert format_column("col1", values, format_mapping) == expected diff --git a/tests/_runtime/test_requests.py b/tests/_runtime/test_requests.py new file mode 100644 index 00000000000..9c40f6cbf8e --- /dev/null +++ b/tests/_runtime/test_requests.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import Any, Optional + +from starlette.datastructures import URL, Headers, QueryParams +from starlette.requests import HTTPConnection + +from marimo._runtime.requests import HTTPRequest + + +class MockHTTPConnection(HTTPConnection): + def __init__( + self, + url: str = "http://localhost:8000/test?param1=value1¶m2=value2", + headers: Optional[dict[str, str]] = None, + path_params: Optional[dict[str, Any]] = None, + ): + url_obj = URL(url) + # Convert headers to list of tuples as expected by Starlette + raw_headers = [(k.lower(), v) for k, v in (headers or {}).items()] + scope = { + "type": "http", + "method": "GET", + "headers": dict(raw_headers), + "path": url_obj.path, + "path_params": path_params or {}, + } + super().__init__(scope) + self._url = url_obj + self._base_url = URL("http://localhost:8000") + + @property + def url(self) -> URL: + return self._url + + @property + def base_url(self) -> URL: + return self._base_url + + @property + def query_params(self) -> QueryParams: + return QueryParams(self.url.query) + + @property + def headers(self) -> Headers: + return Headers(headers=self.scope["headers"]) + + +def test_http_request_like_basic_mapping(): + request = HTTPRequest( + url={"path": "/test"}, + base_url={"path": "/"}, + headers={"Content-Type": "application/json"}, + query_params=defaultdict(list), + path_params={}, + cookies={}, + user={"is_authenticated": True}, + ) + + assert request["url"] == {"path": "/test"} + assert set(request) == { + "url", + "base_url", + "headers", + "query_params", + "path_params", + "cookies", + "user", + } + + +def test_from_request(): + mock_request = MockHTTPConnection( + url="http://localhost:8000/test?param1=value1¶m2=value2", + headers={ + "Content-Type": "application/json", + "Cookie": "session=abc123", + }, + path_params={"id": "123"}, + ) + + request_like = HTTPRequest.from_request(mock_request) + + assert request_like["url"] == { + "path": "/test", + "port": 8000, + "scheme": "http", + "netloc": "localhost:8000", + "query": "param1=value1¶m2=value2", + "hostname": "localhost", + } + + assert request_like["headers"] == { + "content-type": "application/json", + "cookie": "session=abc123", + } + assert request_like["cookies"] == {"session": "abc123"} + assert request_like["path_params"] == {"id": "123"} + + +def test_query_params_filtering(): + mock_request = MockHTTPConnection( + url="http://localhost:8000/test?param1=value1&marimo_param=value2" + ) + + request_like = HTTPRequest.from_request(mock_request) + + # marimo in params is ok + assert dict(request_like["query_params"]) == { + "param1": ["value1"], + "marimo_param": ["value2"], + } + + +def test_header_params_filtering(): + mock_request = MockHTTPConnection( + url="http://localhost:8000/test", + headers={ + "Content-Type": "application/json", + "x-marimo-param": "value1", + "marimo-param": "value2", + }, + ) + + request_like = HTTPRequest.from_request(mock_request) + + assert request_like["headers"] == {"content-type": "application/json"} + + +def test_display(): + request = HTTPRequest( + url={"path": "/test"}, + base_url={"path": "/"}, + headers={}, + query_params=defaultdict(list), + path_params={}, + cookies={}, + user={"is_authenticated": True}, + ) + + display_dict = request._display_() + assert isinstance(display_dict, dict) + assert "url" in display_dict diff --git a/tests/_server/api/endpoints/test_execution.py b/tests/_server/api/endpoints/test_execution.py index 801277fde05..083b1f78589 100644 --- a/tests/_server/api/endpoints/test_execution.py +++ b/tests/_server/api/endpoints/test_execution.py @@ -1,8 +1,11 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations -from typing import TYPE_CHECKING +import json +import time +from typing import TYPE_CHECKING, Any +from marimo._utils.lists import first from tests._server.conftest import get_session_manager from tests._server.mocks import token_header, with_read_session, with_session @@ -155,6 +158,49 @@ def test_takeover_file_key(client: TestClient) -> None: assert response.headers["content-type"] == "application/json" assert response.json()["status"] == "ok" + @staticmethod + @with_session(SESSION_ID) + def test_app_meta_request(client: TestClient) -> None: + response = client.post( + "/api/kernel/run", + headers=HEADERS, + json={ + "cell_ids": ["test-1"], + "codes": [ + "import marimo as mo\n" + "import json\n" + "request = dict(mo.app_meta().request)\n" + "request['user'] = bool(request['user'])\n" # user is not serializable + "print(json.dumps(request))" + ], + }, + ) + assert response.status_code == 200, response.text + assert response.headers["content-type"] == "application/json" + assert "success" in response.json() + + # Sleep for .5 seconds + time.sleep(0.5) + + # Check keys + app_meta_response = get_printed_object(client, "test-1") + assert set(app_meta_response.keys()) == { + "base_url", + "cookies", + "headers", + "user", + "path_params", + "query_params", + "url", + } + # Check no marimo in headers + assert all( + "marimo" not in header + for header in app_meta_response["headers"].keys() + ) + # Check user is True + assert app_meta_response["user"] is True + class TestExecutionRoutes_RunMode: @staticmethod @@ -261,3 +307,55 @@ def test_run_scratchpad(client: TestClient) -> None: def test_takeover_no_file_key(client: TestClient) -> None: response = client.post("/api/kernel/takeover", headers=HEADERS) assert response.status_code == 401, response.text + + @staticmethod + @with_session(SESSION_ID) + def with_read_session(client: TestClient) -> None: + response = client.post( + "/api/kernel/run", + headers=HEADERS, + json={ + "cell_ids": ["test-1"], + "codes": [ + "import marimo as mo\n" + "import json\n" + "request = dict(mo.app_meta().request)\n" + "request['user'] = bool(request['user'])\n" # user is not serializable + "print(json.dumps(request))" + ], + }, + ) + assert response.status_code == 200, response.text + assert response.headers["content-type"] == "application/json" + assert "success" in response.json() + + # Sleep for .5 seconds + time.sleep(0.5) + + # Check keys + app_meta_response = get_printed_object(client, "test-1") + assert set(app_meta_response.keys()) == { + "base_url", + "cookies", + "headers", + "user", + "path_params", + "query_params", + "url", + } + # Check no marimo in headers + assert all( + "marimo" not in header + for header in app_meta_response["headers"].keys() + ) + # Check user is True + assert app_meta_response["user"] is True + + +def get_printed_object(client: TestClient, cell_id: str) -> dict[str, Any]: + session = get_session_manager(client).get_session(SESSION_ID) + assert session + console = first(session.session_view.cell_operations[cell_id].console) + assert console + assert isinstance(console.data, str) + return json.loads(console.data) From b13efd625278422354a05c84840e0376877ebf01 Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Wed, 29 Jan 2025 15:53:38 -0500 Subject: [PATCH 2/5] fixes --- marimo/_cli/development/commands.py | 4 ++-- marimo/_messaging/context.py | 1 - marimo/_runtime/requests.py | 8 +++---- marimo/_server/models/models.py | 2 +- openapi/api.yaml | 37 ++++++++++++++++------------- openapi/src/api.ts | 12 ++++++---- 6 files changed, 35 insertions(+), 29 deletions(-) diff --git a/marimo/_cli/development/commands.py b/marimo/_cli/development/commands.py index 839e1db0029..a681d757746 100644 --- a/marimo/_cli/development/commands.py +++ b/marimo/_cli/development/commands.py @@ -163,7 +163,6 @@ def _generate_schema() -> dict[str, Any]: models.SuccessResponse, models.UpdateComponentValuesRequest, requests.CodeCompletionRequest, - requests.CreationRequest, requests.DeleteCellRequest, requests.ExecuteMultipleRequest, requests.ExecuteScratchpadRequest, @@ -192,7 +191,8 @@ def _generate_schema() -> dict[str, Any]: {"type": "boolean"}, {"type": "null"}, ] - } + }, + "HTTPRequest": {"type": "null"}, } # We must override the names of some Union Types, # otherwise, their __name__ is "Union" diff --git a/marimo/_messaging/context.py b/marimo/_messaging/context.py index 3a5556ec2c6..a7f90f75a87 100644 --- a/marimo/_messaging/context.py +++ b/marimo/_messaging/context.py @@ -37,7 +37,6 @@ class http_request_context: request: Optional[HTTPRequest] def __init__(self, request: Optional[HTTPRequest]) -> None: - assert request is not None self.request = request def __enter__(self) -> None: diff --git a/marimo/_runtime/requests.py b/marimo/_runtime/requests.py index f7a972cf234..378c665589e 100644 --- a/marimo/_runtime/requests.py +++ b/marimo/_runtime/requests.py @@ -120,7 +120,7 @@ class ExecutionRequest: cell_id: CellId_t code: str # incoming request, e.g. from Starlette or FastAPI - request: Optional[HTTPRequest] + request: Optional[HTTPRequest] = None timestamp: float = field(default_factory=time.time) @@ -135,7 +135,7 @@ class ExecuteMultipleRequest: # code to register/run for each cell codes: List[str] # incoming request, e.g. from Starlette or FastAPI - request: Optional[Any] + request: Optional[HTTPRequest] # time at which the request was received timestamp: float = field(default_factory=time.time) @@ -161,7 +161,7 @@ def __post_init__(self) -> None: class ExecuteScratchpadRequest: code: str # incoming request, e.g. from Starlette or FastAPI - request: Optional[Any] + request: Optional[HTTPRequest] @dataclass @@ -174,7 +174,7 @@ class SetUIElementValueRequest: object_ids: List[UIElementId] values: List[Any] # Incoming request, e.g. from Starlette or FastAPI - request: Optional[HTTPRequest] + request: Optional[HTTPRequest] = None # uniquely identifies the request token: str = field(default_factory=lambda: str(uuid4())) diff --git a/marimo/_server/models/models.py b/marimo/_server/models/models.py index 3c32fefaf37..31394db2f6d 100644 --- a/marimo/_server/models/models.py +++ b/marimo/_server/models/models.py @@ -105,7 +105,7 @@ def __post_init__(self) -> None: class RunScratchpadRequest: code: str # incoming request, e.g. from Starlette or FastAPI - request: Optional[Any] = None + request: Optional[HTTPRequest] = None def as_execution_request(self) -> ExecuteScratchpadRequest: return ExecuteScratchpadRequest( diff --git a/openapi/api.yaml b/openapi/api.yaml index 5fc3c9ae497..5cd02225c4c 100644 --- a/openapi/api.yaml +++ b/openapi/api.yaml @@ -330,21 +330,6 @@ components: - source - destination type: object - CreationRequest: - properties: - autoRun: - type: boolean - executionRequests: - items: - $ref: '#/components/schemas/ExecutionRequest' - type: array - setUiElementValueRequest: - $ref: '#/components/schemas/SetUIElementValueRequest' - required: - - executionRequests - - setUiElementValueRequest - - autoRun - type: object CycleError: properties: edges_with_vars: @@ -520,6 +505,9 @@ components: items: type: string type: array + request: + $ref: '#/components/schemas/HTTPRequest' + nullable: true timestamp: type: number required: @@ -531,6 +519,9 @@ components: properties: code: type: string + request: + $ref: '#/components/schemas/HTTPRequest' + nullable: true required: - code type: object @@ -543,6 +534,9 @@ components: type: string code: type: string + request: + $ref: '#/components/schemas/HTTPRequest' + nullable: true timestamp: type: number required: @@ -816,6 +810,8 @@ components: - status - name type: object + HTTPRequest: + type: 'null' HumanReadableStatus: properties: code: @@ -1638,6 +1634,9 @@ components: items: type: string type: array + request: + $ref: '#/components/schemas/HTTPRequest' + nullable: true required: - cellIds - codes @@ -1646,6 +1645,9 @@ components: properties: code: type: string + request: + $ref: '#/components/schemas/HTTPRequest' + nullable: true required: - code type: object @@ -1753,6 +1755,9 @@ components: items: type: string type: array + request: + $ref: '#/components/schemas/HTTPRequest' + nullable: true token: type: string values: @@ -1968,7 +1973,7 @@ components: type: object info: title: marimo API - version: 0.10.13 + version: 0.10.17 openapi: 3.1.0 paths: /@file/{filename_and_length}: diff --git a/openapi/src/api.ts b/openapi/src/api.ts index b07d79d9278..1a04c9099d9 100644 --- a/openapi/src/api.ts +++ b/openapi/src/api.ts @@ -2214,11 +2214,6 @@ export interface components { destination: string; source: string; }; - CreationRequest: { - autoRun: boolean; - executionRequests: components["schemas"]["ExecutionRequest"][]; - setUiElementValueRequest: components["schemas"]["SetUIElementValueRequest"]; - }; CycleError: { edges_with_vars: [string, string[], string][]; /** @enum {string} */ @@ -2292,15 +2287,18 @@ export interface components { ExecuteMultipleRequest: { cellIds: string[]; codes: string[]; + request?: components["schemas"]["HTTPRequest"]; timestamp: number; }; ExecuteScratchpadRequest: { code: string; + request?: components["schemas"]["HTTPRequest"]; }; ExecuteStaleRequest: Record; ExecutionRequest: { cellId: string; code: string; + request?: components["schemas"]["HTTPRequest"]; timestamp: number; }; ExportAsHTMLRequest: { @@ -2410,6 +2408,7 @@ export interface components { return_value?: components["schemas"]["JSONType"]; status: components["schemas"]["HumanReadableStatus"]; }; + HTTPRequest: null; HumanReadableStatus: { /** @enum {string} */ code: "ok" | "error"; @@ -2767,9 +2766,11 @@ export interface components { RunRequest: { cellIds: string[]; codes: string[]; + request?: components["schemas"]["HTTPRequest"]; }; RunScratchpadRequest: { code: string; + request?: components["schemas"]["HTTPRequest"]; }; RunningNotebooksResponse: { files: components["schemas"]["MarimoFile"][]; @@ -2815,6 +2816,7 @@ export interface components { }; SetUIElementValueRequest: { objectIds: string[]; + request?: components["schemas"]["HTTPRequest"]; token: string; values: unknown[]; }; From d25d50e9c3d7d589c0dbce5b2ad85bd54f4daefa Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Wed, 29 Jan 2025 15:59:26 -0500 Subject: [PATCH 3/5] fixes --- marimo/_runtime/requests.py | 2 +- tests/_messaging/test_context_vars.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/marimo/_runtime/requests.py b/marimo/_runtime/requests.py index 378c665589e..586878603ad 100644 --- a/marimo/_runtime/requests.py +++ b/marimo/_runtime/requests.py @@ -135,7 +135,7 @@ class ExecuteMultipleRequest: # code to register/run for each cell codes: List[str] # incoming request, e.g. from Starlette or FastAPI - request: Optional[HTTPRequest] + request: Optional[HTTPRequest] = None # time at which the request was received timestamp: float = field(default_factory=time.time) diff --git a/tests/_messaging/test_context_vars.py b/tests/_messaging/test_context_vars.py index 9ef51c75d89..59a8bbd45e8 100644 --- a/tests/_messaging/test_context_vars.py +++ b/tests/_messaging/test_context_vars.py @@ -93,7 +93,6 @@ def test_context_cleanup(self, mock_request: HTTPRequest): with pytest.raises(LookupError): HTTP_REQUEST_CTX.get() - def test_none_request_assertion(self): - with pytest.raises(AssertionError): - with http_request_context(None): - pass + def test_none_request(self): + with http_request_context(None): + assert HTTP_REQUEST_CTX.get() is None From cdc0f77f0dad7f8474f13cbc4ca943be8ab0922d Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Wed, 29 Jan 2025 16:08:28 -0500 Subject: [PATCH 4/5] try asdict --- marimo/_runtime/requests.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/marimo/_runtime/requests.py b/marimo/_runtime/requests.py index 586878603ad..125fea10f96 100644 --- a/marimo/_runtime/requests.py +++ b/marimo/_runtime/requests.py @@ -69,7 +69,10 @@ def __len__(self) -> int: return len(self.__dict__) def _display_(self) -> Any: - return asdict(self) + try: + return asdict(self) + except TypeError: + return self.__dict__ @staticmethod def from_request(request: HTTPConnection) -> "HTTPRequest": From 5b51143eab2ea6f030ce6ab4801ef901d548554d Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Wed, 29 Jan 2025 17:34:35 -0500 Subject: [PATCH 5/5] docs --- docs/guides/deploying/programmatically.md | 50 +++++++++++++++++++++++ marimo/_runtime/app_meta.py | 23 ++++++++++- marimo/_runtime/runtime.py | 8 ++++ 3 files changed, 79 insertions(+), 2 deletions(-) diff --git a/docs/guides/deploying/programmatically.md b/docs/guides/deploying/programmatically.md index 19e698b04db..0a3c97132a2 100644 --- a/docs/guides/deploying/programmatically.md +++ b/docs/guides/deploying/programmatically.md @@ -64,3 +64,53 @@ for filename in sorted(notebooks_dir.iterdir()): server = server.with_app(path=f"/{app_name}", root=filename) app_names.append(app_name) ``` + +## Accessing Request Data + +Inside your marimo notebooks, you can access the current request data using `mo.app_meta().request`. This is particularly useful when implementing authentication or accessing user data. + +```python +import marimo as mo + +# Access request data in your notebook +request = mo.app_meta().request +if request and request.user and request.user["is_authenticated"]: + content = f"Welcome {request.user['username']}!" +else: + content = "Please log in" + +mo.md(content) +``` + +### Authentication Middleware Example + +Here's an example of how to implement authentication middleware that populates `request.user`: + +```python +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request + +class AuthMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # Add user data to the request scope + # This will be accessible via mo.app_meta().request.user + request.scope["user"] = { + "is_authenticated": True, + "username": "example_user", + # Add any other user data + } + response = await call_next(request) + return response + +# Add the middleware to your FastAPI app +app.add_middleware(AuthMiddleware) +``` + +The `request` object provides access to: + +- `request.headers`: Request headers +- `request.cookies`: Request cookies +- `request.query_params`: Query parameters +- `request.path_params`: Path parameters +- `request.user`: User data added by authentication middleware +- `request.url`: URL information including path, query parameters diff --git a/marimo/_runtime/app_meta.py b/marimo/_runtime/app_meta.py index 944211ae4da..4430a3e292a 100644 --- a/marimo/_runtime/app_meta.py +++ b/marimo/_runtime/app_meta.py @@ -80,10 +80,29 @@ def mode(self) -> Optional[RunMode]: @property def request(self) -> Optional[HTTPRequest]: """ - The current HTTP request if any. + The current HTTP request if any. The shape of the request object depends on the ASGI framework used, + but typically includes: + + - `headers`: Request headers + - `cookies`: Request cookies + - `query_params`: Query parameters + - `path_params`: Path parameters + - `user`: User data added by authentication middleware + - `url`: URL information including path, query parameters + + Examples: + Get the current request and print the path: + + ```python + request = mo.app_meta().request + user = request.user + print( + user["is_authenticated"], user["username"], request.url["path"] + ) + ``` Returns: - Optional[Request]: The current request object if available, None otherwise. + Optional[HTTPRequest]: The current request object if available, None otherwise. """ try: context = get_context() diff --git a/marimo/_runtime/runtime.py b/marimo/_runtime/runtime.py index ae0e8fac038..af722a2beb8 100644 --- a/marimo/_runtime/runtime.py +++ b/marimo/_runtime/runtime.py @@ -262,6 +262,14 @@ def app_meta() -> AppMeta: mo.md("# Developer Notes") if mo.app_meta().mode == "edit" else None ``` + Get the current request headers or user info: + + ```python + request = mo.app_meta().request + print(request.headers) + print(request.user) + ``` + Returns: AppMeta: An AppMeta object containing the app's metadata. """