diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 320695006..850d2b48b 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -938,7 +938,6 @@ async def generate_async( input=messages, response=res, adapters=self._log_adapters ) await tracer.export_async() - res = res.response[0] return res else: # If a prompt is used, we only return the content of the message. diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 2e51e8f48..d376ab4e7 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -14,11 +14,15 @@ # limitations under the License. import asyncio +import os + +import pytest import unittest -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch +from nemoguardrails import LLMRails from nemoguardrails.logging.explain import LLMCallInfo -from nemoguardrails.rails.llm.config import TracingConfig +from nemoguardrails.rails.llm.config import TracingConfig, RailsConfig from nemoguardrails.rails.llm.options import ( ActivatedRail, ExecutedAction, @@ -201,5 +205,39 @@ def test_export_async(self): adapter_non_empty.transform_async.assert_called_once() +@patch.object(Tracer, "export_async", return_value="") +@pytest.mark.asyncio +async def test_tracing_enable_no_crash_issue_1093(mockTracer): + config = RailsConfig.from_content( + colang_content=""" + define user express greeting + "hello" + + define flow + user express greeting + bot express greeting + + define bot express greeting + "Hello World!\\n NewLine World!" + """, + config={ + "models": [], + "rails": {"dialog": {"user_messages": {"embeddings_only": True}}}, + }, + ) + # Force Tracing to be enabled + config.tracing.enabled = True + rails = LLMRails(config) + res = await rails.generate_async( + messages=[ + {"role": "user", "content": "hi!"}, + {"role": "assistant", "content": "hi!"}, + {"role": "user", "content": "hi!"}, + ] + ) + assert mockTracer.called == True + assert res.response != None + + if __name__ == "__main__": unittest.main()