Add macOS-only mlx-lm node to node-hub (#882) #985
Open

Clement795 wants to merge 5 commits into dora-rs:main from Clement795:add-mlx-lm-v2
Changes from 3 commits (of 5 total).

Commits:
- 10d9e4b Add macOS-only mlx-lm node to node-hub (#882) (Clement795)
- 148b2fd Add dora-mlx-lm to skip_test_folders to fix CI (Clement795)
- f665b13 Add dora-mlx-lm to ignored_folders to fix CI (Clement795)
- 9bd6dfb Rename input event id from 'prompt' to 'text' for consistency (Clement795)
- 8d37291 Merge branch 'main' into add-mlx-lm-v2 (phil-opp)
**node-hub/dora-mlx-lm/README.md** (new file, +86 lines)
# Dora MLX-LM Node

## Overview

The `dora-mlx-lm` node integrates the [`mlx-lm`](https://github.com/ml-explore/mlx-lm) library to run large language models (LLMs) optimized for Apple Silicon (M1, M2, M3, and later) on macOS. It takes text prompts as input and generates text responses using a model such as `mlx-community/SmolLM-135M-Instruct-4bit`. The node is designed for use within a [Dora](https://github.com/dora-rs/dora) pipeline and supports activation words, conversation history, and performance metadata.
## Installation

To use the `dora-mlx-lm` node, install the required dependencies:

```bash
pip install dora-rs-cli mlx-lm
```
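Because the node targets Apple Silicon only (see the platform note below), a quick optional sanity check, not part of the PR, is to confirm the interpreter is ARM-native before installing:

```bash
# Optional sanity check (not part of the PR): MLX requires an ARM-native
# Python build on Apple Silicon.
python -c "import platform; print(platform.machine())"  # expect: arm64
pip show mlx-lm  # confirm the package resolved after installation
```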
## Usage

1. **Add the node to your Dora pipeline**:

   Include the `dora-mlx-lm` node in your pipeline YAML file. Below is an example configuration:

   ```yaml
   nodes:
     - id: mlx_lm
       build: pip install mlx-lm
       path: dora-mlx-lm/main.py
       inputs:
         prompt: dora/input
       outputs:
         - text
       env:
         MODEL_PATH: mlx-community/SmolLM-135M-Instruct-4bit
         SYSTEM_PROMPT: "You are a helpful assistant optimized for Apple M-series chips."
         MAX_TOKENS: "100"
         TEMPERATURE: "0.7"
         CONTEXT_SIZE: "2048"
         ACTIVATION_WORDS: "hey assistant"
   ```
### Environment Variables

- `MODEL_PATH`: Path or Hugging Face ID of the model (default: `mlx-community/SmolLM-135M-Instruct-4bit`).
- `SYSTEM_PROMPT`: Optional system prompt to define the model's behavior (default: empty).
- `MAX_TOKENS`: Maximum number of tokens to generate (default: 100).
- `TEMPERATURE`: Sampling temperature for generation (default: 0.7).
- `CONTEXT_SIZE`: Maximum context length for conversation history (default: 2048).
- `ACTIVATION_WORDS`: Space-separated list of words that trigger the node (default: empty, which processes all inputs).
2. **Run the pipeline**:

   Build and execute your pipeline using the Dora CLI:

   ```bash
   dora build your_pipeline.yml --uv
   dora run your_pipeline.yml --uv
   ```
## Inputs

- **prompt**: A text string to be processed by the LLM (e.g., "Write a short story about a robot"). The node validates that the input is a non-empty `pyarrow.Array` containing a string; a minimal sender sketch follows.
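To illustrate the expected payload, here is a hypothetical upstream node, not part of the PR; the output name `input` is an assumption chosen to match the `dora/input` mapping in the example YAML above:

```python
# Hypothetical upstream node (not part of the PR): publishes one prompt as a
# single-element pyarrow string array, the format dora-mlx-lm validates.
import pyarrow as pa
from dora import Node

node = Node()  # assumes this script runs as a node inside a Dora dataflow
node.send_output(
    output_id="input",  # assumed name; wire it to mlx_lm's `prompt` input
    data=pa.array(["hey assistant, write a short story about a robot"]),
)
```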
## Outputs

- **text**: The text response generated by the LLM, sent as a `pyarrow.Array`. The output includes metadata such as:
  - `processing_time`: Time taken to generate the response (in seconds).
  - `model`: The model used (e.g., `mlx-community/SmolLM-135M-Instruct-4bit`).
  - `optimized_for`: Indicates optimization for Apple's M-series chips.
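A downstream node can read both the response and its metadata; the following is a sketch under the same assumptions (dataflow wiring not shown):

```python
# Hypothetical downstream node (not part of the PR): prints each response
# together with the processing_time metadata attached by dora-mlx-lm.
from dora import Node

node = Node()
for event in node:
    if event["type"] == "INPUT" and event["id"] == "text":
        response = event["value"][0].as_py()
        elapsed = event["metadata"].get("processing_time", "n/a")
        print(f"[{elapsed}s] {response}")
```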
## Features

- **Apple Silicon Optimization**: Leverages the MLX framework for efficient inference on M1, M2, M3, and later chips, with automatic GPU and Neural Engine acceleration.
- **Conversation History**: Maintains a conversation history with a configurable system prompt, truncated based on `CONTEXT_SIZE` (see the sketch below).
- **Activation Words**: Optionally processes inputs only when they contain specified activation words.
- **Robust Error Handling**: Validates inputs and logs errors for reliable pipeline integration.
- **Metadata**: Provides performance metrics and configuration details in output metadata.
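The truncation logic itself is not visible in this three-commit view, so the following is only an illustrative sketch of one plausible approach, using token counts from the node's tokenizer; the helper name is hypothetical:

```python
# Illustrative sketch only (not from the PR): drop the oldest non-system turns
# until the chat-formatted history fits within CONTEXT_SIZE tokens.
def truncate_history(history, tokenizer, context_size):
    def total_tokens(msgs):
        # apply_chat_template returns token ids by default, so len() counts tokens.
        return len(tokenizer.apply_chat_template(msgs, add_generation_prompt=True))

    while len(history) > 1 and total_tokens(history) > context_size:
        del history[1]  # keep the system prompt at index 0
    return history
```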
### Using mlx-lm in Dora Node Hub

- **Platform**: macOS 13.5+ (ARM-native Python required)
- Note: This node is only supported on macOS and skips execution on Linux/Windows.
## Notes

- The node uses `mlx-lm`, which is optimized for Apple Silicon. Parameters like `N_GPU_LAYERS` or `N_THREADS` (common in other frameworks such as `llama_cpp`) are not applicable, as MLX manages resource allocation internally.
- For large models, use quantized versions (e.g., 4-bit) to reduce memory usage and improve performance.
- The conversation history is truncated to respect the `CONTEXT_SIZE` limit, ensuring compatibility with the model's context length.
## License

This node is licensed under the [MIT License](https://opensource.org/licenses/MIT), consistent with the `mlx-lm` library.
**node-hub/dora-mlx-lm/dora_mlx_lm/__init__.py** (new file, +13 lines)
| """TODO: Add docstring.""" | ||
|
|
||
| import os | ||
|
|
||
| # Define the path to the README file relative to the package directory | ||
| readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.md") | ||
|
|
||
| # Read the content of the README file | ||
| try: | ||
| with open(readme_path, encoding="utf-8") as f: | ||
| __doc__ = f.read() | ||
| except FileNotFoundError: | ||
| __doc__ = "README file not found." |
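This pattern surfaces the README as the package docstring. A quick hypothetical check, not part of the PR (the module name `dora_mlx_lm` is taken from the test file's import below):

```bash
# Hypothetical check: the package docstring should now be the README's content.
python -c "import dora_mlx_lm; print(dora_mlx_lm.__doc__.splitlines()[0])"
```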
**node-hub/dora-mlx-lm/dora_mlx_lm/__main__.py** (new file, +4 lines)
```python
from .main import main

if __name__ == "__main__":
    main()
```
**node-hub/dora-mlx-lm/dora_mlx_lm/main.py** (new file, +115 lines)
| """Dora node for generating text responses using a pre-trained language model, optimized for Apple M1, M2, M3 chips. | ||
|
|
||
| This node listens for input prompts on the 'prompt' channel, generates text using | ||
| a pre-trained model (default: SmolLM-135M-Instruct-4bit) optimized for Apple's M-series | ||
| chips via MLX, and sends responses to the 'text' output channel. The node can be configured | ||
| via environment variables and supports activation words to filter inputs. | ||
|
|
||
| Note: This node is only supported on macOS. It skips execution on Linux and Windows. | ||
| """ | ||
|
|
||
| import logging | ||
| import os | ||
| import platform | ||
| import sys | ||
| import time | ||
| from pathlib import Path | ||
|
|
||
| # Vérifier si la plateforme est macOS | ||
| if platform.system() != "Darwin": | ||
| logging.basicConfig(level=logging.INFO) | ||
| logging.info("mlx-lm is only supported on macOS. Skipping execution on %s.", platform.system()) | ||
| sys.exit(0) # Sortir sans erreur pour éviter un échec CI | ||
|
|
||
| import pyarrow as pa | ||
| from dora import Node | ||
| from mlx_lm import load, generate | ||
|
|
||
| # Configure logging | ||
| logging.basicConfig(level=logging.INFO) | ||
|
|
||
| # Environment variables for model configuration | ||
| MODEL_PATH = os.getenv("MODEL_PATH", "mlx-community/SmolLM-135M-Instruct-4bit") | ||
| SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "") | ||
| MAX_TOKENS = int(os.getenv("MAX_TOKENS", "100")) | ||
| TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7")) | ||
| CONTEXT_SIZE = int(os.getenv("CONTEXT_SIZE", "2048")) # Context length for the model | ||
| ACTIVATION_WORDS = os.getenv("ACTIVATION_WORDS", "").split() | ||
|
|
||
| def get_model(): | ||
| """Load a pre-trained language model and tokenizer optimized for Apple M1/M2/M3 chips.""" | ||
| try: | ||
| logging.info(f"Loading model from {MODEL_PATH} for Apple M-series optimization") | ||
| model, tokenizer = load( | ||
| MODEL_PATH, tokenizer_config={"eos_token": "<|im_end|>"} | ||
| ) | ||
| logging.info("Model loaded successfully with MLX for M1/M2/M3 performance") | ||
| return model, tokenizer | ||
| except Exception as e: | ||
| logging.exception(f"Error loading model: {e}") | ||
| raise | ||
|
|
||
| def main(): | ||
| """Process input events and generate text responses using the loaded model. | ||
|
|
||
| Optimized for Apple M1, M2, M3 chips using the MLX framework for efficient inference. | ||
| Generates responses independently for each input, using only the system prompt as context. | ||
| """ | ||
| # Initialize model and tokenizer | ||
| model, tokenizer = get_model() | ||
| node = Node() | ||
| history = [{"role": "system", "content": SYSTEM_PROMPT}] if SYSTEM_PROMPT else [] | ||
|
|
||
| for event in node: | ||
| if event["type"] == "INPUT" and event["id"] == "prompt": | ||
| # Validate input | ||
| if not isinstance(event["value"], pa.Array) or len(event["value"]) == 0: | ||
| logging.error("Invalid input: expected a non-empty pyarrow.Array") | ||
| continue | ||
| text = event["value"][0].as_py() | ||
| if not isinstance(text, str): | ||
| logging.error("Invalid input: expected a string") | ||
| continue | ||
|
|
||
| words = text.lower().split() | ||
| if len(ACTIVATION_WORDS) == 0 or any( | ||
| word in ACTIVATION_WORDS for word in words | ||
| ): | ||
| try: | ||
| start_time = time.time() | ||
| messages = history + [{"role": "user", "content": text}] | ||
| formatted_prompt = tokenizer.apply_chat_template( | ||
| messages, add_generation_prompt=True | ||
| ) | ||
|
|
||
| response = generate( | ||
| model, | ||
| tokenizer, | ||
| prompt=formatted_prompt, | ||
| max_tokens=MAX_TOKENS, | ||
| temp=TEMPERATURE, | ||
| verbose=False, | ||
| ) | ||
|
|
||
| processing_time = time.time() - start_time | ||
| node.send_output( | ||
| output_id="text", | ||
| data=pa.array([response]), | ||
| metadata={ | ||
| "processing_time": processing_time, | ||
| "model": MODEL_PATH, | ||
| "optimized_for": "Apple M1/M2/M3", | ||
| }, | ||
| ) | ||
|
|
||
| except Exception as e: | ||
| logging.exception(f"Error generating response: {e}") | ||
|
|
||
| elif event["type"] == "STOP": | ||
| logging.info("Received STOP event, cleaning up...") | ||
| model = None | ||
| tokenizer = None | ||
| break | ||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
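For context, the `load`/`generate` path the node wraps can be exercised standalone. This sketch mirrors the PR's calls but omits the `temp` keyword, since recent mlx-lm releases pass temperature through a sampler instead and the exact signature depends on the installed version:

```python
# Standalone sketch (not part of the PR) of the mlx-lm calls the node wraps.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/SmolLM-135M-Instruct-4bit")
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a short story about a robot."}],
    add_generation_prompt=True,
)
# max_tokens mirrors the node's MAX_TOKENS default of 100.
print(generate(model, tokenizer, prompt=prompt, max_tokens=100))
```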
**node-hub/dora-mlx-lm/pyproject.toml** (new file, +27 lines)
```toml
[project]
name = "dora-mlx-lm"
version = "0.1.0"
authors = [{ name = "Clément Leprêtre", email = "[email protected]" }]
description = "DORA node for running MLX-LM large language models"
license = { text = "MIT" }
readme = "README.md"
requires-python = ">=3.7"
dependencies = [
    "mlx-lm>=0.23.2",
    "dora-rs>=0.3.11"
]

[project.urls]
Repository = "https://github.com/dora-rs/dora"

[tool.ruff.lint]
extend-select = [
    "D",    # pydocstyle
    "UP",   # pyupgrade
    "PERF", # Perflint
    "RET",  # flake8-return
    "RSE",  # flake8-raise
    "NPY",  # NumPy-specific rules
    "N",    # pep8-naming
    "I",    # isort
]
```
**node-hub/dora-mlx-lm/tests/…** (new pytest file, +15 lines)
```python
import pytest


def test_mlx_lm_node():
    """Test the import and execution of the dora-mlx-lm node entry point.

    This test verifies that `main` can be imported from the dora_mlx_lm module
    and checks that calling it outside a DORA dataflow raises a RuntimeError,
    as expected.
    """
    from dora_mlx_lm.main import main

    # Calling the node function outside a DORA dataflow should raise a RuntimeError.
    with pytest.raises(RuntimeError):
        main()
```
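CI skips this folder (see the skip_test_folders commit above), so the test is meant for local runs; a hypothetical invocation on an Apple Silicon Mac, with the folder path assumed:

```bash
# Hypothetical local run (not part of the PR); the node's folder path is an assumption.
pytest node-hub/dora-mlx-lm -v
```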