diff --git a/ddtrace/contrib/fastapi/patch.py b/ddtrace/contrib/fastapi/patch.py index af6fe9af5ce..d4dfbe4e478 100644 --- a/ddtrace/contrib/fastapi/patch.py +++ b/ddtrace/contrib/fastapi/patch.py @@ -1,5 +1,4 @@ import fastapi -from fastapi.middleware import Middleware import fastapi.routing from ddtrace import Pin @@ -35,11 +34,8 @@ def span_modifier(span, scope): span.resource = "{} {}".format(scope["method"], resource) -def traced_init(wrapped, instance, args, kwargs): - mw = kwargs.pop("middleware", []) - mw.insert(0, Middleware(TraceMiddleware, integration_config=config.fastapi)) - kwargs.update({"middleware": mw}) - wrapped(*args, **kwargs) +def wrap_middleware_stack(wrapped, instance, args, kwargs): + return TraceMiddleware(app=wrapped(*args, **kwargs), integration_config=config.fastapi) async def traced_serialize_response(wrapped, instance, args, kwargs): @@ -73,7 +69,7 @@ def patch(): setattr(fastapi, "_datadog_patch", True) Pin().onto(fastapi) - _w("fastapi.applications", "FastAPI.__init__", traced_init) + _w("fastapi.applications", "FastAPI.build_middleware_stack", wrap_middleware_stack) _w("fastapi.routing", "serialize_response", traced_serialize_response) # We need to check that Starlette instrumentation hasn't already patched these @@ -90,7 +86,7 @@ def unpatch(): setattr(fastapi, "_datadog_patch", False) - _u(fastapi.applications.FastAPI, "__init__") + _u(fastapi.applications.FastAPI, "build_middleware_stack") _u(fastapi.routing, "serialize_response") # We need to check that Starlette instrumentation hasn't already unpatched these diff --git a/releasenotes/notes/fastapi-fix-middlewares-705975a535daaea8.yaml b/releasenotes/notes/fastapi-fix-middlewares-705975a535daaea8.yaml new file mode 100644 index 00000000000..99f2ad11b9b --- /dev/null +++ b/releasenotes/notes/fastapi-fix-middlewares-705975a535daaea8.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + fastapi: Previously, custom fastapi middlewares configured after application startup were not traced. This fix ensures that all fastapi middlewares are captured in the `fastapi.request` span. diff --git a/tests/contrib/fastapi/test_fastapi.py b/tests/contrib/fastapi/test_fastapi.py index 282c31d68c0..d8f32484f4c 100644 --- a/tests/contrib/fastapi/test_fastapi.py +++ b/tests/contrib/fastapi/test_fastapi.py @@ -52,6 +52,23 @@ def application(tracer): yield application +@pytest.fixture +def snapshot_app_with_middleware(): + fastapi_patch() + + application = app.get_app() + + @application.middleware("http") + async def traced_middlware(request, call_next): + with ddtrace.tracer.trace("traced_middlware"): + response = await call_next(request) + return response + + yield application + + fastapi_unpatch() + + @pytest.fixture def client(tracer): with TestClient(app.get_app()) as test_client: @@ -608,3 +625,11 @@ def test_host_header(client, tracer, test_spans, host): assert test_spans.spans request_span = test_spans.spans[0] assert request_span.get_tag("http.url") == "http://%s/asynctask" % (host,) + + +@snapshot() +def test_tracing_in_middleware(snapshot_app_with_middleware): + """Test if fastapi middlewares are traced""" + with TestClient(snapshot_app_with_middleware) as test_client: + r = test_client.get("/", headers={"sleep": "False"}) + assert r.status_code == 200 diff --git a/tests/contrib/fastapi/test_fastapi_patch.py b/tests/contrib/fastapi/test_fastapi_patch.py index d36a5ff147d..fa65822d175 100644 --- a/tests/contrib/fastapi/test_fastapi_patch.py +++ b/tests/contrib/fastapi/test_fastapi_patch.py @@ -10,19 +10,19 @@ class TestFastapiPatch(PatchTestCase.Base): __unpatch_func__ = unpatch def assert_module_patched(self, fastapi): - self.assert_wrapped(fastapi.applications.FastAPI.__init__) + self.assert_wrapped(fastapi.applications.FastAPI.build_middleware_stack) self.assert_wrapped(fastapi.routing.serialize_response) self.assert_wrapped(fastapi.routing.APIRoute.handle) self.assert_wrapped(fastapi.routing.Mount.handle) def assert_not_module_patched(self, fastapi): - self.assert_not_wrapped(fastapi.applications.FastAPI.__init__) + self.assert_not_wrapped(fastapi.applications.FastAPI.build_middleware_stack) self.assert_not_wrapped(fastapi.routing.serialize_response) self.assert_not_wrapped(fastapi.routing.APIRoute.handle) self.assert_not_wrapped(fastapi.routing.Mount.handle) def assert_not_module_double_patched(self, fastapi): - self.assert_not_double_wrapped(fastapi.applications.FastAPI.__init__) + self.assert_not_double_wrapped(fastapi.applications.FastAPI.build_middleware_stack) self.assert_not_double_wrapped(fastapi.routing.serialize_response) self.assert_not_double_wrapped(fastapi.routing.APIRoute.handle) self.assert_not_double_wrapped(fastapi.routing.Mount.handle) diff --git a/tests/snapshots/tests.contrib.fastapi.test_fastapi.test_tracing_in_middleware.json b/tests/snapshots/tests.contrib.fastapi.test_fastapi.test_tracing_in_middleware.json new file mode 100644 index 00000000000..55705fdf274 --- /dev/null +++ b/tests/snapshots/tests.contrib.fastapi.test_fastapi.test_tracing_in_middleware.json @@ -0,0 +1,55 @@ +[[ + { + "name": "fastapi.request", + "service": "fastapi", + "resource": "GET /", + "trace_id": 0, + "span_id": 1, + "parent_id": 0, + "type": "web", + "error": 0, + "meta": { + "_dd.p.dm": "-0", + "component": "fastapi", + "http.method": "GET", + "http.status_code": "200", + "http.url": "http://testserver/", + "http.useragent": "testclient", + "http.version": "1.1", + "language": "python", + "runtime-id": "7433b22c2562484081ca485a65d19945" + }, + "metrics": { + "_dd.agent_psr": 1.0, + "_dd.top_level": 1, + "_dd.tracer_kr": 1.0, + "_sampling_priority_v1": 1, + "process_id": 4144 + }, + "duration": 1034000, + "start": 1669131973327481000 + }, + { + "name": "traced_middlware", + "service": "fastapi", + "resource": "traced_middlware", + "trace_id": 0, + "span_id": 2, + "parent_id": 1, + "type": "", + "error": 0, + "duration": 459000, + "start": 1669131973327931000 + }, + { + "name": "fastapi.serialize_response", + "service": "fastapi", + "resource": "fastapi.serialize_response", + "trace_id": 0, + "span_id": 3, + "parent_id": 2, + "type": "", + "error": 0, + "duration": 24000, + "start": 1669131973328198000 + }]]