146 changes: 143 additions & 3 deletions tools/query.py
@@ -1,14 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

__version__ = "0.2"
__version__ = "0.3"

### Imports ###

import argparse
import asyncio
from collections.abc import Mapping
from dataclasses import dataclass
from dataclasses import dataclass, field, replace
import difflib
import json
import os
@@ -83,6 +83,63 @@ class SearchResultData(typing.TypedDict):
results: list[RawSearchResultData]


@dataclass
class HistoryEntry:
"""A single Q&A pair in conversation history."""

question: str
answer: str
had_answer: bool


@dataclass
class ConversationHistory:
"""Tracks recent Q&A pairs for context resolution.

This enables the query engine to resolve pronouns and references
like "it", "she", or "the first point" by providing recent context
to the LLM during query translation.
"""

entries: list[HistoryEntry] = field(default_factory=list)
max_entries: int = 5

def add(self, question: str, answer: str, had_answer: bool) -> None:
"""Add a new Q&A pair, removing oldest if at capacity."""
self.entries.append(HistoryEntry(question, answer, had_answer))
if len(self.entries) > self.max_entries:
self.entries.pop(0)

def clear(self) -> None:
"""Clear all history."""
self.entries.clear()

def to_prompt_section(self) -> typechat.PromptSection | None:
"""Format history as a prompt section for LLM.

Returns None if there's no history to include.
"""
if not self.entries:
return None

lines = [
"Recent conversation history (use this to resolve pronouns and references like 'it', 'he', 'she', 'the first point', etc.):"
]
for i, entry in enumerate(self.entries, 1):
lines.append(f"Q{i}: {entry.question}")
if entry.had_answer:
answer = (
entry.answer[:500] + "..."
if len(entry.answer) > 500
else entry.answer
)
lines.append(f"A{i}: {answer}")
else:
lines.append(f"A{i}: [No answer found]")

return typechat.PromptSection(role="user", content="\n".join(lines))
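
As an aside (not part of the diff), a minimal usage sketch of the ConversationHistory API above, assuming the class is in scope and typechat is installed:

    # Illustrative only; assumes ConversationHistory from tools/query.py is in scope.
    history = ConversationHistory(max_entries=2)
    history.add("Who wrote the design doc?", "Alice wrote it in March.", True)
    history.add("What did she cover?", "Testing and the rollout plan.", True)
    history.add("When does it ship?", "No relevant messages found.", False)  # oldest entry evicted

    section = history.to_prompt_section()  # typechat.PromptSection, or None when empty
    if section is not None:
        print(section)  # numbered Q/A transcript; unanswered turns show "[No answer found]"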


@dataclass
class ProcessingContext:
query_context: query.QueryEvalContext
@@ -101,6 +158,7 @@ class ProcessingContext:
]
lang_search_options: searchlang.LanguageSearchOptions
answer_context_options: answers.AnswerContextOptions
history: ConversationHistory = field(default_factory=ConversationHistory)

def __repr__(self) -> str:
parts = []
@@ -114,6 +172,7 @@ def __repr__(self) -> str:
parts.append(f"debug4={self.debug4}")
parts.append(f"lang_search_options={self.lang_search_options}")
parts.append(f"answer_context_options={self.answer_context_options}")
parts.append(f"history={len(self.history.entries)}/{self.history.max_entries}")
return f"Context({', '.join(parts)})"


@@ -377,11 +436,69 @@ async def cmd_stats(context: ProcessingContext, args: list[str]) -> None:
await print_conversation_stats(context.query_context.conversation)


async def cmd_history(context: ProcessingContext, args: list[str]) -> None:
"""Show or manage conversation history. Usage: @history [--clear] [--size N]

Without arguments, shows current history entries.
--clear: Clears all history.
--size N: Sets max history size (0 to disable history).

History is used to resolve pronouns and references in follow-up questions
like "it", "he", "she", or "the first point".
"""

parser = argparse.ArgumentParser(prog="@history", add_help=True)
parser.add_argument("--clear", action="store_true", help="Clear history")
parser.add_argument("--size", type=int, help="Set max history size")
ns = _parse_command_args(parser, args)
if ns is None:
return

if ns.clear:
context.history.clear()
print("History cleared.")
return

if ns.size is not None:
context.history.max_entries = ns.size
while len(context.history.entries) > ns.size:
context.history.entries.pop(0)
print(f"History size set to {ns.size}.")
return

if not context.history.entries:
print(
f"No history yet (max {context.history.max_entries} entries). "
"Ask some questions first."
)
return

print(
f"Conversation history "
f"({len(context.history.entries)}/{context.history.max_entries} entries):"
)
for i, entry in enumerate(context.history.entries, 1):
q_preview = (
entry.question[:70] + "..." if len(entry.question) > 70 else entry.question
)
a_preview = (
entry.answer[:70] + "..." if len(entry.answer) > 70 else entry.answer
)
status = (
Fore.GREEN + "✓" + Fore.RESET
if entry.had_answer
else Fore.RED + "✗" + Fore.RESET
)
print(f" {i}. [{status}] Q: {q_preview}")
print(f" A: {a_preview}")


commands: dict[str, CommandHandler] = {
"help": cmd_help,
"debug": cmd_debug,
"stage": cmd_stage,
"stats": cmd_stats,
"history": cmd_history,
}
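
For reference, the new handler can be exercised like any other command handler (illustrative only, not part of the diff; this would run inside an async function with a real ProcessingContext bound to context):

    await cmd_history(context, [])                # list current entries
    await cmd_history(context, ["--size", "3"])   # cap history at 3 entries
    await cmd_history(context, ["--clear"])       # wipe history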


@@ -478,6 +595,7 @@ async def main():
answers.AnswerContextOptions(
entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
),
ConversationHistory(max_entries=args.history_size),
)

if args.verbose:
@@ -669,11 +787,20 @@ async def process_query(context: ProcessingContext, query_text: str) -> float |
)
prsep()

history_section = context.history.to_prompt_section()
if history_section:
lang_search_options = replace(
context.lang_search_options,
model_instructions=[history_section],
)
else:
lang_search_options = context.lang_search_options
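
The replace() call above is the standard dataclasses.replace pattern: the shared search options are copied with a per-query override rather than mutated in place. A self-contained sketch of that pattern (the Options fields here are made up, not the real searchlang.LanguageSearchOptions):

    # Sketch of the dataclasses.replace pattern; Options is a stand-in, not the
    # real searchlang.LanguageSearchOptions.
    from dataclasses import dataclass, replace

    @dataclass
    class Options:
        model_instructions: list | None = None
        max_matches: int = 10

    base = Options()
    per_query = replace(base, model_instructions=["<history prompt section>"])
    assert base.model_instructions is None    # shared defaults stay untouched
    assert per_query.max_matches == 10        # unrelated fields carry over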

result = await searchlang.search_conversation_with_language(
context.query_context.conversation,
context.query_translator,
query_text,
context.lang_search_options,
lang_search_options,
debug_context=debug_context,
)
if isinstance(result, typechat.Failure):
@@ -743,6 +870,12 @@ async def process_query(context: ProcessingContext, query_text: str) -> float |
options=context.answer_context_options,
)

if context.history.max_entries > 0:
if combined_answer.type == "Answered":
context.history.add(query_text, combined_answer.answer or "", True)
else:
context.history.add(query_text, combined_answer.whyNoAnswer or "", False)

if context.debug4 == "full":
utils.pretty_print(all_answers)
prsep()
@@ -841,6 +974,13 @@ def make_arg_parser(description: str) -> argparse.ArgumentParser:
action="store_true",
help="Show verbose startup information and timing logs",
)
parser.add_argument(
"--history-size",
type=int,
default=5,
help="Number of recent Q&A pairs to keep for resolving pronouns/references "
"(default: 5, 0 to disable)",
)
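
With the new flag, the history window can be tuned or turned off at startup, for example (invocation sketch; any other required arguments are omitted):

    python tools/query.py --history-size 10   # keep the last 10 Q&A pairs
    python tools/query.py --history-size 0    # disable history entirely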

batch = parser.add_argument_group("Batch mode options")
batch.add_argument(
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default.