Skip to content

Commit 8d06bf3

Browse files
authored
feat(huggingface): support reasoning tokens (#558)
## Description <!-- What does this PR do? --> ## PR Type <!-- Delete the types that don't apply --> 🆕 New Feature ## Relevant issues <!-- e.g. "Fixes #123" --> ## Checklist - [x] I have added unit tests that prove my fix/feature works - [x] New and existing tests pass locally - [x] Documentation was updated where necessary - [x] I have read and followed the [contribution guidelines](https://github.com/mozilla-ai/any-llm/blob/main/CONTRIBUTING.md)
1 parent d7a9e26 commit 8d06bf3

File tree

6 files changed

+271
-14
lines changed

6 files changed

+271
-14
lines changed

src/any_llm/constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55

66
INSIDE_NOTEBOOK = hasattr(builtins, "__IPYTHON__")
77

8+
# Names under which providers surface reasoning / chain-of-thought text,
# either as extra message fields (e.g. msg["reasoning_content"]) or as
# <tag>...</tag> markers embedded in the message content.
REASONING_FIELD_NAMES = [
    "reasoning_content",
    "thinking",
    "think",
    "chain_of_thought",
]
14+
815

916
class LLMProvider(StrEnum):
1017
"""String enum for supported providers."""

src/any_llm/providers/huggingface/huggingface.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import TYPE_CHECKING, Any
44

55
from any_llm.any_llm import AnyLLM
6+
from any_llm.constants import REASONING_FIELD_NAMES
67
from any_llm.types.completion import (
78
ChatCompletion,
89
ChatCompletionChunk,
@@ -11,6 +12,7 @@
1112
CompletionParams,
1213
CompletionUsage,
1314
CreateEmbeddingResponse,
15+
Reasoning,
1416
)
1517

1618
MISSING_PACKAGES_ERROR = None
@@ -21,6 +23,7 @@
2123
_convert_models_list,
2224
_convert_params,
2325
_create_openai_chunk_from_huggingface_chunk,
26+
_normalize_reasoning_on_message,
2427
)
2528
except ImportError as e:
2629
MISSING_PACKAGES_ERROR = e
@@ -47,7 +50,7 @@ class HuggingfaceProvider(AnyLLM):
4750
SUPPORTS_RESPONSES = False
4851
SUPPORTS_COMPLETION_IMAGE = False
4952
SUPPORTS_COMPLETION_PDF = False
50-
SUPPORTS_COMPLETION_REASONING = False
53+
SUPPORTS_COMPLETION_REASONING = True
5154
SUPPORTS_EMBEDDING = False
5255
SUPPORTS_LIST_MODELS = True
5356

@@ -101,14 +104,96 @@ def _init_client(self, api_key: str | None = None, api_base: str | None = None,
101104
**kwargs,
102105
)
103106

107+
@staticmethod
108+
def _find_reasoning_tag(text: str, opening: bool = True) -> tuple[int, str] | None:
109+
"""Find the first reasoning tag (opening or closing) in text.
110+
111+
Returns (position, tag_name) or None if no tag found.
112+
"""
113+
earliest_pos = len(text)
114+
earliest_tag = None
115+
116+
for tag_name in REASONING_FIELD_NAMES:
117+
tag = f"<{tag_name}>" if opening else f"</{tag_name}>"
118+
pos = text.find(tag)
119+
if pos != -1 and pos < earliest_pos:
120+
earliest_pos = pos
121+
earliest_tag = tag_name
122+
123+
return (earliest_pos, earliest_tag) if earliest_tag else None
124+
125+
@staticmethod
126+
def _is_partial_reasoning_tag(text: str, opening: bool = True) -> bool:
127+
"""Check if text could be the start of any reasoning tag."""
128+
for tag_name in REASONING_FIELD_NAMES:
129+
tag = f"<{tag_name}>" if opening else f"</{tag_name}>"
130+
for i in range(1, len(tag) + 1):
131+
if text.startswith(tag[:i]):
132+
return True
133+
return False
134+
104135
async def _stream_completion_async(
105136
self,
106137
**kwargs: Any,
107138
) -> AsyncIterator[ChatCompletionChunk]:
108139
response: AsyncIterator[HuggingFaceChatCompletionStreamOutput] = await self.client.chat_completion(**kwargs)
109140

