Skip to content

Commit 9ed2b19

Browse files
authored
fix: preserve comments in jupyter conversion (#6516)
## 📝 Summary Previous comments were removed on jupyter conversion, this will maintain them
1 parent 3414073 commit 9ed2b19

File tree

7 files changed

+591
-26
lines changed

7 files changed

+591
-26
lines changed

marimo/_ast/transformers.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import ast
55
import inspect
66
import textwrap
7-
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
7+
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, cast
88

99
from marimo._ast.parse import ast_parse
1010
from marimo._ast.variables import unmangle_local
@@ -14,6 +14,8 @@
1414

1515
ARG_PREFIX: str = "*"
1616

17+
T = TypeVar("T", bound="ast.Import | ast.ImportFrom")
18+
1719

1820
class BlockException(Exception):
1921
pass
@@ -190,36 +192,52 @@ class RemoveImportTransformer(ast.NodeTransformer):
190192
To prevent module collisions in top level definitions.
191193
"""
192194

193-
def __init__(self, import_name: str) -> None:
195+
def __init__(self, import_name: str, keep_one: bool = False) -> None:
194196
super().__init__()
197+
self.keep_one = keep_one
195198
self.import_name = import_name
196199

200+
def _return_once(
201+
self,
202+
node: T,
203+
original_names: list[ast.alias],
204+
) -> Optional[T]:
205+
if node.names:
206+
return node
207+
elif self.keep_one:
208+
self.keep_one = False
209+
node.names = original_names
210+
return node
211+
return None
212+
197213
def strip_imports(self, code: str) -> str:
198214
tree = ast_parse(code)
199215
tree = self.visit(tree)
200216
return ast.unparse(tree).strip()
201217

202218
def visit_Import(self, node: ast.Import) -> Optional[ast.Import]:
203219
name = self.import_name
220+
original_names = list(node.names)
204221
node.names = [
205222
alias
206223
for alias in node.names
207224
if (alias.asname and alias.asname != name)
208225
or (not alias.asname and alias.name != name)
209226
]
210-
return node if node.names else None
227+
return self._return_once(node, original_names)
211228

212229
def visit_ImportFrom(
213230
self, node: ast.ImportFrom
214231
) -> Optional[ast.ImportFrom]:
215232
name = self.import_name
233+
original_names = list(node.names)
216234
node.names = [
217235
alias
218236
for alias in node.names
219237
if (alias.asname and alias.asname != name)
220238
or (not alias.asname and alias.name != name)
221239
]
222-
return node if node.names else None
240+
return self._return_once(node, original_names)
223241

224242

225243
class ExtractWithBlock(ast.NodeTransformer):
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright 2025 Marimo. All rights reserved.
2+
from __future__ import annotations
3+
4+
import io
5+
import token as token_types
6+
from dataclasses import dataclass
7+
from tokenize import TokenError, tokenize
8+
from typing import Callable
9+
10+
11+
@dataclass
12+
class CommentToken:
13+
text: str
14+
line: int
15+
col: int
16+
17+
18+
class CommentPreserver:
19+
"""Functor to preserve comments during source code transformations."""
20+
21+
def __init__(self, sources: list[str]):
22+
self.sources = sources
23+
self.comments_by_source: dict[int, list[CommentToken]] = {}
24+
self._extract_all_comments()
25+
26+
def _extract_all_comments(self) -> None:
27+
"""Extract comments from all sources during initialization."""
28+
for i, source in enumerate(self.sources):
29+
self.comments_by_source[i] = self._extract_comments_from_source(
30+
source
31+
)
32+
33+
def _extract_comments_from_source(self, source: str) -> list[CommentToken]:
34+
"""Extract comments from a single source string."""
35+
if not source.strip():
36+
return []
37+
38+
comments = []
39+
try:
40+
tokens = tokenize(io.BytesIO(source.encode("utf-8")).readline)
41+
for token in tokens:
42+
if token.type == token_types.COMMENT:
43+
comments.append(
44+
CommentToken(
45+
text=token.string,
46+
line=token.start[0],
47+
col=token.start[1],
48+
)
49+
)
50+
except (TokenError, SyntaxError):
51+
# If tokenization fails, return empty list - no comments preserved
52+
pass
53+
54+
return comments
55+
56+
def __call__(
57+
self, transform_func: Callable[..., list[str]]
58+
) -> Callable[..., list[str]]:
59+
"""
60+
Method decorator that returns a comment-preserving version of transform_func.
61+
62+
Usage: preserver(transform_func)(sources, *args, **kwargs)
63+
"""
64+
65+
def wrapper(*args: object, **kwargs: object) -> list[str]:
66+
# Apply the original transformation
67+
transformed_sources = transform_func(*args, **kwargs)
68+
69+
# If sources weren't provided or transformation failed, return as-is
70+
if not args or not isinstance(args[0], list):
71+
return transformed_sources
72+
73+
original_sources = args[0]
74+
75+
# Merge comments back into transformed sources
76+
result = self._merge_comments(
77+
original_sources, transformed_sources
78+
)
79+
80+
# Update our internal comment data to track only the clean transformed sources
81+
# This clears old comments that no longer apply
82+
self._update_comments_for_transformed_sources(transformed_sources)
83+
84+
return result
85+
86+
return wrapper
87+
88+
def _merge_comments(
89+
self,
90+
original_sources: list[str],
91+
transformed_sources: list[str],
92+
) -> list[str]:
93+
"""Merge comments from original sources into transformed sources."""
94+
if len(original_sources) != len(transformed_sources):
95+
# If cell count changed, we can't preserve comments reliably
96+
return transformed_sources
97+
98+
result = []
99+
for i, (original, transformed) in enumerate(
100+
zip(original_sources, transformed_sources)
101+
):
102+
comments = self.comments_by_source.get(i, [])
103+
if not comments:
104+
result.append(transformed)
105+
continue
106+
107+
# Apply comment preservation with variable name updates if needed
108+
preserved_source = self._apply_comments_to_source(
109+
original, transformed, comments
110+
)
111+
result.append(preserved_source)
112+
113+
return result
114+
115+
def _apply_comments_to_source(
116+
self,
117+
original: str,
118+
transformed: str,
119+
comments: list[CommentToken],
120+
) -> str:
121+
"""Apply comments to a single transformed source."""
122+
if not comments:
123+
return transformed
124+
125+
original_lines = original.split("\n")
126+
transformed_lines = transformed.split("\n")
127+
128+
# Create a mapping of line numbers to comments
129+
comments_by_line: dict[int, list[CommentToken]] = {}
130+
for comment in comments:
131+
line_num = comment.line
132+
if line_num not in comments_by_line:
133+
comments_by_line[line_num] = []
134+
comments_by_line[line_num].append(comment)
135+
136+
# Apply comments to transformed lines
137+
result_lines = transformed_lines.copy()
138+
139+
for line_num, line_comments in comments_by_line.items():
140+
target_line_idx = min(
141+
line_num - 1, len(result_lines) - 1
142+
) # Convert to 0-based, clamp to bounds
143+
144+
if target_line_idx < 0:
145+
continue
146+
147+
# Select the best comment for this line (line comments take precedence)
148+
line_comment = None
149+
inline_comment = None
150+
151+
for comment in line_comments:
152+
if comment.col == 0: # Line comment (starts at column 0)
153+
line_comment = comment
154+
break # Line comment takes precedence, no need to check others
155+
else: # Inline comment
156+
inline_comment = comment
157+
158+
# Prefer line comment over inline comment
159+
chosen_comment = line_comment if line_comment else inline_comment
160+
161+
if chosen_comment:
162+
comment_text = chosen_comment.text
163+
if chosen_comment.col > 0 and target_line_idx < len(
164+
original_lines
165+
):
166+
# Inline comment - append to the line if not already present
167+
current_line = result_lines[target_line_idx]
168+
if not current_line.rstrip().endswith(
169+
comment_text.rstrip()
170+
):
171+
result_lines[target_line_idx] = (
172+
current_line.rstrip() + " " + comment_text
173+
)
174+
elif target_line_idx >= 0 and comment_text not in result_lines:
175+
# Standalone comment - insert above the line if not already present
176+
result_lines.insert(target_line_idx, comment_text)
177+
178+
return "\n".join(result_lines)
179+
180+
def _update_comments_for_transformed_sources(
181+
self, sources: list[str]
182+
) -> None:
183+
"""Update internal comment data to track the transformed sources."""
184+
self.sources = sources
185+
self.comments_by_source = {}
186+
self._extract_all_comments()

marimo/_convert/ipynb.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from marimo._ast.cell import CellConfig
1313
from marimo._ast.compiler import compile_cell
14-
from marimo._ast.transformers import NameTransformer
14+
from marimo._ast.transformers import NameTransformer, RemoveImportTransformer
1515
from marimo._ast.variables import is_local
1616
from marimo._ast.visitor import Block, NamedNode, ScopedVisitor
1717
from marimo._convert.utils import markdown_to_marimo
@@ -651,20 +651,26 @@ def transform_remove_duplicate_imports(sources: list[str]) -> list[str]:
651651
imports: set[str] = set()
652652
new_sources: list[str] = []
653653
for source in sources:
654-
new_lines: list[str] = []
655-
for line in source.split("\n"):
656-
stripped_line = line.strip()
657-
if stripped_line.startswith("import ") or stripped_line.startswith(
658-
"from "
659-
):
660-
if stripped_line not in imports:
661-
imports.add(stripped_line)
662-
new_lines.append(line)
663-
else:
664-
new_lines.append(line)
665-
666-
new_source = "\n".join(new_lines)
667-
new_sources.append(new_source.strip())
654+
try:
655+
cell = compile_cell(source, cell_id=CellId_t("temp"))
656+
except SyntaxError:
657+
new_sources.append(source)
658+
continue
659+
scoped = set()
660+
for var, instances in cell.variable_data.items():
661+
for instance in instances:
662+
if (
663+
var in imports or var in scoped
664+
) and instance.kind == "import":
665+
# If it's not in global imports, we keep one instance
666+
keep_one = var not in imports
667+
transformer = RemoveImportTransformer(
668+
var, keep_one=keep_one
669+
)
670+
source = transformer.strip_imports(source)
671+
scoped.add(var)
672+
imports.update(scoped)
673+
new_sources.append(source)
668674

669675
return new_sources
670676

@@ -715,23 +721,42 @@ def _transform_sources(
715721
716722
After this step, cells are ready for execution or rendering.
717723
"""
718-
source_transforms: list[Transform] = [
724+
from marimo._convert.comment_preserver import CommentPreserver
725+
726+
# Define transforms that don't need comment preservation
727+
simple_transforms = [
719728
transform_strip_whitespace,
720729
transform_magic_commands,
721730
transform_exclamation_mark,
731+
]
732+
733+
# Define transforms that should preserve comments
734+
comment_preserving_transforms = [
722735
transform_remove_duplicate_imports,
723736
transform_fixup_multiple_definitions,
724737
transform_duplicate_definitions,
725738
]
726739

727-
# Run all the source transforms
728-
for source_transform in source_transforms:
740+
# Run simple transforms first (no comment preservation needed)
741+
for source_transform in simple_transforms:
729742
new_sources = source_transform(sources)
730743
assert len(new_sources) == len(sources), (
731744
f"{source_transform.__name__} changed cell count"
732745
)
733746
sources = new_sources
734747

748+
# Create comment preserver from the simplified sources
749+
comment_preserver = CommentPreserver(sources)
750+
751+
# Run comment-preserving transforms
752+
for base_transform in comment_preserving_transforms:
753+
transform = comment_preserver(base_transform)
754+
new_sources = transform(sources)
755+
assert len(new_sources) == len(sources), (
756+
f"{base_transform.__name__} changed cell count"
757+
)
758+
sources = new_sources
759+
735760
cells = bind_cell_metadata(sources, metadata, hide_flags)
736761

737762
# may change cell count

0 commit comments

Comments
 (0)