|
70 | 70 | from tests.unit_tests.fake.callbacks import FakeCallbackHandler |
71 | 71 | from tests.unit_tests.pydantic_utils import _normalize_schema, _schema |
72 | 72 |
|
| 73 | +try: |
| 74 | + from langgraph.prebuilt import ToolRuntime # type: ignore[import-not-found] |
| 75 | + |
| 76 | + HAS_LANGGRAPH = True |
| 77 | +except ImportError: |
| 78 | + HAS_LANGGRAPH = False |
| 79 | + |
73 | 80 |
|
74 | 81 | def _get_tool_call_json_schema(tool: BaseTool) -> dict: |
75 | 82 | tool_schema = tool.tool_call_schema |
@@ -2773,3 +2780,249 @@ def test_tool( |
2773 | 2780 | "type": "array", |
2774 | 2781 | } |
2775 | 2782 | } |
| 2783 | + |
| 2784 | + |
| 2785 | +class CallbackHandlerWithInputCapture(FakeCallbackHandler): |
| 2786 | + """Callback handler that captures inputs passed to on_tool_start.""" |
| 2787 | + |
| 2788 | + captured_inputs: list[dict | None] = [] |
| 2789 | + |
| 2790 | + def on_tool_start( |
| 2791 | + self, |
| 2792 | + serialized: dict[str, Any], |
| 2793 | + input_str: str, |
| 2794 | + *, |
| 2795 | + run_id: Any, |
| 2796 | + parent_run_id: Any | None = None, |
| 2797 | + tags: list[str] | None = None, |
| 2798 | + metadata: dict[str, Any] | None = None, |
| 2799 | + inputs: dict[str, Any] | None = None, |
| 2800 | + **kwargs: Any, |
| 2801 | + ) -> Any: |
| 2802 | + """Capture the inputs passed to on_tool_start.""" |
| 2803 | + self.captured_inputs.append(inputs) |
| 2804 | + return super().on_tool_start( |
| 2805 | + serialized, |
| 2806 | + input_str, |
| 2807 | + run_id=run_id, |
| 2808 | + parent_run_id=parent_run_id, |
| 2809 | + tags=tags, |
| 2810 | + metadata=metadata, |
| 2811 | + inputs=inputs, |
| 2812 | + **kwargs, |
| 2813 | + ) |
| 2814 | + |
| 2815 | + |
| 2816 | +def test_filter_injected_args_from_callbacks() -> None: |
| 2817 | + """Test that injected tool arguments are filtered from callback inputs.""" |
| 2818 | + |
| 2819 | + @tool |
| 2820 | + def search_tool( |
| 2821 | + query: str, |
| 2822 | + state: Annotated[dict, InjectedToolArg()], |
| 2823 | + ) -> str: |
| 2824 | + """Search with injected state. |
| 2825 | +
|
| 2826 | + Args: |
| 2827 | + query: The search query. |
| 2828 | + state: Injected state context. |
| 2829 | + """ |
| 2830 | + return f"Results for: {query}" |
| 2831 | + |
| 2832 | + handler = CallbackHandlerWithInputCapture(captured_inputs=[]) |
| 2833 | + result = search_tool.invoke( |
| 2834 | + {"query": "test query", "state": {"user_id": 123}}, |
| 2835 | + config={"callbacks": [handler]}, |
| 2836 | + ) |
| 2837 | + |
| 2838 | + assert result == "Results for: test query" |
| 2839 | + assert handler.tool_starts == 1 |
| 2840 | + assert len(handler.captured_inputs) == 1 |
| 2841 | + |
| 2842 | + # Verify that injected 'state' arg is filtered out |
| 2843 | + captured = handler.captured_inputs[0] |
| 2844 | + assert captured is not None |
| 2845 | + assert "query" in captured |
| 2846 | + assert "state" not in captured |
| 2847 | + assert captured["query"] == "test query" |
| 2848 | + |
| 2849 | + |
| 2850 | +def test_filter_run_manager_from_callbacks() -> None: |
| 2851 | + """Test that run_manager is filtered from callback inputs.""" |
| 2852 | + |
| 2853 | + @tool |
| 2854 | + def tool_with_run_manager( |
| 2855 | + message: str, |
| 2856 | + run_manager: CallbackManagerForToolRun | None = None, |
| 2857 | + ) -> str: |
| 2858 | + """Tool with run_manager parameter. |
| 2859 | +
|
| 2860 | + Args: |
| 2861 | + message: The message to process. |
| 2862 | + run_manager: The callback manager. |
| 2863 | + """ |
| 2864 | + return f"Processed: {message}" |
| 2865 | + |
| 2866 | + handler = CallbackHandlerWithInputCapture(captured_inputs=[]) |
| 2867 | + result = tool_with_run_manager.invoke( |
| 2868 | + {"message": "hello"}, |
| 2869 | + config={"callbacks": [handler]}, |
| 2870 | + ) |
| 2871 | + |
| 2872 | + assert result == "Processed: hello" |
| 2873 | + assert handler.tool_starts == 1 |
| 2874 | + assert len(handler.captured_inputs) == 1 |
| 2875 | + |
| 2876 | + # Verify that run_manager is filtered out |
| 2877 | + captured = handler.captured_inputs[0] |
| 2878 | + assert captured is not None |
| 2879 | + assert "message" in captured |
| 2880 | + assert "run_manager" not in captured |
| 2881 | + |
| 2882 | + |
| 2883 | +def test_filter_multiple_injected_args() -> None: |
| 2884 | + """Test filtering multiple injected arguments from callback inputs.""" |
| 2885 | + |
| 2886 | + @tool |
| 2887 | + def complex_tool( |
| 2888 | + query: str, |
| 2889 | + limit: int, |
| 2890 | + state: Annotated[dict, InjectedToolArg()], |
| 2891 | + context: Annotated[str, InjectedToolArg()], |
| 2892 | + run_manager: CallbackManagerForToolRun | None = None, |
| 2893 | + ) -> str: |
| 2894 | + """Complex tool with multiple injected args. |
| 2895 | +
|
| 2896 | + Args: |
| 2897 | + query: The search query. |
| 2898 | + limit: Maximum number of results. |
| 2899 | + state: Injected state. |
| 2900 | + context: Injected context. |
| 2901 | + run_manager: The callback manager. |
| 2902 | + """ |
| 2903 | + return f"Query: {query}, Limit: {limit}" |
| 2904 | + |
| 2905 | + handler = CallbackHandlerWithInputCapture(captured_inputs=[]) |
| 2906 | + result = complex_tool.invoke( |
| 2907 | + { |
| 2908 | + "query": "test", |
| 2909 | + "limit": 10, |
| 2910 | + "state": {"foo": "bar"}, |
| 2911 | + "context": "some context", |
| 2912 | + }, |
| 2913 | + config={"callbacks": [handler]}, |
| 2914 | + ) |
| 2915 | + |
| 2916 | + assert result == "Query: test, Limit: 10" |
| 2917 | + assert handler.tool_starts == 1 |
| 2918 | + assert len(handler.captured_inputs) == 1 |
| 2919 | + |
| 2920 | + # Verify that only non-injected args remain |
| 2921 | + captured = handler.captured_inputs[0] |
| 2922 | + assert captured is not None |
| 2923 | + assert captured == {"query": "test", "limit": 10} |
| 2924 | + assert "state" not in captured |
| 2925 | + assert "context" not in captured |
| 2926 | + assert "run_manager" not in captured |
| 2927 | + |
| 2928 | + |
| 2929 | +def test_no_filtering_for_string_input() -> None: |
| 2930 | + """Test that string inputs are not filtered (passed as None).""" |
| 2931 | + |
| 2932 | + @tool |
| 2933 | + def simple_tool(query: str) -> str: |
| 2934 | + """Simple tool with string input. |
| 2935 | +
|
| 2936 | + Args: |
| 2937 | + query: The query string. |
| 2938 | + """ |
| 2939 | + return f"Result: {query}" |
| 2940 | + |
| 2941 | + handler = CallbackHandlerWithInputCapture(captured_inputs=[]) |
| 2942 | + result = simple_tool.invoke("test query", config={"callbacks": [handler]}) |
| 2943 | + |
| 2944 | + assert result == "Result: test query" |
| 2945 | + assert handler.tool_starts == 1 |
| 2946 | + assert len(handler.captured_inputs) == 1 |
| 2947 | + |
| 2948 | + # String inputs should result in None for the inputs parameter |
| 2949 | + assert handler.captured_inputs[0] is None |
| 2950 | + |
| 2951 | + |
| 2952 | +async def test_filter_injected_args_async() -> None: |
| 2953 | + """Test that injected args are filtered in async tool execution.""" |
| 2954 | + |
| 2955 | + @tool |
| 2956 | + async def async_search_tool( |
| 2957 | + query: str, |
| 2958 | + state: Annotated[dict, InjectedToolArg()], |
| 2959 | + ) -> str: |
| 2960 | + """Async search with injected state. |
| 2961 | +
|
| 2962 | + Args: |
| 2963 | + query: The search query. |
| 2964 | + state: Injected state context. |
| 2965 | + """ |
| 2966 | + return f"Async results for: {query}" |
| 2967 | + |
| 2968 | + handler = CallbackHandlerWithInputCapture(captured_inputs=[]) |
| 2969 | + result = await async_search_tool.ainvoke( |
| 2970 | + {"query": "async test", "state": {"user_id": 456}}, |
| 2971 | + config={"callbacks": [handler]}, |
| 2972 | + ) |
| 2973 | + |
| 2974 | + assert result == "Async results for: async test" |
| 2975 | + assert handler.tool_starts == 1 |
| 2976 | + assert len(handler.captured_inputs) == 1 |
| 2977 | + |
| 2978 | + # Verify filtering in async execution |
| 2979 | + captured = handler.captured_inputs[0] |
| 2980 | + assert captured is not None |
| 2981 | + assert "query" in captured |
| 2982 | + assert "state" not in captured |
| 2983 | + assert captured["query"] == "async test" |
| 2984 | + |
| 2985 | + |
| 2986 | +@pytest.mark.skipif(not HAS_LANGGRAPH, reason="langgraph not installed") |
| 2987 | +def test_filter_tool_runtime_directly_injected_arg() -> None: |
| 2988 | + """Test that ToolRuntime (a _DirectlyInjectedToolArg) is filtered.""" |
| 2989 | + |
| 2990 | + @tool |
| 2991 | + def tool_with_runtime(query: str, limit: int, runtime: ToolRuntime) -> str: |
| 2992 | + """Tool with ToolRuntime parameter. |
| 2993 | +
|
| 2994 | + Args: |
| 2995 | + query: The search query. |
| 2996 | + limit: Max results. |
| 2997 | + runtime: The tool runtime (directly injected). |
| 2998 | + """ |
| 2999 | + return f"Query: {query}, Limit: {limit}" |
| 3000 | + |
| 3001 | + handler = CallbackHandlerWithInputCapture(captured_inputs=[]) |
| 3002 | + |
| 3003 | + # Create a mock ToolRuntime instance |
| 3004 | + class MockRuntime: |
| 3005 | + """Mock ToolRuntime for testing.""" |
| 3006 | + |
| 3007 | + agent_name = "test_agent" |
| 3008 | + context: dict[str, Any] = {} |
| 3009 | + state: dict[str, Any] = {} |
| 3010 | + |
| 3011 | + result = tool_with_runtime.invoke( |
| 3012 | + { |
| 3013 | + "query": "test", |
| 3014 | + "limit": 5, |
| 3015 | + "runtime": MockRuntime(), |
| 3016 | + }, |
| 3017 | + config={"callbacks": [handler]}, |
| 3018 | + ) |
| 3019 | + |
| 3020 | + assert result == "Query: test, Limit: 5" |
| 3021 | + assert handler.tool_starts == 1 |
| 3022 | + assert len(handler.captured_inputs) == 1 |
| 3023 | + |
| 3024 | + # Verify that ToolRuntime is filtered out |
| 3025 | + captured = handler.captured_inputs[0] |
| 3026 | + assert captured is not None |
| 3027 | + assert captured == {"query": "test", "limit": 5} |
| 3028 | + assert "runtime" not in captured |
0 commit comments