Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ jobs:
- name: Lint with ruff
run: uv run --directory py ruff check .

- name: Ensure Python test files follow the *_test.py naming convention
run: |
if find py -name "test_*.py" -not -path "*/.*" | grep -q .; then
echo "Error: Found Python test files starting with 'test_'. Please use the '*_test.py' format instead:"
find py -name "test_*.py" -not -path "*/.*"
exit 1
fi

- name: Type check with Ty
run: uv run --directory py ty check .

Expand Down
10 changes: 9 additions & 1 deletion bin/lint
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,17 @@ TOP_DIR=$(git rev-parse --show-toplevel)

GO_DIR="${TOP_DIR}/go"
PY_DIR="${TOP_DIR}/py"
JS_DIR="${TOP_DIR}/js"j
JS_DIR="${TOP_DIR}/js"

uv run --directory "${PY_DIR}" ruff check --fix --preview --unsafe-fixes .
uv run --directory "${PY_DIR}" ruff format --preview .
# Ensure Python test files follow the *_test.py naming convention.
if find "${PY_DIR}" -name "test_*.py" -not -path "*/.*" | grep -q .; then
echo "Error: Found Python test files starting with 'test_'. Please use the '*_test.py' format instead:"
find "${PY_DIR}" -name "test_*.py" -not -path "*/.*"
exit 1
fi

uv run --directory "${PY_DIR}" ty check .

# Disabled because there are many lint errors.
Expand Down
9 changes: 6 additions & 3 deletions py/bin/sanitize_schema_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, cast
from typing import cast