141+
buffer = ""
142+
current_tag = None
143+
reasoning_buffer = ""
144+
110145
async for chunk in response:
111-
yield self._convert_completion_chunk_response(chunk)
146+
original_chunk = self._convert_completion_chunk_response(chunk)
147+
148+
if not (len(original_chunk.choices) > 0 and original_chunk.choices[0].delta.content):
149+
yield original_chunk
150+
continue
151+
152+
buffer += original_chunk.choices[0].delta.content
153+
content_parts = []
154+
reasoning_parts = []
155+
156+
while buffer:
157+
if current_tag is None:
158+
tag_info = self._find_reasoning_tag(buffer, opening=True)
159+
if tag_info:
160+
tag_start, tag_name = tag_info
161+
if tag_start > 0:
162+
content_parts.append(buffer[:tag_start])
163+
tag_full = f"<{tag_name}>"
164+
buffer = buffer[tag_start + len(tag_full) :]
165+
current_tag = tag_name
166+
elif self._is_partial_reasoning_tag(buffer, opening=True):
167+
break
168+
else:
169+
content_parts.append(buffer)
170+
buffer = ""
171+
else:
172+
tag_close = f"</{current_tag}>"
173+
tag_end = buffer.find(tag_close)
174+
if tag_end != -1:
175+
reasoning_parts.append(reasoning_buffer + buffer[:tag_end])
176+
reasoning_buffer = ""
177+
buffer = buffer[tag_end + len(tag_close) :]
178+
current_tag = None
179+
elif self._is_partial_reasoning_tag(buffer, opening=False):
180+
reasoning_buffer += buffer
181+
buffer = ""
182+
break
183+
else:
184+
reasoning_buffer += buffer
185+
buffer = ""
186+
187+
if content_parts or reasoning_parts:
188+
modified_chunk = original_chunk.model_copy(deep=True)
189+
modified_chunk.choices[0].delta.content = "".join(content_parts) if content_parts else None
190+
if reasoning_parts:
191+
modified_chunk.choices[0].delta.reasoning = Reasoning(content="".join(reasoning_parts))
192+
yield modified_chunk
193+
elif not buffer:
194+
modified_chunk = original_chunk.model_copy(deep=True)
195+
modified_chunk.choices[0].delta.content = None
196+
yield modified_chunk
112197

