
Speed-up the DRY logits processor #6087

Closed
jojje wants to merge 21 commits into oobabooga:main from jojje:feat/dry_speedup

Conversation


@jojje jojje commented Jun 3, 2024

Checklist:

Note: I didn't find any unit tests in this repo, so I created my own to ensure this change does not alter the behavior (output) in any way. You can find the test (both assertion testing and benchmarking in one) here.
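A harness in the spirit of the one described above could look like this (a minimal sketch; the function name, shapes, and random inputs are illustrative, not taken from the linked test):

```python
import numpy as np

def check_dry_equivalence(process_old, process_new, vocab=100, seq=512, seed=0):
    # Feed identical random inputs to both implementations and assert
    # that the penalized scores match. `process_old` and `process_new`
    # are placeholders for the two DRY processor call signatures
    # (input_ids, scores) -> scores.
    rng = np.random.default_rng(seed)
    ids = rng.integers(0, vocab, size=(1, seq))
    scores_a = rng.standard_normal((1, vocab)).astype(np.float32)
    scores_b = scores_a.copy()
    out_a = process_old(ids, scores_a)
    out_b = process_new(ids, scores_b)
    assert np.allclose(out_a, out_b), "implementations diverge"
```

Wrapping the two calls in a timer on the same inputs would give the benchmarking half of such a test.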

@jojje jojje changed the title Feat/dry speedup Speed-up the DRY logits processor Jun 3, 2024
@jojje
Author

jojje commented Jun 3, 2024

After integrating profiling into text-ui, it seems this solution, while faster according to the profiling data, hardly makes a dent in the overall generation latency from what I can see.

If you have access to a fast model that yields more tokens/s than I have, then it's worth a shot.

According to the profiling data, this version is about 2.5x faster when generating 4500 tokens using microsoft_Phi-3-mini-128k-instruct and the settings below.

This corresponds to the performance benchmark posted in the thread.

{   'max_new_tokens': 512,
    'temperature': 1,
    'temperature_last': False,
    'dynamic_temperature': False,
    'dynatemp_low': 1,
    'dynatemp_high': 1,
    'dynatemp_exponent': 1,
    'smoothing_factor': 0,
    'smoothing_curve': 1,
    'top_p': 1,
    'min_p': 0,
    'top_k': 1,
    'repetition_penalty': 1,
    'presence_penalty': 0,
    'frequency_penalty': 0,
    'repetition_penalty_range': 1024,
    'typical_p': 1,
    'tfs': 1,
    'top_a': 0,
    'guidance_scale': 1,
    'penalty_alpha': 0,
    'mirostat_mode': 0,
    'mirostat_tau': 5,
    'mirostat_eta': 0.1,
    'do_sample': False,
    'encoder_repetition_penalty': 1,
    'no_repeat_ngram_size': 0,
    'dry_multiplier': 0.8,
    'dry_base': 1.75,
    'dry_allowed_length': 2,
    'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
    'sampler_priority': [   'temperature',
                            'dynamic_temperature',
                            'quadratic_sampling',
                            'top_k',
                            'top_p',
                            'typical_p',
                            'epsilon_cutoff',
                            'eta_cutoff',
                            'tfs',
                            'top_a',
                            'min_p',
                            'mirostat'],
    'use_cache': True,
    'eos_token_id': [32000],
    'stopping_criteria': [   <modules.callbacks._StopEverythingStoppingCriteria object at 0x00000177A42DBED0>],
    'logits_processor': []}

Here's the profiling data for the original code (the code in the dev branch):

Timer unit: 1e-06 s

Total time: 0.93335 s
File: text-generation-webui/modules/sampler_hijack.py
Function: __call__ at line 203

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   203                                               @profile
   204                                               def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
   205       284        248.7      0.9      0.0          if self._range > 0:
   206       284       1387.0      4.9      0.1              input_ids = input_ids[:, -self._range:]
   207                                           
   208       568       4473.6      7.9      0.5          for input_ids_row, scores_row in zip(input_ids, scores):
   209                                                       # Raw integer must be extracted here to check for set membership.
   210       284     148665.8    523.5     15.9              last_token = input_ids_row[-1].item()
   211                                           
   212       284        235.7      0.8      0.0              if last_token in self.sequence_breakers:
   213         6          1.6      0.3      0.0                  continue
   214                                           
   215                                                       # Exclude the last token as it always matches.
   216       278      24540.7     88.3      2.6              match_indices = (input_ids_row[:-1] == last_token).nonzero()
   217                                           
   218                                                       # Stores the maximum matching sequence length
   219                                                       # for each token immediately following the sequence in the input.
   220       278        112.6      0.4      0.0              match_lengths = {}
   221                                           
   222      4322       5767.5      1.3      0.6              for i in match_indices:
   223      4044     191851.9     47.4     20.6                  next_token = input_ids_row[i+1].item()
   224                                           
   225      4044       2437.7      0.6      0.3                  if next_token in self.sequence_breakers:
   226       196         36.9      0.2      0.0                      continue
   227                                           
   228                                                           # We have already found that `last_token` matches at this index,
   229                                                           # so the match is at least of length 1.
   230      3848        709.5      0.2      0.1                  match_length = 1
   231                                           
   232                                                           # Extend the match backwards as far as possible.
   233      4500        801.7      0.2      0.1                  while True:
   234      4500      35658.7      7.9      3.8                      j = i - match_length
   235      4500     155454.7     34.5     16.7                      if j < 0:
   236                                                                   # Start of input reached.
   237                                                                   break
   238                                           
   239      4500     128235.6     28.5     13.7                      previous_token = input_ids_row[-(match_length+1)].item()
   240      4500     220285.0     49.0     23.6                      if input_ids_row[j] != previous_token:
   241                                                                   # Start of match reached.
   242      3836        878.3      0.2      0.1                          break
   243                                           
   244       664        415.9      0.6      0.0                      if previous_token in self.sequence_breakers:
   245                                                                   # Sequence-breaking token reached.
   246        12          2.2      0.2      0.0                          break
   247                                           
   248       652        217.6      0.3      0.0                      match_length += 1
   249                                           
   250      3848       1595.6      0.4      0.2                  if next_token in match_lengths:
   251      1446       2007.6      1.4      0.2                      match_lengths[next_token] = max(match_length, match_lengths[next_token])
   252                                                           else:
   253      2402       1123.1      0.5      0.1                      match_lengths[next_token] = match_length
   254                                           
   255                                                       # Apply penalties.
   256      2680        894.3      0.3      0.1              for token, match_length in match_lengths.items():
   257      2402        576.5      0.2      0.1                  if match_length >= self.allowed_length:
   258       408        262.9      0.6      0.0                      penalty = self.multiplier * self.base ** (match_length - self.allowed_length)
   259       408       4415.7     10.8      0.5                      scores_row[token] -= penalty
   260                                           
   261       284         55.0      0.2      0.0          return scores

And here are the results for this PR

Timer unit: 1e-06 s

Total time: 0.185979 s
File: text-generation-webui/modules/sampler_hijack.py
Function: __call__ at line 203

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   203                                               @profile
   204                                               def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
   205                                                   # Operations used for the algorithm are faster on CPU with numpy.
   206       284        492.9      1.7      0.3          if input_ids.device.type != 'cpu':
   207       284     151634.1    533.9     81.5              input_ids = input_ids.to('cpu')
   208       284       2714.9      9.6      1.5          input_ids = input_ids.numpy()
   209                                           
   210       284        241.0      0.8      0.1          if self._range > 0:
   211       284        632.2      2.2      0.3              input_ids = input_ids[:, -self._range:]
   212                                           
   213       568       4229.4      7.4      2.3          for input_ids_row, scores_row in zip(input_ids, scores):
   214                                                       # Raw integer must be extracted here to check for set membership.
   215       284        517.9      1.8      0.3              last_token = input_ids_row[-1]
   216                                           
   217       284        285.4      1.0      0.2              if last_token in self.sequence_breakers:
   218         6          1.2      0.2      0.0                  continue
   219                                           
   220                                                       # Exclude the last token as it always matches.
   221       278       4891.5     17.6      2.6              match_indices = (input_ids_row[:-1] == last_token).nonzero()[0]
   222                                           
   223                                                       # Stores the maximum matching sequence length
   224                                                       # for each token immediately following the sequence in the input.
   225       278         98.4      0.4      0.1              match_lengths = {}
   226                                           
   227      4322       1747.6      0.4      0.9              for i in match_indices:
   228      4044       1521.2      0.4      0.8                  next_token = input_ids_row[i + 1]
   229                                           
   230      4044       1005.0      0.2      0.5                  if next_token in self.sequence_breakers:
   231       196         26.7      0.1      0.0                      continue
   232                                           
   233                                                           # We have already found that `last_token` matches at this index,
   234                                                           # so the match is at least of length 1.
   235      3848        584.8      0.2      0.3                  match_length = 1
   236                                           
   237                                                           # Extend the match backwards as far as possible.
   238      4500        694.0      0.2      0.4                  while True:
   239      4500        969.6      0.2      0.5                      j = i - match_length
   240      4500        986.6      0.2      0.5                      if j < 0:
   241                                                                   # Start of input reached.
   242                                                                   break
   243                                           
   244      4500       1229.4      0.3      0.7                      previous_token = input_ids_row[-(match_length + 1)]
   245      4500       1291.0      0.3      0.7                      if input_ids_row[j] != previous_token:
   246                                                                   # Start of match reached.
   247      3836        531.6      0.1      0.3                          break
   248                                           
   249       664        160.0      0.2      0.1                      if previous_token in self.sequence_breakers:
   250                                                                   # Sequence-breaking token reached.
   251        12          1.8      0.1      0.0                          break
   252                                           
   253       652        126.4      0.2      0.1                      match_length += 1
   254                                           
   255      3848        921.9      0.2      0.5                  if next_token in match_lengths:
   256      1446        775.5      0.5      0.4                      match_lengths[next_token] = max(match_length, match_lengths[next_token])
   257                                                           else:
   258      2402        733.6      0.3      0.4                      match_lengths[next_token] = match_length
   259                                           
   260                                                       # Apply penalties.
   261      2680        827.2      0.3      0.4              for token, match_length in match_lengths.items():
   262      2402        547.5      0.2      0.3                  if match_length >= self.allowed_length:
   263       408        267.1      0.7      0.1                      penalty = self.multiplier * self.base ** (match_length - self.allowed_length)
   264       408       5234.8     12.8      2.8                      scores_row[token] -= penalty
   265                                           
   266       284         57.0      0.2      0.0          return scores

To summarize the data:

  • The original code has a cost of 3.29 ms per token generated.
  • This PR has a cost of ~0.66 ms.
    All results are of course specific to the machine tested on, but it's the relative difference that matters.
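For readers skimming the profiles, the per-row loop that both versions share can be sketched in pure NumPy roughly as follows (illustrative name and signature; the real code lives in `modules/sampler_hijack.py`, and the PR's speedup comes from running this on a NumPy array instead of issuing per-element `.item()` calls against a GPU tensor):

```python
import numpy as np

def dry_match_lengths(row, sequence_breakers):
    # For each token that immediately follows an occurrence of the
    # last token, record the longest suffix match ending at that
    # occurrence. Sketch of the loop covered by the profiles above.
    last_token = row[-1]
    if last_token in sequence_breakers:
        return {}
    match_lengths = {}
    # Exclude the last position, which trivially matches itself.
    for i in (row[:-1] == last_token).nonzero()[0]:
        next_token = row[i + 1]
        if next_token in sequence_breakers:
            continue
        # `last_token` matches at index i, so the match is at least 1.
        match_length = 1
        # Extend the match backwards as far as possible.
        while True:
            j = i - match_length
            if j < 0:
                break  # start of input reached
            previous_token = row[-(match_length + 1)]
            if row[j] != previous_token:
                break  # start of match reached
            if previous_token in sequence_breakers:
                break  # sequence-breaking token reached
            match_length += 1
        match_lengths[next_token] = max(match_length,
                                        match_lengths.get(next_token, 0))
    return match_lengths
```

The penalty step then subtracts `multiplier * base ** (match_length - allowed_length)` from each token's score once its match length reaches `allowed_length`.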

@jojje jojje closed this Jun 3, 2024
@jojje jojje reopened this Jun 3, 2024
@belladoreai
Contributor

Your PR is currently pointed at main; you probably want to point it at dev instead.

@belladoreai
Contributor

For reference:

@p-e-w
Contributor

p-e-w commented Jun 3, 2024

As pointed out by @belladoreai, this PR is quite similar to #6053, minus the match length cap to guarantee linear-time complexity for adversarial inputs. I don't have enough VRAM to run a large model at very long context length, so if you do, perhaps you can benchmark to compare the two implementations.
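The cap referred to could be sketched as below (the helper name and the `MAX_MATCH_LENGTH` value are hypothetical; the actual cap and code in #6053 may differ). The idea is that bounding the backward extension bounds the inner loop, keeping worst-case cost linear even for adversarial inputs such as a long run of one repeated token:

```python
import numpy as np

MAX_MATCH_LENGTH = 50  # illustrative cap, not the value used in #6053

def capped_extension(row, i, max_len=MAX_MATCH_LENGTH):
    # Extend the match ending at index i backwards, as in the DRY
    # loop, but stop once the match reaches max_len tokens.
    match_length = 1
    while match_length < max_len:
        j = i - match_length
        if j < 0:
            break  # start of input reached
        if row[j] != row[-(match_length + 1)]:
            break  # start of match reached
        match_length += 1
    return match_length
```

On a row of ten identical tokens, the uncapped extension from index 8 walks all the way back (length 9), while a cap of 5 stops early, which is the worst case the cap is meant to defuse.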

@jojje
Author

jojje commented Jun 3, 2024

As per the conversation in #5677, I'll close this one so as not to split the focus unnecessarily. #6047 seems to be where we're already headed for this specific issue, so let's focus on that one instead.

PS: Thanks for taking a look at it, though.

@jojje jojje closed this Jun 3, 2024
