Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Fixed
- Fix async redis clients not being traced correctly ([#1830](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1830))

## Version 1.18.0/0.39b0 (2023-05-10)

- `opentelemetry-instrumentation-system-metrics` Add `process.` prefix to `runtime.memory`, `runtime.cpu.time`, and `runtime.gc_count`. Change `runtime.memory` from count to UpDownCounter. ([#1735](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1735))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,44 @@ def _set_connection_attributes(span, conn):
span.set_attribute(key, value)


def _build_span_name(instance, cmd_args):
if len(cmd_args) > 0 and cmd_args[0]:
name = cmd_args[0]
else:
name = instance.connection_pool.connection_kwargs.get("db", 0)
return name


def _build_span_meta_data_for_pipeline(instance, sanitize_query):
try:
command_stack = (
instance.command_stack
if hasattr(instance, "command_stack")
else instance._command_stack
)

cmds = [
_format_command_args(
c.args if hasattr(c, "args") else c[0], sanitize_query
)
for c in command_stack
]
resource = "\n".join(cmds)

span_name = " ".join(
[
(c.args[0] if hasattr(c, "args") else c[0][0])
for c in command_stack
]
)
except (AttributeError, IndexError):
command_stack = []
resource = ""
span_name = ""

return command_stack, resource, span_name


def _instrument(
tracer,
request_hook: _RequestHookT = None,
Expand All @@ -165,11 +203,8 @@ def _instrument(
):
def _traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args, sanitize_query)
name = _build_span_name(instance, args)

if len(args) > 0 and args[0]:
name = args[0]
else:
name = instance.connection_pool.connection_kwargs.get("db", 0)
with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
Expand All @@ -185,31 +220,11 @@ def _traced_execute_command(func, instance, args, kwargs):
return response

def _traced_execute_pipeline(func, instance, args, kwargs):
try:
command_stack = (
instance.command_stack
if hasattr(instance, "command_stack")
else instance._command_stack
)

cmds = [
_format_command_args(
c.args if hasattr(c, "args") else c[0], sanitize_query
)
for c in command_stack
]
resource = "\n".join(cmds)

span_name = " ".join(
[
(c.args[0] if hasattr(c, "args") else c[0][0])
for c in command_stack
]
)
except (AttributeError, IndexError):
command_stack = []
resource = ""
span_name = ""
(
command_stack,
resource,
span_name,
) = _build_span_meta_data_for_pipeline(instance, sanitize_query)

with tracer.start_as_current_span(
span_name, kind=trace.SpanKind.CLIENT
Expand Down Expand Up @@ -254,32 +269,72 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
"ClusterPipeline.execute",
_traced_execute_pipeline,
)

async def _async_traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args, sanitize_query)
name = _build_span_name(instance, args)

with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
_set_connection_attributes(span, instance)
span.set_attribute("db.redis.args_length", len(args))
if callable(request_hook):
request_hook(span, instance, args, kwargs)
response = await func(*args, **kwargs)
if callable(response_hook):
response_hook(span, instance, response)
return response

async def _async_traced_execute_pipeline(func, instance, args, kwargs):
(
command_stack,
resource,
span_name,
) = _build_span_meta_data_for_pipeline(instance, sanitize_query)

with tracer.start_as_current_span(
span_name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
_set_connection_attributes(span, instance)
span.set_attribute(
"db.redis.pipeline_length", len(command_stack)
)
response = await func(*args, **kwargs)
if callable(response_hook):
response_hook(span, instance, response)
return response

if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
wrap_function_wrapper(
"redis.asyncio",
f"{redis_class}.execute_command",
_traced_execute_command,
_async_traced_execute_command,
)
wrap_function_wrapper(
"redis.asyncio.client",
f"{pipeline_class}.execute",
_traced_execute_pipeline,
_async_traced_execute_pipeline,
)
wrap_function_wrapper(
"redis.asyncio.client",
f"{pipeline_class}.immediate_execute_command",
_traced_execute_command,
_async_traced_execute_command,
)
if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION:
wrap_function_wrapper(
"redis.asyncio.cluster",
"RedisCluster.execute_command",
_traced_execute_command,
_async_traced_execute_command,
)
wrap_function_wrapper(
"redis.asyncio.cluster",
"ClusterPipeline.execute",
_traced_execute_pipeline,
_async_traced_execute_pipeline,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from unittest import mock

import redis
import redis.asyncio

from opentelemetry import trace
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind


class AsyncMock:
"""A sufficient async mock implementation.

Python 3.7 doesn't have an inbuilt async mock class, so this is used.
"""

def __init__(self):
self.mock = mock.Mock()

async def __call__(self, *args, **kwargs):
f = asyncio.Future()
f.set_result("random")
return f

def __getattr__(self, item):
return AsyncMock()


class TestRedis(TestBase):
def setUp(self):
super().setUp()
Expand Down Expand Up @@ -87,6 +107,35 @@ def test_instrument_uninstrument(self):
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

def test_instrument_uninstrument_async_client_command(self):
redis_client = redis.asyncio.Redis()

with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.memory_exporter.clear()

# Test uninstrument
RedisInstrumentor().uninstrument()

with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 0)
self.memory_exporter.clear()

# Test instrument again
RedisInstrumentor().instrument()

with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

def test_response_hook(self):
redis_client = redis.Redis()
connection = redis.connection.Connection()
Expand Down