-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathbackend_exllamav2.py
More file actions
executable file
·65 lines (56 loc) · 2.39 KB
/
backend_exllamav2.py
File metadata and controls
executable file
·65 lines (56 loc) · 2.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Standard library / third-party imports.
import gc
import sys

import torch
from random import random

# Project-local generation backend base class.
from backend import backend

# exllamav2 is vendored as a folder in the project root; it must be on
# sys.path before the imports below can resolve.
local_path = 'exllamav2'  # folder in project root
if local_path not in sys.path:
    sys.path.append(local_path)
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2Sampler
class backend_exllamav2(backend):
    """Inference backend wrapping the exllamav2 library.

    Loads an ExLlamaV2 model from a local directory and implements the
    tokens_count / generate / unload interface of the ``backend`` base
    class.  Sampling parameters (temperature, top_p, top_k, typical,
    rep_pen, max_length) are read from attributes provided by the base
    class at generation time.
    """

    def __init__(self, model_directory, max_context_length=None, gpu_split=None):
        """Load the model, KV cache, tokenizer and sampler settings.

        model_directory -- folder containing the exllamav2 model files.
        max_context_length -- optional override for the model's
            max_seq_len; when falsy, the model's own configured limit
            is used.
        gpu_split -- forwarded to ExLlamaV2.load (per-GPU memory split;
            None lets exllamav2 decide).
        """
        super().__init__()
        config = ExLlamaV2Config()
        config.model_dir = model_directory
        config.prepare()
        if max_context_length:
            self.max_context_length = max_context_length
            config.max_seq_len = max_context_length
        else:
            self.max_context_length = config.max_seq_len
        self._model = ExLlamaV2(config)
        self._model.load(gpu_split)
        self._cache = ExLlamaV2Cache(self._model)
        self._tokenizer = ExLlamaV2Tokenizer(config)
        self._settings = ExLlamaV2Sampler.Settings()

    def tokens_count(self, text):
        """Return the number of tokens that ``text`` encodes to."""
        return self._tokenizer.encode(text).shape[-1]

    def generate(self, prompt, stop, on_stream=None):
        """Generate a completion for ``prompt``.

        prompt -- input text; truncated from the left so that prompt
            plus up to self.max_length generated tokens fit the context.
        stop -- stop condition forwarded to self.process.
        on_stream -- optional streaming callback forwarded to
            self.process.

        Returns whatever self.process returns for the step function
        below, which yields the decoded generated text so far, or None
        once the model emits its end-of-sequence token.
        """
        # Push the current sampler knobs into exllamav2's settings object.
        self._settings.temperature = self.temperature
        self._settings.top_p = self.top_p
        self._settings.top_k = self.top_k
        self._settings.typical = self.typical
        self._settings.token_repetition_penalty = self.rep_pen
        self._settings.token_repetition_range = -1

        ids = self._tokenizer.encode(prompt)
        # Keep only the rightmost tokens so generation cannot overflow
        # the model's context window.
        ids = ids[:, -(self.max_context_length - self.max_length):]
        initial_len = ids.shape[-1]

        # Prefill: reset the cache and run the prompt (minus its last
        # token) through the model to populate the KV cache; the last
        # token is fed in the first step call below.
        self._cache.current_seq_len = 0
        self._model.forward(ids[:, :-1], self._cache, input_mask=None, preprocess_only=True)

        def step():
            # Produce one token and return the decoded text generated
            # so far; None signals end-of-sequence to self.process.
            nonlocal ids
            logits = self._model.forward(ids[:, -1:], self._cache, input_mask=None).float().cpu()
            token, _ = ExLlamaV2Sampler.sample(logits, self._settings, ids, random())
            ids = torch.cat([ids, token], dim=1)
            if token.item() == self._tokenizer.eos_token_id:
                return None
            # Decode the whole generated suffix each step (not just the
            # new token) — presumably so partial multi-token characters
            # render correctly once completed.
            return self._tokenizer.decode(ids[:, initial_len:])[0]

        return self.process(step, stop, on_stream)

    def unload(self):
        """Release the model and its GPU memory; safe to call twice."""
        if self._model is None:
            return
        self._model = None
        # The KV cache holds large GPU tensors; it must be dropped too,
        # or torch.cuda.empty_cache() cannot reclaim its memory (the
        # original only cleared the model and leaked the cache).
        self._cache = None
        # NOTE(review): _generator is never created in this class —
        # looks like a leftover from an earlier backend; kept in case
        # the base class inspects it.
        self._generator = None
        gc.collect()
        torch.cuda.empty_cache()