113198
async def _acompletion(
114199
self,
@@ -127,10 +212,19 @@ async def _acompletion(
127212
choices_out: list[Choice] = []
128213
for i, ch in enumerate(data.get("choices", [])):
129214
msg = ch.get("message", {})
215+
216+
_normalize_reasoning_on_message(msg)
217+
218+
reasoning_obj = None
219+
if msg.get("reasoning") and isinstance(msg["reasoning"], dict):
220+
if "content" in msg["reasoning"]:
221+
reasoning_obj = Reasoning(content=msg["reasoning"]["content"])
222+
130223
message = ChatCompletionMessage(
131224
role="assistant",
132225
content=msg.get("content"),
133226
tool_calls=msg.get("tool_calls"),
227+
reasoning=reasoning_obj,
134228
)
135229
choices_out.append(Choice(index=i, finish_reason=ch.get("finish_reason"), message=message))
136230

src/any_llm/providers/huggingface/utils.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
import uuid
23
from collections.abc import Iterable
34
from typing import Any, Literal, cast
@@ -8,16 +9,54 @@
89
)
910
from openai.lib._parsing import type_to_response_format_param
1011

12+
from any_llm.constants import REASONING_FIELD_NAMES
1113
from any_llm.types.completion import (
1214
ChatCompletionChunk,
1315
ChoiceDelta,
1416
ChunkChoice,
1517
CompletionParams,
1618
CompletionUsage,
19+
Reasoning,
1720
)
1821
from any_llm.types.model import Model
1922

2023

24+
def _normalize_reasoning_on_message(message_dict: dict[str, Any]) -> None:
    """Mutate a message dict to extract reasoning from content tags and provider-specific fields."""
    existing = message_dict.get("reasoning")
    if isinstance(existing, dict) and "content" in existing:
        # Already in normalized form; nothing to do.
        return

    # Provider-specific top-level fields take priority; first match wins.
    reasoning_content = next(
        (message_dict[name] for name in REASONING_FIELD_NAMES if message_dict.get(name) is not None),
        None,
    )
    if reasoning_content is None and isinstance(existing, str):
        reasoning_content = existing

    # Pull <tag>...</tag> reasoning spans out of the textual content.
    content = message_dict.get("content")
    if isinstance(content, str):
        for tag_name in REASONING_FIELD_NAMES:
            pattern = re.escape(f"<{tag_name}>") + r"(.*?)" + re.escape(f"</{tag_name}>")
            found = re.findall(pattern, content, re.DOTALL)
            if not found:
                continue
            extracted = "\n".join(found)
            reasoning_content = f"{reasoning_content}\n{extracted}" if reasoning_content else extracted
            content = re.sub(pattern, "", content, flags=re.DOTALL).strip()

    message_dict["content"] = content

    if reasoning_content is not None:
        message_dict["reasoning"] = {"content": str(reasoning_content)}
58+
59+
2160
def _create_openai_chunk_from_huggingface_chunk(chunk: HuggingFaceChatCompletionStreamOutput) -> ChatCompletionChunk:
2261
"""Convert a HuggingFace streaming chunk to OpenAI ChatCompletionChunk format."""
2362

@@ -30,14 +69,31 @@ def _create_openai_chunk_from_huggingface_chunk(chunk: HuggingFaceChatCompletion
3069

3170
for i, hf_choice in enumerate(hf_choices):
3271
hf_delta = hf_choice.delta
33-
content = hf_delta.content
34-
role = hf_delta.role
3572

36-
openai_role = None
37-
if role:
38-
openai_role = cast("Literal['developer', 'system', 'user', 'assistant', 'tool']", role)
73+
delta_dict: dict[str, Any] = {}
74+
if hf_delta.content is not None:
75+
delta_dict["content"] = hf_delta.content
76+
if hf_delta.role is not None:
77+
delta_dict["role"] = hf_delta.role
78+
if hasattr(hf_delta, "reasoning"):
79+
delta_dict["reasoning"] = hf_delta.reasoning
3980

40-
delta = ChoiceDelta(content=content, role=openai_role)
81+
_normalize_reasoning_on_message(delta_dict)
82+
83+
openai_role = None
84+
if delta_dict.get("role"):
85+
openai_role = cast("Literal['developer', 'system', 'user', 'assistant', 'tool']", delta_dict["role"])
86+
87+
reasoning_obj = None
88+
if delta_dict.get("reasoning") and isinstance(delta_dict["reasoning"], dict):
89+
if "content" in delta_dict["reasoning"]:
90+
reasoning_obj = Reasoning(content=delta_dict["reasoning"]["content"])
91+
92+
delta = ChoiceDelta(
93+
content=delta_dict.get("content"),
94+
role=openai_role,
95+
reasoning=reasoning_obj,
96+
)
4197

4298
choice = ChunkChoice(
4399
index=i,

src/any_llm/providers/openai/utils.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from openai.types.chat.chat_completion import ChatCompletion as OpenAIChatCompletion
66

7+
from any_llm.constants import REASONING_FIELD_NAMES
78
from any_llm.logging import logger
89
from any_llm.types.completion import ChatCompletion
910

@@ -13,11 +14,7 @@ def _normalize_reasoning_on_message(message_dict: dict[str, Any]) -> None:
1314
if isinstance(message_dict.get("reasoning"), dict) and "content" in message_dict["reasoning"]:
1415
return
1516

16-
possible_fields = [
17-
"reasoning_content",
18-
"thinking",
19-
"chain_of_thought",
20-
]
17+
possible_fields = REASONING_FIELD_NAMES
2118
value: Any | None = None
2219
for field_name in possible_fields:
2320
if field_name in message_dict and message_dict[field_name] is not None:

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def provider_reasoning_model_map() -> dict[LLMProvider, str]:
3030
LLMProvider.MOONSHOT: "kimi-thinking-preview",
3131
LLMProvider.DATABRICKS: "databricks-gpt-oss-20b", # Untested, needs to be verified once we get a Databricks account
3232
LLMProvider.BEDROCK: "us.anthropic.claude-haiku-4-5-20251001-v1:0",
33+
LLMProvider.HUGGINGFACE: "huggingface/tgi",
3334
}
3435

3536

0 commit comments

Comments
 (0)