Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/core/src/agents/agent-scheduler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ export interface AgentSchedulingOptions {
signal: AbortSignal;
/** Optional function to get the preferred editor for tool modifications. */
getPreferredEditor?: () => EditorType | undefined;
/** Optional function to be notified when the scheduler is waiting for user confirmation. */
onWaitingForConfirmation?: (waiting: boolean) => void;
}

/**
Expand All @@ -48,6 +50,7 @@ export async function scheduleAgentTools(
toolRegistry,
signal,
getPreferredEditor,
onWaitingForConfirmation,
} = options;

// Create a proxy/override of the config to provide the agent-specific tool registry.
Expand All @@ -60,6 +63,7 @@ export async function scheduleAgentTools(
getPreferredEditor: getPreferredEditor ?? (() => undefined),
schedulerId,
parentCallId,
onWaitingForConfirmation,
});

return scheduler.schedule(requests, signal);
Expand Down
40 changes: 31 additions & 9 deletions packages/core/src/agents/local-executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import { getModelConfigAlias } from './registry.js';
import { getVersion } from '../utils/version.js';
import { getToolCallContext } from '../utils/toolCallContext.js';
import { scheduleAgentTools } from './agent-scheduler.js';
import { DeadlineTimer } from '../utils/deadlineTimer.js';

/** A callback function to report on agent activity. */
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
Expand Down Expand Up @@ -231,6 +232,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
turnCounter: number,
combinedSignal: AbortSignal,
timeoutSignal: AbortSignal, // Pass the timeout controller's signal
onWaitingForConfirmation?: (waiting: boolean) => void,
): Promise<AgentTurnResult> {
const promptId = `${this.agentId}#${turnCounter}`;

Expand Down Expand Up @@ -265,7 +267,12 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
}

