90 changes: 79 additions & 11 deletions transformers_cfg/cli/cli_main.py
@@ -75,18 +75,21 @@ def parse_arguments(args=None):
action="store_true",
help="Load the model in 8-bit mode using bitsandbytes",
)

generate_parser.add_argument(
"--no_contrast_mode",
action="store_true",
help="Disable contrast mode (enabled by default)",
)

generate_parser.add_argument(
"--save_to",
type=str,
help="File path to save the generated text",
)
generate_parser.add_argument(
"--use_mlx",
action="store_true",
help="Use MLX on max to speed up generation",
)

return parser.parse_args(args)

@@ -102,10 +105,84 @@ def check_model_support(model_name):


def generate_text(args):
# Store results for optional file output
result = f"Prompt: {args.prompt}\n\n"

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_id)
tokenizer.pad_token = tokenizer.eos_token

# Load grammar
with open(args.grammar_file_path, "r") as file:
grammar_str = file.read()
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

if args.use_mlx:
try:
import_module("mlx_lm")
except ImportError:
raise ImportError(
"You need to install mlx to use MLX. Install it with `pip install 'git+https://github.com/nathanrchn/mlx-examples.git@logits_processor#subdirectory=llms'`."
)

import numpy as np
import mlx.core as mx
from mlx_lm import load, stream_generate

model, _ = load(args.model_id)

if not args.no_contrast_mode:
print("\033[91m" + "Unconstrained Generation:" + "\033[0m")
result += "Unconstrained Generation:\n"
generation_stream = stream_generate(
model,
tokenizer,
prompt=args.prompt,
max_tokens=args.max_new_tokens,
repetition_penalty=args.repetition_penalty,
)

for token in generation_stream:
result += token
print(token, end="", flush=True)

print()

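# Bridge between MLX and PyTorch: grammar_processor operates on torch
# tensors, so convert the MLX arrays via NumPy, apply the grammar mask,
# and map the masked logits back to an mx.array for stream_generate.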
def logits_processor(input_ids: mx.array, logits: mx.array) -> mx.array:
torch_input_ids = torch.tensor(np.array(input_ids[None, :]), device=args.device)
torch_logits = torch.tensor(np.array(logits), device=args.device)

torch_processed_logits = grammar_processor(torch_input_ids, torch_logits)
return mx.array(torch_processed_logits.cpu().numpy())

generation_stream = stream_generate(
model,
tokenizer,
prompt=args.prompt,
max_tokens=args.max_new_tokens,
repetition_penalty=args.repetition_penalty,
logits_processor=logits_processor
)

# print prompt first in color
print("\033[92m" + "Prompt:" + args.prompt + "\033[0m")

print("\033[94m" + "Constrained Generation:" + "\033[0m")
result += "Constrained Generation:\n"
for token in generation_stream:
result += token
print(token, end="", flush=True)

print()

if args.save_to:
with open(args.save_to, "w") as f:
f.write(result)
print(f"\nResults saved to {args.save_to}")

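# The MLX path prints and saves its own results, so return here rather
# than falling through to the torch/Hugging Face path below.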
return

# Load the model with bitsandbytes if 8bit or 4bit flag is set
if args.use_8bit or args.use_4bit:
try:
@@ -136,12 +213,6 @@ def generate_text(args):
input_ids = inputs["input_ids"].to(args.device)
attention_mask = inputs["attention_mask"].to(args.device)

# Load grammar
with open(args.grammar_file_path, "r") as file:
grammar_str = file.read()
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

# Generate with grammar constraints
constrained_output = model.generate(
input_ids,
@@ -163,9 +234,6 @@ def generate_text(args):
# print prompt first in color
print("\033[92m" + "Prompt:" + args.prompt + "\033[0m")

# Store results for optional file output
result = f"Prompt: {args.prompt}\n\n"

# Generate without grammar constraints (if contrast mode is enabled)
if not args.no_contrast_mode:
unconstrained_output = model.generate(
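For reference, a hypothetical smoke test of the new argument wiring. Only names visible in this diff are used; the "generate" subcommand name and the long-flag spellings are assumptions, not confirmed by the diff:

# Hypothetical check of the new flags; subcommand and flag spellings are assumed.
from transformers_cfg.cli.cli_main import parse_arguments

args = parse_arguments([
    "generate",
    "--model_id", "gpt2",
    "--grammar_file_path", "examples/grammars/json.ebnf",
    "--prompt", "Generate a JSON object:",
    "--use_mlx",
    "--save_to", "out.txt",
])
assert args.use_mlx and args.save_to == "out.txt"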
9 changes: 6 additions & 3 deletions transformers_cfg/generation/logits_process.py
@@ -2,6 +2,7 @@
import math
import os
import pprint
from typing import Optional

import torch
import logging
@@ -11,18 +12,20 @@
)
from transformers.utils import add_start_docstrings

from transformers_cfg.token_grammar_recognizer import AbsTokenRecognizer

logger = logging.getLogger(__name__)


class GrammarConstrainedLogitsProcessor(LogitsProcessor):
def __init__(self, grammar_constraint, valid_token_start_idx=None, device=None):
def __init__(self, grammar_constraint: AbsTokenRecognizer, valid_token_start_idx: Optional[int] = None, device: Optional[torch.device] = None) -> None:
self.last_size = None
self.grammar_constraint = grammar_constraint
self.batch_parsing_states = None
self.valid_token_start_idx = valid_token_start_idx
self.device = device

def mask_logits(self, logits, device):
def mask_logits(self, logits: torch.FloatTensor, device: torch.device) -> torch.FloatTensor:
masked_logits = logits.clone()
# resolve each stack to a tensor of True/False for each token
# indicating acceptance
@@ -77,7 +80,7 @@ def mask_logits(self, logits, device):
masked_logits[~acceptance] = -math.inf
return masked_logits

def process_logits(self, input_ids, scores):
def process_logits(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""
:param input_ids:
:param scores:
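For context, a minimal sketch of how this processor typically plugs into Hugging Face generation. The IncrementalGrammarConstraint import path, the model choice, and the toy grammar are illustrative assumptions rather than part of this diff:

# Sketch: constrain generate() with a toy two-word grammar (assumed imports).
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

grammar = IncrementalGrammarConstraint('root ::= "yes" | "no"', "root", tokenizer)
processor = GrammarConstrainedLogitsProcessor(grammar)

input_ids = tokenizer("Is the sky blue? ", return_tensors="pt").input_ids
output = model.generate(input_ids, max_new_tokens=4, logits_processor=[processor])
print(tokenizer.decode(output[0], skip_special_tokens=True))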
31 changes: 20 additions & 11 deletions transformers_cfg/token_grammar_recognizer.py
@@ -4,6 +4,7 @@
from typing import List, Optional

import torch
from transformers import PreTrainedTokenizer

from transformers_cfg.recognizer import StringRecognizer, AcceptState
from transformers_cfg.parser import parse_ebnf
@@ -18,11 +19,11 @@
class AbsTokenRecognizer(ABC):
def __init__(
self,
grammar_str,
tokenizer,
start_rule_name="root",
trie=None,
homomorphism=None,
grammar_str: str,
tokenizer: PreTrainedTokenizer,
start_rule_name: str = "root",
trie: Optional[ByteTrie] = None,
homomorphism: Optional[TokenizerMiddleMapping] = None,
):
parsed_grammar = parse_ebnf(grammar_str)
grammar_encoding = parsed_grammar.grammar_encoding
Expand Down Expand Up @@ -66,13 +67,13 @@ def update_state_with_batch_token_seqs(self, *args, **kwargs):
"""Process a list of tokens according to the grammar rules."""
raise NotImplementedError

def batch_filter_vocab(self, batch_parsing_states, device) -> torch.Tensor:
def batch_filter_vocab(self, batch_parsing_states: List[AcceptState], device: torch.device) -> torch.Tensor:
batch_acceptance = []
for parsing_state in batch_parsing_states:
batch_acceptance.append(self.filter_vocab(parsing_state, device))
return torch.stack(batch_acceptance)

def filter_vocab(self, parsing_state, device) -> torch.Tensor:
def filter_vocab(self, parsing_state: AcceptState, device: torch.device) -> torch.Tensor:
if not parsing_state.stacks: # Check if stacks is empty
# Handle the empty case: for example, return a tensor of False
# The size of the tensor should match the size of your vocabulary
@@ -87,7 +88,7 @@ def filter_vocab(self, parsing_state, device) -> torch.Tensor:

return acceptance

def get_next_token_acceptance(self, parsing_state, device) -> torch.Tensor:
def get_next_token_acceptance(self, parsing_state: AcceptState, device: torch.device) -> torch.Tensor:
raise NotImplementedError

def validate_and_set_eos_acceptance(self, acceptance: torch.Tensor) -> torch.Tensor:
Expand All @@ -111,7 +112,12 @@ def detect_unicode(text: str) -> bool:

class IncrementalTokenRecognizer(AbsTokenRecognizer):
def __init__(
self, grammar_str, start_rule_name, tokenizer, trie=None, homomorphism=None
self,
grammar_str: str,
start_rule_name: str,
tokenizer: PreTrainedTokenizer,
trie: Optional[ByteTrie] = None,
homomorphism: Optional[TokenizerMiddleMapping] = None,
):
super().__init__(
grammar_str,
@@ -154,8 +160,11 @@ def _update_state_with_token_id(
# In this case, do nothing.

def update_state_with_batch_token_seqs(
self, input_ids, batch_parsing_states, valid_token_start_idx=None
):
self,
input_ids: torch.LongTensor,
batch_parsing_states: List[AcceptState],
valid_token_start_idx: Optional[int] = None,
) -> List[AcceptState]:

if self.last_size is None:
valid_prefix_tokens = [
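The new annotations mirror how the recognizer is built from an EBNF string in cli_main.py above. A minimal construction sketch, with an illustrative grammar that follows the GBNF-style syntax this parser expects (grammar and model choice are assumptions, not from the repo):

# Sketch of the annotated constructor; grammar and model are illustrative.
from transformers import AutoTokenizer
from transformers_cfg.token_grammar_recognizer import IncrementalTokenRecognizer

grammar_str = '''
root ::= greeting " world"
greeting ::= "hello" | "hi"
'''

tokenizer = AutoTokenizer.from_pretrained("gpt2")
recognizer = IncrementalTokenRecognizer(grammar_str, "root", tokenizer)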