Skip to content

Commit 2a71c01

Browse files
committed
fix: Avoid ClassCastException and reduce copy/pasta 🍝 in FunctionTool
Prompted by google#487 (comment).
1 parent 51d60d1 commit 2a71c01

1 file changed

Lines changed: 32 additions & 56 deletions

File tree

core/src/main/java/com/google/adk/tools/FunctionTool.java

Lines changed: 32 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -181,55 +181,7 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
181181
@SuppressWarnings("unchecked") // For tool parameter type casting.
182182
private Maybe<Map<String, Object>> call(Map<String, Object> args, ToolContext toolContext)
183183
throws IllegalAccessException, InvocationTargetException {
184-
Parameter[] parameters = func.getParameters();
185-
Object[] arguments = new Object[parameters.length];
186-
for (int i = 0; i < parameters.length; i++) {
187-
String paramName =
188-
parameters[i].isAnnotationPresent(Annotations.Schema.class)
189-
&& !parameters[i].getAnnotation(Annotations.Schema.class).name().isEmpty()
190-
? parameters[i].getAnnotation(Annotations.Schema.class).name()
191-
: parameters[i].getName();
192-
if (paramName.equals("toolContext")) {
193-
arguments[i] = toolContext;
194-
continue;
195-
}
196-
if (paramName.equals("inputStream")) {
197-
arguments[i] = null;
198-
continue;
199-
}
200-
if (!args.containsKey(paramName)) {
201-
throw new IllegalArgumentException(
202-
String.format(
203-
"The parameter '%s' was not found in the arguments provided by the model.",
204-
paramName));
205-
}
206-
Class<?> paramType = parameters[i].getType();
207-
Object argValue = args.get(paramName);
208-
if (paramType.equals(List.class)) {
209-
if (argValue instanceof List) {
210-
Type type =
211-
((ParameterizedType) parameters[i].getParameterizedType())
212-
.getActualTypeArguments()[0];
213-
Class<?> typeArgClass;
214-
if (type instanceof Class) {
215-
// Case 1: The argument is a simple class like String, Integer, etc.
216-
typeArgClass = (Class<?>) type;
217-
} else if (type instanceof ParameterizedType pType) {
218-
// Case 2: The argument is another parameterized type like Map<String, Integer>
219-
typeArgClass = (Class<?>) pType.getRawType(); // Get the raw class (e.g., Map)
220-
} else {
221-
throw new IllegalArgumentException(
222-
String.format("Unsupported parameterized type %s for '%s'", type, paramName));
223-
}
224-
arguments[i] = createList((List<Object>) argValue, typeArgClass);
225-
continue;
226-
}
227-
} else if (argValue instanceof Map) {
228-
arguments[i] = OBJECT_MAPPER.convertValue(argValue, paramType);
229-
continue;
230-
}
231-
arguments[i] = castValue(argValue, paramType);
232-
}
184+
Object[] arguments = buildArguments(args, toolContext, null);
233185
Object result = func.invoke(instance, arguments);
234186
if (result == null) {
235187
return Maybe.empty();
@@ -253,6 +205,21 @@ private Maybe<Map<String, Object>> call(Map<String, Object> args, ToolContext to
253205
public Flowable<Map<String, Object>> callLive(
254206
Map<String, Object> args, ToolContext toolContext, InvocationContext invocationContext)
255207
throws IllegalAccessException, InvocationTargetException {
208+
Object[] arguments = buildArguments(args, toolContext, invocationContext);
209+
Object result = func.invoke(instance, arguments);
210+
if (result instanceof Flowable) {
211+
return (Flowable<Map<String, Object>>) result;
212+
} else {
213+
throw new IllegalArgumentException(
214+
"callLive was called but the underlying function does not return a Flowable.");
215+
}
216+
}
217+
218+
@SuppressWarnings("unchecked") // For tool parameter type casting.
219+
private Object[] buildArguments(
220+
Map<String, Object> args,
221+
ToolContext toolContext,
222+
@Nullable InvocationContext invocationContext) {
256223
Parameter[] parameters = func.getParameters();
257224
Object[] arguments = new Object[parameters.length];
258225
for (int i = 0; i < parameters.length; i++) {
@@ -266,7 +233,8 @@ public Flowable<Map<String, Object>> callLive(
266233
continue;
267234
}
268235
if (paramName.equals("inputStream")) {
269-
if (invocationContext.activeStreamingTools().containsKey(this.name())
236+
if (invocationContext != null
237+
&& invocationContext.activeStreamingTools().containsKey(this.name())
270238
&& invocationContext.activeStreamingTools().get(this.name()).stream() != null) {
271239
arguments[i] = invocationContext.activeStreamingTools().get(this.name()).stream();
272240
} else {
@@ -287,7 +255,8 @@ public Flowable<Map<String, Object>> callLive(
287255
Type type =
288256
((ParameterizedType) parameters[i].getParameterizedType())
289257
.getActualTypeArguments()[0];
290-
arguments[i] = createList((List<Object>) argValue, (Class) type);
258+
Class<?> typeArgClass = getTypeClass(type, paramName);
259+
arguments[i] = createList((List<Object>) argValue, typeArgClass);
291260
continue;
292261
}
293262
} else if (argValue instanceof Map) {
@@ -296,12 +265,19 @@ public Flowable<Map<String, Object>> callLive(
296265
}
297266
arguments[i] = castValue(argValue, paramType);
298267
}
299-
Object result = func.invoke(instance, arguments);
300-
if (result instanceof Flowable) {
301-
return (Flowable<Map<String, Object>>) result;
268+
return arguments;
269+
}
270+
271+
private static Class<?> getTypeClass(Type type, String paramName) {
272+
if (type instanceof Class) {
273+
// Case 1: The argument is a simple class like String, Integer, etc.
274+
return (Class<?>) type;
275+
} else if (type instanceof ParameterizedType pType) {
276+
// Case 2: The argument is another parameterized type like Map<String, Integer>
277+
return (Class<?>) pType.getRawType(); // Get the raw class (e.g., Map)
302278
} else {
303-
logger.warn("callLive was called but the underlying function does not return a Flowable.");
304-
return Flowable.empty();
279+
throw new IllegalArgumentException(
280+
String.format("Unsupported parameterized type %s for '%s'", type, paramName));
305281
}
306282
}
307283

0 commit comments

Comments
 (0)