Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ async def _call_tool_in_thread_pool(
# For sync FunctionTool, call the underlying function directly
def run_sync_tool():
if isinstance(tool, FunctionTool):
args_to_call = tool._preprocess_args(args)
args_to_call, validation_errors = tool._preprocess_args(args)
if validation_errors:
return tool._build_validation_error_response(validation_errors)
signature = inspect.signature(tool.func)
valid_params = {param for param in signature.parameters}
if tool._context_param_name in valid_params:
Expand Down
8 changes: 6 additions & 2 deletions src/google/adk/tools/crewai_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,12 @@ async def run_async(
duplicates, but is re-added if the function signature explicitly requires it
as a parameter.
"""
# Preprocess arguments (includes Pydantic model conversion)
args_to_call = self._preprocess_args(args)
# Preprocess arguments (includes Pydantic model conversion and type
# validation)
args_to_call, validation_errors = self._preprocess_args(args)

if validation_errors:
return self._build_validation_error_response(validation_errors)

signature = inspect.signature(self.func)
valid_params = {param for param in signature.parameters}
Expand Down
119 changes: 72 additions & 47 deletions src/google/adk/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import logging
from typing import Any
from typing import Callable
from typing import get_args
from typing import get_origin
from typing import Optional
from typing import Union

Expand Down Expand Up @@ -85,6 +83,7 @@ def __init__(
self._context_param_name = find_context_parameter(func) or 'tool_context'
self._ignore_params = [self._context_param_name, 'input_stream']
self._require_confirmation = require_confirmation
self._type_adapter_cache: dict[Any, pydantic.TypeAdapter] = {}

@override
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
Expand All @@ -100,68 +99,94 @@ def _get_declaration(self) -> Optional[types.FunctionDeclaration]:

return function_decl

def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
"""Preprocess and convert function arguments before invocation.
def _preprocess_args(
self, args: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""Preprocess, validate, and convert function arguments before invocation.

Currently handles:
Handles:
- Converting JSON dictionaries to Pydantic model instances where expected

Future extensions could include:
- Type coercion for other complex types
- Validation and sanitization
- Custom conversion logic
- Validating and coercing primitive types (int, float, str, bool)
- Validating enum values
- Validating container types (list[int], dict[str, float], etc.)

Args:
args: Raw arguments from the LLM tool call

Returns:
Processed arguments ready for function invocation
A tuple of (processed_args, validation_errors). If validation_errors is
non-empty, the caller should return the errors to the LLM instead of
invoking the function.
"""
signature = inspect.signature(self.func)
converted_args = args.copy()
validation_errors = []

for param_name, param in signature.parameters.items():
if param_name in args and param.annotation != inspect.Parameter.empty:
target_type = param.annotation

# Handle Optional[PydanticModel] types
if get_origin(param.annotation) is Union:
union_args = get_args(param.annotation)
# Find the non-None type in Optional[T] (which is Union[T, None])
non_none_types = [arg for arg in union_args if arg is not type(None)]
if len(non_none_types) == 1:
target_type = non_none_types[0]

# Check if the target type is a Pydantic model
if inspect.isclass(target_type) and issubclass(
target_type, pydantic.BaseModel
):
# Skip conversion if the value is None and the parameter is Optional
if args[param_name] is None:
continue

# Convert to Pydantic model if it's not already the correct type
if not isinstance(args[param_name], target_type):
try:
converted_args[param_name] = target_type.model_validate(
args[param_name]
)
except Exception as e:
logger.warning(
f"Failed to convert argument '{param_name}' to Pydantic model"
f' {target_type.__name__}: {e}'
)
# Keep the original value if conversion fails
pass

return converted_args
if (
param_name not in args
or param.annotation is inspect.Parameter.empty
or param_name in self._ignore_params
):
continue

target_type = param.annotation

# Validate and coerce using TypeAdapter. Handles primitives, enums,
# Pydantic models, Optional[T], T | None, and container types natively.
try:
try:
adapter = self._type_adapter_cache[target_type]
except TypeError:
adapter = pydantic.TypeAdapter(target_type)
except KeyError:
adapter = pydantic.TypeAdapter(target_type)
self._type_adapter_cache[target_type] = adapter
converted_args[param_name] = adapter.validate_python(
args[param_name]
)
except pydantic.ValidationError as e:
validation_errors.append(
f"Parameter '{param_name}': expected type '{getattr(target_type, '__name__', target_type)}',"
f' validation error: {e}'
)
except (TypeError, NameError) as e:
# TypeAdapter could not handle this annotation (e.g. a forward
# reference string). Skip validation but log a warning.
logger.warning(
"Skipping validation for parameter '%s' due to unhandled"
" annotation type '%s': %s",
param_name,
target_type,
e,
)

return converted_args, validation_errors

def _build_validation_error_response(
self, validation_errors: list[str]
) -> dict[str, str]:
"""Formats validation errors into an error dict for the LLM."""
validation_errors_str = '\n'.join(validation_errors)
return {
'error': (
f'Invoking `{self.name}()` failed due to argument validation'
f' errors:\n{validation_errors_str}\nYou could retry calling'
' this tool with corrected argument types.'
)
}

@override
async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
) -> Any:
# Preprocess arguments (includes Pydantic model conversion)
args_to_call = self._preprocess_args(args)
# Preprocess arguments (includes Pydantic model conversion and type
# validation). Validation errors are returned to the LLM so it can
# self-correct and retry with proper argument types.
args_to_call, validation_errors = self._preprocess_args(args)

if validation_errors:
return self._build_validation_error_response(validation_errors)

signature = inspect.signature(self.func)
valid_params = {param for param in signature.parameters}
Expand Down
213 changes: 213 additions & 0 deletions tests/unittests/tools/test_function_tool_arg_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for FunctionTool argument type validation and coercion."""

from enum import Enum
from typing import Optional
from unittest.mock import MagicMock

from google.adk.agents.invocation_context import InvocationContext
from google.adk.sessions.session import Session
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_context import ToolContext
import pytest


class Color(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"


def int_func(num: int) -> int:
return num


def float_func(val: float) -> float:
return val


def bool_func(flag: bool) -> bool:
return flag


def enum_func(color: Color) -> str:
return color.value


def list_int_func(nums: list[int]) -> list[int]:
return nums


def optional_int_func(num: Optional[int] = None) -> Optional[int]:
return num


def multi_param_func(name: str, count: int, flag: bool) -> dict:
return {"name": name, "count": count, "flag": flag}


# --- _preprocess_args coercion tests ---


class TestArgCoercion:

def test_string_to_int(self):
tool = FunctionTool(int_func)
args, errors = tool._preprocess_args({"num": "42"})
assert errors == []
assert args["num"] == 42
assert isinstance(args["num"], int)

def test_float_to_int(self):
"""Pydantic lax mode truncates float to int."""
tool = FunctionTool(int_func)
args, errors = tool._preprocess_args({"num": 3.0})
assert errors == []
assert args["num"] == 3
assert isinstance(args["num"], int)

def test_string_to_float(self):
tool = FunctionTool(float_func)
args, errors = tool._preprocess_args({"val": "3.14"})
assert errors == []
assert abs(args["val"] - 3.14) < 1e-9

def test_int_to_float(self):
tool = FunctionTool(float_func)
args, errors = tool._preprocess_args({"val": 5})
assert errors == []
assert args["val"] == 5.0
assert isinstance(args["val"], float)

def test_enum_valid_value(self):
tool = FunctionTool(enum_func)
args, errors = tool._preprocess_args({"color": "red"})
assert errors == []
assert args["color"] == Color.RED

def test_enum_invalid_value(self):
tool = FunctionTool(enum_func)
args, errors = tool._preprocess_args({"color": "purple"})
assert len(errors) == 1
assert "color" in errors[0]

def test_list_int_coercion(self):
tool = FunctionTool(list_int_func)
args, errors = tool._preprocess_args({"nums": ["1", "2", "3"]})
assert errors == []
assert args["nums"] == [1, 2, 3]

def test_optional_none_skipped(self):
tool = FunctionTool(optional_int_func)
args, errors = tool._preprocess_args({"num": None})
assert errors == []
assert args["num"] is None

def test_optional_value_coerced(self):
tool = FunctionTool(optional_int_func)
args, errors = tool._preprocess_args({"num": "7"})
assert errors == []
assert args["num"] == 7

def test_bool_from_int(self):
tool = FunctionTool(bool_func)
args, errors = tool._preprocess_args({"flag": 1})
assert errors == []
assert args["flag"] is True


# --- _preprocess_args validation error tests ---


class TestArgValidationErrors:

def test_string_for_int_returns_error(self):
tool = FunctionTool(int_func)
args, errors = tool._preprocess_args({"num": "foobar"})
assert len(errors) == 1
assert "num" in errors[0]

def test_none_for_required_int_returns_error(self):
"""None for a non-Optional int should be flagged."""
tool = FunctionTool(int_func)
# None passed for a required int param. The Optional unwrap won't
# trigger because the annotation is plain `int`, not Optional[int].
# TypeAdapter(int).validate_python(None) raises ValidationError.
args, errors = tool._preprocess_args({"num": None})
assert len(errors) == 1
assert "num" in errors[0]

def test_multiple_param_errors(self):
tool = FunctionTool(multi_param_func)
args, errors = tool._preprocess_args(
{"name": 123, "count": "not_a_number", "flag": "not_a_bool"}
)
# All three fail: pydantic rejects int->str, "not_a_number"->int,
# and "not_a_bool"->bool.
assert len(errors) == 3
assert any("name" in e for e in errors)
assert any("count" in e for e in errors)
assert any("flag" in e for e in errors)


# --- run_async integration tests ---


def _make_tool_context():
tool_context_mock = MagicMock(spec=ToolContext)
invocation_context_mock = MagicMock(spec=InvocationContext)
session_mock = MagicMock(spec=Session)
invocation_context_mock.session = session_mock
tool_context_mock.invocation_context = invocation_context_mock
return tool_context_mock


class TestRunAsyncValidation:

@pytest.mark.asyncio
async def test_invalid_arg_returns_error_to_llm(self):
tool = FunctionTool(int_func)
result = await tool.run_async(
args={"num": "foobar"}, tool_context=_make_tool_context()
)
assert isinstance(result, dict)
assert "error" in result
assert "validation error" in result["error"].lower()

@pytest.mark.asyncio
async def test_valid_coercion_invokes_function(self):
tool = FunctionTool(int_func)
result = await tool.run_async(
args={"num": "42"}, tool_context=_make_tool_context()
)
assert result == 42

@pytest.mark.asyncio
async def test_enum_invalid_returns_error(self):
tool = FunctionTool(enum_func)
result = await tool.run_async(
args={"color": "purple"}, tool_context=_make_tool_context()
)
assert isinstance(result, dict)
assert "error" in result

@pytest.mark.asyncio
async def test_enum_valid_invokes_function(self):
tool = FunctionTool(enum_func)
result = await tool.run_async(
args={"color": "green"}, tool_context=_make_tool_context()
)
assert result == "green"
Loading
Loading