From a41e1df57ff01b94623b54709fe0a78770ee7ccd Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Thu, 17 Apr 2025 17:03:27 -0700 Subject: [PATCH 1/4] renames session manager classes to message history --- ...manager.ipynb => 07_message_history.ipynb} | 79 ++-- redisvl/extensions/constants.py | 10 +- .../extensions/message_history/__init__.py | 7 + .../message_history/base_history.py | 157 +++++++ .../message_history/message_history.py | 230 +++++++++++ redisvl/extensions/message_history/schema.py | 101 +++++ .../semantic_message_history.py | 375 +++++++++++++++++ .../extensions/session_manager/__init__.py | 35 +- .../session_manager/base_session.py | 165 +------- redisvl/extensions/session_manager/schema.py | 117 +----- .../session_manager/semantic_session.py | 387 +----------------- .../session_manager/standard_session.py | 244 +---------- ...ion_manager.py => test_message_history.py} | 309 +++++++------- ...hema.py => test_message_history_schema.py} | 2 +- 14 files changed, 1160 insertions(+), 1058 deletions(-) rename docs/user_guide/{07_session_manager.ipynb => 07_message_history.ipynb} (83%) create mode 100644 redisvl/extensions/message_history/__init__.py create mode 100644 redisvl/extensions/message_history/base_history.py create mode 100644 redisvl/extensions/message_history/message_history.py create mode 100644 redisvl/extensions/message_history/schema.py create mode 100644 redisvl/extensions/message_history/semantic_message_history.py rename tests/integration/{test_session_manager.py => test_message_history.py} (65%) rename tests/unit/{test_session_schema.py => test_message_history_schema.py} (98%) diff --git a/docs/user_guide/07_session_manager.ipynb b/docs/user_guide/07_message_history.ipynb similarity index 83% rename from docs/user_guide/07_session_manager.ipynb rename to docs/user_guide/07_message_history.ipynb index ed6c61e4..80f0ae14 100644 --- a/docs/user_guide/07_session_manager.ipynb +++ b/docs/user_guide/07_message_history.ipynb @@ -4,7 +4,7 
@@ "cell_type": "markdown", "metadata": {}, "source": [ - "# LLM Session Memory" + "# LLM Message History" ] }, { @@ -15,12 +15,12 @@ "\n", "The solution to this problem is to append the previous conversation history to each subsequent call to the LLM.\n", "\n", - "This notebook will show how to use Redis to structure and store and retrieve this conversational session memory." + "This notebook will show how to use Redis to structure and store and retrieve this conversational message history." ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -32,8 +32,8 @@ } ], "source": [ - "from redisvl.extensions.session_manager import StandardSessionManager\n", - "chat_session = StandardSessionManager(name='student tutor')" + "from redisvl.extensions.message_history import MessageHistory\n", + "chat_history = MessageHistory(name='student tutor')" ] }, { @@ -48,12 +48,12 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "chat_session.add_message({\"role\":\"system\", \"content\":\"You are a helpful geography tutor, giving simple and short answers to questions about Europen countries.\"})\n", - "chat_session.add_messages([\n", + "chat_history.add_message({\"role\":\"system\", \"content\":\"You are a helpful geography tutor, giving simple and short answers to questions about European countries.\"})\n", + "chat_history.add_messages([\n", " {\"role\":\"user\", \"content\":\"What is the capital of France?\"},\n", " {\"role\":\"llm\", \"content\":\"The capital is Paris.\"},\n", " {\"role\":\"user\", \"content\":\"And what is the capital of Spain?\"},\n", @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -88,7 +88,7 @@ } ], "source": [ - "context = chat_session.get_recent()\n", + "context = chat_history.get_recent()\n", "for message in context:\n", " print(message)" ] @@ 
-97,12 +97,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In many LLM flows the conversation progresses in a series of prompt and response pairs. session managers offer a convienience function `store()` to add these simply." + "In many LLM flows the conversation progresses in a series of prompt and response pairs. Message history offer a convenience function `store()` to add these simply." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -121,9 +121,9 @@ "source": [ "prompt = \"what is the size of England compared to Portugal?\"\n", "response = \"England is larger in land area than Portal by about 15000 square miles.\"\n", - "chat_session.store(prompt, response)\n", + "chat_history.store(prompt, response)\n", "\n", - "context = chat_session.get_recent(top_k=6)\n", + "context = chat_history.get_recent(top_k=6)\n", "for message in context:\n", " print(message)" ] @@ -144,7 +144,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -160,8 +160,8 @@ } ], "source": [ - "chat_session.add_message({\"role\":\"system\", \"content\":\"You are a helpful algebra tutor, giving simple answers to math problems.\"}, session_tag='student two')\n", - "chat_session.add_messages([\n", + "chat_history.add_message({\"role\":\"system\", \"content\":\"You are a helpful algebra tutor, giving simple answers to math problems.\"}, session_tag='student two')\n", + "chat_history.add_messages([\n", " {\"role\":\"user\", \"content\":\"What is the value of x in the equation 2x + 3 = 7?\"},\n", " {\"role\":\"llm\", \"content\":\"The value of x is 2.\"},\n", " {\"role\":\"user\", \"content\":\"What is the value of y in the equation 3y - 5 = 7?\"},\n", @@ -169,7 +169,7 @@ " session_tag='student two'\n", " )\n", "\n", - "for math_message in chat_session.get_recent(session_tag='student two'):\n", + "for math_message in chat_history.get_recent(session_tag='student 
two'):\n", " print(math_message)" ] }, @@ -177,16 +177,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Semantic conversation memory\n", + "## Semantic message history\n", "For longer conversations our list of messages keeps growing. Since LLMs are stateless we have to continue to pass this conversation history on each subsequent call to ensure the LLM has the correct context.\n", "\n", "A typical flow looks like this:\n", "```\n", "while True:\n", " prompt = input('enter your next question')\n", - " context = chat_session.get_recent()\n", + " context = chat_history.get_recent()\n", " response = LLM_api_call(prompt=prompt, context=context)\n", - " chat_session.store(prompt, response)\n", + " chat_history.store(prompt, response)\n", "```\n", "\n", "This works, but as context keeps growing so too does our LLM token count, which increases latency and cost.\n", @@ -195,12 +195,12 @@ "\n", "A better solution is to pass only the relevant conversational context on each subsequent call.\n", "\n", - "For this, RedisVL has the `SemanticSessionManager`, which uses vector similarity search to return only semantically relevant sections of the conversation." + "For this, RedisVL has the `SemanticMessageHistory`, which uses vector similarity search to return only semantically relevant sections of the conversation." 
] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -212,15 +212,15 @@ } ], "source": [ - "from redisvl.extensions.session_manager import SemanticSessionManager\n", - "semantic_session = SemanticSessionManager(name='tutor')\n", + "from redisvl.extensions.message_history import SemanticMessageHistory\n", + "semantic_history = SemanticMessageHistory(name='tutor')\n", "\n", - "semantic_session.add_messages(chat_session.get_recent(top_k=8))" + "semantic_history.add_messages(chat_history.get_recent(top_k=8))" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -234,8 +234,8 @@ ], "source": [ "prompt = \"what have I learned about the size of England?\"\n", - "semantic_session.set_distance_threshold(0.35)\n", - "context = semantic_session.get_relevant(prompt)\n", + "semantic_history.set_distance_threshold(0.35)\n", + "context = semantic_history.get_relevant(prompt)\n", "for message in context:\n", " print(message)" ] @@ -251,7 +251,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -266,9 +266,9 @@ } ], "source": [ - "semantic_session.set_distance_threshold(0.7)\n", + "semantic_history.set_distance_threshold(0.7)\n", "\n", - "larger_context = semantic_session.get_relevant(prompt)\n", + "larger_context = semantic_history.get_relevant(prompt)\n", "for message in larger_context:\n", " print(message)" ] @@ -284,7 +284,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -300,28 +300,29 @@ } ], "source": [ - "semantic_session.store(\n", + "semantic_history.store(\n", " prompt=\"what is the smallest country in Europe?\",\n", " response=\"Monaco is the smallest country in Europe at 0.78 square miles.\" # Incorrect. 
Vatican City is the smallest country in Europe\n", " )\n", "\n", "# get the key of the incorrect message\n", - "context = semantic_session.get_recent(top_k=1, raw=True)\n", + "context = semantic_history.get_recent(top_k=1, raw=True)\n", "bad_key = context[0]['entry_id']\n", - "semantic_session.drop(bad_key)\n", + "semantic_history.drop(bad_key)\n", "\n", - "corrected_context = semantic_session.get_recent()\n", + "corrected_context = semantic_history.get_recent()\n", "for message in corrected_context:\n", " print(message)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "chat_session.clear()" + "chat_history.clear()\n", + "semantic_history.clear()" ] } ], diff --git a/redisvl/extensions/constants.py b/redisvl/extensions/constants.py index dfd2ceef..d6d7945a 100644 --- a/redisvl/extensions/constants.py +++ b/redisvl/extensions/constants.py @@ -1,10 +1,10 @@ """ -Constants used within the extension classes SemanticCache, BaseSessionManager, -StandardSessionManager,SemanticSessionManager and SemanticRouter. +Constants used within the extension classes SemanticCache, BaseMessageHistory, +MessageHistory, SemanticMessageHistory and SemanticRouter. These constants are also used within theses classes corresponding schema. 
""" -# BaseSessionManager +# BaseMessageHistory ID_FIELD_NAME: str = "entry_id" ROLE_FIELD_NAME: str = "role" CONTENT_FIELD_NAME: str = "content" @@ -12,8 +12,8 @@ TIMESTAMP_FIELD_NAME: str = "timestamp" SESSION_FIELD_NAME: str = "session_tag" -# SemanticSessionManager -SESSION_VECTOR_FIELD_NAME: str = "vector_field" +# SemanticMessageHistory +MESSAGE_VECTOR_FIELD_NAME: str = "vector_field" # SemanticCache REDIS_KEY_FIELD_NAME: str = "key" diff --git a/redisvl/extensions/message_history/__init__.py b/redisvl/extensions/message_history/__init__.py new file mode 100644 index 00000000..a98d8f2a --- /dev/null +++ b/redisvl/extensions/message_history/__init__.py @@ -0,0 +1,7 @@ +from redisvl.extensions.message_history.base_history import BaseMessageHistory +from redisvl.extensions.message_history.message_history import MessageHistory +from redisvl.extensions.message_history.semantic_message_history import ( + SemanticMessageHistory, +) + +__all__ = ["BaseMessageHistory", "MessageHistory", "SemanticMessageHistory"] diff --git a/redisvl/extensions/message_history/base_history.py b/redisvl/extensions/message_history/base_history.py new file mode 100644 index 00000000..72825877 --- /dev/null +++ b/redisvl/extensions/message_history/base_history.py @@ -0,0 +1,157 @@ +from typing import Any, Dict, List, Optional, Union + +from redisvl.extensions.constants import ( + CONTENT_FIELD_NAME, + ROLE_FIELD_NAME, + TOOL_FIELD_NAME, +) +from redisvl.extensions.message_history.schema import ChatMessage +from redisvl.utils.utils import create_ulid + + +class BaseMessageHistory: + + def __init__( + self, + name: str, + session_tag: Optional[str] = None, + ): + """Initialize message history with index + + Message History stores the current and previous user text prompts and + LLM responses to allow for enriching future prompts with session + context. Message history is stored in individual user or LLM prompts and + responses. + + Args: + name (str): The name of the message history index. 
+ session_tag (str): Tag to be added to entries to link to a specific + conversation session. Defaults to instance ULID. + """ + self._name = name + self._session_tag = session_tag or create_ulid() + + def clear(self) -> None: + """Clears the chat message history.""" + raise NotImplementedError + + def delete(self) -> None: + """Clear all conversation history and remove any search indices.""" + raise NotImplementedError + + def drop(self, id_field: Optional[str] = None) -> None: + """Remove a specific exchange from the conversation history. + + Args: + id_field (Optional[str]): The id_field of the entry to delete. + If None then the last entry is deleted. + """ + raise NotImplementedError + + @property + def messages(self) -> Union[List[str], List[Dict[str, str]]]: + """Returns the full chat history.""" + raise NotImplementedError + + def get_recent( + self, + top_k: int = 5, + as_text: bool = False, + raw: bool = False, + session_tag: Optional[str] = None, + ) -> Union[List[str], List[Dict[str, str]]]: + """Retreive the recent conversation history in sequential order. + + Args: + top_k (int): The number of previous exchanges to return. Default is 5. + Note that one exchange contains both a prompt and response. + as_text (bool): Whether to return the conversation as a single string, + or list of alternating prompts and responses. + raw (bool): Whether to return the full Redis hash entry or just the + prompt and response + session_tag (str): Tag to be added to entries to link to a specific + conversation session. Defaults to instance ULID. + + Returns: + Union[str, List[str]]: A single string transcription of the messages + or list of strings if as_text is false. + + Raises: + ValueError: If top_k is not an integer greater than or equal to 0. 
+ """ + raise NotImplementedError + + def _format_context( + self, messages: List[Dict[str, Any]], as_text: bool + ) -> Union[List[str], List[Dict[str, str]]]: + """Extracts the prompt and response fields from the Redis hashes and + formats them as either flat dictionaries or strings. + + Args: + messages (List[Dict[str, Any]]): The messages from the message history index. + as_text (bool): Whether to return the conversation as a single string, + or list of alternating prompts and responses. + + Returns: + Union[str, List[str]]: A single string transcription of the messages + or list of strings if as_text is false. + """ + context = [] + + for message in messages: + + chat_message = ChatMessage(**message) + + if as_text: + context.append(chat_message.content) + else: + chat_message_dict = { + ROLE_FIELD_NAME: chat_message.role, + CONTENT_FIELD_NAME: chat_message.content, + } + if chat_message.tool_call_id is not None: + chat_message_dict[TOOL_FIELD_NAME] = chat_message.tool_call_id + + context.append(chat_message_dict) # type: ignore + + return context + + def store( + self, prompt: str, response: str, session_tag: Optional[str] = None + ) -> None: + """Insert a prompt:response pair into the message history. A timestamp + is associated with each exchange so that they can be later sorted + in sequential ordering after retrieval. + + Args: + prompt (str): The user prompt to the LLM. + response (str): The corresponding LLM response. + session_tag (Optional[str]): The tag to mark the message with. Defaults to None. + """ + raise NotImplementedError + + def add_messages( + self, messages: List[Dict[str, str]], session_tag: Optional[str] = None + ) -> None: + """Insert a list of prompts and responses into the message history. + A timestamp is associated with each so that they can be later sorted + in sequential ordering after retrieval. + + Args: + messages (List[Dict[str, str]]): The list of user prompts and LLM responses. 
+ session_tag (Optional[str]): The tag to mark the messages with. Defaults to None. + """ + raise NotImplementedError + + def add_message( + self, message: Dict[str, str], session_tag: Optional[str] = None + ) -> None: + """Insert a single prompt or response into the message history. + A timestamp is associated with it so that it can be later sorted + in sequential ordering after retrieval. + + Args: + message (Dict[str,str]): The user prompt or LLM response. + session_tag (Optional[str]): The tag to mark the message with. Defaults to None. + """ + raise NotImplementedError diff --git a/redisvl/extensions/message_history/message_history.py b/redisvl/extensions/message_history/message_history.py new file mode 100644 index 00000000..4520d7a4 --- /dev/null +++ b/redisvl/extensions/message_history/message_history.py @@ -0,0 +1,230 @@ +from typing import Any, Dict, List, Optional, Union + +from redis import Redis + +from redisvl.extensions.constants import ( + CONTENT_FIELD_NAME, + ID_FIELD_NAME, + ROLE_FIELD_NAME, + SESSION_FIELD_NAME, + TIMESTAMP_FIELD_NAME, + TOOL_FIELD_NAME, +) +from redisvl.extensions.message_history import BaseMessageHistory +from redisvl.extensions.message_history.schema import ChatMessage, MessageHistorySchema +from redisvl.index import SearchIndex +from redisvl.query import FilterQuery +from redisvl.query.filter import Tag + + +class MessageHistory(BaseMessageHistory): + + def __init__( + self, + name: str, + session_tag: Optional[str] = None, + prefix: Optional[str] = None, + redis_client: Optional[Redis] = None, + redis_url: str = "redis://localhost:6379", + connection_kwargs: Dict[str, Any] = {}, + **kwargs, + ): + """Initialize message history + + Message History stores the current and previous user text prompts and + LLM responses to allow for enriching future prompts with session + context. Message history is stored in individual user or LLM prompts and + responses. + + Args: + name (str): The name of the message history index. 
+ session_tag (Optional[str]): Tag to be added to entries to link to a specific + conversation session. Defaults to instance ULID. + prefix (Optional[str]): Prefix for the keys for this conversation data. + Defaults to None and will be replaced with the index name. + redis_client (Optional[Redis]): A Redis client instance. Defaults to + None. + redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. + connection_kwargs (Dict[str, Any]): The connection arguments + for the redis client. Defaults to empty {}. + + """ + super().__init__(name, session_tag) + + prefix = prefix or name + + schema = MessageHistorySchema.from_params(name, prefix) + + self._index = SearchIndex( + schema=schema, + redis_client=redis_client, + redis_url=redis_url, + **connection_kwargs, + ) + + self._index.create(overwrite=False) + + self._default_session_filter = Tag(SESSION_FIELD_NAME) == self._session_tag + + def clear(self) -> None: + """Clears the conversation message history.""" + self._index.clear() + + def delete(self) -> None: + """Clear all conversation keys and remove the search index.""" + self._index.delete(drop=True) + + def drop(self, id: Optional[str] = None) -> None: + """Remove a specific exchange from the conversation history. + + Args: + id (Optional[str]): The id of the message entry to delete. + If None then the last entry is deleted. + """ + if id is None: + id = self.get_recent(top_k=1, raw=True)[0][ID_FIELD_NAME] # type: ignore + + self._index.client.delete(self._index.key(id)) # type: ignore + + @property + def messages(self) -> Union[List[str], List[Dict[str, str]]]: + """Returns the full message history.""" + # TODO raw or as_text? + # TODO refactor this method to use get_recent and support other session tags? 
+ return_fields = [ + ID_FIELD_NAME, + SESSION_FIELD_NAME, + ROLE_FIELD_NAME, + CONTENT_FIELD_NAME, + TOOL_FIELD_NAME, + TIMESTAMP_FIELD_NAME, + ] + + query = FilterQuery( + filter_expression=self._default_session_filter, + return_fields=return_fields, + ) + query.sort_by(TIMESTAMP_FIELD_NAME, asc=True) + messages = self._index.query(query) + + return self._format_context(messages, as_text=False) + + def get_recent( + self, + top_k: int = 5, + as_text: bool = False, + raw: bool = False, + session_tag: Optional[str] = None, + ) -> Union[List[str], List[Dict[str, str]]]: + """Retrieve the recent message history in sequential order. + + Args: + top_k (int): The number of previous messages to return. Default is 5. + as_text (bool): Whether to return the conversation as a single string, + or list of alternating prompts and responses. + raw (bool): Whether to return the full Redis hash entry or just the + prompt and response. + session_tag (Optional[str]): Tag of the entries linked to a specific + conversation session. Defaults to instance ULID. + + Returns: + Union[str, List[str]]: A single string transcription of the messages + or list of strings if as_text is false. + + Raises: + ValueError: if top_k is not an integer greater than or equal to 0. 
+ """ + if type(top_k) != int or top_k < 0: + raise ValueError("top_k must be an integer greater than or equal to 0") + + return_fields = [ + ID_FIELD_NAME, + SESSION_FIELD_NAME, + ROLE_FIELD_NAME, + CONTENT_FIELD_NAME, + TOOL_FIELD_NAME, + TIMESTAMP_FIELD_NAME, + ] + + session_filter = ( + Tag(SESSION_FIELD_NAME) == session_tag + if session_tag + else self._default_session_filter + ) + + query = FilterQuery( + filter_expression=session_filter, + return_fields=return_fields, + num_results=top_k, + ) + query.sort_by(TIMESTAMP_FIELD_NAME, asc=False) + messages = self._index.query(query) + + if raw: + return messages[::-1] + return self._format_context(messages[::-1], as_text) + + def store( + self, prompt: str, response: str, session_tag: Optional[str] = None + ) -> None: + """Insert a prompt:response pair into the message history. A timestamp + is associated with each exchange so that they can be later sorted + in sequential ordering after retrieval. + + Args: + prompt (str): The user prompt to the LLM. + response (str): The corresponding LLM response. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + conversation session. Defaults to instance ULID. + """ + self.add_messages( + [ + {ROLE_FIELD_NAME: "user", CONTENT_FIELD_NAME: prompt}, + {ROLE_FIELD_NAME: "llm", CONTENT_FIELD_NAME: response}, + ], + session_tag, + ) + + def add_messages( + self, messages: List[Dict[str, str]], session_tag: Optional[str] = None + ) -> None: + """Insert a list of prompts and responses into the message history. + A timestamp is associated with each so that they can be later sorted + in sequential ordering after retrieval. + + Args: + messages (List[Dict[str, str]]): The list of user prompts and LLM responses. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + conversation session. Defaults to instance ULID. 
+ """ + session_tag = session_tag or self._session_tag + chat_messages: List[Dict[str, Any]] = [] + + for message in messages: + + chat_message = ChatMessage( + role=message[ROLE_FIELD_NAME], + content=message[CONTENT_FIELD_NAME], + session_tag=session_tag, + ) + + if TOOL_FIELD_NAME in message: + chat_message.tool_call_id = message[TOOL_FIELD_NAME] + + chat_messages.append(chat_message.to_dict()) + + self._index.load(data=chat_messages, id_field=ID_FIELD_NAME) + + def add_message( + self, message: Dict[str, str], session_tag: Optional[str] = None + ) -> None: + """Insert a single prompt or response into the message history. + A timestamp is associated with it so that it can be later sorted + in sequential ordering after retrieval. + + Args: + message (Dict[str,str]): The user prompt or LLM response. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + conversation session. Defaults to instance ULID. + """ + self.add_messages([message], session_tag) diff --git a/redisvl/extensions/message_history/schema.py b/redisvl/extensions/message_history/schema.py new file mode 100644 index 00000000..839b84ff --- /dev/null +++ b/redisvl/extensions/message_history/schema.py @@ -0,0 +1,101 @@ +from typing import Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from redisvl.extensions.constants import ( + CONTENT_FIELD_NAME, + ID_FIELD_NAME, + MESSAGE_VECTOR_FIELD_NAME, + ROLE_FIELD_NAME, + SESSION_FIELD_NAME, + TIMESTAMP_FIELD_NAME, + TOOL_FIELD_NAME, +) +from redisvl.redis.utils import array_to_buffer +from redisvl.schema import IndexSchema +from redisvl.utils.utils import current_timestamp + + +class ChatMessage(BaseModel): + """A single chat message exchanged between a user and an LLM.""" + + entry_id: Optional[str] = Field(default=None) + """A unique identifier for the message.""" + role: str # TODO -- do we enumify this? 
+ """The role of the message sender (e.g., 'user' or 'llm').""" + content: str + """The content of the message.""" + session_tag: str + """Tag associated with the current conversation session.""" + timestamp: Optional[float] = Field(default=None) + """The time the message was sent, in UTC, rounded to milliseconds.""" + tool_call_id: Optional[str] = Field(default=None) + """An optional identifier for a tool call associated with the message.""" + vector_field: Optional[List[float]] = Field(default=None) + """The vector representation of the message content.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + + @model_validator(mode="before") + @classmethod + def generate_id(cls, values): + if TIMESTAMP_FIELD_NAME not in values: + values[TIMESTAMP_FIELD_NAME] = current_timestamp() + if ID_FIELD_NAME not in values: + values[ID_FIELD_NAME] = ( + f"{values[SESSION_FIELD_NAME]}:{values[TIMESTAMP_FIELD_NAME]}" + ) + return values + + def to_dict(self, dtype: Optional[str] = None) -> Dict: + data = self.model_dump(exclude_none=True) + + # handle optional fields + if MESSAGE_VECTOR_FIELD_NAME in data: + data[MESSAGE_VECTOR_FIELD_NAME] = array_to_buffer( + data[MESSAGE_VECTOR_FIELD_NAME], dtype # type: ignore[arg-type] + ) + return data + + +class MessageHistorySchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, prefix: str): + + return cls( + index={"name": name, "prefix": prefix}, # type: ignore + fields=[ # type: ignore + {"name": ROLE_FIELD_NAME, "type": "tag"}, + {"name": CONTENT_FIELD_NAME, "type": "text"}, + {"name": TOOL_FIELD_NAME, "type": "tag"}, + {"name": TIMESTAMP_FIELD_NAME, "type": "numeric"}, + {"name": SESSION_FIELD_NAME, "type": "tag"}, + ], + ) + + +class SemanticMessageHistorySchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, prefix: str, vectorizer_dims: int, dtype: str): + + return cls( + index={"name": name, "prefix": prefix}, # type: ignore + fields=[ # type: ignore + {"name": ROLE_FIELD_NAME, "type": 
"tag"}, + {"name": CONTENT_FIELD_NAME, "type": "text"}, + {"name": TOOL_FIELD_NAME, "type": "tag"}, + {"name": TIMESTAMP_FIELD_NAME, "type": "numeric"}, + {"name": SESSION_FIELD_NAME, "type": "tag"}, + { + "name": MESSAGE_VECTOR_FIELD_NAME, + "type": "vector", + "attrs": { + "dims": vectorizer_dims, + "datatype": dtype, + "distance_metric": "cosine", + "algorithm": "flat", + }, + }, + ], + ) diff --git a/redisvl/extensions/message_history/semantic_message_history.py b/redisvl/extensions/message_history/semantic_message_history.py new file mode 100644 index 00000000..529a9a86 --- /dev/null +++ b/redisvl/extensions/message_history/semantic_message_history.py @@ -0,0 +1,375 @@ +from typing import Any, Dict, List, Optional, Union + +from redis import Redis + +from redisvl.extensions.constants import ( + CONTENT_FIELD_NAME, + ID_FIELD_NAME, + MESSAGE_VECTOR_FIELD_NAME, + ROLE_FIELD_NAME, + SESSION_FIELD_NAME, + TIMESTAMP_FIELD_NAME, + TOOL_FIELD_NAME, +) +from redisvl.extensions.message_history import BaseMessageHistory +from redisvl.extensions.message_history.schema import ( + ChatMessage, + SemanticMessageHistorySchema, +) +from redisvl.index import SearchIndex +from redisvl.query import FilterQuery, RangeQuery +from redisvl.query.filter import Tag +from redisvl.utils.utils import deprecated_argument, validate_vector_dims +from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer + + +class SemanticMessageHistory(BaseMessageHistory): + + @deprecated_argument("dtype", "vectorizer") + def __init__( + self, + name: str, + session_tag: Optional[str] = None, + prefix: Optional[str] = None, + vectorizer: Optional[BaseVectorizer] = None, + distance_threshold: float = 0.3, + redis_client: Optional[Redis] = None, + redis_url: str = "redis://localhost:6379", + connection_kwargs: Dict[str, Any] = {}, + overwrite: bool = False, + **kwargs, + ): + """Initialize message history with index + + Semantic Message History stores the current and previous user text prompts + 
and LLM responses to allow for enriching future prompts with session + context. Message history is stored in individual user or LLM prompts and + responses. + + Args: + name (str): The name of the message history index. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + conversation session. Defaults to instance ULID. + prefix (Optional[str]): Prefix for the keys for this message data. + Defaults to None and will be replaced with the index name. + vectorizer (Optional[BaseVectorizer]): The vectorizer used to create embeddings. + distance_threshold (float): The maximum semantic distance to be + included in the context. Defaults to 0.3. + redis_client (Optional[Redis]): A Redis client instance. Defaults to + None. + redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. + connection_kwargs (Dict[str, Any]): The connection arguments + for the redis client. Defaults to empty {}. + overwrite (bool): Whether or not to force overwrite the schema for + the semantic message index. Defaults to false. + + The proposed schema will support a single vector embedding constructed + from either the prompt or response in a single string. 
+ """ + super().__init__(name, session_tag) + + prefix = prefix or name + dtype = kwargs.pop("dtype", None) + + # Validate a provided vectorizer or set the default + if vectorizer: + if not isinstance(vectorizer, BaseVectorizer): + raise TypeError("Must provide a valid redisvl.vectorizer class.") + if dtype and vectorizer.dtype != dtype: + raise ValueError( + f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}" + ) + else: + vectorizer_kwargs = kwargs + + if dtype: + vectorizer_kwargs.update(**{"dtype": dtype}) + + vectorizer = HFTextVectorizer( + model="sentence-transformers/all-mpnet-base-v2", + **vectorizer_kwargs, + ) + + self._vectorizer = vectorizer + + self.set_distance_threshold(distance_threshold) + + schema = SemanticMessageHistorySchema.from_params( + name, prefix, vectorizer.dims, vectorizer.dtype # type: ignore + ) + + self._index = SearchIndex( + schema=schema, + redis_client=redis_client, + redis_url=redis_url, + **connection_kwargs, + ) + + # Check for existing message history index + if not overwrite and self._index.exists(): + existing_index = SearchIndex.from_existing( + name, redis_client=self._index.client + ) + if existing_index.schema.to_dict() != self._index.schema.to_dict(): + raise ValueError( + f"Existing index {name} schema does not match the user provided schema for the semantic message history. " + "If you wish to overwrite the index schema, set overwrite=True during initialization." + ) + self._index.create(overwrite=overwrite, drop=False) + + self._default_session_filter = Tag(SESSION_FIELD_NAME) == self._session_tag + + def clear(self) -> None: + """Clears the message history.""" + self._index.clear() + + def delete(self) -> None: + """Clear all message keys and remove the search index.""" + self._index.delete(drop=True) + + def drop(self, id: Optional[str] = None) -> None: + """Remove a specific exchange from the message history. + + Args: + id (Optional[str]): The id of the message entry to delete. 
+ If None then the last entry is deleted. + """ + if id is None: + id = self.get_recent(top_k=1, raw=True)[0][ID_FIELD_NAME] # type: ignore + + self._index.client.delete(self._index.key(id)) # type: ignore + + @property + def messages(self) -> Union[List[str], List[Dict[str, str]]]: + """Returns the full message history.""" + # TODO raw or as_text? + # TODO refactor method to use get_recent and support other session tags + return_fields = [ + ID_FIELD_NAME, + SESSION_FIELD_NAME, + ROLE_FIELD_NAME, + CONTENT_FIELD_NAME, + TOOL_FIELD_NAME, + TIMESTAMP_FIELD_NAME, + ] + + query = FilterQuery( + filter_expression=self._default_session_filter, + return_fields=return_fields, + ) + query.sort_by(TIMESTAMP_FIELD_NAME, asc=True) + messages = self._index.query(query) + + return self._format_context(messages, as_text=False) + + def get_relevant( + self, + prompt: str, + as_text: bool = False, + top_k: int = 5, + fall_back: bool = False, + session_tag: Optional[str] = None, + raw: bool = False, + distance_threshold: Optional[float] = None, + ) -> Union[List[str], List[Dict[str, str]]]: + """Searches the message history for information semantically related to + the specified prompt. + + This method uses vector similarity search with a text prompt as input. + It checks for semantically similar prompts and responses and gets + the top k most relevant previous prompts or responses to include as + context to the next LLM call. + + Args: + prompt (str): The message text to search for in message history + as_text (bool): Whether to return the prompts and responses as text + or as JSON. + top_k (int): The number of previous messages to return. Default is 5. + session_tag (Optional[str]): Tag of the entries linked to a specific + conversation session. Defaults to instance ULID. + distance_threshold (Optional[float]): The threshold for semantic + vector distance. + fall_back (bool): Whether to drop back to recent conversation history + if no relevant context is found. 
+ raw (bool): Whether to return the full Redis hash entry or just the + message. + + Returns: + Union[List[str], List[Dict[str,str]]: Either a list of strings, or a + list of prompts and responses in JSON containing the most relevant. + + Raises ValueError: if top_k is not an integer greater or equal to 0. + """ + if type(top_k) != int or top_k < 0: + raise ValueError("top_k must be an integer greater than or equal to 0") + if top_k == 0: + return [] + + # override distance threshold + distance_threshold = distance_threshold or self._distance_threshold + + return_fields = [ + SESSION_FIELD_NAME, + ROLE_FIELD_NAME, + CONTENT_FIELD_NAME, + TIMESTAMP_FIELD_NAME, + TOOL_FIELD_NAME, + ] + + session_filter = ( + Tag(SESSION_FIELD_NAME) == session_tag + if session_tag + else self._default_session_filter + ) + + query = RangeQuery( + vector=self._vectorizer.embed(prompt), + vector_field_name=MESSAGE_VECTOR_FIELD_NAME, + return_fields=return_fields, + distance_threshold=distance_threshold, + num_results=top_k, + return_score=True, + filter_expression=session_filter, + dtype=self._vectorizer.dtype, + ) + messages = self._index.query(query) + + # if we don't find semantic matches fallback to returning recent context + if not messages and fall_back: + return self.get_recent(as_text=as_text, top_k=top_k, raw=raw) + if raw: + return messages + return self._format_context(messages, as_text) + + def get_recent( + self, + top_k: int = 5, + as_text: bool = False, + raw: bool = False, + session_tag: Optional[str] = None, + ) -> Union[List[str], List[Dict[str, str]]]: + """Retrieve the recent message history in sequential order. + + Args: + top_k (int): The number of previous exchanges to return. Default is 5. + as_text (bool): Whether to return the conversation as a single string, + or list of alternating prompts and responses.
+ raw (bool): Whether to return the full Redis hash entry or just the + prompt and response + session_tag (Optional[str]): Tag of the entries linked to a specific + conversation session. Defaults to instance ULID. + + Returns: + Union[str, List[str]]: A single string transcription of the session + or list of strings if as_text is false. + + Raises: + ValueError: if top_k is not an integer greater than or equal to 0. + """ + if type(top_k) != int or top_k < 0: + raise ValueError("top_k must be an integer greater than or equal to 0") + + return_fields = [ + ID_FIELD_NAME, + SESSION_FIELD_NAME, + ROLE_FIELD_NAME, + CONTENT_FIELD_NAME, + TOOL_FIELD_NAME, + TIMESTAMP_FIELD_NAME, + ] + + session_filter = ( + Tag(SESSION_FIELD_NAME) == session_tag + if session_tag + else self._default_session_filter + ) + + query = FilterQuery( + filter_expression=session_filter, + return_fields=return_fields, + num_results=top_k, + ) + query.sort_by(TIMESTAMP_FIELD_NAME, asc=False) + messages = self._index.query(query) + + if raw: + return messages[::-1] + return self._format_context(messages[::-1], as_text) + + @property + def distance_threshold(self): + return self._distance_threshold + + def set_distance_threshold(self, threshold): + self._distance_threshold = threshold + + def store( + self, prompt: str, response: str, session_tag: Optional[str] = None + ) -> None: + """Insert a prompt:response pair into the message history. A timestamp + is associated with each message so that they can be later sorted + in sequential ordering after retrieval. + + Args: + prompt (str): The user prompt to the LLM. + response (str): The corresponding LLM response. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + conversation session. Defaults to instance ULID. 
+ """ + self.add_messages( + [ + {ROLE_FIELD_NAME: "user", CONTENT_FIELD_NAME: prompt}, + {ROLE_FIELD_NAME: "llm", CONTENT_FIELD_NAME: response}, + ], + session_tag, + ) + + def add_messages( + self, messages: List[Dict[str, str]], session_tag: Optional[str] = None + ) -> None: + """Insert a list of prompts and responses into the session memory. + A timestamp is associated with each so that they can be later sorted + in sequential ordering after retrieval. + + Args: + messages (List[Dict[str, str]]): The list of user prompts and LLM responses. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + conversation session. Defaults to instance ULID. + """ + session_tag = session_tag or self._session_tag + chat_messages: List[Dict[str, Any]] = [] + + for message in messages: + content_vector = self._vectorizer.embed(message[CONTENT_FIELD_NAME]) + validate_vector_dims( + len(content_vector), + self._index.schema.fields[MESSAGE_VECTOR_FIELD_NAME].attrs.dims, # type: ignore + ) + + chat_message = ChatMessage( + role=message[ROLE_FIELD_NAME], + content=message[CONTENT_FIELD_NAME], + session_tag=session_tag, + vector_field=content_vector, # type: ignore + ) + + if TOOL_FIELD_NAME in message: + chat_message.tool_call_id = message[TOOL_FIELD_NAME] + + chat_messages.append(chat_message.to_dict(dtype=self._vectorizer.dtype)) + + self._index.load(data=chat_messages, id_field=ID_FIELD_NAME) + + def add_message( + self, message: Dict[str, str], session_tag: Optional[str] = None + ) -> None: + """Insert a single prompt or response into the message history. + A timestamp is associated with it so that it can be later sorted + in sequential ordering after retrieval. + + Args: + message (Dict[str,str]): The user prompt or LLM response. + session_tag (Optional[str]): Tag to be added to entry to link to a specific + conversation session. Defaults to instance ULID. 
+ """ + self.add_messages([message], session_tag) diff --git a/redisvl/extensions/session_manager/__init__.py b/redisvl/extensions/session_manager/__init__.py index 53afbf8b..619c2ee5 100644 --- a/redisvl/extensions/session_manager/__init__.py +++ b/redisvl/extensions/session_manager/__init__.py @@ -1,5 +1,32 @@ -from redisvl.extensions.session_manager.base_session import BaseSessionManager -from redisvl.extensions.session_manager.semantic_session import SemanticSessionManager -from redisvl.extensions.session_manager.standard_session import StandardSessionManager +""" +RedisVL Session Manager Extensions (Deprecated Path) -__all__ = ["BaseSessionManager", "StandardSessionManager", "SemanticSessionManager"] +This module is kept for backward compatibility. Please use `redisvl.extensions.message_history` instead. +""" + +import warnings + +from redisvl.extensions.message_history.message_history import MessageHistory +from redisvl.extensions.message_history.schema import ( + ChatMessage, + MessageHistorySchema, + SemanticMessageHistorySchema, +) +from redisvl.extensions.message_history.semantic_message_history import ( + SemanticMessageHistory, +) + +warnings.warn( + "Importing from redisvl.extensions.session_manager is deprecated. 
" + "Please import from redisvl.extensions.message_history instead.", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "MessageHistory", + "SemanticMessageHistory", + "ChatMessage", + "MessageHistorySchema", + "SemanticMessageHistorySchema", +] diff --git a/redisvl/extensions/session_manager/base_session.py b/redisvl/extensions/session_manager/base_session.py index 6bc4f1f5..9d61279b 100644 --- a/redisvl/extensions/session_manager/base_session.py +++ b/redisvl/extensions/session_manager/base_session.py @@ -1,157 +1,18 @@ -from typing import Any, Dict, List, Optional, Union +""" +RedisVL Standard Session Manager (Deprecated Path) -from redisvl.extensions.constants import ( - CONTENT_FIELD_NAME, - ROLE_FIELD_NAME, - TOOL_FIELD_NAME, -) -from redisvl.extensions.session_manager.schema import ChatMessage -from redisvl.utils.utils import create_ulid - - -class BaseSessionManager: - - def __init__( - self, - name: str, - session_tag: Optional[str] = None, - ): - """Initialize session memory with index - - Session Manager stores the current and previous user text prompts and - LLM responses to allow for enriching future prompts with session - context. Session history is stored in individual user or LLM prompts and - responses. - - Args: - name (str): The name of the session manager index. - session_tag (str): Tag to be added to entries to link to a specific - session. Defaults to instance ULID. - """ - self._name = name - self._session_tag = session_tag or create_ulid() - - def clear(self) -> None: - """Clears the chat session history.""" - raise NotImplementedError - - def delete(self) -> None: - """Clear all conversation history and remove any search indices.""" - raise NotImplementedError - - def drop(self, id_field: Optional[str] = None) -> None: - """Remove a specific exchange from the conversation history. - - Args: - id_field (Optional[str]): The id_field of the entry to delete. - If None then the last entry is deleted. 
- """ - raise NotImplementedError - - @property - def messages(self) -> Union[List[str], List[Dict[str, str]]]: - """Returns the full chat history.""" - raise NotImplementedError - - def get_recent( - self, - top_k: int = 5, - as_text: bool = False, - raw: bool = False, - session_tag: Optional[str] = None, - ) -> Union[List[str], List[Dict[str, str]]]: - """Retreive the recent conversation history in sequential order. - - Args: - top_k (int): The number of previous exchanges to return. Default is 5. - Note that one exchange contains both a prompt and response. - as_text (bool): Whether to return the conversation as a single string, - or list of alternating prompts and responses. - raw (bool): Whether to return the full Redis hash entry or just the - prompt and response - session_tag (str): Tag to be added to entries to link to a specific - session. Defaults to instance ULID. +This module is kept for backward compatibility. Please use `redisvl.extensions.message_history` instead. +""" - Returns: - Union[str, List[str]]: A single string transcription of the session - or list of strings if as_text is false. +import warnings - Raises: - ValueError: If top_k is not an integer greater than or equal to 0. - """ - raise NotImplementedError +from redisvl.extensions.message_history.message_history import MessageHistory - def _format_context( - self, messages: List[Dict[str, Any]], as_text: bool - ) -> Union[List[str], List[Dict[str, str]]]: - """Extracts the prompt and response fields from the Redis hashes and - formats them as either flat dictionaries or strings. - - Args: - messages (List[Dict[str, Any]]): The messages from the session index. - as_text (bool): Whether to return the conversation as a single string, - or list of alternating prompts and responses. - - Returns: - Union[str, List[str]]: A single string transcription of the session - or list of strings if as_text is false.
- """ - context = [] - - for message in messages: - - chat_message = ChatMessage(**message) - - if as_text: - context.append(chat_message.content) - else: - chat_message_dict = { - ROLE_FIELD_NAME: chat_message.role, - CONTENT_FIELD_NAME: chat_message.content, - } - if chat_message.tool_call_id is not None: - chat_message_dict[TOOL_FIELD_NAME] = chat_message.tool_call_id - - context.append(chat_message_dict) # type: ignore - - return context - - def store( - self, prompt: str, response: str, session_tag: Optional[str] = None - ) -> None: - """Insert a prompt:response pair into the session memory. A timestamp - is associated with each exchange so that they can be later sorted - in sequential ordering after retrieval. - - Args: - prompt (str): The user prompt to the LLM. - response (str): The corresponding LLM response. - session_tag (Optional[str]): The tag to mark the message with. Defaults to None. - """ - raise NotImplementedError - - def add_messages( - self, messages: List[Dict[str, str]], session_tag: Optional[str] = None - ) -> None: - """Insert a list of prompts and responses into the session memory. - A timestamp is associated with each so that they can be later sorted - in sequential ordering after retrieval. - - Args: - messages (List[Dict[str, str]]): The list of user prompts and LLM responses. - session_tag (Optional[str]): The tag to mark the messages with. Defaults to None. - """ - raise NotImplementedError - - def add_message( - self, message: Dict[str, str], session_tag: Optional[str] = None - ) -> None: - """Insert a single prompt or response into the session memory. - A timestamp is associated with it so that it can be later sorted - in sequential ordering after retrieval. +warnings.warn( + "Importing from redisvl.extensions.session_manager.standard_session is deprecated. " + "Please import from redisvl.extensions.message_history instead.", + DeprecationWarning, + stacklevel=2, +) - Args: - message (Dict[str,str]): The user prompt or LLM response. 
- session_tag (Optional[str]): The tag to mark the message with. Defaults to None. - """ - raise NotImplementedError +__all__ = ["MessageHistory"] diff --git a/redisvl/extensions/session_manager/schema.py b/redisvl/extensions/session_manager/schema.py index 6be28f22..7a206f58 100644 --- a/redisvl/extensions/session_manager/schema.py +++ b/redisvl/extensions/session_manager/schema.py @@ -1,101 +1,26 @@ -from typing import Dict, List, Optional +""" +RedisVL Session Manager Schema (Deprecated Path) -from pydantic import BaseModel, ConfigDict, Field, model_validator +This module is kept for backward compatibility. Please use `redisvl.extensions.message_history.schema` instead. +""" -from redisvl.extensions.constants import ( - CONTENT_FIELD_NAME, - ID_FIELD_NAME, - ROLE_FIELD_NAME, - SESSION_FIELD_NAME, - SESSION_VECTOR_FIELD_NAME, - TIMESTAMP_FIELD_NAME, - TOOL_FIELD_NAME, -) -from redisvl.redis.utils import array_to_buffer -from redisvl.schema import IndexSchema -from redisvl.utils.utils import current_timestamp - - -class ChatMessage(BaseModel): - """A single chat message exchanged between a user and an LLM.""" - - entry_id: Optional[str] = Field(default=None) - """A unique identifier for the message.""" - role: str # TODO -- do we enumify this?
- """The role of the message sender (e.g., 'user' or 'llm').""" - content: str - """The content of the message.""" - session_tag: str - """Tag associated with the current session.""" - timestamp: Optional[float] = Field(default=None) - """The time the message was sent, in UTC, rounded to milliseconds.""" - tool_call_id: Optional[str] = Field(default=None) - """An optional identifier for a tool call associated with the message.""" - vector_field: Optional[List[float]] = Field(default=None) - """The vector representation of the message content.""" - model_config = ConfigDict(arbitrary_types_allowed=True) - - @model_validator(mode="before") - @classmethod - def generate_id(cls, values): - if TIMESTAMP_FIELD_NAME not in values: - values[TIMESTAMP_FIELD_NAME] = current_timestamp() - if ID_FIELD_NAME not in values: - values[ID_FIELD_NAME] = ( - f"{values[SESSION_FIELD_NAME]}:{values[TIMESTAMP_FIELD_NAME]}" - ) - return values - - def to_dict(self, dtype: Optional[str] = None) -> Dict: - data = self.model_dump(exclude_none=True) - - # handle optional fields - if SESSION_VECTOR_FIELD_NAME in data: - data[SESSION_VECTOR_FIELD_NAME] = array_to_buffer( - data[SESSION_VECTOR_FIELD_NAME], dtype # type: ignore[arg-type] - ) - return data +import warnings +from redisvl.extensions.message_history.schema import ( + ChatMessage, + SemanticMessageHistorySchema, + MessageHistorySchema, +) -class StandardSessionIndexSchema(IndexSchema): - - @classmethod - def from_params(cls, name: str, prefix: str): - - return cls( - index={"name": name, "prefix": prefix}, # type: ignore - fields=[ # type: ignore - {"name": ROLE_FIELD_NAME, "type": "tag"}, - {"name": CONTENT_FIELD_NAME, "type": "text"}, - {"name": TOOL_FIELD_NAME, "type": "tag"}, - {"name": TIMESTAMP_FIELD_NAME, "type": "numeric"}, - {"name": SESSION_FIELD_NAME, "type": "tag"}, - ], - ) - - -class SemanticSessionIndexSchema(IndexSchema): - - @classmethod - def from_params(cls, name: str, prefix: str, vectorizer_dims: int, dtype: str): 
+warnings.warn( + "Importing from redisvl.extensions.session_manager.schema is deprecated. " + "Please import from redisvl.extensions.message_history.schema instead.", + DeprecationWarning, + stacklevel=2, +) - return cls( - index={"name": name, "prefix": prefix}, # type: ignore - fields=[ # type: ignore - {"name": ROLE_FIELD_NAME, "type": "tag"}, - {"name": CONTENT_FIELD_NAME, "type": "text"}, - {"name": TOOL_FIELD_NAME, "type": "tag"}, - {"name": TIMESTAMP_FIELD_NAME, "type": "numeric"}, - {"name": SESSION_FIELD_NAME, "type": "tag"}, - { - "name": SESSION_VECTOR_FIELD_NAME, - "type": "vector", - "attrs": { - "dims": vectorizer_dims, - "datatype": dtype, - "distance_metric": "cosine", - "algorithm": "flat", - }, - }, - ], - ) +__all__ = [ + "ChatMessage", + "MessageHistorySchema", + "SemanticMessageHistorySchema", +] diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index d08a7002..d41727b3 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -1,379 +1,18 @@ -from typing import Any, Dict, List, Optional, Union +""" +RedisVL Semantic Session Manager (Deprecated Path) -from redis import Redis +This module is kept for backward compatibility. Please use `redisvl.extensions.message_history` instead.
+""" -from redisvl.extensions.constants import ( - CONTENT_FIELD_NAME, - ID_FIELD_NAME, - ROLE_FIELD_NAME, - SESSION_FIELD_NAME, - SESSION_VECTOR_FIELD_NAME, - TIMESTAMP_FIELD_NAME, - TOOL_FIELD_NAME, -) -from redisvl.extensions.session_manager import BaseSessionManager -from redisvl.extensions.session_manager.schema import ( - ChatMessage, - SemanticSessionIndexSchema, -) -from redisvl.index import SearchIndex -from redisvl.query import FilterQuery, RangeQuery -from redisvl.query.filter import Tag -from redisvl.utils.utils import deprecated_argument, validate_vector_dims -from redisvl.utils.vectorize.base import BaseVectorizer - - -class SemanticSessionManager(BaseSessionManager): - - @deprecated_argument("dtype", "vectorizer") - def __init__( - self, - name: str, - session_tag: Optional[str] = None, - prefix: Optional[str] = None, - vectorizer: Optional[BaseVectorizer] = None, - distance_threshold: float = 0.3, - redis_client: Optional[Redis] = None, - redis_url: str = "redis://localhost:6379", - connection_kwargs: Dict[str, Any] = {}, - overwrite: bool = False, - **kwargs, - ): - """Initialize session memory with index - - Session Manager stores the current and previous user text prompts and - LLM responses to allow for enriching future prompts with session - context. Session history is stored in individual user or LLM prompts and - responses. - - - Args: - name (str): The name of the session manager index. - session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. Defaults to instance ULID. - prefix (Optional[str]): Prefix for the keys for this session data. - Defaults to None and will be replaced with the index name. - vectorizer (Optional[BaseVectorizer]): The vectorizer used to create embeddings. - distance_threshold (float): The maximum semantic distance to be - included in the context. Defaults to 0.3. - redis_client (Optional[Redis]): A Redis client instance. Defaults to - None. - redis_url (str, optional): The redis url. 
Defaults to redis://localhost:6379. - connection_kwargs (Dict[str, Any]): The connection arguments - for the redis client. Defaults to empty {}. - overwrite (bool): Whether or not to force overwrite the schema for - the semantic session index. Defaults to false. - - The proposed schema will support a single vector embedding constructed - from either the prompt or response in a single string. - - """ - super().__init__(name, session_tag) - - prefix = prefix or name - dtype = kwargs.pop("dtype", None) - - # Validate a provided vectorizer or set the default - if vectorizer: - if not isinstance(vectorizer, BaseVectorizer): - raise TypeError("Must provide a valid redisvl.vectorizer class.") - if dtype and vectorizer.dtype != dtype: - raise ValueError( - f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}" - ) - else: - from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer - - vectorizer_kwargs = kwargs - - if dtype: - vectorizer_kwargs.update(**{"dtype": dtype}) - - vectorizer = HFTextVectorizer( - model="sentence-transformers/all-mpnet-base-v2", - **vectorizer_kwargs, - ) - - self._vectorizer = vectorizer - - self.set_distance_threshold(distance_threshold) - - schema = SemanticSessionIndexSchema.from_params( - name, prefix, vectorizer.dims, vectorizer.dtype # type: ignore - ) - - self._index = SearchIndex( - schema=schema, - redis_client=redis_client, - redis_url=redis_url, - **connection_kwargs, - ) - - # Check for existing session index - if not overwrite and self._index.exists(): - existing_index = SearchIndex.from_existing( - name, redis_client=self._index.client - ) - if existing_index.schema.to_dict() != self._index.schema.to_dict(): - raise ValueError( - f"Existing index {name} schema does not match the user provided schema for the semantic session. " - "If you wish to overwrite the index schema, set overwrite=True during initialization." 
- ) - self._index.create(overwrite=overwrite, drop=False) - - self._default_session_filter = Tag(SESSION_FIELD_NAME) == self._session_tag - - def clear(self) -> None: - """Clears the chat session history.""" - self._index.clear() - - def delete(self) -> None: - """Clear all conversation keys and remove the search index.""" - self._index.delete(drop=True) - - def drop(self, id: Optional[str] = None) -> None: - """Remove a specific exchange from the conversation history. - - Args: - id (Optional[str]): The id of the session entry to delete. - If None then the last entry is deleted. - """ - if id is None: - id = self.get_recent(top_k=1, raw=True)[0][ID_FIELD_NAME] # type: ignore - - self._index.client.delete(self._index.key(id)) # type: ignore - - @property - def messages(self) -> Union[List[str], List[Dict[str, str]]]: - """Returns the full chat history.""" - # TODO raw or as_text? - # TODO refactor method to use get_recent and support other session tags - return_fields = [ - ID_FIELD_NAME, - SESSION_FIELD_NAME, - ROLE_FIELD_NAME, - CONTENT_FIELD_NAME, - TOOL_FIELD_NAME, - TIMESTAMP_FIELD_NAME, - ] - - query = FilterQuery( - filter_expression=self._default_session_filter, - return_fields=return_fields, - ) - query.sort_by(TIMESTAMP_FIELD_NAME, asc=True) - messages = self._index.query(query) - - return self._format_context(messages, as_text=False) +import warnings - def get_relevant( - self, - prompt: str, - as_text: bool = False, - top_k: int = 5, - fall_back: bool = False, - session_tag: Optional[str] = None, - raw: bool = False, - distance_threshold: Optional[float] = None, - ) -> Union[List[str], List[Dict[str, str]]]: - """Searches the chat history for information semantically related to - the specified prompt. +from redisvl.extensions.message_history.semantic_message_history import SemanticMessageHistory - This method uses vector similarity search with a text prompt as input.
- It checks for semantically similar prompts and responses and gets - the top k most relevant previous prompts or responses to include as - context to the next LLM call. - - Args: - prompt (str): The message text to search for in session memory - as_text (bool): Whether to return the prompts and responses as text - or as JSON - top_k (int): The number of previous messages to return. Default is 5. - session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. Defaults to instance ULID. - distance_threshold (Optional[float]): The threshold for semantic - vector distance. - fall_back (bool): Whether to drop back to recent conversation history - if no relevant context is found. - raw (bool): Whether to return the full Redis hash entry or just the - message. - - Returns: - Union[List[str], List[Dict[str,str]]: Either a list of strings, or a - list of prompts and responses in JSON containing the most relevant. - - Raises ValueError: if top_k is not an integer greater or equal to 0. 
- """ - if type(top_k) != int or top_k < 0: - raise ValueError("top_k must be an integer greater than or equal to -1") - if top_k == 0: - return [] - - # override distance threshold - distance_threshold = distance_threshold or self._distance_threshold - - return_fields = [ - SESSION_FIELD_NAME, - ROLE_FIELD_NAME, - CONTENT_FIELD_NAME, - TIMESTAMP_FIELD_NAME, - TOOL_FIELD_NAME, - ] - - session_filter = ( - Tag(SESSION_FIELD_NAME) == session_tag - if session_tag - else self._default_session_filter - ) - - query = RangeQuery( - vector=self._vectorizer.embed(prompt), - vector_field_name=SESSION_VECTOR_FIELD_NAME, - return_fields=return_fields, - distance_threshold=distance_threshold, - num_results=top_k, - return_score=True, - filter_expression=session_filter, - dtype=self._vectorizer.dtype, - ) - messages = self._index.query(query) - - # if we don't find semantic matches fallback to returning recent context - if not messages and fall_back: - return self.get_recent(as_text=as_text, top_k=top_k, raw=raw) - if raw: - return messages - return self._format_context(messages, as_text) - - def get_recent( - self, - top_k: int = 5, - as_text: bool = False, - raw: bool = False, - session_tag: Optional[str] = None, - ) -> Union[List[str], List[Dict[str, str]]]: - """Retreive the recent conversation history in sequential order. - - Args: - top_k (int): The number of previous exchanges to return. Default is 5. - as_text (bool): Whether to return the conversation as a single string, - or list of alternating prompts and responses. - raw (bool): Whether to return the full Redis hash entry or just the - prompt and response - session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. Defaults to instance ULID. - - Returns: - Union[str, List[str]]: A single string transcription of the session - or list of strings if as_text is false. - - Raises: - ValueError: if top_k is not an integer greater than or equal to 0. 
- """ - if type(top_k) != int or top_k < 0: - raise ValueError("top_k must be an integer greater than or equal to 0") - - return_fields = [ - ID_FIELD_NAME, - SESSION_FIELD_NAME, - ROLE_FIELD_NAME, - CONTENT_FIELD_NAME, - TOOL_FIELD_NAME, - TIMESTAMP_FIELD_NAME, - ] - - session_filter = ( - Tag(SESSION_FIELD_NAME) == session_tag - if session_tag - else self._default_session_filter - ) - - query = FilterQuery( - filter_expression=session_filter, - return_fields=return_fields, - num_results=top_k, - ) - query.sort_by(TIMESTAMP_FIELD_NAME, asc=False) - messages = self._index.query(query) - - if raw: - return messages[::-1] - return self._format_context(messages[::-1], as_text) - - @property - def distance_threshold(self): - return self._distance_threshold - - def set_distance_threshold(self, threshold): - self._distance_threshold = threshold - - def store( - self, prompt: str, response: str, session_tag: Optional[str] = None - ) -> None: - """Insert a prompt:response pair into the session memory. A timestamp - is associated with each message so that they can be later sorted - in sequential ordering after retrieval. - - Args: - prompt (str): The user prompt to the LLM. - response (str): The corresponding LLM response. - session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. Defaults to instance ULID. - """ - self.add_messages( - [ - {ROLE_FIELD_NAME: "user", CONTENT_FIELD_NAME: prompt}, - {ROLE_FIELD_NAME: "llm", CONTENT_FIELD_NAME: response}, - ], - session_tag, - ) - - def add_messages( - self, messages: List[Dict[str, str]], session_tag: Optional[str] = None - ) -> None: - """Insert a list of prompts and responses into the session memory. - A timestamp is associated with each so that they can be later sorted - in sequential ordering after retrieval. - - Args: - messages (List[Dict[str, str]]): The list of user prompts and LLM responses. - session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. 
Defaults to instance ULID. - """ - session_tag = session_tag or self._session_tag - chat_messages: List[Dict[str, Any]] = [] - - for message in messages: - content_vector = self._vectorizer.embed(message[CONTENT_FIELD_NAME]) - validate_vector_dims( - len(content_vector), - self._index.schema.fields[SESSION_VECTOR_FIELD_NAME].attrs.dims, # type: ignore - ) - - chat_message = ChatMessage( - role=message[ROLE_FIELD_NAME], - content=message[CONTENT_FIELD_NAME], - session_tag=session_tag, - vector_field=content_vector, # type: ignore - ) - - if TOOL_FIELD_NAME in message: - chat_message.tool_call_id = message[TOOL_FIELD_NAME] - - chat_messages.append(chat_message.to_dict(dtype=self._vectorizer.dtype)) - - self._index.load(data=chat_messages, id_field=ID_FIELD_NAME) - - def add_message( - self, message: Dict[str, str], session_tag: Optional[str] = None - ) -> None: - """Insert a single prompt or response into the session memory. - A timestamp is associated with it so that it can be later sorted - in sequential ordering after retrieval. +warnings.warn( + "Importing from redisvl.extensions.session_manager.semantic_session is deprecated. " + "Please import from redisvl.extensions.message_history instead.", + DeprecationWarning, + stacklevel=2, +) - Args: - message (Dict[str,str]): The user prompt or LLM response. - session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. Defaults to instance ULID.
- """ - self.add_messages([message], session_tag) +__all__ = ["SemanticMessageHistory"] diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index 4e46010c..356e600d 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -1,236 +1,18 @@ -from typing import Any, Dict, List, Optional, Union +""" +RedisVL Standard Session Manager (Deprecated Path) -from redis import Redis +This module is kept for backward compatibility. Please use `redisvl.extensions.message_history` instead. +""" -from redisvl.extensions.constants import ( - CONTENT_FIELD_NAME, - ID_FIELD_NAME, - ROLE_FIELD_NAME, - SESSION_FIELD_NAME, - TIMESTAMP_FIELD_NAME, - TOOL_FIELD_NAME, -) -from redisvl.extensions.session_manager import BaseSessionManager -from redisvl.extensions.session_manager.schema import ( - ChatMessage, - StandardSessionIndexSchema, -) -from redisvl.index import SearchIndex -from redisvl.query import FilterQuery -from redisvl.query.filter import Tag - - -class StandardSessionManager(BaseSessionManager): - - def __init__( - self, - name: str, - session_tag: Optional[str] = None, - prefix: Optional[str] = None, - redis_client: Optional[Redis] = None, - redis_url: str = "redis://localhost:6379", - connection_kwargs: Dict[str, Any] = {}, - **kwargs, - ): - """Initialize session memory - - Session Manager stores the current and previous user text prompts and - LLM responses to allow for enriching future prompts with session - context.Session history is stored in individual user or LLM prompts and - responses. - - Args: - name (str): The name of the session manager index. - session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. Defaults to instance ULID. - prefix (Optional[str]): Prefix for the keys for this session data. - Defaults to None and will be replaced with the index name.
- redis_client (Optional[Redis]): A Redis client instance. Defaults to - None. - redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. - connection_kwargs (Dict[str, Any]): The connection arguments - for the redis client. Defaults to empty {}. - - The proposed schema will support a single combined vector embedding - constructed from the prompt & response in a single string. - - """ - super().__init__(name, session_tag) - - prefix = prefix or name - - schema = StandardSessionIndexSchema.from_params(name, prefix) - - self._index = SearchIndex( - schema=schema, - redis_client=redis_client, - redis_url=redis_url, - **connection_kwargs, - ) - - self._index.create(overwrite=False) - - self._default_session_filter = Tag(SESSION_FIELD_NAME) == self._session_tag - - def clear(self) -> None: - """Clears the chat session history.""" - self._index.clear() - - def delete(self) -> None: - """Clear all conversation keys and remove the search index.""" - self._index.delete(drop=True) - - def drop(self, id: Optional[str] = None) -> None: - """Remove a specific exchange from the conversation history. - - Args: - id (Optional[str]): The id of the session entry to delete. - If None then the last entry is deleted. - """ - if id is None: - id = self.get_recent(top_k=1, raw=True)[0][ID_FIELD_NAME] # type: ignore - - self._index.client.delete(self._index.key(id)) # type: ignore - - @property - def messages(self) -> Union[List[str], List[Dict[str, str]]]: - """Returns the full chat history.""" - # TODO raw or as_text? - # TODO refactor this method to use get_recent and support other session tags? 
- return_fields = [ - ID_FIELD_NAME, - SESSION_FIELD_NAME, - ROLE_FIELD_NAME, - CONTENT_FIELD_NAME, - TOOL_FIELD_NAME, - TIMESTAMP_FIELD_NAME, - ] +import warnings - query = FilterQuery( - filter_expression=self._default_session_filter, - return_fields=return_fields, - ) - query.sort_by(TIMESTAMP_FIELD_NAME, asc=True) - messages = self._index.query(query) +from redisvl.extensions.message_history.standard_history import MessageHistory - return self._format_context(messages, as_text=False) - - def get_recent( - self, - top_k: int = 5, - as_text: bool = False, - raw: bool = False, - session_tag: Optional[str] = None, - ) -> Union[List[str], List[Dict[str, str]]]: - """Retrieve the recent conversation history in sequential order. - - Args: - top_k (int): The number of previous messages to return. Default is 5. - as_text (bool): Whether to return the conversation as a single string, - or list of alternating prompts and responses. - raw (bool): Whether to return the full Redis hash entry or just the - prompt and response - session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. Defaults to instance ULID. - - Returns: - Union[str, List[str]]: A single string transcription of the session - or list of strings if as_text is false. - - Raises: - ValueError: if top_k is not an integer greater than or equal to 0. 
- """ - if type(top_k) != int or top_k < 0: - raise ValueError("top_k must be an integer greater than or equal to 0") - - return_fields = [ - ID_FIELD_NAME, - SESSION_FIELD_NAME, - ROLE_FIELD_NAME, - CONTENT_FIELD_NAME, - TOOL_FIELD_NAME, - TIMESTAMP_FIELD_NAME, - ] - - session_filter = ( - Tag(SESSION_FIELD_NAME) == session_tag - if session_tag - else self._default_session_filter - ) - - query = FilterQuery( - filter_expression=session_filter, - return_fields=return_fields, - num_results=top_k, - ) - query.sort_by(TIMESTAMP_FIELD_NAME, asc=False) - messages = self._index.query(query) - - if raw: - return messages[::-1] - return self._format_context(messages[::-1], as_text) - - def store( - self, prompt: str, response: str, session_tag: Optional[str] = None - ) -> None: - """Insert a prompt:response pair into the session memory. A timestamp - is associated with each exchange so that they can be later sorted - in sequential ordering after retrieval. - - Args: - prompt (str): The user prompt to the LLM. - response (str): The corresponding LLM response. - session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. Defaults to instance ULID. - """ - self.add_messages( - [ - {ROLE_FIELD_NAME: "user", CONTENT_FIELD_NAME: prompt}, - {ROLE_FIELD_NAME: "llm", CONTENT_FIELD_NAME: response}, - ], - session_tag, - ) - - def add_messages( - self, messages: List[Dict[str, str]], session_tag: Optional[str] = None - ) -> None: - """Insert a list of prompts and responses into the session memory. - A timestamp is associated with each so that they can be later sorted - in sequential ordering after retrieval. - - Args: - messages (List[Dict[str, str]]): The list of user prompts and LLM responses. - session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. Defaults to instance ULID. 
- """ - session_tag = session_tag or self._session_tag - chat_messages: List[Dict[str, Any]] = [] - - for message in messages: - - chat_message = ChatMessage( - role=message[ROLE_FIELD_NAME], - content=message[CONTENT_FIELD_NAME], - session_tag=session_tag, - ) - - if TOOL_FIELD_NAME in message: - chat_message.tool_call_id = message[TOOL_FIELD_NAME] - - chat_messages.append(chat_message.to_dict()) - - self._index.load(data=chat_messages, id_field=ID_FIELD_NAME) - - def add_message( - self, message: Dict[str, str], session_tag: Optional[str] = None - ) -> None: - """Insert a single prompt or response into the session memory. - A timestamp is associated with it so that it can be later sorted - in sequential ordering after retrieval. +warnings.warn( + "Importing from redisvl.extensions.session_manger.standard_session is deprecated. " + "Please import from redisvl.extensions.message_history instead.", + DeprecationWarning, + stacklevel=2, +) - Args: - message (Dict[str,str]): The user prompt or LLM response. - session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. Defaults to instance ULID. 
- """ - self.add_messages([message], session_tag) +__all__ = ["MessageHistory"] diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_message_history.py similarity index 65% rename from tests/integration/test_session_manager.py rename to tests/integration/test_message_history.py index 59d64b97..4123bf65 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_message_history.py @@ -5,10 +5,7 @@ from redisvl.exceptions import RedisModuleVersionError from redisvl.extensions.constants import ID_FIELD_NAME -from redisvl.extensions.session_manager import ( - SemanticSessionManager, - StandardSessionManager, -) +from redisvl.extensions.message_history import MessageHistory, SemanticMessageHistory from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer @@ -18,18 +15,18 @@ def app_name(): @pytest.fixture -def standard_session(app_name, client): - session = StandardSessionManager(app_name, redis_client=client) - yield session - session.clear() +def standard_history(app_name, client): + history = MessageHistory(app_name, redis_client=client) + yield history + history.clear() @pytest.fixture -def semantic_session(app_name, client): - session = SemanticSessionManager(app_name, redis_client=client, overwrite=True) - yield session - session.clear() - session.delete() +def semantic_history(app_name, client): + history = SemanticMessageHistory(app_name, redis_client=client, overwrite=True) + yield history + history.clear() + history.delete() @pytest.fixture(autouse=True) @@ -39,42 +36,42 @@ def disable_deprecation_warnings(): yield -# test standard session manager +# test standard message history def test_specify_redis_client(client): - session = StandardSessionManager(name="test_app", redis_client=client) - assert isinstance(session._index.client, type(client)) + history = MessageHistory(name="test_app", redis_client=client) + assert isinstance(history._index.client, type(client)) def 
test_specify_redis_url(client, redis_url): - session = StandardSessionManager( + history = MessageHistory( name="test_app", session_tag="abc", redis_url=redis_url, ) - assert isinstance(session._index.client, type(client)) + assert isinstance(history._index.client, type(client)) def test_standard_bad_connection_info(): with pytest.raises(ConnectionError): - StandardSessionManager( + MessageHistory( name="test_app", session_tag="abc", redis_url="redis://localhost:6389", # bad url ) -def test_standard_store(standard_session): - context = standard_session.get_recent() +def test_standard_store(standard_history): + context = standard_history.get_recent() assert len(context) == 0 - standard_session.store(prompt="first prompt", response="first response") - standard_session.store(prompt="second prompt", response="second response") - standard_session.store(prompt="third prompt", response="third response") - standard_session.store(prompt="fourth prompt", response="fourth response") - standard_session.store(prompt="fifth prompt", response="fifth response") + standard_history.store(prompt="first prompt", response="first response") + standard_history.store(prompt="second prompt", response="second response") + standard_history.store(prompt="third prompt", response="third response") + standard_history.store(prompt="fourth prompt", response="fourth response") + standard_history.store(prompt="fifth prompt", response="fifth response") # test that order is maintained - full_context = standard_session.get_recent(top_k=10) + full_context = standard_history.get_recent(top_k=10) assert full_context == [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, @@ -89,37 +86,37 @@ def test_standard_store(standard_session): ] -def test_standard_add_and_get(standard_session): - context = standard_session.get_recent() +def test_standard_add_and_get(standard_history): + context = standard_history.get_recent() assert len(context) == 0 - 
standard_session.add_message({"role": "user", "content": "first prompt"}) - standard_session.add_message({"role": "llm", "content": "first response"}) - standard_session.add_message({"role": "user", "content": "second prompt"}) - standard_session.add_message({"role": "llm", "content": "second response"}) - standard_session.add_message( + standard_history.add_message({"role": "user", "content": "first prompt"}) + standard_history.add_message({"role": "llm", "content": "first response"}) + standard_history.add_message({"role": "user", "content": "second prompt"}) + standard_history.add_message({"role": "llm", "content": "second response"}) + standard_history.add_message( { "role": "tool", "content": "tool result 1", "tool_call_id": "tool call one", } ) - standard_session.add_message( + standard_history.add_message( { "role": "tool", "content": "tool result 2", "tool_call_id": "tool call two", } ) - standard_session.add_message({"role": "user", "content": "third prompt"}) - standard_session.add_message({"role": "llm", "content": "third response"}) + standard_history.add_message({"role": "user", "content": "third prompt"}) + standard_history.add_message({"role": "llm", "content": "third response"}) # test default context history size - default_context = standard_session.get_recent() + default_context = standard_history.get_recent() assert len(default_context) == 5 # default is 5 # test specified context history size - partial_context = standard_session.get_recent(top_k=3) + partial_context = standard_history.get_recent(top_k=3) assert len(partial_context) == 3 assert partial_context == [ {"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"}, @@ -128,7 +125,7 @@ def test_standard_add_and_get(standard_session): ] # test that order is maintained - full_context = standard_session.get_recent(top_k=10) + full_context = standard_history.get_recent(top_k=10) assert full_context == [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": 
"first response"}, @@ -142,23 +139,23 @@ def test_standard_add_and_get(standard_session): # test that a ValueError is raised when top_k is invalid with pytest.raises(ValueError): - bad_context = standard_session.get_recent(top_k=-2) + bad_context = standard_history.get_recent(top_k=-2) with pytest.raises(ValueError): - bad_context = standard_session.get_recent(top_k=-2.0) + bad_context = standard_history.get_recent(top_k=-2.0) with pytest.raises(ValueError): - bad_context = standard_session.get_recent(top_k=1.3) + bad_context = standard_history.get_recent(top_k=1.3) with pytest.raises(ValueError): - bad_context = standard_session.get_recent(top_k="3") + bad_context = standard_history.get_recent(top_k="3") -def test_standard_add_messages(standard_session): - context = standard_session.get_recent() +def test_standard_add_messages(standard_history): + context = standard_history.get_recent() assert len(context) == 0 - standard_session.add_messages( + standard_history.add_messages( [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, @@ -179,7 +176,7 @@ def test_standard_add_messages(standard_session): ] ) - full_context = standard_session.get_recent(top_k=10) + full_context = standard_history.get_recent(top_k=10) assert len(full_context) == 8 assert full_context == [ {"role": "user", "content": "first prompt"}, @@ -193,8 +190,8 @@ def test_standard_add_messages(standard_session): ] -def test_standard_messages_property(standard_session): - standard_session.add_messages( +def test_standard_messages_property(standard_history): + standard_history.add_messages( [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, @@ -204,7 +201,7 @@ def test_standard_messages_property(standard_session): ] ) - assert standard_session.messages == [ + assert standard_history.messages == [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, {"role": "user", "content": "second 
prompt"}, @@ -213,47 +210,47 @@ def test_standard_messages_property(standard_session): ] -def test_standard_scope(standard_session): +def test_standard_scope(standard_history): # store entries under default session tag - standard_session.store("some prompt", "some response") + standard_history.store("some prompt", "some response") # test that changing session tag does indeed change access scope new_session = "def" - standard_session.store( + standard_history.store( "new user prompt", "new user response", session_tag=new_session ) - context = standard_session.get_recent(session_tag=new_session) + context = standard_history.get_recent(session_tag=new_session) assert context == [ {"role": "user", "content": "new user prompt"}, {"role": "llm", "content": "new user response"}, ] # test that default session data is still accessible - context = standard_session.get_recent() + context = standard_history.get_recent() assert context == [ {"role": "user", "content": "some prompt"}, {"role": "llm", "content": "some response"}, ] bad_session = "xyz" - no_context = standard_session.get_recent(session_tag=bad_session) + no_context = standard_history.get_recent(session_tag=bad_session) assert no_context == [] -def test_standard_get_text(standard_session): - standard_session.store("first prompt", "first response") - text = standard_session.get_recent(as_text=True) +def test_standard_get_text(standard_history): + standard_history.store("first prompt", "first response") + text = standard_history.get_recent(as_text=True) assert text == ["first prompt", "first response"] - standard_session.add_message({"role": "system", "content": "system level prompt"}) - text = standard_session.get_recent(as_text=True) + standard_history.add_message({"role": "system", "content": "system level prompt"}) + text = standard_history.get_recent(as_text=True) assert text == ["first prompt", "first response", "system level prompt"] -def test_standard_get_raw(standard_session): - standard_session.store("first 
prompt", "first response") - standard_session.store("second prompt", "second response") - raw = standard_session.get_recent(raw=True) +def test_standard_get_raw(standard_history): + standard_history.store("first prompt", "first response") + standard_history.store("second prompt", "second response") + raw = standard_history.get_recent(raw=True) assert len(raw) == 4 assert raw[0]["role"] == "user" assert raw[0]["content"] == "first prompt" @@ -261,15 +258,15 @@ def test_standard_get_raw(standard_session): assert raw[1]["content"] == "first response" -def test_standard_drop(standard_session): - standard_session.store("first prompt", "first response") - standard_session.store("second prompt", "second response") - standard_session.store("third prompt", "third response") - standard_session.store("fourth prompt", "fourth response") +def test_standard_drop(standard_history): + standard_history.store("first prompt", "first response") + standard_history.store("second prompt", "second response") + standard_history.store("third prompt", "third response") + standard_history.store("fourth prompt", "fourth response") # test drop() with no arguments removes the last element - standard_session.drop() - context = standard_session.get_recent(top_k=3) + standard_history.drop() + context = standard_history.get_recent(top_k=3) assert context == [ {"role": "user", "content": "third prompt"}, {"role": "llm", "content": "third response"}, @@ -277,10 +274,10 @@ def test_standard_drop(standard_session): ] # test drop(id) removes the specified element - context = standard_session.get_recent(top_k=10, raw=True) + context = standard_history.get_recent(top_k=10, raw=True) middle_id = context[3][ID_FIELD_NAME] - standard_session.drop(middle_id) - context = standard_session.get_recent(top_k=6) + standard_history.drop(middle_id) + context = standard_history.get_recent(top_k=6) assert context == [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, @@ -291,82 
+288,82 @@ def test_standard_drop(standard_session): ] -def test_standard_clear(standard_session): - standard_session.store("some prompt", "some response") - standard_session.clear() - empty_context = standard_session.get_recent(top_k=10) +def test_standard_clear(standard_history): + standard_history.store("some prompt", "some response") + standard_history.clear() + empty_context = standard_history.get_recent(top_k=10) assert empty_context == [] -# test semantic session manager +# test semantic message history def test_semantic_specify_client(client): - session = SemanticSessionManager( + history = SemanticMessageHistory( name="test_app", session_tag="abc", redis_client=client, overwrite=True ) - assert isinstance(session._index.client, type(client)) + assert isinstance(history._index.client, type(client)) def test_semantic_bad_connection_info(): with pytest.raises(ConnectionError): - SemanticSessionManager( + SemanticMessageHistory( name="test_app", session_tag="abc", redis_url="redis://localhost:6389", ) -def test_semantic_scope(semantic_session): +def test_semantic_scope(semantic_history): # store entries under default session tag - semantic_session.store("some prompt", "some response") + semantic_history.store("some prompt", "some response") # test that changing session tag does indeed change access scope new_session = "def" - semantic_session.store( + semantic_history.store( "new user prompt", "new user response", session_tag=new_session ) - context = semantic_session.get_recent(session_tag=new_session) + context = semantic_history.get_recent(session_tag=new_session) assert context == [ {"role": "user", "content": "new user prompt"}, {"role": "llm", "content": "new user response"}, ] # test that previous session data is still accessible - context = semantic_session.get_recent() + context = semantic_history.get_recent() assert context == [ {"role": "user", "content": "some prompt"}, {"role": "llm", "content": "some response"}, ] bad_session = "xyz" - no_context 
= semantic_session.get_recent(session_tag=bad_session) + no_context = semantic_history.get_recent(session_tag=bad_session) assert no_context == [] -def test_semantic_store_and_get_recent(semantic_session): - context = semantic_session.get_recent() +def test_semantic_store_and_get_recent(semantic_history): + context = semantic_history.get_recent() assert len(context) == 0 - semantic_session.store(prompt="first prompt", response="first response") - semantic_session.store(prompt="second prompt", response="second response") - semantic_session.store(prompt="third prompt", response="third response") - semantic_session.store(prompt="fourth prompt", response="fourth response") - semantic_session.add_message( + semantic_history.store(prompt="first prompt", response="first response") + semantic_history.store(prompt="second prompt", response="second response") + semantic_history.store(prompt="third prompt", response="third response") + semantic_history.store(prompt="fourth prompt", response="fourth response") + semantic_history.add_message( {"role": "tool", "content": "tool result", "tool_call_id": "tool id"} ) # test default context history size - default_context = semantic_session.get_recent() + default_context = semantic_history.get_recent() assert len(default_context) == 5 # 5 is default # test specified context history size - partial_context = semantic_session.get_recent(top_k=4) + partial_context = semantic_history.get_recent(top_k=4) assert len(partial_context) == 4 # test larger context history returns full history - too_large_context = semantic_session.get_recent(top_k=100) + too_large_context = semantic_history.get_recent(top_k=100) assert len(too_large_context) == 9 # test that order is maintained - full_context = semantic_session.get_recent(top_k=9) + full_context = semantic_history.get_recent(top_k=9) assert full_context == [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, @@ -380,7 +377,7 @@ def 
test_semantic_store_and_get_recent(semantic_session): ] # test that more recent entries are returned - context = semantic_session.get_recent(top_k=4) + context = semantic_history.get_recent(top_k=4) assert context == [ {"role": "llm", "content": "third response"}, {"role": "user", "content": "fourth prompt"}, @@ -389,28 +386,28 @@ def test_semantic_store_and_get_recent(semantic_session): ] # test no entries are returned and no error is raised if top_k == 0 - context = semantic_session.get_recent(top_k=0) + context = semantic_history.get_recent(top_k=0) assert context == [] # test that a ValueError is raised when top_k is invalid with pytest.raises(ValueError): - bad_context = semantic_session.get_recent(top_k=0.5) + bad_context = semantic_history.get_recent(top_k=0.5) with pytest.raises(ValueError): - bad_context = semantic_session.get_recent(top_k=-1) + bad_context = semantic_history.get_recent(top_k=-1) with pytest.raises(ValueError): - bad_context = semantic_session.get_recent(top_k=-2.0) + bad_context = semantic_history.get_recent(top_k=-2.0) with pytest.raises(ValueError): - bad_context = semantic_session.get_recent(top_k=1.3) + bad_context = semantic_history.get_recent(top_k=1.3) with pytest.raises(ValueError): - bad_context = semantic_session.get_recent(top_k="3") + bad_context = semantic_history.get_recent(top_k="3") -def test_semantic_messages_property(semantic_session): - semantic_session.add_messages( +def test_semantic_messages_property(semantic_history): + semantic_history.add_messages( [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, @@ -430,7 +427,7 @@ def test_semantic_messages_property(semantic_session): ] ) - assert semantic_session.messages == [ + assert semantic_history.messages == [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, {"role": "tool", "content": "tool result 1", "tool_call_id": "tool call one"}, @@ -441,23 +438,23 @@ def 
test_semantic_messages_property(semantic_session): ] -def test_semantic_add_and_get_relevant(semantic_session): - semantic_session.add_message( +def test_semantic_add_and_get_relevant(semantic_history): + semantic_history.add_message( {"role": "system", "content": "discussing common fruits and vegetables"} ) - semantic_session.store( + semantic_history.store( prompt="list of common fruits", response="apples, oranges, bananas, strawberries", ) - semantic_session.store( + semantic_history.store( prompt="list of common vegetables", response="carrots, broccoli, onions, spinach", ) - semantic_session.store( + semantic_history.store( prompt="winter sports in the olympics", response="downhill skiing, ice skating, luge", ) - semantic_session.add_message( + semantic_history.add_message( { "role": "tool", "content": "skiing, skating, luge", @@ -466,7 +463,7 @@ def test_semantic_add_and_get_relevant(semantic_session): ) # test default distance metric - default_context = semantic_session.get_relevant( + default_context = semantic_history.get_relevant( "set of common fruits like apples and bananas" ) assert len(default_context) == 2 @@ -477,15 +474,15 @@ def test_semantic_add_and_get_relevant(semantic_session): } # test increasing distance metric broadens results - semantic_session.set_distance_threshold(0.5) - default_context = semantic_session.get_relevant("list of fruits and vegetables") + semantic_history.set_distance_threshold(0.5) + default_context = semantic_history.get_relevant("list of fruits and vegetables") assert len(default_context) == 5 # 2 pairs of prompt:response, and system - assert default_context == semantic_session.get_relevant( + assert default_context == semantic_history.get_relevant( "list of fruits and vegetables", distance_threshold=0.5 ) # test tool calls can also be returned - context = semantic_session.get_relevant("winter sports like skiing") + context = semantic_history.get_relevant("winter sports like skiing") assert context == [ { "role": "user", 
@@ -504,22 +501,22 @@ def test_semantic_add_and_get_relevant(semantic_session): # test that a ValueError is raised when top_k is invalid with pytest.raises(ValueError): - bad_context = semantic_session.get_relevant("test prompt", top_k=-1) + bad_context = semantic_history.get_relevant("test prompt", top_k=-1) with pytest.raises(ValueError): - bad_context = semantic_session.get_relevant("test prompt", top_k=-2.0) + bad_context = semantic_history.get_relevant("test prompt", top_k=-2.0) with pytest.raises(ValueError): - bad_context = semantic_session.get_relevant("test prompt", top_k=1.3) + bad_context = semantic_history.get_relevant("test prompt", top_k=1.3) with pytest.raises(ValueError): - bad_context = semantic_session.get_relevant("test prompt", top_k="3") + bad_context = semantic_history.get_relevant("test prompt", top_k="3") -def test_semantic_get_raw(semantic_session): - semantic_session.store("first prompt", "first response") - semantic_session.store("second prompt", "second response") - raw = semantic_session.get_recent(raw=True) +def test_semantic_get_raw(semantic_history): + semantic_history.store("first prompt", "first response") + semantic_history.store("second prompt", "second response") + raw = semantic_history.get_recent(raw=True) assert len(raw) == 4 assert raw[0]["role"] == "user" assert raw[0]["content"] == "first prompt" @@ -527,15 +524,15 @@ def test_semantic_get_raw(semantic_session): assert raw[1]["content"] == "first response" -def test_semantic_drop(semantic_session): - semantic_session.store("first prompt", "first response") - semantic_session.store("second prompt", "second response") - semantic_session.store("third prompt", "third response") - semantic_session.store("fourth prompt", "fourth response") +def test_semantic_drop(semantic_history): + semantic_history.store("first prompt", "first response") + semantic_history.store("second prompt", "second response") + semantic_history.store("third prompt", "third response") + 
semantic_history.store("fourth prompt", "fourth response") # test drop() with no arguments removes the last element - semantic_session.drop() - context = semantic_session.get_recent(top_k=3) + semantic_history.drop() + context = semantic_history.get_recent(top_k=3) assert context == [ {"role": "user", "content": "third prompt"}, {"role": "llm", "content": "third response"}, @@ -543,10 +540,10 @@ def test_semantic_drop(semantic_session): ] # test drop(id) removes the specified element - context = semantic_session.get_recent(top_k=5, raw=True) + context = semantic_history.get_recent(top_k=5, raw=True) middle_id = context[2][ID_FIELD_NAME] - semantic_session.drop(middle_id) - context = semantic_session.get_recent(top_k=4) + semantic_history.drop(middle_id) + context = semantic_history.get_recent(top_k=4) assert context == [ {"role": "user", "content": "second prompt"}, {"role": "llm", "content": "second response"}, @@ -557,16 +554,16 @@ def test_semantic_drop(semantic_session): def test_different_vector_dtypes(): try: - bfloat_sess = SemanticSessionManager(name="bfloat_session", dtype="bfloat16") + bfloat_sess = SemanticMessageHistory(name="bfloat_history", dtype="bfloat16") bfloat_sess.add_message({"role": "user", "content": "bfloat message"}) - float16_sess = SemanticSessionManager(name="float16_session", dtype="float16") + float16_sess = SemanticMessageHistory(name="float16_history", dtype="float16") float16_sess.add_message({"role": "user", "content": "float16 message"}) - float32_sess = SemanticSessionManager(name="float32_session", dtype="float32") + float32_sess = SemanticMessageHistory(name="float32_history", dtype="float32") float32_sess.add_message({"role": "user", "content": "float32 message"}) - float64_sess = SemanticSessionManager(name="float64_session", dtype="float64") + float64_sess = SemanticMessageHistory(name="float64_history", dtype="float64") float64_sess.add_message({"role": "user", "content": "float64 message"}) for sess in [bfloat_sess, 
float16_sess, float32_sess, float64_sess]: @@ -576,27 +573,27 @@ def test_different_vector_dtypes(): pytest.skip("Not using a late enough version of Redis") -def test_bad_dtype_connecting_to_exiting_session(redis_url): +def test_bad_dtype_connecting_to_exiting_history(redis_url): try: - session = SemanticSessionManager( - name="float64 session", dtype="float64", redis_url=redis_url + history = SemanticMessageHistory( + name="float64 history", dtype="float64", redis_url=redis_url ) - same_type = SemanticSessionManager( - name="float64 session", dtype="float64", redis_url=redis_url + same_type = SemanticMessageHistory( + name="float64 history", dtype="float64", redis_url=redis_url ) # under the hood uses from_existing except RedisModuleVersionError: pytest.skip("Not using a late enough version of Redis") with pytest.raises(ValueError): - bad_type = SemanticSessionManager( - name="float64 session", dtype="float16", redis_url=redis_url + bad_type = SemanticMessageHistory( + name="float64 history", dtype="float16", redis_url=redis_url ) def test_vectorizer_dtype_mismatch(redis_url): with pytest.raises(ValueError): - SemanticSessionManager( + SemanticMessageHistory( name="test_dtype_mismatch", dtype="float32", vectorizer=HFTextVectorizer(dtype="float16"), @@ -607,7 +604,7 @@ def test_vectorizer_dtype_mismatch(redis_url): def test_invalid_vectorizer(redis_url): with pytest.raises(TypeError): - SemanticSessionManager( + SemanticMessageHistory( name="test_invalid_vectorizer", vectorizer="invalid_vectorizer", # type: ignore redis_url=redis_url, @@ -617,7 +614,7 @@ def test_invalid_vectorizer(redis_url): def test_passes_through_dtype_to_default_vectorizer(redis_url): # The default is float32, so we should see float64 if we pass it in. 
- cache = SemanticSessionManager( + cache = SemanticMessageHistory( name="test_pass_through_dtype", dtype="float64", redis_url=redis_url, @@ -628,6 +625,6 @@ def test_passes_through_dtype_to_default_vectorizer(redis_url): def test_deprecated_dtype_argument(redis_url): with pytest.warns(DeprecationWarning): - SemanticSessionManager( - name="float64 session", dtype="float64", redis_url=redis_url, overwrite=True + SemanticMessageHistory( + name="float64 history", dtype="float64", redis_url=redis_url, overwrite=True ) diff --git a/tests/unit/test_session_schema.py b/tests/unit/test_message_history_schema.py similarity index 98% rename from tests/unit/test_session_schema.py rename to tests/unit/test_message_history_schema.py index 5bd2c221..6143d2da 100644 --- a/tests/unit/test_session_schema.py +++ b/tests/unit/test_message_history_schema.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from redisvl.extensions.session_manager.schema import ChatMessage +from redisvl.extensions.message_history.schema import ChatMessage from redisvl.redis.utils import array_to_buffer from redisvl.utils.utils import create_ulid, current_timestamp From 7a2a2b255a8f2732528e80305c3797afb1ee329e Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Fri, 18 Apr 2025 14:06:06 -0700 Subject: [PATCH 2/4] wraps message history classes for backward compatibility. 
improves warnings --- .../extensions/message_history/__init__.py | 4 +--- ...message_history.py => semantic_history.py} | 0 .../extensions/session_manager/__init__.py | 23 ++++++------------- .../session_manager/base_session.py | 17 ++++++++++---- redisvl/extensions/session_manager/schema.py | 17 +++++++++++--- .../session_manager/semantic_session.py | 13 ++++++++--- .../session_manager/standard_session.py | 13 ++++++++--- 7 files changed, 54 insertions(+), 33 deletions(-) rename redisvl/extensions/message_history/{semantic_message_history.py => semantic_history.py} (100%) diff --git a/redisvl/extensions/message_history/__init__.py b/redisvl/extensions/message_history/__init__.py index a98d8f2a..c040b698 100644 --- a/redisvl/extensions/message_history/__init__.py +++ b/redisvl/extensions/message_history/__init__.py @@ -1,7 +1,5 @@ from redisvl.extensions.message_history.base_history import BaseMessageHistory from redisvl.extensions.message_history.message_history import MessageHistory -from redisvl.extensions.message_history.semantic_message_history import ( - SemanticMessageHistory, -) +from redisvl.extensions.message_history.semantic_history import SemanticMessageHistory __all__ = ["BaseMessageHistory", "MessageHistory", "SemanticMessageHistory"] diff --git a/redisvl/extensions/message_history/semantic_message_history.py b/redisvl/extensions/message_history/semantic_history.py similarity index 100% rename from redisvl/extensions/message_history/semantic_message_history.py rename to redisvl/extensions/message_history/semantic_history.py diff --git a/redisvl/extensions/session_manager/__init__.py b/redisvl/extensions/session_manager/__init__.py index 619c2ee5..d86b69c1 100644 --- a/redisvl/extensions/session_manager/__init__.py +++ b/redisvl/extensions/session_manager/__init__.py @@ -6,27 +6,18 @@ import warnings -from redisvl.extensions.message_history.message_history import MessageHistory -from redisvl.extensions.message_history.schema import ( - ChatMessage, - 
MessageHistorySchema, - SemanticMessageHistorySchema, -) -from redisvl.extensions.message_history.semantic_message_history import ( - SemanticMessageHistory, -) +from redisvl.extensions.session_manager.base_session import BaseSessionManager +from redisvl.extensions.session_manager.semantic_session import SemanticSessionManager +from redisvl.extensions.session_manager.standard_session import StandardSessionManager warnings.warn( "Importing from redisvl.extensions.session_manager is deprecated. " + "StandardSessionManager has been renamed to MessageHistory. " + "SemanticSessionManager has been renamed to SemanticMessageHistory. " "Please import from redisvl.extensions.message_history instead.", DeprecationWarning, stacklevel=2, ) -__all__ = [ - "MessageHistory", - "SemanticMessageHistory", - "ChatMessage", - "MessageHistorySchema", - "SemanticMessageHistorySchema", -] + +__all__ = ["BaseSessionManager", "StandardSessionManager", "SemanticSessionManager"] diff --git a/redisvl/extensions/session_manager/base_session.py b/redisvl/extensions/session_manager/base_session.py index 9d61279b..92d30c77 100644 --- a/redisvl/extensions/session_manager/base_session.py +++ b/redisvl/extensions/session_manager/base_session.py @@ -1,18 +1,25 @@ """ RedisVL Standard Session Manager (Deprecated Path) -This module is kept for backward compatibility. Please use `redisvl.extensions.standard_history` instead. +This module is kept for backward compatibility. Please use `redisvl.extensions.message_history` instead. """ import warnings -from redisvl.extensions.message_history.standard_history import StandardHistory +from redisvl.extensions.message_history.base_history import BaseMessageHistory warnings.warn( - "Importing from redisvl.extensions.session_manager.standard_session is deprecated. " - "Please import from redisvl.extensions.message_history instead.", + "Importing from redisvl.extensions.session_manager.base_session is deprecated. 
" + "BaseSessionManager has been renamed to BaseMessageHistory. " + "Please import BaseMessageHistory from redisvl.extensions.base_history instead.", DeprecationWarning, stacklevel=2, ) -__all__ = ["StandardHistory"] + +class BaseSessionManager(BaseMessageHistory): + # keep for backward compatibility + pass + + +__all__ = ["BaseSessionManager"] diff --git a/redisvl/extensions/session_manager/schema.py b/redisvl/extensions/session_manager/schema.py index 7a206f58..28d44147 100644 --- a/redisvl/extensions/session_manager/schema.py +++ b/redisvl/extensions/session_manager/schema.py @@ -8,8 +8,8 @@ from redisvl.extensions.message_history.schema import ( ChatMessage, - SemanticMessageHistorySchema, MessageHistorySchema, + SemanticMessageHistorySchema, ) warnings.warn( @@ -19,8 +19,19 @@ stacklevel=2, ) + +class StandardSessionIndexSchema(MessageHistorySchema): + # keep for backward compatibility + pass + + +class SemanticSessionIndexSchema(SemanticMessageHistorySchema): + # keep for backward compatibility + pass + + __all__ = [ "ChatMessage", - "MessageHistory", - "SemanticMessageHistory", + "StandardSessionIndexSchema", + "SemanticSessionIndexSchema", ] diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index d41727b3..fbdf342e 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -6,13 +6,20 @@ import warnings -from redisvl.extensions.message_history.semantic_history import SemanticHistory +from redisvl.extensions.message_history.semantic_history import SemanticMessageHistory warnings.warn( "Importing from redisvl.extensions.session_manger.semantic_session is deprecated. " - "Please import from redisvl.extensions.semantic_history instead.", + "SemanticSessionManager has been renamed to SemanticMessageHistory. 
" + "Please import SemanticMessageHistory from redisvl.extensions.semantic_history instead.", DeprecationWarning, stacklevel=2, ) -__all__ = ["SemanticHistory"] + +class SemanticSessionManager(SemanticMessageHistory): + # keep for backwards compatibility + pass + + +__all__ = ["SemanticSessionManager"] diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index 356e600d..ba7f0e57 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -6,13 +6,20 @@ import warnings -from redisvl.extensions.message_history.standard_history import MessageHistory +from redisvl.extensions.message_history.message_history import MessageHistory warnings.warn( "Importing from redisvl.extensions.session_manger.standard_session is deprecated. " - "Please import from redisvl.extensions.message_history instead.", + "StandardSessionManager has been renamed to MessageHistory. 
" + "Please import MessageHistory from redisvl.extensions.message_history instead.", DeprecationWarning, stacklevel=2, ) -__all__ = ["MessageHistory"] + +class StandardSessionManager(MessageHistory): + # keep for backward compatibility + pass + + +__all__ = ["StandardSessionManager"] From 79902960d6ec6d2815bd93f9bd0cd228a0dbb6ef Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Fri, 18 Apr 2025 14:06:41 -0700 Subject: [PATCH 3/4] updates message history notebook --- docs/user_guide/07_message_history.ipynb | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/user_guide/07_message_history.ipynb b/docs/user_guide/07_message_history.ipynb index 80f0ae14..baf6ee0c 100644 --- a/docs/user_guide/07_message_history.ipynb +++ b/docs/user_guide/07_message_history.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -48,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -144,7 +144,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -200,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -220,7 +220,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -251,7 +251,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -284,7 +284,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -317,7 +317,7 @@ }, { 
"cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ From 93f3640bbc31167fa90d73321ecd8576b8d57f06 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Tue, 22 Apr 2025 09:37:26 -0700 Subject: [PATCH 4/4] renames session manager in .md files --- README.md | 16 ++++++++-------- docs/api/index.md | 2 +- docs/user_guide/index.md | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index f5eaceb1..bfe63ccc 100644 --- a/README.md +++ b/README.md @@ -264,20 +264,20 @@ print(response[0]["response"]) > Learn more about [semantic caching]((https://docs.redisvl.com/en/stable/user_guide/03_llmcache.html)) for LLMs. -### LLM Session Management +### LLM Memory History -Improve personalization and accuracy of LLM responses by providing user chat history as context. Manage access to the session data using recency or relevancy, *powered by vector search* with the [`SemanticSessionManager`](https://docs.redisvl.com/en/stable/api/session_manager.html). +Improve personalization and accuracy of LLM responses by providing user message history as context. Manage access to message history data using recency or relevancy, *powered by vector search* with the [`MessageHistory`](https://docs.redisvl.com/en/stable/api/message_history.html). 
```python -from redisvl.extensions.session_manager import SemanticSessionManager +from redisvl.extensions.message_history import SemanticMessageHistory -session = SemanticSessionManager( +history = SemanticMessageHistory( name="my-session", redis_url="redis://localhost:6379", distance_threshold=0.7 ) -session.add_messages([ +history.add_messages([ {"role": "user", "content": "hello, how are you?"}, {"role": "assistant", "content": "I'm doing fine, thanks."}, {"role": "user", "content": "what is the weather going to be today?"}, @@ -286,19 +286,19 @@ session.add_messages([ ``` Get recent chat history: ```python -session.get_recent(top_k=1) +history.get_recent(top_k=1) ``` ```stdout >>> [{"role": "assistant", "content": "I don't know"}] ``` Get relevant chat history (powered by vector search): ```python -session.get_relevant("weather", top_k=1) +history.get_relevant("weather", top_k=1) ``` ```stdout >>> [{"role": "user", "content": "what is the weather going to be today?"}] ``` -> Learn more about [LLM session management]((https://docs.redisvl.com/en/stable/user_guide/07_session_manager.html)). +> Learn more about [LLM message history]((https://docs.redisvl.com/en/stable/user_guide/07_message_history.html)). ### LLM Semantic Routing diff --git a/docs/api/index.md b/docs/api/index.md index e3dc486e..4a98c69a 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -20,7 +20,7 @@ filter vectorizer reranker cache -session_manager +message_history router threshold_optimizer ``` diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md index 4d6e5c04..873522ae 100644 --- a/docs/user_guide/index.md +++ b/docs/user_guide/index.md @@ -19,7 +19,7 @@ User guides provide helpful resources for using RedisVL and its different compon 04_vectorizers 05_hash_vs_json 06_rerankers -07_session_manager +07_message_history 08_semantic_router 09_threshold_optimization release_guide/index