
Commit 87234d2

Fix error in the batch method of the MLXLM model
Parent commit: ca0cf18

File tree: 3 files changed (+42 −5 lines)

docs/features/models/mlxlm.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -100,7 +100,7 @@ for chunk in model.stream("Write a short story about a cat.", max_tokens=100):
 
 #### Batch Generation
 
-The `MLXLM` model supports generating text in batches. To do so, use the `batch` method and provide a list of strings as a model input. For instance:
+The `MLXLM` model supports generating text in batches. To do so, use the `batch` method and provide a list of strings as a model input. However, constrained generation is not supported with batching, so you cannot provide an `output_type`. For instance:
 
 ```python
 import outlines
````
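For context, here is a minimal sketch of the batch usage the updated docs describe. The checkpoint name, prompts, and `max_tokens` value are illustrative assumptions, and the sketch presumes the v1-style `outlines.from_mlxlm` constructor paired with `mlx_lm.load`:

```python
import outlines
from mlx_lm import load

# Hypothetical MLX checkpoint; any mlx-lm compatible model should behave the same.
model = outlines.from_mlxlm(*load("mlx-community/SmolLM-135M-Instruct-4bit"))

# `batch` takes a list of prompts and returns a list of generated strings.
# Note: per this commit, passing an `output_type` here raises NotImplementedError.
results = model.batch(
    ["Write a haiku about cats.", "Write a haiku about dogs."],
    max_tokens=50,
)
print(results[0])
print(results[1])
```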

outlines/models/mlxlm.py

Lines changed: 29 additions & 4 deletions
```diff
@@ -173,16 +173,41 @@ def generate_batch(
             The list of text generated by the model.
 
         """
-        from mlx_lm import generate_batch
+        from mlx_lm import batch_generate
 
-        return generate_batch(
+        if output_type:
+            raise NotImplementedError(
+                "mlx-lm does not support constrained generation with batching."
+                + "You cannot provide an `output_type` with this method."
+            )
+
+        model_input = [self.type_adapter.format_input(item) for item in model_input]
+
+        # Contrarily to the other generate methods, batch_generate requires
+        # tokenized prompts
+        add_special_tokens = [
+            (
+                self.mlx_tokenizer.bos_token is None
+                or not prompt.startswith(self.mlx_tokenizer.bos_token)
+            )
+            for prompt in model_input
+        ]
+        tokenized_model_input = [
+            self.mlx_tokenizer.encode(
+                model_input[i], add_special_tokens=add_special_tokens[i]
+            )
+            for i in range(len(model_input))
+        ]
+
+        response = batch_generate(
             self.model,
             self.mlx_tokenizer,
-            [self.type_adapter.format_input(item) for item in model_input],
-            logits_processors=self.type_adapter.format_output_type(output_type),
+            tokenized_model_input,
             **kwargs,
         )
 
+        return response.texts
+
     def generate_stream(
         self,
         model_input: str,
```
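The BOS handling in the new code avoids prepending a second BOS token when a prompt (for example, one produced by a chat template) already starts with one. A standalone sketch of that check, using a hypothetical Hugging Face-style tokenizer exposing `bos_token` and `encode`:

```python
def encode_prompts(tokenizer, prompts):
    """Tokenize prompts without duplicating the BOS token (sketch of the logic above)."""
    encoded = []
    for prompt in prompts:
        # Let the tokenizer add special tokens only when the prompt does not
        # already begin with the BOS token string.
        add_special = (
            tokenizer.bos_token is None
            or not prompt.startswith(tokenizer.bos_token)
        )
        encoded.append(tokenizer.encode(prompt, add_special_tokens=add_special))
    return encoded
```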

tests/models/test_mlxlm.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -128,3 +128,15 @@ def test_mlxlm_batch(model):
     assert len(result) == 2
     assert isinstance(result[0], str)
     assert isinstance(result[1], str)
+
+
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_batch_output_type(model):
+    with pytest.raises(
+        NotImplementedError,
+        match="mlx-lm does not support constrained generation with batching."
+    ):
+        model.batch(
+            ["Respond with one word.", "Respond with one word."],
+            Regex(r"[0-9]")
+        )
```
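The new test exercises the error path end to end. For illustration, roughly how a caller would hit it, assuming `Regex` is imported from `outlines.types` as in the rest of the test suite:

```python
from outlines.types import Regex

try:
    model.batch(["Respond with one word."], Regex(r"[0-9]"))
except NotImplementedError as err:
    # "mlx-lm does not support constrained generation with batching. ..."
    print(err)
```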
