diff --git a/marimo/_ast/transformers.py b/marimo/_ast/transformers.py index 5fd54253803..d15f1185e58 100644 --- a/marimo/_ast/transformers.py +++ b/marimo/_ast/transformers.py @@ -4,7 +4,7 @@ import ast import inspect import textwrap -from typing import TYPE_CHECKING, Any, Callable, Optional, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, cast from marimo._ast.parse import ast_parse from marimo._ast.variables import unmangle_local @@ -14,6 +14,8 @@ ARG_PREFIX: str = "*" +T = TypeVar("T", bound="ast.Import | ast.ImportFrom") + class BlockException(Exception): pass @@ -190,10 +192,24 @@ class RemoveImportTransformer(ast.NodeTransformer): To prevent module collisions in top level definitions. """ - def __init__(self, import_name: str) -> None: + def __init__(self, import_name: str, keep_one: bool = False) -> None: super().__init__() + self.keep_one = keep_one self.import_name = import_name + def _return_once( + self, + node: T, + original_names: list[ast.alias], + ) -> Optional[T]: + if node.names: + return node + elif self.keep_one: + self.keep_one = False + node.names = original_names + return node + return None + def strip_imports(self, code: str) -> str: tree = ast_parse(code) tree = self.visit(tree) @@ -201,25 +217,27 @@ def strip_imports(self, code: str) -> str: def visit_Import(self, node: ast.Import) -> Optional[ast.Import]: name = self.import_name + original_names = list(node.names) node.names = [ alias for alias in node.names if (alias.asname and alias.asname != name) or (not alias.asname and alias.name != name) ] - return node if node.names else None + return self._return_once(node, original_names) def visit_ImportFrom( self, node: ast.ImportFrom ) -> Optional[ast.ImportFrom]: name = self.import_name + original_names = list(node.names) node.names = [ alias for alias in node.names if (alias.asname and alias.asname != name) or (not alias.asname and alias.name != name) ] - return node if node.names else None + return self._return_once(node, original_names) class ExtractWithBlock(ast.NodeTransformer): diff --git a/marimo/_convert/comment_preserver.py b/marimo/_convert/comment_preserver.py new file mode 100644 index 00000000000..c9c539bda8b --- /dev/null +++ b/marimo/_convert/comment_preserver.py @@ -0,0 +1,186 @@ +# Copyright 2025 Marimo. All rights reserved. +from __future__ import annotations + +import io +import token as token_types +from dataclasses import dataclass +from tokenize import TokenError, tokenize +from typing import Callable + + +@dataclass +class CommentToken: + text: str + line: int + col: int + + +class CommentPreserver: + """Functor to preserve comments during source code transformations.""" + + def __init__(self, sources: list[str]): + self.sources = sources + self.comments_by_source: dict[int, list[CommentToken]] = {} + self._extract_all_comments() + + def _extract_all_comments(self) -> None: + """Extract comments from all sources during initialization.""" + for i, source in enumerate(self.sources): + self.comments_by_source[i] = self._extract_comments_from_source( + source + ) + + def _extract_comments_from_source(self, source: str) -> list[CommentToken]: + """Extract comments from a single source string.""" + if not source.strip(): + return [] + + comments = [] + try: + tokens = tokenize(io.BytesIO(source.encode("utf-8")).readline) + for token in tokens: + if token.type == token_types.COMMENT: + comments.append( + CommentToken( + text=token.string, + line=token.start[0], + col=token.start[1], + ) + ) + except (TokenError, SyntaxError): + # If tokenization fails, return empty list - no comments preserved + pass + + return comments + + def __call__( + self, transform_func: Callable[..., list[str]] + ) -> Callable[..., list[str]]: + """ + Method decorator that returns a comment-preserving version of transform_func. + + Usage: preserver(transform_func)(sources, *args, **kwargs) + """ + + def wrapper(*args: object, **kwargs: object) -> list[str]: + # Apply the original transformation + transformed_sources = transform_func(*args, **kwargs) + + # If sources weren't provided or transformation failed, return as-is + if not args or not isinstance(args[0], list): + return transformed_sources + + original_sources = args[0] + + # Merge comments back into transformed sources + result = self._merge_comments( + original_sources, transformed_sources + ) + + # Update our internal comment data to track only the clean transformed sources + # This clears old comments that no longer apply + self._update_comments_for_transformed_sources(transformed_sources) + + return result + + return wrapper + + def _merge_comments( + self, + original_sources: list[str], + transformed_sources: list[str], + ) -> list[str]: + """Merge comments from original sources into transformed sources.""" + if len(original_sources) != len(transformed_sources): + # If cell count changed, we can't preserve comments reliably + return transformed_sources + + result = [] + for i, (original, transformed) in enumerate( + zip(original_sources, transformed_sources) + ): + comments = self.comments_by_source.get(i, []) + if not comments: + result.append(transformed) + continue + + # Apply comment preservation with variable name updates if needed + preserved_source = self._apply_comments_to_source( + original, transformed, comments + ) + result.append(preserved_source) + + return result + + def _apply_comments_to_source( + self, + original: str, + transformed: str, + comments: list[CommentToken], + ) -> str: + """Apply comments to a single transformed source.""" + if not comments: + return transformed + + original_lines = original.split("\n") + transformed_lines = transformed.split("\n") + + # Create a mapping of line numbers to comments + comments_by_line: dict[int, list[CommentToken]] = {} + for comment in comments: + line_num = comment.line + if line_num not in comments_by_line: + comments_by_line[line_num] = [] + comments_by_line[line_num].append(comment) + + # Apply comments to transformed lines + result_lines = transformed_lines.copy() + + for line_num, line_comments in comments_by_line.items(): + target_line_idx = min( + line_num - 1, len(result_lines) - 1 + ) # Convert to 0-based, clamp to bounds + + if target_line_idx < 0: + continue + + # Select the best comment for this line (line comments take precedence) + line_comment = None + inline_comment = None + + for comment in line_comments: + if comment.col == 0: # Line comment (starts at column 0) + line_comment = comment + break # Line comment takes precedence, no need to check others + else: # Inline comment + inline_comment = comment + + # Prefer line comment over inline comment + chosen_comment = line_comment if line_comment else inline_comment + + if chosen_comment: + comment_text = chosen_comment.text + if chosen_comment.col > 0 and target_line_idx < len( + original_lines + ): + # Inline comment - append to the line if not already present + current_line = result_lines[target_line_idx] + if not current_line.rstrip().endswith( + comment_text.rstrip() + ): + result_lines[target_line_idx] = ( + current_line.rstrip() + " " + comment_text + ) + elif target_line_idx >= 0 and comment_text not in result_lines: + # Standalone comment - insert above the line if not already present + result_lines.insert(target_line_idx, comment_text) + + return "\n".join(result_lines) + + def _update_comments_for_transformed_sources( + self, sources: list[str] + ) -> None: + """Update internal comment data to track the transformed sources.""" + self.sources = sources + self.comments_by_source = {} + self._extract_all_comments() diff --git a/marimo/_convert/ipynb.py b/marimo/_convert/ipynb.py index 660397f8418..ba5c89cfafe 100644 --- a/marimo/_convert/ipynb.py +++ b/marimo/_convert/ipynb.py @@ -11,7 +11,7 @@ from marimo._ast.cell import CellConfig from marimo._ast.compiler import compile_cell -from marimo._ast.transformers import NameTransformer +from marimo._ast.transformers import NameTransformer, RemoveImportTransformer from marimo._ast.variables import is_local from marimo._ast.visitor import Block, NamedNode, ScopedVisitor from marimo._convert.utils import markdown_to_marimo @@ -651,20 +651,26 @@ def transform_remove_duplicate_imports(sources: list[str]) -> list[str]: imports: set[str] = set() new_sources: list[str] = [] for source in sources: - new_lines: list[str] = [] - for line in source.split("\n"): - stripped_line = line.strip() - if stripped_line.startswith("import ") or stripped_line.startswith( - "from " - ): - if stripped_line not in imports: - imports.add(stripped_line) - new_lines.append(line) - else: - new_lines.append(line) - - new_source = "\n".join(new_lines) - new_sources.append(new_source.strip()) + try: + cell = compile_cell(source, cell_id=CellId_t("temp")) + except SyntaxError: + new_sources.append(source) + continue + scoped = set() + for var, instances in cell.variable_data.items(): + for instance in instances: + if ( + var in imports or var in scoped + ) and instance.kind == "import": + # If it's not in global imports, we keep one instance + keep_one = var not in imports + transformer = RemoveImportTransformer( + var, keep_one=keep_one + ) + source = transformer.strip_imports(source) + scoped.add(var) + imports.update(scoped) + new_sources.append(source) return new_sources @@ -715,23 +721,42 @@ def _transform_sources( After this step, cells are ready for execution or rendering. """ - source_transforms: list[Transform] = [ + from marimo._convert.comment_preserver import CommentPreserver + + # Define transforms that don't need comment preservation + simple_transforms = [ transform_strip_whitespace, transform_magic_commands, transform_exclamation_mark, + ] + + # Define transforms that should preserve comments + comment_preserving_transforms = [ transform_remove_duplicate_imports, transform_fixup_multiple_definitions, transform_duplicate_definitions, ] - # Run all the source transforms - for source_transform in source_transforms: + # Run simple transforms first (no comment preservation needed) + for source_transform in simple_transforms: new_sources = source_transform(sources) assert len(new_sources) == len(sources), ( f"{source_transform.__name__} changed cell count" ) sources = new_sources + # Create comment preserver from the simplified sources + comment_preserver = CommentPreserver(sources) + + # Run comment-preserving transforms + for base_transform in comment_preserving_transforms: + transform = comment_preserver(base_transform) + new_sources = transform(sources) + assert len(new_sources) == len(sources), ( + f"{base_transform.__name__} changed cell count" + ) + sources = new_sources + cells = bind_cell_metadata(sources, metadata, hide_flags) # may change cell count diff --git a/tests/_convert/ipynb_data/comments_preservation.ipynb b/tests/_convert/ipynb_data/comments_preservation.ipynb new file mode 100644 index 00000000000..81a3de0bb07 --- /dev/null +++ b/tests/_convert/ipynb_data/comments_preservation.ipynb @@ -0,0 +1,229 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n" + ] + } + ], + "source": [ + "# Cell 1: Basic inline and line comments\n", + "x = 1 # This is an inline comment\n", + "print(x) # Another inline comment" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10\n" + ] + } + ], + "source": [ + "# Cell 2: Comments with duplicate definitions\n", + "# This variable will be redefined later\n", + "y = 10 # First definition\n", + "print(y) # Should preserve this comment" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "40\n" + ] + } + ], + "source": [ + "# Cell 3: Redefinition with more comments\n", + "# This is the second definition of y\n", + "y = 20 # Second definition with inline comment\n", + "result = y * 2 # Calculate result\n", + "print(result) # Print the result" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 4 μs, sys: 1 μs, total: 5 μs\n", + "Wall time: 8.11 μs\n", + "cell_metadata.ipynb\n", + "comments_preservation.ipynb\n", + "duplicate_definitions_and_aug_assign.ipynb\n", + "duplicate_definitions_read_before_write.ipynb\n", + "duplicate_definitions_syntax_error.ipynb\n", + "hides_markdown_cells.ipynb\n", + "multiple_definitions_multiline.ipynb\n", + "multiple_definitions.ipynb\n", + "pip_commands.ipynb\n" + ] + } + ], + "source": [ + "# Cell 4: Magic commands with comments\n", + "%time x = 5 # Time this operation\n", + "# Comment before magic\n", + "%ls # List files" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "UsageError: Cell magic `%%sql` not found.\n" + ] + } + ], + "source": [ + "%%sql\n", + "-- SQL comment inside magic\n", + "SELECT * FROM table -- Another SQL comment\n", + "WHERE id > 0 # Python-style comment in SQL" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'numpy'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Cell 6: Comments with import statements\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m \u001b[38;5;66;03m# Import numpy\u001b[39;00m\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpandas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpd\u001b[39;00m \u001b[38;5;66;03m# Import pandas\u001b[39;00m\n\u001b[32m 4\u001b[39m \u001b[38;5;66;03m# These imports might be duplicated elsewhere\u001b[39;00m\n", + "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'numpy'" + ] + } + ], + "source": [ + "# Cell 6: Comments with import statements\n", + "import numpy as np # Import numpy\n", + "import pandas as pd # Import pandas\n", + "# These imports might be duplicated elsewhere" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'numpy'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Cell 7: Duplicate import with different comment\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m \u001b[38;5;66;03m# Import numpy again with different comment\u001b[39;00m\n\u001b[32m 3\u001b[39m \u001b[38;5;66;03m# This should be deduplicated but preserve one of the comments\u001b[39;00m\n\u001b[32m 4\u001b[39m arr = np.array([\u001b[32m1\u001b[39m, \u001b[32m2\u001b[39m, \u001b[32m3\u001b[39m]) \u001b[38;5;66;03m# Create array\u001b[39;00m\n", + "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'numpy'" + ] + } + ], + "source": [ + "# Cell 7: Duplicate import with different comment\n", + "import numpy as np # Import numpy again with different comment\n", + "# This should be deduplicated but preserve one of the comments\n", + "arr = np.array([1, 2, 3]) # Create array" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 8: Complex expressions with comments\n", + "# Calculate something complex\n", + "z = (x * 2 # Multiply x by 2\n", + " + y) # Add y to the result\n", + "# Final comment in cell" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 9: Augmented assignment with comments\n", + "counter = 0 # Initialize counter\n", + "counter += 1 # Increment counter (this will be transformed)\n", + "# This should preserve comments during aug assign transformation" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 10: Function definitions with comments\n", + "def my_function(): # Define a function\n", + " \"\"\"This is a docstring, not a comment.\"\"\"\n", + " # This is a comment inside the function\n", + " return 42 # Return a value\n", + "# Comment after function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/_convert/snapshots/convert_comments_preservation.py.txt b/tests/_convert/snapshots/convert_comments_preservation.py.txt new file mode 100644 index 00000000000..aec962af943 --- /dev/null +++ b/tests/_convert/snapshots/convert_comments_preservation.py.txt @@ -0,0 +1,107 @@ +import marimo + +app = marimo.App() + + +@app.cell +def _(): + # Cell 1: Basic inline and line comments + x = 1 # This is an inline comment + print(x) # Another inline comment + return (x,) + + +@app.cell +def _(): + # Cell 2: Comments with duplicate definitions + # This variable will be redefined later + y = 10 # First definition + print(y) # Should preserve this comment + return + + +@app.cell +def _(): + # Cell 3: Redefinition with more comments + # This is the second definition of y + y_1 = 20 # Second definition with inline comment + result = y_1 * 2 # Calculate result + print(result) # Print the result + return (y_1,) + + +@app.cell +def _(): + # Cell 4: Magic commands with comments + # magic command not supported in marimo; please file an issue to add support + # %time x = 5 # Time this operation + # Comment before magic + import os + os.listdir() + return + + +@app.cell +def _(mo): + _df = mo.sql(""" + -- SQL comment inside magic + SELECT * FROM table -- Another SQL comment + WHERE id > 0 # Python-style comment in SQL + """) + return + + +@app.cell +def _(): + # Cell 6: Comments with import statements + import numpy as np # Import numpy + import pandas as pd # Import pandas + # These imports might be duplicated elsewhere + return (np,) + + +@app.cell +def _(np): + # Cell 7: Duplicate import with different comment + # This should be deduplicated but preserve one of the comments + arr = np.array([1, 2, 3]) # Import numpy again with different comment # Create array + return + + +@app.cell +def _(x, y_1): + # Cell 8: Complex expressions with comments + # Calculate something complex + # Final comment in cell + z = x * 2 + y_1 # Multiply x by 2 # Add y to the result + return + + +@app.cell +def _(): + # Cell 9: Augmented assignment with comments + counter = 0 # Initialize counter + # This should preserve comments during aug assign transformation + counter = counter + 1 # Increment counter (this will be transformed) + return + + +@app.cell +def _(): + # Cell 10: Function definitions with comments + def my_function(): # Define a function + """This is a docstring, not a comment.""" + # This is a comment inside the function + return 42 # Return a value + # Comment after function + return + + +@app.cell +def _(): + import marimo as mo + return (mo,) + + +if __name__ == "__main__": + app.run() diff --git a/tests/_convert/snapshots/convert_multiple_definitions.py.txt b/tests/_convert/snapshots/convert_multiple_definitions.py.txt index e668aef49a7..54aca2a8577 100644 --- a/tests/_convert/snapshots/convert_multiple_definitions.py.txt +++ b/tests/_convert/snapshots/convert_multiple_definitions.py.txt @@ -6,14 +6,14 @@ app = marimo.App() @app.cell def _(): _x = 1 - print(_x) + print(_x) # print return @app.cell def _(): _x = 2 - print(_x) + print(_x) # print return diff --git a/tests/_convert/test_ipynb.py b/tests/_convert/test_ipynb.py index 6a3eab5a1c9..af0e02d4719 100644 --- a/tests/_convert/test_ipynb.py +++ b/tests/_convert/test_ipynb.py @@ -423,7 +423,7 @@ def test_transform_remove_duplicate_imports_complex(): assert result == [ "import numpy as np\nfrom pandas import DataFrame\nimport matplotlib.pyplot as plt", # noqa: E501 "from sklearn.model_selection import train_test_split, cross_val_score", # noqa: E501 - "from pandas import Series\nfrom matplotlib import pyplot as plt\nimport pandas as pd", # noqa: E501 + "from pandas import Series\nimport pandas as pd", # noqa: E501 ] @@ -505,7 +505,7 @@ def test_transform_remove_duplicate_imports_with_aliases(): assert result == [ "import numpy as np\nimport pandas as pd", "import numpy as numpy\nfrom pandas import DataFrame as DF", - "import numpy\nfrom pandas import Series", + "from pandas import Series", ]