from backend import backend
from os import path
from glob import glob

try:
    from exllama.generator import ExLlamaGenerator
    from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
    from exllama.tokenizer import ExLlamaTokenizer
except ModuleNotFoundError:
    # Fall back to an exllama checkout placed in the project root.
    local_path = 'exllama'
    import sys
    if local_path not in sys.path:
        sys.path.append(local_path)
    from generator import ExLlamaGenerator
    from model import ExLlama, ExLlamaCache, ExLlamaConfig
    from tokenizer import ExLlamaTokenizer


class backend_exllama(backend):
    def __init__(self, model_directory, max_context_length=2048):
        super().__init__()
        self.max_context_length = max_context_length

        # Locate the tokenizer, the model config and the first *.safetensors
        # weights file inside the model directory.
        tokenizer_path = path.join(model_directory, "tokenizer.model")
        tokenizer = ExLlamaTokenizer(tokenizer_path)
        model_config_path = path.join(model_directory, "config.json")
        st_pattern = path.join(model_directory, "*.safetensors")
        model_path = glob(st_pattern)[0]

        config = ExLlamaConfig(model_config_path)
        config.max_seq_len = max_context_length
        config.model_path = model_path

        self._model = ExLlama(config)
        cache = ExLlamaCache(self._model)
        self._generator = ExLlamaGenerator(self._model, tokenizer, cache)
        # Sustain the repetition penalty over the whole context window.
        self._generator.settings.token_repetition_penalty_sustain = config.max_seq_len

    def tokens_count(self, text):
        return self._generator.tokenizer.encode(text).shape[-1]

    def generate(self, prompt, stop, on_stream=None):
        # Push the current sampling settings into the generator.
        self._generator.settings.temperature = self.temperature
        self._generator.settings.top_p = self.top_p
        self._generator.settings.top_k = self.top_k
        self._generator.settings.typical = self.typical
        self._generator.settings.token_repetition_penalty_max = self.rep_pen

        # Tokenize the prompt and keep only the tail that leaves room for
        # max_length newly generated tokens within the context window.
        ids = self._generator.tokenizer.encode(prompt, max_seq_len=self.max_context_length)
        ids = ids[:, -(self.max_context_length - self.max_length):]
        self._generator.gen_begin_reuse(ids)
        initial_len = self._generator.sequence[0].shape[0]

        def generate():
            # Produce one token; return None on EOS, otherwise the full
            # decoded completion so far.
            token = self._generator.gen_single_token()
            if token.item() == self._generator.tokenizer.eos_token_id:
                return None
            return self._generator.tokenizer.decode(self._generator.sequence[0][initial_len:])

        result = self.process(generate, stop, on_stream)
        self._generator.end_beam_search()
        return result

    def unload(self):
        # Free model weights and CUDA memory so another backend can load.
        if self._model is None:
            return
        self._model.free_unmanaged()
        self._model = None
        self._generator = None
        import gc, torch
        gc.collect()
        torch.cuda.empty_cache()
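
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file). It
# assumes the backend base class initializes the sampling attributes read in
# generate() (temperature, top_p, top_k, typical, rep_pen, max_length) and
# provides the process() streaming driver; the model directory path below is
# hypothetical.
#
#     llm = backend_exllama("models/llama-7b-4bit", max_context_length=2048)
#     print(llm.tokens_count("Hello, world!"))            # prompt size in tokens
#     reply = llm.generate("Q: What is ExLlama?\nA:", stop=["\n"])
#     print(reply)
#     llm.unload()                                        # release VRAM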