Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions marimo/_ast/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,36 +190,52 @@ class RemoveImportTransformer(ast.NodeTransformer):
To prevent module collisions in top level definitions.
"""

def __init__(self, import_name: str, keep_one: bool = False) -> None:
    """Configure which imported name this transformer strips.

    Args:
        import_name: The bound name to remove; the visitors compare it
            against each alias's ``asname`` when present, otherwise
            against its ``name``.
        keep_one: When True, the first import statement that would be
            removed entirely is kept instead; the flag is cleared once
            that has happened.
    """
    super().__init__()
    self.keep_one = keep_one
    self.import_name = import_name

def _return_once(
self,
node: ast.Import | ast.ImportFrom,
original_names: list[ast.alias],
) -> Optional[ast.Import | ast.ImportFrom]:
if node.names:
return node
elif self.keep_one:
self.keep_one = False
node.names = original_names
return node
return None

def strip_imports(self, code: str) -> str:
    """Parse ``code``, run this transformer over it and return new source.

    The result is the ``ast.unparse`` rendering of the visited tree with
    surrounding whitespace stripped.
    """
    visited = self.visit(ast_parse(code))
    return ast.unparse(visited).strip()

def visit_Import(self, node: ast.Import) -> Optional[ast.Import]:
    """Remove aliases of an ``import`` statement that bind ``import_name``.

    An alias binds its ``asname`` when one is given and its ``name``
    otherwise. The filtered node is handed to ``_return_once`` together
    with the original alias list, so that method decides whether an
    emptied statement is dropped or kept.
    """
    target = self.import_name
    original_names = list(node.names)
    node.names = [
        alias
        for alias in node.names
        if (alias.asname or alias.name) != target
    ]
    # Always delegate: returning early here would bypass ``keep_one``.
    return self._return_once(node, original_names)

def visit_ImportFrom(
    self, node: ast.ImportFrom
) -> Optional[ast.ImportFrom]:
    """Remove aliases of a ``from ... import`` that bind ``import_name``.

    An alias binds its ``asname`` when one is given and its ``name``
    otherwise. The filtered node is handed to ``_return_once`` together
    with the original alias list, so that method decides whether an
    emptied statement is dropped or kept.
    """
    target = self.import_name
    original_names = list(node.names)
    node.names = [
        alias
        for alias in node.names
        if (alias.asname or alias.name) != target
    ]
    # Always delegate: returning early here would bypass ``keep_one``.
    return self._return_once(node, original_names)


class ExtractWithBlock(ast.NodeTransformer):
Expand Down
186 changes: 186 additions & 0 deletions marimo/_convert/comment_preserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations

import functools
import io
import token as token_types
from dataclasses import dataclass
from tokenize import TokenError, tokenize
from typing import Callable


@dataclass
class CommentToken:
    """A single comment found while tokenizing a source string."""

    # Comment text as reported by the tokenizer (starts with "#").
    text: str
    # 1-based line number of the comment within its source.
    line: int
    # 0-based column at which the comment starts on that line.
    col: int


class CommentPreserver:
    """Functor to preserve comments during source code transformations.

    Comments are tokenized out of ``sources`` when the instance is built.
    Calling the instance with a transform returns a wrapped transform that
    runs the original and then merges the recorded comments back into its
    output, cell by cell.
    """

    def __init__(self, sources: list[str]):
        """Record the comments of every entry in ``sources``."""
        self.sources = sources
        # Maps the index of a source to the comments found in it.
        self.comments_by_source: dict[int, list[CommentToken]] = {}
        self._extract_all_comments()

    def _extract_all_comments(self) -> None:
        """Rebuild ``comments_by_source`` from ``self.sources``."""
        self.comments_by_source = {
            index: self._extract_comments_from_source(source)
            for index, source in enumerate(self.sources)
        }

    def _extract_comments_from_source(self, source: str) -> list[CommentToken]:
        """Extract comments from a single source string.

        Returns an empty list for blank sources and for sources that
        cannot be tokenized.
        """
        if not source.strip():
            return []

        comments: list[CommentToken] = []
        try:
            readline = io.BytesIO(source.encode("utf-8")).readline
            for tok in tokenize(readline):
                if tok.type == token_types.COMMENT:
                    comments.append(
                        CommentToken(
                            text=tok.string,
                            line=tok.start[0],
                            col=tok.start[1],
                        )
                    )
        except (TokenError, SyntaxError):
            # Tokenization failed part-way through: discard the partial
            # result so no comments are preserved for this source.
            return []

        return comments

    def __call__(
        self, transform_func: Callable[..., list[str]]
    ) -> Callable[..., list[str]]:
        """
        Method decorator that returns a comment-preserving version of transform_func.

        Usage: preserver(transform_func)(sources, *args, **kwargs)
        """

        @functools.wraps(transform_func)
        def wrapper(*args: object, **kwargs: object) -> list[str]:
            # Apply the original transformation
            transformed_sources = transform_func(*args, **kwargs)

            # Without a positional list of sources there is nothing to
            # merge against, so return the transform's output as-is.
            if not args or not isinstance(args[0], list):
                return transformed_sources

            result = self._merge_comments(args[0], transformed_sources)

            # Re-index from the merged output: that is what the next
            # transform receives, so its comments must stay tracked.
            self._update_comments_for_transformed_sources(result)

            return result

        return wrapper

    def _merge_comments(
        self,
        original_sources: list[str],
        transformed_sources: list[str],
    ) -> list[str]:
        """Merge comments from original sources into transformed sources."""
        if len(original_sources) != len(transformed_sources):
            # If cell count changed, we can't preserve comments reliably
            return transformed_sources

        return [
            self._apply_comments_to_source(
                original, transformed, self.comments_by_source.get(index, [])
            )
            for index, (original, transformed) in enumerate(
                zip(original_sources, transformed_sources)
            )
        ]

    @staticmethod
    def _standalone_indent(
        comment: CommentToken, original_lines: list[str]
    ) -> str | None:
        """Return the indentation of a standalone comment, else ``None``.

        A comment is standalone when only whitespace precedes it on its
        original line; ``None`` means it trails code (an inline comment).
        """
        line_idx = comment.line - 1
        if 0 <= line_idx < len(original_lines):
            prefix = original_lines[line_idx][: comment.col]
            return None if prefix.strip() else prefix
        # Original line unavailable: fall back to the column alone.
        return "" if comment.col == 0 else None

    def _apply_comments_to_source(
        self,
        original: str,
        transformed: str,
        comments: list[CommentToken],
    ) -> str:
        """Apply comments to a single transformed source.

        Standalone comments are inserted as their own line above the
        target line; inline comments are appended to the target line.
        Comments that are already present are not added again.
        """
        if not comments:
            return transformed

        original_lines = original.split("\n")
        result_lines = transformed.split("\n")

        # A physical line holds at most one comment token, so comments can
        # be applied one by one in source order.
        for comment in comments:
            # Convert to 0-based, clamp to bounds
            target_idx = min(comment.line - 1, len(result_lines) - 1)
            if target_idx < 0:
                continue

            indent = self._standalone_indent(comment, original_lines)
            if indent is None:
                # Inline comment - append to the line if not already present
                current = result_lines[target_idx].rstrip()
                if current.endswith(comment.text.rstrip()):
                    continue
                result_lines[target_idx] = (
                    f"{current}  {comment.text}" if current else comment.text
                )
            else:
                # Standalone comment - insert above the line if not already
                # present. Inserting shifts later lines down by one, which
                # lines them up with the original when the transform
                # dropped the comment line.
                comment_line = indent + comment.text
                if comment_line not in result_lines:
                    result_lines.insert(target_idx, comment_line)

        return "\n".join(result_lines)

    def _update_comments_for_transformed_sources(
        self, sources: list[str]
    ) -> None:
        """Update internal comment data to track the given sources."""
        self.sources = sources
        self._extract_all_comments()
61 changes: 43 additions & 18 deletions marimo/_convert/ipynb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from marimo._ast.cell import CellConfig
from marimo._ast.compiler import compile_cell
from marimo._ast.transformers import NameTransformer
from marimo._ast.transformers import NameTransformer, RemoveImportTransformer
from marimo._ast.variables import is_local
from marimo._ast.visitor import Block, NamedNode, ScopedVisitor
from marimo._convert.utils import markdown_to_marimo
Expand Down Expand Up @@ -651,20 +651,26 @@ def transform_remove_duplicate_imports(sources: list[str]) -> list[str]:
imports: set[str] = set()
new_sources: list[str] = []
for source in sources:
new_lines: list[str] = []
for line in source.split("\n"):
stripped_line = line.strip()
if stripped_line.startswith("import ") or stripped_line.startswith(
"from "
):
if stripped_line not in imports:
imports.add(stripped_line)
new_lines.append(line)
else:
new_lines.append(line)

new_source = "\n".join(new_lines)
new_sources.append(new_source.strip())
try:
cell = compile_cell(source, cell_id=CellId_t("temp"))
except SyntaxError:
new_sources.append(source)
continue
scoped = set()
for var, instances in cell.variable_data.items():
for instance in instances:
if (
var in imports or var in scoped
) and instance.kind == "import":
# If it's not in global imports, we keep one instance
keep_one = var not in imports
transformer = RemoveImportTransformer(
var, keep_one=keep_one
)
source = transformer.strip_imports(source)
scoped.add(var)
imports.update(scoped)
new_sources.append(source)

return new_sources

Expand Down Expand Up @@ -715,23 +721,42 @@ def _transform_sources(

After this step, cells are ready for execution or rendering.
"""
source_transforms: list[Transform] = [
from marimo._convert.comment_preserver import CommentPreserver

# Define transforms that don't need comment preservation
simple_transforms = [
transform_strip_whitespace,
transform_magic_commands,
transform_exclamation_mark,
]

# Define transforms that should preserve comments
comment_preserving_transforms = [
transform_remove_duplicate_imports,
transform_fixup_multiple_definitions,
transform_duplicate_definitions,
]

# Run all the source transforms
for source_transform in source_transforms:
# Run simple transforms first (no comment preservation needed)
for source_transform in simple_transforms:
new_sources = source_transform(sources)
assert len(new_sources) == len(sources), (
f"{source_transform.__name__} changed cell count"
)
sources = new_sources

# Create comment preserver from the simplified sources
comment_preserver = CommentPreserver(sources)

# Run comment-preserving transforms
for base_transform in comment_preserving_transforms:
transform = comment_preserver(base_transform)
new_sources = transform(sources)
assert len(new_sources) == len(sources), (
f"{base_transform.__name__} changed cell count"
)
sources = new_sources

cells = bind_cell_metadata(sources, metadata, hide_flags)

# may change cell count
Expand Down
Loading
Loading