diff --git a/marimo/_runtime/packages/conda_package_manager.py b/marimo/_runtime/packages/conda_package_manager.py index f1641a1a48d..d1f426c657e 100644 --- a/marimo/_runtime/packages/conda_package_manager.py +++ b/marimo/_runtime/packages/conda_package_manager.py @@ -1,14 +1,11 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations -from typing import Optional - from marimo._runtime.packages.module_name_to_conda_name import ( module_name_to_conda_name, ) from marimo._runtime.packages.package_manager import ( CanonicalizingPackageManager, - LogCallback, PackageDescription, ) from marimo._runtime.packages.utils import split_packages @@ -25,23 +22,12 @@ def _construct_module_name_mapping(self) -> dict[str, str]: class PixiPackageManager(CondaPackageManager): name = "pixi" - async def _install( - self, - package: str, - *, - upgrade: bool, - log_callback: Optional[LogCallback] = None, - ) -> bool: - if upgrade: - return self.run( - ["pixi", "upgrade", *split_packages(package)], - log_callback=log_callback, - ) - else: - return self.run( - ["pixi", "add", *split_packages(package)], - log_callback=log_callback, - ) + def install_command(self, package: str, *, upgrade: bool) -> list[str]: + return [ + "pixi", + "upgrade" if upgrade else "add", + *split_packages(package), + ] async def uninstall(self, package: str) -> bool: return self.run( diff --git a/marimo/_runtime/packages/package_manager.py b/marimo/_runtime/packages/package_manager.py index a3338e509c8..c38f1f35d6f 100644 --- a/marimo/_runtime/packages/package_manager.py +++ b/marimo/_runtime/packages/package_manager.py @@ -56,7 +56,16 @@ def is_manager_installed(self) -> bool: ) return False - @abc.abstractmethod + def install_command(self, package: str, *, upgrade: bool) -> list[str]: + """ + Get the shell command to install a package (where applicable). + + Used by the _install method. If not applicable (for example, with micropip), + override the _install method instead. + """ + # PackageManager's may not implement this method if they override _install + raise NotImplementedError + async def _install( self, package: str, @@ -65,7 +74,10 @@ async def _install( log_callback: Optional[LogCallback] = None, ) -> bool: """Installation logic.""" - ... + return self.run( + self.install_command(package, upgrade=upgrade), + log_callback=log_callback, + ) async def install( self, diff --git a/marimo/_runtime/packages/pypi_package_manager.py b/marimo/_runtime/packages/pypi_package_manager.py index c70f59dcc57..96af6c2c960 100644 --- a/marimo/_runtime/packages/pypi_package_manager.py +++ b/marimo/_runtime/packages/pypi_package_manager.py @@ -56,19 +56,15 @@ class PipPackageManager(PypiPackageManager): name = "pip" docs_url = "https://pip.pypa.io/" - async def _install( - self, - package: str, - *, - upgrade: bool, - log_callback: Optional[LogCallback] = None, - ) -> bool: - LOGGER.info(f"Installing {package} with pip") - cmd = ["pip", "--python", PY_EXE, "install"] - if upgrade: - cmd.append("--upgrade") - cmd.extend(split_packages(package)) - return self.run(cmd, log_callback=log_callback) + def install_command(self, package: str, *, upgrade: bool) -> list[str]: + return [ + "pip", + "--python", + PY_EXE, + "install", + *(["--upgrade"] if upgrade else []), + *split_packages(package), + ] async def uninstall(self, package: str) -> bool: LOGGER.info(f"Uninstalling {package} with pip") @@ -165,20 +161,11 @@ def _uv_bin(self) -> str: def is_manager_installed(self) -> bool: return self._uv_bin != "uv" or super().is_manager_installed() - async def _install( - self, - package: str, - *, - upgrade: bool, - log_callback: Optional[LogCallback] = None, - ) -> bool: + def install_command(self, package: str, *, upgrade: bool) -> list[str]: install_cmd: list[str] if self.is_in_uv_project: - LOGGER.info(f"Installing in {package} with 'uv add'") install_cmd = [self._uv_bin, "add"] else: - LOGGER.info(f"Installing in {package} with 'uv pip install'") - install_cmd = [self._uv_bin, "pip", "install"] # Allow for explicit site directory location if needed @@ -189,10 +176,28 @@ async def _install( if upgrade: install_cmd.append("--upgrade") - return self.run( + return install_cmd + [ # trade installation time for faster start time - install_cmd - + ["--compile", *split_packages(package), "-p", PY_EXE], + "--compile", + *split_packages(package), + "-p", + PY_EXE, + ] + + async def _install( + self, + package: str, + *, + upgrade: bool, + log_callback: Optional[LogCallback] = None, + ) -> bool: + """Installation logic.""" + LOGGER.info( + f"Installing in {package} with 'uv {'add' if self.is_in_uv_project else 'pip install'}'" + ) + return await super()._install( + package, + upgrade=upgrade, log_callback=log_callback, ) @@ -474,21 +479,12 @@ class RyePackageManager(PypiPackageManager): name = "rye" docs_url = "https://rye.astral.sh/" - async def _install( - self, - package: str, - *, - upgrade: bool, - log_callback: Optional[LogCallback] = None, - ) -> bool: - if upgrade: - return self.run( - ["rye", "sync", "--update", *split_packages(package)], - log_callback=log_callback, - ) - return self.run( - ["rye", "add", *split_packages(package)], log_callback=log_callback - ) + def install_command(self, package: str, *, upgrade: bool) -> list[str]: + return [ + "rye", + *(["sync", "--update"] if upgrade else ["add"]), + *split_packages(package), + ] async def uninstall(self, package: str) -> bool: return self.run( @@ -504,28 +500,13 @@ class PoetryPackageManager(PypiPackageManager): name = "poetry" docs_url = "https://python-poetry.org/docs/" - async def _install( - self, - package: str, - *, - upgrade: bool, - log_callback: Optional[LogCallback] = None, - ) -> bool: - if upgrade: - return self.run( - [ - "poetry", - "update", - "--no-interaction", - *split_packages(package), - ], - log_callback=log_callback, - ) - - return self.run( - ["poetry", "add", "--no-interaction", *split_packages(package)], - log_callback=log_callback, - ) + def install_command(self, package: str, *, upgrade: bool) -> list[str]: + return [ + "poetry", + "update" if upgrade else "add", + "--no-interaction", + *split_packages(package), + ] async def uninstall(self, package: str) -> bool: return self.run(