Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 89 additions & 61 deletions packages/core/src/policy/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import * as fs from 'node:fs/promises';
import * as path from 'node:path';
import * as crypto from 'node:crypto';
import { fileURLToPath } from 'node:url';
import { Storage } from '../config/storage.js';
import {
Expand All @@ -17,7 +18,7 @@ import {
} from './types.js';
import type { PolicyEngine } from './policy-engine.js';
import { loadPoliciesFromToml, type PolicyFileError } from './toml-loader.js';
import { buildArgsPatterns } from './utils.js';
import { buildArgsPatterns, isSafeRegExp } from './utils.js';
import toml from '@iarna/toml';
import {
MessageBusType,
Expand Down Expand Up @@ -331,6 +332,9 @@ export function createPolicyUpdater(
policyEngine: PolicyEngine,
messageBus: MessageBus,
) {
// Use a sequential queue for persistence to avoid lost updates from concurrent events.
let persistenceQueue = Promise.resolve();

messageBus.subscribe(
MessageBusType.UPDATE_POLICY,
async (message: UpdatePolicy) => {
Expand All @@ -341,6 +345,8 @@ export function createPolicyUpdater(
const patterns = buildArgsPatterns(undefined, message.commandPrefix);
for (const pattern of patterns) {
if (pattern) {
// Note: patterns from buildArgsPatterns are derived from escapeRegex,
// which is safe and won't contain ReDoS patterns.
policyEngine.addRule({
toolName,
decision: PolicyDecision.ALLOW,
Expand All @@ -354,6 +360,14 @@ export function createPolicyUpdater(
}
}
} else {
if (message.argsPattern && !isSafeRegExp(message.argsPattern)) {
coreEvents.emitFeedback(
'error',
`Invalid or unsafe regular expression for tool ${toolName}: ${message.argsPattern}`,
);
return;
}

const argsPattern = message.argsPattern
? new RegExp(message.argsPattern)
: undefined;
Expand All @@ -371,72 +385,86 @@ export function createPolicyUpdater(
}

if (message.persist) {
try {
const userPoliciesDir = Storage.getUserPoliciesDir();
await fs.mkdir(userPoliciesDir, { recursive: true });
const policyFile = path.join(userPoliciesDir, 'auto-saved.toml');

// Read existing file
let existingData: { rule?: TomlRule[] } = {};
persistenceQueue = persistenceQueue.then(async () => {
try {
const fileContent = await fs.readFile(policyFile, 'utf-8');
existingData = toml.parse(fileContent) as { rule?: TomlRule[] };
} catch (error) {
if ((error as NodeJS.ErrnoException).code !== 'ENOENT') {
debugLogger.warn(
`Failed to parse ${policyFile}, overwriting with new policy.`,
error,
);
const userPoliciesDir = Storage.getUserPoliciesDir();
await fs.mkdir(userPoliciesDir, { recursive: true });
const policyFile = path.join(userPoliciesDir, 'auto-saved.toml');

// Read existing file
let existingData: { rule?: TomlRule[] } = {};
try {
const fileContent = await fs.readFile(policyFile, 'utf-8');
existingData = toml.parse(fileContent) as { rule?: TomlRule[] };
} catch (error) {
if ((error as NodeJS.ErrnoException).code !== 'ENOENT') {
debugLogger.warn(
`Failed to parse ${policyFile}, overwriting with new policy.`,
error,
);
}
}
}

// Initialize rule array if needed
if (!existingData.rule) {
existingData.rule = [];
}

// Create new rule object
const newRule: TomlRule = {};

if (message.mcpName) {
newRule.mcpName = message.mcpName;
// Extract simple tool name
const simpleToolName = toolName.startsWith(`${message.mcpName}__`)
? toolName.slice(message.mcpName.length + 2)
: toolName;
newRule.toolName = simpleToolName;
newRule.decision = 'allow';
newRule.priority = 200;
} else {
newRule.toolName = toolName;
newRule.decision = 'allow';
newRule.priority = 100;
}

if (message.commandPrefix) {
newRule.commandPrefix = message.commandPrefix;
} else if (message.argsPattern) {
newRule.argsPattern = message.argsPattern;
}
// Initialize rule array if needed
if (!existingData.rule) {
existingData.rule = [];
}

// Add to rules
existingData.rule.push(newRule);
// Create new rule object
const newRule: TomlRule = {};

if (message.mcpName) {
newRule.mcpName = message.mcpName;
// Extract simple tool name
const simpleToolName = toolName.startsWith(`${message.mcpName}__`)
? toolName.slice(message.mcpName.length + 2)
: toolName;
newRule.toolName = simpleToolName;
newRule.decision = 'allow';
newRule.priority = 200;
} else {
newRule.toolName = toolName;
newRule.decision = 'allow';
newRule.priority = 100;
}

// Serialize back to TOML
// @iarna/toml stringify might not produce beautiful output but it handles escaping correctly
const newContent = toml.stringify(existingData as toml.JsonMap);
if (message.commandPrefix) {
newRule.commandPrefix = message.commandPrefix;
} else if (message.argsPattern) {
// message.argsPattern was already validated above
newRule.argsPattern = message.argsPattern;
}

// Atomic write: write to tmp then rename
const tmpFile = `${policyFile}.tmp`;
await fs.writeFile(tmpFile, newContent, 'utf-8');
await fs.rename(tmpFile, policyFile);
} catch (error) {
coreEvents.emitFeedback(
'error',
`Failed to persist policy for ${toolName}`,
error,
);
}
// Add to rules
existingData.rule.push(newRule);

// Serialize back to TOML
// @iarna/toml stringify might not produce beautiful output but it handles escaping correctly
const newContent = toml.stringify(existingData as toml.JsonMap);

// Atomic write: write to a unique tmp file then rename to the target file.
// Using a unique suffix avoids race conditions where concurrent processes
// overwrite each other's temporary files, leading to ENOENT errors on rename.
const tmpSuffix = crypto.randomBytes(8).toString('hex');
const tmpFile = `${policyFile}.${tmpSuffix}.tmp`;

let handle: fs.FileHandle | undefined;
try {
// Use 'wx' to create the file exclusively (fails if exists) for security.
handle = await fs.open(tmpFile, 'wx');
await handle.writeFile(newContent, 'utf-8');
} finally {
await handle?.close();
}
await fs.rename(tmpFile, policyFile);
} catch (error) {
coreEvents.emitFeedback(
'error',
`Failed to persist policy for ${toolName}`,
error,
);
}
});
}
},
);
Expand Down
47 changes: 35 additions & 12 deletions packages/core/src/policy/persistence.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ describe('createPolicyUpdater', () => {
(fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'),
); // Simulate new file
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);

const mockFileHandle = {
writeFile: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
};
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
(fs.rename as unknown as Mock).mockResolvedValue(undefined);

const toolName = 'test_tool';
Expand All @@ -70,10 +75,11 @@ describe('createPolicyUpdater', () => {
recursive: true,
});

expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');

// Check written content
const expectedContent = expect.stringContaining(`toolName = "test_tool"`);
expect(fs.writeFile).toHaveBeenCalledWith(
expect.stringMatching(/\.tmp$/),
expect(mockFileHandle.writeFile).toHaveBeenCalledWith(
expectedContent,
'utf-8',
);
Expand Down Expand Up @@ -106,7 +112,12 @@ describe('createPolicyUpdater', () => {
(fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'),
);
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);

const mockFileHandle = {
writeFile: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
};
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
(fs.rename as unknown as Mock).mockResolvedValue(undefined);

const toolName = 'run_shell_command';
Expand All @@ -131,8 +142,8 @@ describe('createPolicyUpdater', () => {
);

// Verify file written
expect(fs.writeFile).toHaveBeenCalledWith(
expect.stringMatching(/\.tmp$/),
expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');
expect(mockFileHandle.writeFile).toHaveBeenCalledWith(
expect.stringContaining(`commandPrefix = "git status"`),
'utf-8',
);
Expand All @@ -147,7 +158,12 @@ describe('createPolicyUpdater', () => {
(fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'),
);
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);

const mockFileHandle = {
writeFile: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
};
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
(fs.rename as unknown as Mock).mockResolvedValue(undefined);

const mcpName = 'my-jira-server';
Expand All @@ -164,8 +180,9 @@ describe('createPolicyUpdater', () => {
await new Promise((resolve) => setTimeout(resolve, 0));

// Verify file written
const writeCall = (fs.writeFile as unknown as Mock).mock.calls[0];
const writtenContent = writeCall[1] as string;
expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');
const writeCall = mockFileHandle.writeFile.mock.calls[0];
const writtenContent = writeCall[0] as string;
expect(writtenContent).toContain(`mcpName = "${mcpName}"`);
expect(writtenContent).toContain(`toolName = "${simpleToolName}"`);
expect(writtenContent).toContain('priority = 200');
Expand All @@ -180,7 +197,12 @@ describe('createPolicyUpdater', () => {
(fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'),
);
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);

const mockFileHandle = {
writeFile: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
};
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
(fs.rename as unknown as Mock).mockResolvedValue(undefined);

const mcpName = 'my"jira"server';
Expand All @@ -195,8 +217,9 @@ describe('createPolicyUpdater', () => {

await new Promise((resolve) => setTimeout(resolve, 0));

const writeCall = (fs.writeFile as unknown as Mock).mock.calls[0];
const writtenContent = writeCall[1] as string;
expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');
const writeCall = mockFileHandle.writeFile.mock.calls[0];
const writtenContent = writeCall[0] as string;

// Verify escaping - should be valid TOML
// Note: @iarna/toml optimizes for shortest representation, so it may use single quotes 'foo"bar'
Expand Down
26 changes: 23 additions & 3 deletions packages/core/src/policy/policy-updater.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,14 @@ describe('createPolicyUpdater', () => {
createPolicyUpdater(policyEngine, messageBus);
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);

const mockFileHandle = {
writeFile: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
};
vi.mocked(fs.open).mockResolvedValue(
mockFileHandle as unknown as fs.FileHandle,
);
vi.mocked(fs.rename).mockResolvedValue(undefined);

await messageBus.publish({
Expand All @@ -120,8 +127,8 @@ describe('createPolicyUpdater', () => {
// Wait for the async listener to complete
await new Promise((resolve) => setTimeout(resolve, 0));

expect(fs.writeFile).toHaveBeenCalled();
const [_path, content] = vi.mocked(fs.writeFile).mock.calls[0] as [
expect(fs.open).toHaveBeenCalled();
const [content] = mockFileHandle.writeFile.mock.calls[0] as [
string,
string,
];
Expand All @@ -130,6 +137,19 @@ describe('createPolicyUpdater', () => {
expect(parsed.rule).toHaveLength(1);
expect(parsed.rule![0].commandPrefix).toEqual(['echo', 'ls']);
});

it('should reject unsafe regex patterns', async () => {
createPolicyUpdater(policyEngine, messageBus);

await messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: 'test_tool',
argsPattern: '(a+)+',
persist: false,
});

expect(policyEngine.addRule).not.toHaveBeenCalled();
});
});

describe('ShellToolInvocation Policy Update', () => {
Expand Down
Loading
Loading