Skip to content

Commit 76dd656

Browse files
fix: filter out injected args from tracing (#33729)
This is CC-generated, and I want to do a thorough review and update the tests, but it should be able to ship today. Before (eek): <img width="637" height="485" alt="Screenshot 2025-10-29 at 12 34 52 PM" src="https://github.com/user-attachments/assets/121def87-fb7b-4847-b9e2-74f37b3b4763" /> Now (woo): <img width="651" height="158" alt="Screenshot 2025-10-29 at 12 36 09 PM" src="https://github.com/user-attachments/assets/1fc0e19e-a83f-417c-81e2-3aa0028630d6" />
1 parent d218936 commit 76dd656

File tree

2 files changed

+319
-12
lines changed

2 files changed

+319
-12
lines changed

libs/core/langchain_core/tools/base.py

Lines changed: 66 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,35 @@ async def _arun(self, *args: Any, **kwargs: Any) -> Any:
707707
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
708708
return await run_in_executor(None, self._run, *args, **kwargs)
709709

710+
def _filter_injected_args(self, tool_input: dict) -> dict:
    """Return a copy of ``tool_input`` with injected arguments removed.

    A key is dropped when it is listed in ``FILTERED_ARGS`` or when the
    matching field of ``self.args_schema`` carries a type annotation that
    ``_is_injected_arg_type`` recognizes as injected.

    Args:
        tool_input: The tool input dictionary to filter. It is not mutated.

    Returns:
        A new dictionary holding only the entries of ``tool_input`` whose
        keys are not considered injected.
    """
    # Annotate the variable and use the plain constructor; calling the
    # subscripted alias (``set[str](...)``) is legal but not idiomatic.
    filtered_keys: set[str] = set(FILTERED_ARGS)

    # When a schema is available, its annotations tell us which extra
    # fields are injected.
    if self.args_schema is not None:
        try:
            annotations = get_all_basemodel_annotations(self.args_schema)
            for field_name, field_type in annotations.items():
                if _is_injected_arg_type(field_type):
                    filtered_keys.add(field_name)
        except Exception:  # noqa: S110
            # Deliberate best effort: this filtering only shapes what is
            # reported to callbacks, so a schema we cannot introspect must
            # not break the tool call. Keys collected so far (at minimum
            # FILTERED_ARGS) are still applied.
            pass

    return {
        key: value for key, value in tool_input.items() if key not in filtered_keys
    }
738+
710739
def _to_args_and_kwargs(
711740
self, tool_input: str | dict, tool_call_id: str | None
712741
) -> tuple[tuple, dict]:
@@ -794,17 +823,29 @@ def run(
794823
self.metadata,
795824
)
796825

826+
# Filter out injected arguments from callback inputs
827+
filtered_tool_input = (
828+
self._filter_injected_args(tool_input)
829+
if isinstance(tool_input, dict)
830+
else None
831+
)
832+
833+
# Use filtered inputs for the input_str parameter as well
834+
tool_input_str = (
835+
tool_input
836+
if isinstance(tool_input, str)
837+
else str(
838+
filtered_tool_input if filtered_tool_input is not None else tool_input
839+
)
840+
)
841+
797842
run_manager = callback_manager.on_tool_start(
798843
{"name": self.name, "description": self.description},
799-
tool_input if isinstance(tool_input, str) else str(tool_input),
844+
tool_input_str,
800845
color=start_color,
801846
name=run_name,
802847
run_id=run_id,
803-
# Inputs by definition should always be dicts.
804-
# For now, it's unclear whether this assumption is ever violated,
805-
# but if it is we will send a `None` value to the callback instead
806-
# TODO: will need to address issue via a patch.
807-
inputs=tool_input if isinstance(tool_input, dict) else None,
848+
inputs=filtered_tool_input,
808849
**kwargs,
809850
)
810851

@@ -905,17 +946,30 @@ async def arun(
905946
metadata,
906947
self.metadata,
907948
)
949+
950+
# Filter out injected arguments from callback inputs
951+
filtered_tool_input = (
952+
self._filter_injected_args(tool_input)
953+
if isinstance(tool_input, dict)
954+
else None
955+
)
956+
957+
# Use filtered inputs for the input_str parameter as well
958+
tool_input_str = (
959+
tool_input
960+
if isinstance(tool_input, str)
961+
else str(
962+
filtered_tool_input if filtered_tool_input is not None else tool_input
963+
)
964+
)
965+
908966
run_manager = await callback_manager.on_tool_start(
909967
{"name": self.name, "description": self.description},
910-
tool_input if isinstance(tool_input, str) else str(tool_input),
968+
tool_input_str,
911969
color=start_color,
912970
name=run_name,
913971
run_id=run_id,
914-
# Inputs by definition should always be dicts.
915-
# For now, it's unclear whether this assumption is ever violated,
916-
# but if it is we will send a `None` value to the callback instead
917-
# TODO: will need to address issue via a patch.
918-
inputs=tool_input if isinstance(tool_input, dict) else None,
972+
inputs=filtered_tool_input,
919973
**kwargs,
920974
)
921975
content = None

libs/core/tests/unit_tests/test_tools.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@
7070
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
7171
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
7272

73+
try:
74+
from langgraph.prebuilt import ToolRuntime # type: ignore[import-not-found]
75+
76+
HAS_LANGGRAPH = True
77+
except ImportError:
78+
HAS_LANGGRAPH = False
79+
7380

7481
def _get_tool_call_json_schema(tool: BaseTool) -> dict:
7582
tool_schema = tool.tool_call_schema
@@ -2773,3 +2780,249 @@ def test_tool(
27732780
"type": "array",
27742781
}
27752782
}
2783+
2784+
2785+
class CallbackHandlerWithInputCapture(FakeCallbackHandler):
    """Fake callback handler that records the ``inputs`` of each tool start."""

    # One entry per ``on_tool_start`` call, in call order.
    captured_inputs: list[dict | None] = []

    def on_tool_start(
        self,
        serialized: dict[str, Any],
        input_str: str,
        *,
        run_id: Any,
        parent_run_id: Any | None = None,
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        inputs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Record ``inputs`` and then delegate to the parent handler."""
        self.captured_inputs.append(inputs)

        # Forward every keyword argument unchanged, in the same order the
        # parent would have received them from a direct call.
        forwarded: dict[str, Any] = {
            "run_id": run_id,
            "parent_run_id": parent_run_id,
            "tags": tags,
            "metadata": metadata,
            "inputs": inputs,
            **kwargs,
        }
        return super().on_tool_start(serialized, input_str, **forwarded)
2814+
2815+
2816+
def test_filter_injected_args_from_callbacks() -> None:
    """An ``InjectedToolArg`` value must not show up in the callback inputs."""

    @tool
    def search_tool(
        query: str,
        state: Annotated[dict, InjectedToolArg()],
    ) -> str:
        """Search with injected state.

        Args:
            query: The search query.
            state: Injected state context.
        """
        return f"Results for: {query}"

    capture = CallbackHandlerWithInputCapture(captured_inputs=[])
    payload = {"query": "test query", "state": {"user_id": 123}}
    output = search_tool.invoke(payload, config={"callbacks": [capture]})

    # The tool itself still runs normally and fires exactly one start event.
    assert output == "Results for: test query"
    assert capture.tool_starts == 1
    assert len(capture.captured_inputs) == 1

    # The recorded inputs keep ``query`` but not the injected ``state``.
    seen = capture.captured_inputs[0]
    assert seen is not None
    assert "query" in seen
    assert "state" not in seen
    assert seen["query"] == "test query"
2848+
2849+
2850+
def test_filter_run_manager_from_callbacks() -> None:
    """``run_manager`` must not show up in the callback inputs."""

    @tool
    def tool_with_run_manager(
        message: str,
        run_manager: CallbackManagerForToolRun | None = None,
    ) -> str:
        """Tool with run_manager parameter.

        Args:
            message: The message to process.
            run_manager: The callback manager.
        """
        return f"Processed: {message}"

    capture = CallbackHandlerWithInputCapture(captured_inputs=[])
    callbacks_config = {"callbacks": [capture]}
    output = tool_with_run_manager.invoke({"message": "hello"}, config=callbacks_config)

    # Normal execution, exactly one recorded tool start.
    assert output == "Processed: hello"
    assert capture.tool_starts == 1
    assert len(capture.captured_inputs) == 1

    # Only the user-supplied ``message`` may be reported.
    seen = capture.captured_inputs[0]
    assert seen is not None
    assert "message" in seen
    assert "run_manager" not in seen
2881+
2882+
2883+
def test_filter_multiple_injected_args() -> None:
    """Several injected arguments are all removed from the callback inputs."""

    @tool
    def complex_tool(
        query: str,
        limit: int,
        state: Annotated[dict, InjectedToolArg()],
        context: Annotated[str, InjectedToolArg()],
        run_manager: CallbackManagerForToolRun | None = None,
    ) -> str:
        """Complex tool with multiple injected args.

        Args:
            query: The search query.
            limit: Maximum number of results.
            state: Injected state.
            context: Injected context.
            run_manager: The callback manager.
        """
        return f"Query: {query}, Limit: {limit}"

    capture = CallbackHandlerWithInputCapture(captured_inputs=[])
    payload = {
        "query": "test",
        "limit": 10,
        "state": {"foo": "bar"},
        "context": "some context",
    }
    output = complex_tool.invoke(payload, config={"callbacks": [capture]})

    assert output == "Query: test, Limit: 10"
    assert capture.tool_starts == 1
    assert len(capture.captured_inputs) == 1

    # Exactly the two non-injected arguments should have been reported.
    seen = capture.captured_inputs[0]
    assert seen is not None
    assert seen == {"query": "test", "limit": 10}
    for injected_name in ("state", "context", "run_manager"):
        assert injected_name not in seen
2927+
2928+
2929+
def test_no_filtering_for_string_input() -> None:
    """A plain string input is reported to callbacks as ``inputs=None``."""

    @tool
    def simple_tool(query: str) -> str:
        """Simple tool with string input.

        Args:
            query: The query string.
        """
        return f"Result: {query}"

    capture = CallbackHandlerWithInputCapture(captured_inputs=[])
    callbacks_config = {"callbacks": [capture]}
    output = simple_tool.invoke("test query", config=callbacks_config)

    assert output == "Result: test query"
    assert capture.tool_starts == 1
    assert len(capture.captured_inputs) == 1

    # There is no dict to filter, so the handler receives ``None``.
    first_inputs = capture.captured_inputs[0]
    assert first_inputs is None
2950+
2951+
2952+
async def test_filter_injected_args_async() -> None:
    """The async execution path filters injected arguments as well."""

    @tool
    async def async_search_tool(
        query: str,
        state: Annotated[dict, InjectedToolArg()],
    ) -> str:
        """Async search with injected state.

        Args:
            query: The search query.
            state: Injected state context.
        """
        return f"Async results for: {query}"

    capture = CallbackHandlerWithInputCapture(captured_inputs=[])
    payload = {"query": "async test", "state": {"user_id": 456}}
    output = await async_search_tool.ainvoke(payload, config={"callbacks": [capture]})

    assert output == "Async results for: async test"
    assert capture.tool_starts == 1
    assert len(capture.captured_inputs) == 1

    # Same expectation as the sync path: ``query`` stays, ``state`` is gone.
    seen = capture.captured_inputs[0]
    assert seen is not None
    assert "query" in seen
    assert "state" not in seen
    assert seen["query"] == "async test"
2984+
2985+
2986+
@pytest.mark.skipif(not HAS_LANGGRAPH, reason="langgraph not installed")
def test_filter_tool_runtime_directly_injected_arg() -> None:
    """A ``ToolRuntime`` parameter is removed from the callback inputs."""

    @tool
    def tool_with_runtime(query: str, limit: int, runtime: ToolRuntime) -> str:
        """Tool with ToolRuntime parameter.

        Args:
            query: The search query.
            limit: Max results.
            runtime: The tool runtime (directly injected).
        """
        return f"Query: {query}, Limit: {limit}"

    # Lightweight stand-in passed in place of a real ToolRuntime instance.
    class MockRuntime:
        """Mock ToolRuntime for testing."""

        agent_name = "test_agent"
        context: dict[str, Any] = {}
        state: dict[str, Any] = {}

    capture = CallbackHandlerWithInputCapture(captured_inputs=[])
    payload = {
        "query": "test",
        "limit": 5,
        "runtime": MockRuntime(),
    }
    output = tool_with_runtime.invoke(payload, config={"callbacks": [capture]})

    assert output == "Query: test, Limit: 5"
    assert capture.tool_starts == 1
    assert len(capture.captured_inputs) == 1

    # Only the ordinary arguments should have been reported.
    seen = capture.captured_inputs[0]
    assert seen is not None
    assert seen == {"query": "test", "limit": 5}
    assert "runtime" not in seen

0 commit comments

Comments
 (0)