const { nextMessage, submittedOutput, taskCompleted } =
await this.processFunctionCalls(functionCalls, combinedSignal, promptId);
await this.processFunctionCalls(
functionCalls,
combinedSignal,
promptId,
onWaitingForConfirmation,
);
if (taskCompleted) {
const finalResult = submittedOutput ?? 'Task completed successfully.';
return {
Expand Down Expand Up @@ -322,6 +329,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
| AgentTerminateMode.MAX_TURNS
| AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL,
externalSignal: AbortSignal, // The original signal passed to run()
onWaitingForConfirmation?: (waiting: boolean) => void,
): Promise<string | null> {
this.emitActivity('THOUGHT_CHUNK', {
text: `Execution limit reached (${reason}). Attempting one final recovery turn with a grace period.`,
Expand Down Expand Up @@ -355,6 +363,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
turnCounter, // This will be the "last" turn number
combinedSignal,
graceTimeoutController.signal, // Pass grace signal to identify a *grace* timeout
onWaitingForConfirmation,
);

if (
Expand Down Expand Up @@ -415,14 +424,22 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
this.definition.runConfig.maxTimeMinutes ?? DEFAULT_MAX_TIME_MINUTES;
const maxTurns = this.definition.runConfig.maxTurns ?? DEFAULT_MAX_TURNS;

const timeoutController = new AbortController();
const timeoutId = setTimeout(
() => timeoutController.abort(new Error('Agent timed out.')),
const deadlineTimer = new DeadlineTimer(
maxTimeMinutes * 60 * 1000,
'Agent timed out.',
);

// Track time spent waiting for user confirmation to credit it back to the agent.
const onWaitingForConfirmation = (waiting: boolean) => {
if (waiting) {
deadlineTimer.pause();
} else {
deadlineTimer.resume();
}
};

// Combine the external signal with the internal timeout signal.
const combinedSignal = AbortSignal.any([signal, timeoutController.signal]);
const combinedSignal = AbortSignal.any([signal, deadlineTimer.signal]);

logAgentStart(
this.runtimeContext,
Expand Down Expand Up @@ -458,7 +475,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
// Check for timeout or external abort.
if (combinedSignal.aborted) {
// Determine which signal caused the abort.
terminateReason = timeoutController.signal.aborted
terminateReason = deadlineTimer.signal.aborted
? AgentTerminateMode.TIMEOUT
: AgentTerminateMode.ABORTED;
break;
Expand All @@ -469,7 +486,8 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
currentMessage,
turnCounter++,
combinedSignal,
timeoutController.signal,
deadlineTimer.signal,
onWaitingForConfirmation,
);

if (turnResult.status === 'stop') {
Expand Down Expand Up @@ -498,6 +516,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
turnCounter, // Use current turnCounter for the recovery attempt
terminateReason,
signal, // Pass the external signal
onWaitingForConfirmation,
);

if (recoveryResult !== null) {
Expand Down Expand Up @@ -551,7 +570,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
if (
error instanceof Error &&
error.name === 'AbortError' &&
timeoutController.signal.aborted &&
deadlineTimer.signal.aborted &&
!signal.aborted // Ensure the external signal was not the cause
) {
terminateReason = AgentTerminateMode.TIMEOUT;
Expand All @@ -563,6 +582,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
turnCounter, // Use current turnCounter
AgentTerminateMode.TIMEOUT,
signal,
onWaitingForConfirmation,
);

if (recoveryResult !== null) {
Expand Down Expand Up @@ -591,7 +611,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
this.emitActivity('ERROR', { error: String(error) });
throw error; // Re-throw other errors or external aborts.
} finally {
clearTimeout(timeoutId);
deadlineTimer.abort();
logAgentFinish(
this.runtimeContext,
new AgentFinishEvent(
Expand Down Expand Up @@ -779,6 +799,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
functionCalls: FunctionCall[],
signal: AbortSignal,
promptId: string,
onWaitingForConfirmation?: (waiting: boolean) => void,
): Promise<{
nextMessage: Content;
submittedOutput: string | null;
Expand Down Expand Up @@ -979,6 +1000,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
parentCallId: this.parentCallId,
toolRegistry: this.toolRegistry,
signal,
onWaitingForConfirmation,
},
);

Expand Down
5 changes: 4 additions & 1 deletion packages/core/src/scheduler/confirmation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ export async function resolveConfirmation(
modifier: ToolModificationHandler;
getPreferredEditor: () => EditorType | undefined;
schedulerId: string;
onWaitingForConfirmation?: (waiting: boolean) => void;
},
): Promise<ResolutionResult> {
const { state } = deps;
const { state, onWaitingForConfirmation } = deps;
const callId = toolCall.request.callId;
let outcome = ToolConfirmationOutcome.ModifyWithEditor;
let lastDetails: SerializableConfirmationDetails | undefined;
Expand Down Expand Up @@ -142,12 +143,14 @@ export async function resolveConfirmation(
correlationId,
});

onWaitingForConfirmation?.(true);
const response = await waitForConfirmation(
deps.messageBus,
correlationId,
signal,
ideConfirmation,
);
onWaitingForConfirmation?.(false);
outcome = response.outcome;

if ('onConfirm' in details && typeof details.onConfirm === 'function') {
Expand Down
4 changes: 4 additions & 0 deletions packages/core/src/scheduler/scheduler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export interface SchedulerOptions {
getPreferredEditor: () => EditorType | undefined;
schedulerId: string;
parentCallId?: string;
onWaitingForConfirmation?: (waiting: boolean) => void;
}

const createErrorResponse = (
Expand Down Expand Up @@ -90,6 +91,7 @@ export class Scheduler {
private readonly getPreferredEditor: () => EditorType | undefined;
private readonly schedulerId: string;
private readonly parentCallId?: string;
private readonly onWaitingForConfirmation?: (waiting: boolean) => void;

private isProcessing = false;
private isCancelling = false;
Expand All @@ -101,6 +103,7 @@ export class Scheduler {
this.getPreferredEditor = options.getPreferredEditor;
this.schedulerId = options.schedulerId;
this.parentCallId = options.parentCallId;
this.onWaitingForConfirmation = options.onWaitingForConfirmation;
this.state = new SchedulerStateManager(
this.messageBus,
this.schedulerId,
Expand Down Expand Up @@ -437,6 +440,7 @@ export class Scheduler {
modifier: this.modifier,
getPreferredEditor: this.getPreferredEditor,
schedulerId: this.schedulerId,
onWaitingForConfirmation: this.onWaitingForConfirmation,
});
outcome = result.outcome;
lastDetails = result.lastDetails;
Expand Down
80 changes: 80 additions & 0 deletions packages/core/src/scheduler/scheduler_waiting_callback.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/

import { describe, it, expect, vi, beforeEach } from 'vitest';
import { Scheduler } from './scheduler.js';
import { resolveConfirmation } from './confirmation.js';
import { checkPolicy } from './policy.js';
import { PolicyDecision } from '../policy/types.js';
import { ToolConfirmationOutcome } from '../tools/tools.js';
import { ToolRegistry } from '../tools/tool-registry.js';
import { MockTool } from '../test-utils/mock-tool.js';
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
import { makeFakeConfig } from '../test-utils/config.js';
import type { Config } from '../config/config.js';
import type { ToolCallRequestInfo } from './types.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';

vi.mock('./confirmation.js');
vi.mock('./policy.js');

describe('Scheduler waiting callback', () => {
let mockConfig: Config;
let messageBus: MessageBus;
let toolRegistry: ToolRegistry;
let mockTool: MockTool;

beforeEach(() => {
messageBus = createMockMessageBus();
mockConfig = makeFakeConfig();

// Override methods to use our mocks
vi.spyOn(mockConfig, 'getMessageBus').mockReturnValue(messageBus);

mockTool = new MockTool({ name: 'test_tool' });
toolRegistry = new ToolRegistry(mockConfig, messageBus);
vi.spyOn(mockConfig, 'getToolRegistry').mockReturnValue(toolRegistry);
toolRegistry.registerTool(mockTool);

vi.mocked(checkPolicy).mockResolvedValue({
decision: PolicyDecision.ASK_USER,
rule: undefined,
});
});

it('should trigger onWaitingForConfirmation callback', async () => {
const onWaitingForConfirmation = vi.fn();
const scheduler = new Scheduler({
config: mockConfig,
messageBus,
getPreferredEditor: () => undefined,
schedulerId: 'test-scheduler',
onWaitingForConfirmation,
});

vi.mocked(resolveConfirmation).mockResolvedValue({
outcome: ToolConfirmationOutcome.ProceedOnce,
});

const req: ToolCallRequestInfo = {
callId: 'call-1',
name: 'test_tool',
args: {},
isClientInitiated: false,
prompt_id: 'test-prompt',
};

await scheduler.schedule(req, new AbortController().signal);

expect(resolveConfirmation).toHaveBeenCalledWith(
expect.anything(),
expect.anything(),
expect.objectContaining({
onWaitingForConfirmation,
}),
);
});
});
82 changes: 82 additions & 0 deletions packages/core/src/utils/deadlineTimer.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/

import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { DeadlineTimer } from './deadlineTimer.js';

describe('DeadlineTimer', () => {
beforeEach(() => {
vi.useFakeTimers();
});

afterEach(() => {
vi.restoreAllMocks();
});

it('should abort when timeout is reached', () => {
const timer = new DeadlineTimer(1000);
const signal = timer.signal;
expect(signal.aborted).toBe(false);

vi.advanceTimersByTime(1000);
expect(signal.aborted).toBe(true);
expect(signal.reason).toBeInstanceOf(Error);
expect((signal.reason as Error).message).toBe('Timeout exceeded.');
});

it('should allow extending the deadline', () => {
const timer = new DeadlineTimer(1000);
const signal = timer.signal;

vi.advanceTimersByTime(500);
expect(signal.aborted).toBe(false);

timer.extend(1000); // New deadline is 1000 + 1000 = 2000 from start

vi.advanceTimersByTime(600); // 1100 total
expect(signal.aborted).toBe(false);

vi.advanceTimersByTime(900); // 2000 total
expect(signal.aborted).toBe(true);
});

it('should allow pausing and resuming the timer', () => {
const timer = new DeadlineTimer(1000);
const signal = timer.signal;

vi.advanceTimersByTime(500);
timer.pause();

vi.advanceTimersByTime(2000); // Wait a long time while paused
expect(signal.aborted).toBe(false);

timer.resume();
vi.advanceTimersByTime(400);
expect(signal.aborted).toBe(false);

vi.advanceTimersByTime(200); // Total active time 500 + 400 + 200 = 1100
expect(signal.aborted).toBe(true);
});

it('should abort immediately when abort() is called', () => {
const timer = new DeadlineTimer(1000);
const signal = timer.signal;

timer.abort('cancelled');
expect(signal.aborted).toBe(true);
expect(signal.reason).toBe('cancelled');
});

it('should not fire timeout if aborted manually', () => {
const timer = new DeadlineTimer(1000);
const signal = timer.signal;

timer.abort();
vi.advanceTimersByTime(1000);
// Already aborted, but shouldn't re-abort or throw
expect(signal.aborted).toBe(true);
});
});
Loading
Loading