diff --git a/transformers_cfg/cli/cli_main.py b/transformers_cfg/cli/cli_main.py
index e9e39b1..5b2c5a3 100755
--- a/transformers_cfg/cli/cli_main.py
+++ b/transformers_cfg/cli/cli_main.py
@@ -75,18 +75,21 @@ def parse_arguments(args=None):
         action="store_true",
         help="Load the model in 8-bit mode using bitsandbytes",
     )
-
     generate_parser.add_argument(
         "--no_contrast_mode",
         action="store_true",
         help="Disable contrast mode (enabled by default)",
     )
-
     generate_parser.add_argument(
         "--save_to",
         type=str,
         help="File path to save the generated text",
    )
+    generate_parser.add_argument(
+        "--use_mlx",
+        action="store_true",
+        help="Use MLX on Mac to speed up generation",
+    )
 
     return parser.parse_args(args)
@@ -102,10 +105,84 @@ def check_model_support(model_name):
 
 
 def generate_text(args):
+    # Store results for optional file output
+    result = f"Prompt: {args.prompt}\n\n"
+
     # Load model and tokenizer
     tokenizer = AutoTokenizer.from_pretrained(args.model_id)
     tokenizer.pad_token = tokenizer.eos_token
 
+    # Load grammar
+    with open(args.grammar_file_path, "r") as file:
+        grammar_str = file.read()
+    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
+    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
+
+    if args.use_mlx:
+        try:
+            import_module("mlx_lm")
+        except ImportError:
+            raise ImportError(
+                "You need to install mlx_lm to use MLX. Install it with `pip install 'git+https://github.com/nathanrchn/mlx-examples.git@logits_processor#subdirectory=llms'`."
+            )
+
+        import numpy as np
+        import mlx.core as mx
+        from mlx_lm import load, stream_generate
+
+        model, _ = load(args.model_id)
+
+        if not args.no_contrast_mode:
+            print("\033[91m" + "Unconstrained Generation:" + "\033[0m")
+            result += "Unconstrained Generation:\n"
+            generation_stream = stream_generate(
+                model,
+                tokenizer,
+                prompt=args.prompt,
+                max_tokens=args.max_new_tokens,
+                repetition_penalty=args.repetition_penalty,
+            )
+
+            for token in generation_stream:
+                result += token
+                print(token, end="", flush=True)
+
+            print()
+
+        def logits_processor(input_ids: mx.array, logits: mx.array) -> mx.array:
+            torch_input_ids = torch.tensor(np.array(input_ids[None, :]), device=args.device)
+            torch_logits = torch.tensor(np.array(logits), device=args.device)
+
+            torch_processed_logits = grammar_processor(torch_input_ids, torch_logits)
+            return mx.array(torch_processed_logits.cpu().numpy())
+
+        generation_stream = stream_generate(
+            model,
+            tokenizer,
+            prompt=args.prompt,
+            max_tokens=args.max_new_tokens,
+            repetition_penalty=args.repetition_penalty,
+            logits_processor=logits_processor,
+        )
+
+        # print prompt first in color
+        print("\033[92m" + "Prompt:" + args.prompt + "\033[0m")
+
+        print("\033[94m" + "Constrained Generation:" + "\033[0m")
+        result += "Constrained Generation:\n"
+        for token in generation_stream:
+            result += token
+            print(token, end="", flush=True)
+
+        print()
+
+        if args.save_to:
+            with open(args.save_to, "w") as f:
+                f.write(result)
+            print(f"\nResults saved to {args.save_to}")
+
+        return
+
     # Load the model with bitsandbytes if 8bit or 4bit flag is set
     if args.use_8bit or args.use_4bit:
         try:
@@ -136,12 +213,6 @@ def generate_text(args):
 
     input_ids = inputs["input_ids"].to(args.device)
     attention_mask = inputs["attention_mask"].to(args.device)
 
-    # Load grammar
-    with open(args.grammar_file_path, "r") as file:
-        grammar_str = file.read()
-    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
-    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
-
     # Generate with grammar constraints
     constrained_output = model.generate(
@@ -163,9 +234,6 @@ def generate_text(args):
     # print prompt first in color
     print("\033[92m" + "Prompt:" + args.prompt + "\033[0m")
 
-    # Store results for optional file output
-    result = f"Prompt: {args.prompt}\n\n"
-
     # Generate without grammar constraints (if contrast mode is enabled)
     if not args.no_contrast_mode:
         unconstrained_output = model.generate(
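Usage sketch (outside the patch): one way to drive the new flag through the parser above. The "generate" subcommand name, the mlx-community model id, and the grammar path are illustrative assumptions, not confirmed by this diff alone; any MLX-convertible model and EBNF grammar file should work the same way.

    # Hypothetical invocation of the new MLX path; subcommand name, model id,
    # and grammar path are placeholders.
    from transformers_cfg.cli.cli_main import parse_arguments, generate_text

    args = parse_arguments([
        "generate",
        "--model_id", "mlx-community/Mistral-7B-Instruct-v0.2-4bit",
        "--grammar_file_path", "examples/grammars/json.ebnf",
        "--prompt", "This is a valid json string for an http request:",
        "--max_new_tokens", "60",
        "--use_mlx",
    ])
    generate_text(args)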
diff --git a/transformers_cfg/generation/logits_process.py b/transformers_cfg/generation/logits_process.py
index 24b1028..2406c52 100644
--- a/transformers_cfg/generation/logits_process.py
+++ b/transformers_cfg/generation/logits_process.py
@@ -2,6 +2,7 @@
 import math
 import os
 import pprint
+from typing import Optional
 
 import torch
 import logging
@@ -11,18 +12,20 @@
 )
 from transformers.utils import add_start_docstrings
 
+from transformers_cfg.token_grammar_recognizer import AbsTokenRecognizer
+
 logger = logging.getLogger(__name__)
 
 
 class GrammarConstrainedLogitsProcessor(LogitsProcessor):
-    def __init__(self, grammar_constraint, valid_token_start_idx=None, device=None):
+    def __init__(self, grammar_constraint: AbsTokenRecognizer, valid_token_start_idx: Optional[int] = None, device: Optional[torch.device] = None) -> None:
         self.last_size = None
         self.grammar_constraint = grammar_constraint
         self.batch_parsing_states = None
         self.valid_token_start_idx = valid_token_start_idx
         self.device = device
 
-    def mask_logits(self, logits, device):
+    def mask_logits(self, logits: torch.FloatTensor, device: torch.device) -> torch.FloatTensor:
         masked_logits = logits.clone()
         # resolve each stack to a tensor of True/False for each token
         # indicating acceptance
@@ -77,7 +80,7 @@ def mask_logits(self, logits, device):
         masked_logits[~acceptance] = -math.inf
         return masked_logits
 
-    def process_logits(self, input_ids, scores):
+    def process_logits(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         """
         :param input_ids:
         :param scores:
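For context, a minimal sketch of how this processor is normally wired into Hugging Face generation, matching the new annotation (grammar_constraint is an AbsTokenRecognizer). The "gpt2" model and the inline yes/no grammar are placeholders:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
    from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # One-rule grammar: generation may only produce "yes" or "no".
    grammar = IncrementalGrammarConstraint('root ::= "yes" | "no"', "root", tokenizer)
    processor = GrammarConstrainedLogitsProcessor(grammar)

    input_ids = tokenizer("Is the sky blue? ", return_tensors="pt").input_ids
    output = model.generate(input_ids, max_new_tokens=3, logits_processor=[processor])
    print(tokenizer.decode(output[0], skip_special_tokens=True))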
diff --git a/transformers_cfg/token_grammar_recognizer.py b/transformers_cfg/token_grammar_recognizer.py
index 7f71736..b096adc 100644
--- a/transformers_cfg/token_grammar_recognizer.py
+++ b/transformers_cfg/token_grammar_recognizer.py
@@ -4,6 +4,7 @@
 from typing import List, Optional
 
 import torch
+from transformers import PreTrainedTokenizer
 
 from transformers_cfg.recognizer import StringRecognizer, AcceptState
 from transformers_cfg.parser import parse_ebnf
@@ -18,11 +19,11 @@ class AbsTokenRecognizer(ABC):
 
     def __init__(
         self,
-        grammar_str,
-        tokenizer,
-        start_rule_name="root",
-        trie=None,
-        homomorphism=None,
+        grammar_str: str,
+        tokenizer: PreTrainedTokenizer,
+        start_rule_name: str = "root",
+        trie: Optional[ByteTrie] = None,
+        homomorphism: Optional[TokenizerMiddleMapping] = None,
     ):
         parsed_grammar = parse_ebnf(grammar_str)
         grammar_encoding = parsed_grammar.grammar_encoding
@@ -66,13 +67,13 @@ def update_state_with_batch_token_seqs(self, *args, **kwargs):
         """Process a list of tokens according to the grammar rules."""
         raise NotImplementedError
 
-    def batch_filter_vocab(self, batch_parsing_states, device) -> torch.Tensor:
+    def batch_filter_vocab(self, batch_parsing_states: List[AcceptState], device: torch.device) -> torch.Tensor:
         batch_acceptance = []
         for parsing_state in batch_parsing_states:
             batch_acceptance.append(self.filter_vocab(parsing_state, device))
         return torch.stack(batch_acceptance)
 
-    def filter_vocab(self, parsing_state, device) -> torch.Tensor:
+    def filter_vocab(self, parsing_state: AcceptState, device: torch.device) -> torch.Tensor:
         if not parsing_state.stacks:  # Check if stacks is empty
             # Handle the empty case: for example, return a tensor of False
             # The size of the tensor should match the size of your vocabulary
@@ -87,7 +88,7 @@ def filter_vocab(self, parsing_state, device) -> torch.Tensor:
 
         return acceptance
 
-    def get_next_token_acceptance(self, parsing_state, device) -> torch.Tensor:
+    def get_next_token_acceptance(self, parsing_state: AcceptState, device: torch.device) -> torch.Tensor:
         raise NotImplementedError
 
     def validate_and_set_eos_acceptance(self, acceptance: torch.Tensor) -> torch.Tensor:
@@ -111,7 +112,12 @@ def detect_unicode(text: str) -> bool:
 
 class IncrementalTokenRecognizer(AbsTokenRecognizer):
     def __init__(
-        self, grammar_str, start_rule_name, tokenizer, trie=None, homomorphism=None
+        self,
+        grammar_str: str,
+        start_rule_name: str,
+        tokenizer: PreTrainedTokenizer,
+        trie: Optional[ByteTrie] = None,
+        homomorphism: Optional[TokenizerMiddleMapping] = None,
     ):
         super().__init__(
             grammar_str,
@@ -154,8 +160,11 @@ def _update_state_with_token_id(
         # In this case, do nothing.
 
     def update_state_with_batch_token_seqs(
-        self, input_ids, batch_parsing_states, valid_token_start_idx=None
-    ):
+        self,
+        input_ids: torch.LongTensor,
+        batch_parsing_states: List[AcceptState],
+        valid_token_start_idx: Optional[int] = None,
+    ) -> List[AcceptState]:
         if self.last_size is None:
             valid_prefix_tokens = [
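Finally, a standalone sketch of the array round-trip behind the new logits_processor bridge in cli_main.py. MLX and torch do not share a tensor type, so numpy serves as the neutral interchange format (the shape below is illustrative):

    import numpy as np
    import torch
    import mlx.core as mx

    logits = mx.zeros((1, 32000))        # stand-in for one step of MLX logits
    t = torch.tensor(np.array(logits))   # mx.array -> np.ndarray -> torch.Tensor
    t[:, 0] = -float("inf")              # e.g. a grammar mask forbidding token 0
    back = mx.array(t.cpu().numpy())     # torch.Tensor -> np.ndarray -> mx.array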