Conversation
This looks interesting. But in general I'm not convinced we need to go this route of fine-grained locking. It might work just as well, maybe even better, and be a lot cleaner/faster to have a lock where we do the task submission in the main
A couple of higher-level comments:
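To make the contrast concrete, here is a schematic Python sketch (not MLX code; the class and method names are made up) of fine-grained locking versus a single lock held only around task submission.

import threading
from collections import deque

class FineGrainedSubmit:
    """Hypothetical: separate locks for the task queue and buffer bookkeeping."""
    def __init__(self):
        self.queue_lock = threading.Lock()
        self.buffer_lock = threading.Lock()
        self.tasks = deque()
        self.in_flight = 0

    def submit(self, task):
        # Two locks on the submission path; every other code path touching
        # these structures must acquire them in a consistent order.
        with self.buffer_lock:
            self.in_flight += 1
        with self.queue_lock:
            self.tasks.append(task)

class CoarseGrainedSubmit:
    """Hypothetical: one lock held only where tasks are submitted."""
    def __init__(self):
        self.lock = threading.Lock()
        self.tasks = deque()
        self.in_flight = 0

    def submit(self, task):
        with self.lock:
            self.in_flight += 1
            self.tasks.append(task)

The coarse version trades some potential concurrency for a much smaller surface where lock ordering has to be reasoned about.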
Thank you! Let me see what I can do along those lines. The tests definitely make sense; I wasn't sure what performance benchmark would make sense here. Are the existing ones fine to use to see what impact the changes have?
I would do more end-to-end benchmarks, and focus on latency-sensitive ones, since that's where this type of change matters. For example, LM inference with a smallish LM (like a 4-bit model, 1-3B in size) would be a good place to start (you can use mlx-lm for that).
Force-pushed from 82a117e to 28902ec
I'm still working on some good simple tests; I ran into a few more errors with the prior proposed changes. But I wanted to ask what you thought of this approach, and I appreciate any feedback. I've spot-checked the default model with
I like this new approach as it's much simpler, though I do wonder about the possibility of deadlock. Say we have two streams, and Stream A is waiting on the output of Stream B. Something like that seems plausible in a multi-threaded setup. I'm not sure it's necessarily a dealbreaker, because sharing graphs across threads is not a good idea for other reasons. But it would be good to set up a few C++ tests to really exercise the multi-threaded cases we expect this to work for.
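As a rough sketch of the scenario (in Python for brevity, though the suggestion above is for C++ tests), the following builds a graph whose pieces alternate between two streams and evaluates it from two threads at once; it is illustrative only, not a test from this PR.

import threading
import mlx.core as mx

s_a = mx.new_stream(mx.default_device())
s_b = mx.new_stream(mx.default_device())

x = mx.ones((512, 512))
a = mx.add(x, 1, stream=s_a)  # produced on stream A
b = mx.add(a, 1, stream=s_b)  # stream B waits on stream A's output
c = mx.add(b, 1, stream=s_a)  # stream A waits back on stream B's output

# Two threads force evaluation of graphs that cross both streams at once.
t1 = threading.Thread(target=lambda: mx.eval(c))
t2 = threading.Thread(target=lambda: mx.eval(b))
t1.start()
t2.start()
t1.join()
t2.join()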
I've added a few tests; how do they look to you? The changes caused one test around buffers to fail very occasionally.
I ran a few benchmarks, apologies for the delay! @awni

Prompt TPS
Generation TPS
The benchmark was pretty simple: the prompt was very short (it could be made longer), and I set the max tokens to 1000 (which the Qwen models sometimes reached in my benchmark). Here's the code for reference. More trials could be run, and with a longer prompt, but hopefully this gives a decent idea of the time difference.

from mlx_lm import load, stream_generate
import pandas as pd

max_tokens = 1_000
verbose = False
warmup_count = 3
num_trials = 10

df_results = pd.DataFrame()
checkpoints = [
    "mlx-community/Llama-3.2-1B-Instruct-4bit",
    "mlx-community/Llama-3.2-3B-Instruct-4bit",
    "mlx-community/Qwen3-0.6B-4bit",
    "mlx-community/Qwen3-0.6B-6bit",
    "mlx-community/Qwen3-0.6B-8bit",
    "mlx-community/Qwen3-1.7B-3bit",
    "mlx-community/Qwen3-1.7B-4bit",
]

for checkpoint in checkpoints:
    model, tokenizer = load(path_or_hf_repo=checkpoint)
    prompt = "Hello! I'm teaching a science class on our solar system and wanted to ask for your help! " \
             "Could you tell what the planets in our solar system are called, and a little about each one?"
    conversation = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(
        conversation=conversation, add_generation_prompt=True
    )

    # Warmup runs (not recorded)
    for _ in range(warmup_count):
        text = ""
        for response in stream_generate(model, tokenizer, prompt, max_tokens=max_tokens):
            if verbose:
                print(response.text, end="", flush=True)
            text += response.text

    # Timed trials: the final `response` from stream_generate carries the
    # prompt/generation stats for the run.
    for i in range(num_trials):
        text = ""
        for response in stream_generate(model, tokenizer, prompt, max_tokens=max_tokens):
            if verbose:
                print(response.text, end="", flush=True)
            text += response.text
        response_dict = {
            'model': checkpoint,
            'trial': i,
            'prompt_tokens': response.prompt_tokens,
            'prompt_tps': response.prompt_tps,
            'generation_tokens': response.generation_tokens,
            'generation_tps': response.generation_tps,
            'peak_memory': response.peak_memory,
        }
        df_trial = pd.DataFrame(response_dict, index=[0])
        df_results = pd.concat([df_results, df_trial], ignore_index=True)

print(df_results.head())
df_results.to_csv('trial_runs.csv', index=False)

As usual, any feedback is greatly appreciated!
It does seem to be slightly slower (with some high variance on the prompt TPS). I'm not sure where to go from here: whether this makes the PR a no-go, whether I can try relaxing some of the mutex locks (if that's possible), or whether it's something with my benchmark.

Prompt TPS
Generation TPS
I have put some thoughts about thread safety in #3078 (comment): basically, I think we should not try to achieve thread safety for arrays shared across different threads, at least not in a first try; for more practical targets we shouldn't need a global mutex.
Proposed changes
These changes are an attempt to improve thread safety for the Metal backend. This is related to #2067.
Please let me know what you think.
Checklist
pre-commit run --all-files to format my code / installed pre-commit prior to committing changes