264 changes: 264 additions & 0 deletions invokeai/frontend/web/src/common/util/promptAST.ts
@@ -0,0 +1,264 @@
/**
 * Attention weight applied to a word or group: either a run of '+' or '-'
 * characters (e.g. '+', '--') or a numeric weight such as 1.2 or 0.8.
 */
export type Attention = string | number;

type Word = string;

type Punct = string;

type Whitespace = string;

type Embedding = string;

export type Token =
| { type: 'word'; value: Word }
| { type: 'whitespace'; value: Whitespace }
| { type: 'punct'; value: Punct }
| { type: 'lparen' }
| { type: 'rparen' }
| { type: 'weight'; value: Attention }
| { type: 'lembed' }
| { type: 'rembed' };

export type ASTNode =
| { type: 'word'; text: Word; attention?: Attention }
| { type: 'group'; children: ASTNode[]; attention?: Attention }
| { type: 'embedding'; value: Embedding }
| { type: 'whitespace'; value: Whitespace }
| { type: 'punct'; value: Punct };
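
/**
 * Shape sketch (illustrative only): the prompt `a (b c)++` parses to roughly
 *
 *   [
 *     { type: 'word', text: 'a' },
 *     { type: 'whitespace', value: ' ' },
 *     { type: 'group', attention: '++', children: [
 *       { type: 'word', text: 'b' },
 *       { type: 'whitespace', value: ' ' },
 *       { type: 'word', text: 'c' },
 *     ] },
 *   ]
 *
 * Numeric weights (e.g. `)1.2`) are stored as numbers, while runs of `+`/`-`
 * are kept as strings.
 */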

/**
 * Convert a prompt string into a flat list of tokens.
 * @param prompt string
 * @returns Token[]
 */
export function tokenize(prompt: string): Token[] {
if (!prompt) {
return [];
}

let i = 0;
const tokens: Token[] = [];

while (i < prompt.length) {
const char = prompt[i];
if (!char) {
break;
}

// Whitespace (including newlines)
if (/\s/.test(char)) {
tokens.push({ type: 'whitespace', value: char });
i++;
continue;
}

// Parentheses
if (char === '(') {
tokens.push({ type: 'lparen' });
i++;
continue;
}

if (char === ')') {
// Look ahead for weight like ')1.1' or ')-0.9' or ')+' or ')-'
const weightMatch = prompt.slice(i + 1).match(/^[+-]?(\d+(\.\d+)?|[+-]+)/);
if (weightMatch && weightMatch[0]) {
let weight: Attention = weightMatch[0];
if (!isNaN(Number(weight))) {
weight = Number(weight);
}
tokens.push({ type: 'rparen' });
tokens.push({ type: 'weight', value: weight });
i += 1 + weightMatch[0].length;
continue;
}
tokens.push({ type: 'rparen' });
i++;
continue;
}

// Handle punctuation (comma, period, etc.)
if (/[,.]/.test(char)) {
tokens.push({ type: 'punct', value: char });
i++;
continue;
}

// Read a word (letters, digits, underscores)
if (/[a-zA-Z0-9_]/.test(char)) {
let j = i;
while (j < prompt.length && /[a-zA-Z0-9_]/.test(prompt[j]!)) {
j++;
}
const word = prompt.slice(i, j);
tokens.push({ type: 'word', value: word });

// Check for weight immediately after word (e.g., "Lorem+", "consectetur-")
const weightMatch = prompt.slice(j).match(/^[+-]?(\d+(\.\d+)?|[+-]+)/);
if (weightMatch && weightMatch[0]) {
tokens.push({ type: 'weight', value: weightMatch[0] });
i = j + weightMatch[0].length;
} else {
i = j;
}
continue;
}

// Embeddings
if (char === '<') {
tokens.push({ type: 'lembed' });
i++;
continue;
}

if (char === '>') {
tokens.push({ type: 'rembed' });
i++;
continue;
}

// Any other single character punctuation
if (!/\s/.test(char)) {
tokens.push({ type: 'punct', value: char });
}

i++;
}

return tokens;
}
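
/**
 * Example (illustrative): `tokenize('masterpiece, (detailed face)1.2 <my_embed>')`
 * yields, in order, roughly
 *
 *   word('masterpiece'), punct(','), whitespace(' '),
 *   lparen, word('detailed'), whitespace(' '), word('face'), rparen, weight(1.2),
 *   whitespace(' '), lembed, word('my_embed'), rembed
 *
 * A numeric weight following `)` is converted to a number; runs of `+`/`-`
 * after a word or group are kept as strings.
 */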

/**
* Convert tokens into an AST.
* @param tokens Token[]
* @returns ASTNode[]
*/
export function parseTokens(tokens: Token[]): ASTNode[] {
let pos = 0;

function peek(): Token | undefined {
return tokens[pos];
}

function consume(): Token | undefined {
return tokens[pos++];
}

function parseGroup(): ASTNode[] {
const nodes: ASTNode[] = [];

while (pos < tokens.length) {
const token = peek();
if (!token || token.type === 'rparen') {
break;
}

switch (token.type) {
case 'whitespace': {
const wsToken = consume() as Token & { type: 'whitespace' };
nodes.push({ type: 'whitespace', value: wsToken.value });
break;
}
case 'lparen': {
consume();
const groupChildren = parseGroup();

let attention: Attention | undefined;
if (peek()?.type === 'rparen') {
consume(); // consume ')'
if (peek()?.type === 'weight') {
attention = (consume() as Token & { type: 'weight' }).value;
}
}

nodes.push({ type: 'group', children: groupChildren, attention });
break;
}
case 'lembed': {
consume(); // consume '<'
// Accumulate everything up to '>' as the embedding name; tokens without a
// string value (e.g. stray weight tokens) are skipped.
let embedValue = '';
while (peek() && peek()!.type !== 'rembed') {
const embedToken = consume()!;
embedValue +=
embedToken.type === 'word' || embedToken.type === 'punct' || embedToken.type === 'whitespace'
? embedToken.value
: '';
}
if (peek()?.type === 'rembed') {
consume(); // consume '>'
}
nodes.push({ type: 'embedding', value: embedValue.trim() });
break;
}
case 'word': {
const wordToken = consume() as Token & { type: 'word' };
let attention: Attention | undefined;

// Check for immediate weight after word
if (peek()?.type === 'weight') {
attention = (consume() as Token & { type: 'weight' }).value;
}

nodes.push({ type: 'word', text: wordToken.value, attention });
break;
}
case 'punct': {
const punctToken = consume() as Token & { type: 'punct' };
nodes.push({ type: 'punct', value: punctToken.value });
break;
}
default: {
consume();
}
}
}

return nodes;
}

return parseGroup();
}
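
/**
 * Example (illustrative): parsing the tokens from the `tokenize` example above
 * yields roughly
 *
 *   [
 *     { type: 'word', text: 'masterpiece' },
 *     { type: 'punct', value: ',' },
 *     { type: 'whitespace', value: ' ' },
 *     { type: 'group', attention: 1.2, children: [
 *       { type: 'word', text: 'detailed' },
 *       { type: 'whitespace', value: ' ' },
 *       { type: 'word', text: 'face' },
 *     ] },
 *     { type: 'whitespace', value: ' ' },
 *     { type: 'embedding', value: 'my_embed' },
 *   ]
 */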

/**
* Convert an AST back into a prompt string.
* @param ast ASTNode[]
* @returns string
*/
export function serialize(ast: ASTNode[]): string {
let prompt = '';

for (const node of ast) {
switch (node.type) {
case 'punct':
case 'whitespace': {
prompt += node.value;
break;
}
case 'word': {
prompt += node.text;
if (node.attention) {
prompt += String(node.attention);
}
break;
}
case 'group': {
prompt += '(';
prompt += serialize(node.children);
prompt += ')';
if (node.attention) {
prompt += String(node.attention);
}
break;
}
case 'embedding': {
prompt += `<${node.value}>`;
break;
}
}
}

return prompt;
}
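
/**
 * Usage sketch (illustrative): a parse/serialize round trip. Whitespace and
 * punctuation are preserved as tokens, so typical prompts serialize back to
 * the original string; prompts the tokenizer does not fully understand (e.g.
 * embedding names containing characters outside `[a-zA-Z0-9_]`) may not.
 *
 *   const prompt = 'masterpiece, (detailed face)1.2 <my_embed>';
 *   const ast = parseTokens(tokenize(prompt));
 *   serialize(ast); // => 'masterpiece, (detailed face)1.2 <my_embed>'
 */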