Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion frontend/src/core/cells/__tests__/cells.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1305,9 +1305,29 @@ describe("cell reducer", () => {
actions.setCellIds({ cellIds: newIds });
expect(state.cellIds.atOrThrow(FIRST_COLUMN).topLevelIds).toEqual(newIds);

actions.setCellCodes({ codes: newCodes, ids: newIds });
// When codeIsStale is false, lastCodeRun should match code
actions.setCellCodes({
codes: newCodes,
ids: newIds,
codeIsStale: false,
});
newIds.forEach((id, index) => {
expect(state.cellData[id].code).toBe(newCodes[index]);
expect(state.cellData[id].lastCodeRun).toBe(newCodes[index]);
expect(state.cellData[id].edited).toBe(false);
});

// When codeIsStale is true, lastCodeRun should not change
const staleCodes = ["stale1", "stale2", "stale3"];
actions.setCellCodes({
codes: staleCodes,
ids: newIds,
codeIsStale: true,
});
newIds.forEach((id, index) => {
expect(state.cellData[id].code).toBe(staleCodes[index]);
expect(state.cellData[id].lastCodeRun).toBe(newCodes[index]);
expect(state.cellData[id].edited).toBe(true);
});
});

Expand Down
26 changes: 22 additions & 4 deletions frontend/src/core/cells/cells.ts
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,10 @@ const {
cellHandles: nextCellHandles,
};
},
setCellCodes: (state, action: { codes: string[]; ids: CellId[] }) => {
setCellCodes: (
state,
action: { codes: string[]; ids: CellId[]; codeIsStale: boolean },
) => {
invariant(
action.codes.length === action.ids.length,
"Expected codes and ids to have the same length",
Expand All @@ -791,11 +794,26 @@ const {
const code = action.codes[i];

state = updateCellData(state, cellId, (cell) => {
// No change
if (cell.code.trim() === code.trim()) {
return cell;
}

// Update codemirror if mounted
const cellHandle = state.cellHandles[cellId].current;
if (cellHandle?.editorView) {
updateEditorCodeFromPython(cellHandle.editorView, code);
}

// If code is stale, we don't promote it to lastCodeRun
const lastCodeRun = action.codeIsStale ? cell.lastCodeRun : code;

return {
...cell,
code,
edited: false,
lastCodeRun: code,
code: code,
// Mark as edited if the code has changed
edited: lastCodeRun ? lastCodeRun.trim() !== code.trim() : false,
lastCodeRun,
};
});
}
Expand Down
1 change: 1 addition & 0 deletions frontend/src/core/websocket/useMarimoWebSocket.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ export function useMarimoWebSocket(opts: {
setCellCodes({
codes: msg.data.codes,
ids: msg.data.cell_ids as CellId[],
codeIsStale: msg.data.code_is_stale,
});
return;
case "update-cell-ids":
Expand Down
11 changes: 10 additions & 1 deletion marimo/_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,14 @@ def main(
help=sandbox_message,
)
@click.option("--profile-dir", default=None, type=str, hidden=True)
@click.option(
"--watch",
is_flag=True,
default=False,
show_default=True,
type=bool,
help="Watch the file for changes and reload the code when saved in another editor.",
)
@click.argument("name", required=False, type=click.Path())
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
def edit(
Expand All @@ -300,6 +308,7 @@ def edit(
skip_update_check: bool,
sandbox: bool,
profile_dir: Optional[str],
watch: bool,
name: Optional[str],
args: tuple[str, ...],
) -> None:
Expand Down Expand Up @@ -369,7 +378,7 @@ def edit(
headless=headless,
mode=SessionMode.EDIT,
include_code=True,
watch=False,
watch=watch,
cli_args=parse_args(args),
auth_token=_resolve_token(token, token_password),
base_url=base_url,
Expand Down
1 change: 1 addition & 0 deletions marimo/_messaging/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ class UpdateCellCodes(Op):
name: ClassVar[str] = "update-cell-codes"
cell_ids: List[CellId_t]
codes: List[str]
code_is_stale: bool


@dataclass
Expand Down
9 changes: 0 additions & 9 deletions marimo/_server/api/lifespans.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,6 @@ async def lsp(app: Starlette) -> AsyncIterator[None]:
yield


@contextlib.asynccontextmanager
async def watcher(app: Starlette) -> AsyncIterator[None]:
state = AppState.from_app(app)
if state.watch:
session_mgr = state.session_manager
session_mgr.start_file_watcher()
yield


@contextlib.asynccontextmanager
async def open_browser(app: Starlette) -> AsyncIterator[None]:
state = AppState.from_app(app)
Expand Down
114 changes: 76 additions & 38 deletions marimo/_server/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
ConnectionDistributor,
QueueDistributor,
)
from marimo._utils.file_watcher import FileWatcher
from marimo._utils.file_watcher import FileWatcherManager
from marimo._utils.paths import import_files
from marimo._utils.repr import format_repr
from marimo._utils.typed_connection import TypedConnection
Expand Down Expand Up @@ -165,7 +165,7 @@ def __init__(
virtual_files_supported: bool,
redirect_console_to_browser: bool,
) -> None:
self.kernel_task: Optional[threading.Thread] | Optional[mp.Process]
self.kernel_task: Optional[threading.Thread | mp.Process] = None
self.queue_manager = queue_manager
self.mode = mode
self.configs = configs
Expand Down Expand Up @@ -554,6 +554,8 @@ def put_control_request(
UpdateCellCodes(
cell_ids=request.cell_ids,
codes=request.codes,
# Not stale because we just ran the code
code_is_stale=False,
),
except_consumer=from_consumer_id,
)
Expand Down Expand Up @@ -704,6 +706,7 @@ def __init__(
auth_token: Optional[AuthToken],
redirect_console_to_browser: bool,
ttl_seconds: Optional[int],
watch: bool = False,
) -> None:
self.file_router = file_router
self.mode = mode
Expand All @@ -713,7 +716,8 @@ def __init__(
self.include_code = include_code
self.ttl_seconds = ttl_seconds
self.lsp_server = lsp_server
self.watcher: Optional[FileWatcher] = None
self.watcher_manager = FileWatcherManager()
self.watch = watch
self.recents = RecentFilesManager()
self.user_config_manager = user_config_manager
self.cli_args = cli_args
Expand Down Expand Up @@ -772,7 +776,7 @@ def create_session(
if app_file_manager.path:
self.recents.touch(app_file_manager.path)

self.sessions[session_id] = Session.create(
session = Session.create(
initialization_id=file_key,
session_consumer=session_consumer,
mode=self.mode,
Expand All @@ -787,8 +791,60 @@ def create_session(
redirect_console_to_browser=self.redirect_console_to_browser,
ttl_seconds=self.ttl_seconds,
)
self.sessions[session_id] = session

# Start file watcher if enabled
if self.watch and app_file_manager.path:
self._start_file_watcher_for_session(session)

return self.sessions[session_id]

def _start_file_watcher_for_session(self, session: Session) -> None:
"""Start a file watcher for a session."""
if not session.app_file_manager.path:
return

async def on_file_changed(path: Path) -> None:
LOGGER.debug(f"{path} was modified")
# Skip if the session does not relate to the file
if session.app_file_manager.path != os.path.abspath(path):
return

# Reload the file manager to get the latest code
try:
session.app_file_manager.reload()
except Exception as e:
# If there are syntax errors, we just skip
# and don't send the changes
LOGGER.error(f"Error loading file: {e}")
return
# In run, we just call Reload()
if self.mode == SessionMode.RUN:
session.write_operation(Reload(), from_consumer_id=None)
return

# Get the latest codes
codes = list(session.app_file_manager.app.cell_manager.codes())
cell_ids = list(
session.app_file_manager.app.cell_manager.cell_ids()
)
# Send the updated codes to the frontend
session.write_operation(
UpdateCellCodes(
cell_ids=cell_ids,
codes=codes,
# The code is considered stale
# because it has not been run yet.
# In the future, we may add auto-run here.
code_is_stale=True,
),
from_consumer_id=None,
)

self.watcher_manager.add_callback(
Path(session.app_file_manager.path), on_file_changed
)

def get_session(self, session_id: SessionId) -> Optional[Session]:
session = self.sessions.get(session_id)
if session:
Expand Down Expand Up @@ -905,13 +961,22 @@ async def start_lsp_server(self) -> None:
return

def close_session(self, session_id: SessionId) -> bool:
"""Close a session and remove its file watcher if it has one."""
LOGGER.debug("Closing session %s", session_id)
session = self.get_session(session_id)
if session is not None:
session.close()
del self.sessions[session_id]
return True
return False
if session is None:
return False

# Remove the file watcher callback for this session
if session.app_file_manager.path and self.watch:
self.watcher_manager.remove_callback(
Path(session.app_file_manager.path),
self._start_file_watcher_for_session(session).__wrapped__, # type: ignore
)

session.close()
del self.sessions[session_id]
return True

def close_all_sessions(self) -> None:
LOGGER.debug("Closing all sessions (sessions: %s)", self.sessions)
Expand All @@ -921,43 +986,16 @@ def close_all_sessions(self) -> None:
self.sessions = {}

def shutdown(self) -> None:
"""Shutdown the session manager and stop all file watchers."""
LOGGER.debug("Shutting down")
self.close_all_sessions()
self.lsp_server.stop()
if self.watcher:
self.watcher.stop()
self.watcher_manager.stop_all()

def should_send_code_to_frontend(self) -> bool:
"""Returns True if the server can send messages to the frontend."""
return self.mode == SessionMode.EDIT or self.include_code

def start_file_watcher(self) -> Disposable:
"""Starts the file watcher if it is not already started"""
if self.mode == SessionMode.EDIT:
# We don't support file watching in edit mode yet
# as there are some edge cases that would need to be handled.
# - what to do if the file is deleted, or is renamed
# - do we re-run the app or just show the changed code
# - we don't properly handle saving from the frontend
LOGGER.warning("Cannot start file watcher in edit mode")
return Disposable.empty()
file = self.file_router.maybe_get_single_file()
if not file:
return Disposable.empty()

file_path = file.path

async def on_file_changed(path: Path) -> None:
LOGGER.debug(f"{path} was modified")
for _, session in self.sessions.items():
session.app_file_manager.reload()
session.write_operation(Reload(), from_consumer_id=None)

LOGGER.debug("Starting file watcher for %s", file_path)
self.watcher = FileWatcher.create(Path(file_path), on_file_changed)
self.watcher.start()
return Disposable(self.watcher.stop)

def get_active_connection_count(self) -> int:
return len(
[
Expand Down
17 changes: 16 additions & 1 deletion marimo/_server/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
initialize_fd_limit,
)
from marimo._server.uvicorn_utils import initialize_signals
from marimo._tracer import LOGGER
from marimo._utils.paths import import_files

DEFAULT_PORT = 2718
Expand Down Expand Up @@ -103,6 +104,20 @@ def start(

config_reader = get_default_config_manager(current_path=start_path)

# If watch is true, disable auto-save and format-on-save,
# watch is enabled when they are editing in another editor
if watch:
config_reader = config_reader.with_overrides(
{
"save": {
"autosave": "off",
"format_on_save": False,
"autosave_delay": 1000,
}
}
)
LOGGER.info("Watch mode enabled, auto-save is disabled")

session_manager = SessionManager(
file_router=file_router,
mode=mode,
Expand All @@ -115,6 +130,7 @@ def start(
cli_args=cli_args,
auth_token=auth_token,
redirect_console_to_browser=redirect_console_to_browser,
watch=watch,
)

log_level = "info" if development_mode else "error"
Expand All @@ -126,7 +142,6 @@ def start(
lifespan=lifespans.Lifespans(
[
lifespans.lsp,
lifespans.watcher,
lifespans.etc,
lifespans.signal_handler,
lifespans.logging,
Expand Down
Loading
Loading