-
Notifications
You must be signed in to change notification settings - Fork 828
fix(compaction): estimate context usage after compaction and show 0.1% precision #1269
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
| from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable | ||
|
|
||
| import kosong | ||
| from kosong.chat_provider import TokenUsage | ||
| from kosong.message import Message | ||
| from kosong.tooling.empty import EmptyToolset | ||
|
|
||
|
|
@@ -14,9 +15,47 @@ | |
| from kimi_cli.wire.types import ContentPart, TextPart, ThinkPart | ||
|
|
||
|
|
||
| class CompactionResult(NamedTuple): | ||
| messages: Sequence[Message] | ||
| usage: TokenUsage | None | ||
|
|
||
| @property | ||
| def estimated_token_count(self) -> int: | ||
| """Estimate the token count of the compacted messages. | ||
|
|
||
| When LLM usage is available, ``usage.output`` gives the exact token count | ||
| of the generated summary (the first message). Preserved messages (all | ||
| subsequent messages) are estimated from their text length. | ||
|
|
||
| When usage is not available (no compaction LLM call was made), all | ||
| messages are estimated from text length. | ||
|
|
||
| The estimate is intentionally conservative — it will be replaced by the | ||
| real value on the next LLM call. | ||
| """ | ||
| if self.usage is not None and len(self.messages) > 0: | ||
| summary_tokens = self.usage.output | ||
| preserved_tokens = _estimate_text_tokens(self.messages[1:]) | ||
| return summary_tokens + preserved_tokens | ||
|
|
||
| return _estimate_text_tokens(self.messages) | ||
|
|
||
|
|
||
| def _estimate_text_tokens(messages: Sequence[Message]) -> int: | ||
| """Estimate tokens from message text content using a character-based heuristic.""" | ||
| total_chars = 0 | ||
| for msg in messages: | ||
| for part in msg.content: | ||
| if isinstance(part, TextPart): | ||
| total_chars += len(part.text) | ||
| # ~4 chars per token for English; somewhat underestimates for CJK text, | ||
| # but this is a temporary estimate that gets corrected on the next LLM call. | ||
| return total_chars // 4 | ||
|
Comment on lines
+51
to
+53
|
||
|
|
||
|
|
||
| @runtime_checkable | ||
| class Compaction(Protocol): | ||
| async def compact(self, messages: Sequence[Message], llm: LLM) -> Sequence[Message]: | ||
| async def compact(self, messages: Sequence[Message], llm: LLM) -> CompactionResult: | ||
| """ | ||
| Compact a sequence of messages into a new sequence of messages. | ||
|
|
||
|
|
@@ -25,7 +64,7 @@ async def compact(self, messages: Sequence[Message], llm: LLM) -> Sequence[Messa | |
| llm (LLM): The LLM to use for compaction. | ||
|
|
||
| Returns: | ||
| Sequence[Message]: The compacted messages. | ||
| CompactionResult: The compacted messages and token usage from the compaction LLM call. | ||
|
|
||
| Raises: | ||
| ChatProviderError: When the chat provider returns an error. | ||
|
|
@@ -43,10 +82,10 @@ class SimpleCompaction: | |
| def __init__(self, max_preserved_messages: int = 2) -> None: | ||
| self.max_preserved_messages = max_preserved_messages | ||
|
|
||
| async def compact(self, messages: Sequence[Message], llm: LLM) -> Sequence[Message]: | ||
| async def compact(self, messages: Sequence[Message], llm: LLM) -> CompactionResult: | ||
| compact_message, to_preserve = self.prepare(messages) | ||
| if compact_message is None: | ||
| return to_preserve | ||
| return CompactionResult(messages=to_preserve, usage=None) | ||
|
|
||
| # Call kosong.step to get the compacted context | ||
| # TODO: set max completion tokens | ||
|
|
@@ -73,7 +112,7 @@ async def compact(self, messages: Sequence[Message], llm: LLM) -> Sequence[Messa | |
| content.extend(part for part in compacted_msg.content if not isinstance(part, ThinkPart)) | ||
| compacted_messages: list[Message] = [Message(role="user", content=content)] | ||
| compacted_messages.extend(to_preserve) | ||
| return compacted_messages | ||
| return CompactionResult(messages=compacted_messages, usage=result.usage) | ||
|
|
||
| class PrepareResult(NamedTuple): | ||
| compact_message: Message | None | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,7 +32,7 @@ | |
| wire_send, | ||
| ) | ||
| from kimi_cli.soul.agent import Agent, Runtime | ||
| from kimi_cli.soul.compaction import SimpleCompaction | ||
| from kimi_cli.soul.compaction import CompactionResult, SimpleCompaction | ||
| from kimi_cli.soul.context import Context | ||
| from kimi_cli.soul.message import check_message, system, tool_result_to_message | ||
| from kimi_cli.soul.slash import registry as soul_slash_registry | ||
|
|
@@ -555,7 +555,7 @@ async def compact_context(self) -> None: | |
|
|
||
| chat_provider = self._runtime.llm.chat_provider if self._runtime.llm is not None else None | ||
|
|
||
| async def _run_compaction_once() -> Sequence[Message]: | ||
| async def _run_compaction_once() -> CompactionResult: | ||
| if self._runtime.llm is None: | ||
| raise LLMNotSet() | ||
| return await self._compaction.compact(self._context.history, self._runtime.llm) | ||
|
|
@@ -567,18 +567,22 @@ async def _run_compaction_once() -> Sequence[Message]: | |
| stop=stop_after_attempt(self._loop_control.max_retries_per_step), | ||
| reraise=True, | ||
| ) | ||
| async def _compact_with_retry() -> Sequence[Message]: | ||
| async def _compact_with_retry() -> CompactionResult: | ||
| return await self._run_with_connection_recovery( | ||
| "compaction", | ||
| _run_compaction_once, | ||
| chat_provider=chat_provider, | ||
| ) | ||
|
|
||
| wire_send(CompactionBegin()) | ||
| compacted_messages = await _compact_with_retry() | ||
| compaction_result = await _compact_with_retry() | ||
| await self._context.clear() | ||
| await self._checkpoint() | ||
| await self._context.append_message(compacted_messages) | ||
| await self._context.append_message(compaction_result.messages) | ||
|
|
||
| # Estimate token count so context_usage is not reported as 0% | ||
| await self._context.update_token_count(compaction_result.estimated_token_count) | ||
|
|
||
|
Comment on lines
579
to
+585
|
||
| wire_send(CompactionEnd()) | ||
|
|
||
| @staticmethod | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`_estimate_text_tokens` currently only counts `TextPart` in `message.content` and ignores other token-bearing fields like `Message.tool_calls` (function names/arguments) and any non-text content that still consumes tokens (e.g., images). Since `Context.token_count` is used to decide when to compact (`token_count + reserved >= max_context_size`), this underestimation can prevent compaction and lead to provider context-limit errors. Consider extending the estimator to include tool call names/arguments (and optionally apply a fallback cost for non-text parts) so the estimate is biased high rather than low.