class ClassTransformer(ast.NodeTransformer):
Expand Down Expand Up @@ -75,7 +75,7 @@ def create_model_config(self, existing_config: ast.Call | None = None, frozen: b
found_populate = False
found_frozen = False

# Preserve existing keywords if present, but override 'extra'
# Preserve existing keywords if present, but override 'extra' and 'alias_generator'
if existing_config:
for kw in existing_config.keywords:
if kw.arg == 'populate_by_name':
Expand All @@ -90,6 +90,9 @@ def create_model_config(self, existing_config: ast.Call | None = None, frozen: b
elif kw.arg == 'extra':
# Skip the existing 'extra', we will enforce 'forbid'
continue
elif kw.arg == 'alias_generator':
# Skip existing alias_generator, we will add our own
continue
elif kw.arg == 'frozen':
# Use the provided 'frozen' value
keywords.append(
Expand Down Expand Up @@ -189,7 +192,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign: # noqa: N802

return node

def visit_ClassDef(self, node: ast.ClassDef) -> Any: # noqa: N802
def visit_ClassDef(self, node: ast.ClassDef) -> object: # noqa: N802
"""Visit and transform a class definition node.

Args:
Expand Down
50 changes: 30 additions & 20 deletions py/packages/genkit/src/genkit/ai/_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class while customizing it with any plugins.
import uuid
from collections.abc import AsyncIterator
from pathlib import Path
from typing import Any, cast
from typing import TypedDict, cast

from genkit.aio import Channel
from genkit.blocks.document import Document
Expand Down Expand Up @@ -67,6 +67,16 @@ class while customizing it with any plugins.
from ._server import ServerSpec


class OutputConfigDict(TypedDict, total=False):
"""TypedDict for output configuration when passed as a dict."""

format: str | None
content_type: str | None
instructions: bool | str | None
schema: type | dict[str, object] | None
constrained: bool | None


class Genkit(GenkitBase):
"""Genkit asyncio user-facing API."""

Expand Down Expand Up @@ -127,16 +137,16 @@ async def generate(
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
tool_responses: list[Part] | None = None,
config: GenerationCommonConfig | dict[str, Any] | None = None,
config: GenerationCommonConfig | dict[str, object] | None = None,
max_turns: int | None = None,
on_chunk: ModelStreamingCallback | None = None,
context: dict[str, Any] | None = None,
context: dict[str, object] | None = None,
output_format: str | None = None,
output_content_type: str | None = None,
output_instructions: bool | str | None = None,
output_schema: type | dict[str, Any] | None = None,
output_schema: type | dict[str, object] | None = None,
output_constrained: bool | None = None,
output: OutputConfig | dict[str, Any] | None = None,
output: OutputConfig | OutputConfigDict | None = None,
use: list[ModelMiddleware] | None = None,
docs: list[DocumentData] | None = None,
) -> GenerateResponseWrapper:
Expand Down Expand Up @@ -274,15 +284,15 @@ def generate_stream(
tools: list[str] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
config: GenerationCommonConfig | dict[str, Any] | None = None,
config: GenerationCommonConfig | dict[str, object] | None = None,
max_turns: int | None = None,
context: dict[str, Any] | None = None,
context: dict[str, object] | None = None,
output_format: str | None = None,
output_content_type: str | None = None,
output_instructions: bool | str | None = None,
output_schema: type | dict[str, Any] | None = None,
output_schema: type | dict[str, object] | None = None,
output_constrained: bool | None = None,
output: OutputConfig | dict[str, Any] | None = None,
output: OutputConfig | OutputConfigDict | None = None,
use: list[ModelMiddleware] | None = None,
docs: list[DocumentData] | None = None,
timeout: float | None = None,
Expand Down Expand Up @@ -379,7 +389,7 @@ async def retrieve(
self,
retriever: str | RetrieverRef | None = None,
query: str | DocumentData | None = None,
options: dict[str, Any] | None = None,
options: dict[str, object] | None = None,
) -> RetrieverResponse:
"""Retrieves documents based on query.

Expand All @@ -392,7 +402,7 @@ async def retrieve(
The generated response with documents.
"""
retriever_name: str
retriever_config: dict[str, Any] = {}
retriever_config: dict[str, object] = {}

if isinstance(retriever, RetrieverRef):
retriever_name = retriever.name
Expand Down Expand Up @@ -429,7 +439,7 @@ async def index(
self,
indexer: str | IndexerRef | None = None,
documents: list[Document] | None = None,
options: dict[str, Any] | None = None,
options: dict[str, object] | None = None,
) -> None:
"""Indexes documents.

Expand All @@ -439,7 +449,7 @@ async def index(
options: Optional indexer-specific options.
"""
indexer_name: str
indexer_config: dict[str, Any] = {}
indexer_config: dict[str, object] = {}

if isinstance(indexer, IndexerRef):
indexer_name = indexer.name
Expand Down Expand Up @@ -473,8 +483,8 @@ async def embed(
self,
embedder: str | EmbedderRef | None = None,
content: str | Document | DocumentData | None = None,
metadata: dict[str, Any] | None = None,
options: dict[str, Any] | None = None,
metadata: dict[str, object] | None = None,
options: dict[str, object] | None = None,
) -> list[Embedding]:
"""Embeds a single document or string.

Expand Down Expand Up @@ -524,7 +534,7 @@ async def embed(
>>> embeddings = await ai.embed(embedder=ref, content='Text')
"""
embedder_name = self._resolve_embedder_name(embedder)
embedder_config: dict[str, Any] = {}
embedder_config: dict[str, object] = {}

# Extract config and version from EmbedderRef (not done for embed_many per JS behavior)
if isinstance(embedder, EmbedderRef):
Expand Down Expand Up @@ -558,8 +568,8 @@ async def embed_many(
self,
embedder: str | EmbedderRef | None = None,
content: list[str] | list[Document] | list[DocumentData] | None = None,
metadata: dict[str, Any] | None = None,
options: dict[str, Any] | None = None,
metadata: dict[str, object] | None = None,
options: dict[str, object] | None = None,
) -> list[Embedding]:
"""Embeds multiple documents or strings in a single batch call.

Expand Down Expand Up @@ -630,7 +640,7 @@ async def evaluate(
self,
evaluator: str | EvaluatorRef | None = None,
dataset: list[BaseDataPoint] | None = None,
options: dict[str, Any] | None = None,
options: dict[str, object] | None = None,
eval_run_id: str | None = None,
) -> EvalResponse:
"""Evaluates a dataset using an evaluator.
Expand All @@ -645,7 +655,7 @@ async def evaluate(
The evaluation results.
"""
evaluator_name: str = ''
evaluator_config: dict[str, Any] = {}
evaluator_config: dict[str, object] = {}

if isinstance(evaluator, EvaluatorRef):
evaluator_name = evaluator.name
Expand Down
Loading
Loading