Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion nemoguardrails/rails/llm/llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,6 @@ async def generate_async(
input=messages, response=res, adapters=self._log_adapters
)
await tracer.export_async()
res = res.response[0]
return res
else:
# If a prompt is used, we only return the content of the message.
Expand Down
41 changes: 39 additions & 2 deletions tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@
# limitations under the License.

import asyncio
import os

import pytest
import unittest
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch

from nemoguardrails import LLMRails
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.rails.llm.config import TracingConfig
from nemoguardrails.rails.llm.config import TracingConfig, RailsConfig
from nemoguardrails.rails.llm.options import (
ActivatedRail,
ExecutedAction,
Expand Down Expand Up @@ -200,6 +204,39 @@ def test_export_async(self):
asyncio.run(tracer_non_empty.export_async())
adapter_non_empty.transform_async.assert_called_once()

@patch.object(Tracer, "export_async", return_value="")
@pytest.mark.asyncio
async def test_tracing_enable_no_crash_issue_1093(mockTracer):
config = RailsConfig.from_content(
colang_content="""
define user express greeting
"hello"

define flow
user express greeting
bot express greeting

define bot express greeting
"Hello World!\\n NewLine World!"
""",
config={
"models": [],
"rails": {"dialog": {"user_messages": {"embeddings_only": True}}},
},
)
# Force Tracing to be enabled
config.tracing.enabled = True
rails = LLMRails(config)
res = await rails.generate_async(
messages=[
{"role": "user", "content": "hi!"},
{"role": "assistant", "content": "hi!"},
{"role": "user", "content": "hi!"},
]
)
assert mockTracer.called == True
assert res.response != None


if __name__ == "__main__":
unittest.main()
Loading