-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathbackend_kobold.py
More file actions
executable file
·70 lines (61 loc) · 2.44 KB
/
backend_kobold.py
File metadata and controls
executable file
·70 lines (61 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from backend import backend
import requests
import json
class backend_kobold(backend):
    """Text-generation backend that talks to a KoboldAI-compatible HTTP API.

    All requests go to endpoints under ``<api_url>/extra/``. Sampling settings
    read in ``generate`` (``max_length``, ``temperature``, ``rep_pen``,
    ``top_k``, ``top_p``, ``min_p``, ``typical``) and the ``process`` method are
    not defined here. NOTE(review): presumably provided by the ``backend`` base
    class; confirm.
    """

    def __init__(self, api_url, max_context_length=None):
        """Configure the client for the server at *api_url*.

        Args:
            api_url: Base URL of the server. A trailing ``/`` is appended if
                missing.
            max_context_length: Context size to send with each generation
                request. When falsy, it is queried from
                ``extra/true_max_context_length``. If that query fails, a message
                is printed and ``self.max_context_length`` is not assigned here.
                NOTE(review): assumes ``backend.__init__`` sets a default.
        """
        super().__init__()
        if not api_url.endswith('/'):
            api_url += '/'
        if max_context_length:
            self.max_context_length = max_context_length
        else:
            try:
                r = requests.get(api_url + 'extra/true_max_context_length')
                self.max_context_length = r.json()['value']
            except Exception:
                # Best effort: fall back to the default on any request/parse
                # failure. Unlike a bare ``except:`` this does not swallow
                # KeyboardInterrupt or SystemExit.
                print('unable to get max context, using default')
        self.api_url = api_url + 'extra/'
        # Order in which the server applies its samplers, by numeric sampler id.
        # NOTE(review): id meanings are defined by the Kobold API; confirm there.
        self.sampler_order = [6, 0, 1, 3, 4, 2, 5]

    def tokens_count(self, text):
        """Return the server-side token count of *text* (``extra/tokencount``)."""
        r = requests.post(self.api_url + 'tokencount', json={'prompt': text})
        return r.json()['value']

    def generate(self, prompt, stop, on_stream=None):
        """Stream a completion for *prompt* from ``extra/generate/stream/``.

        Args:
            prompt: Text for the model to continue.
            stop: Stop sequence. It is sent to the server as a one-element list
                and also passed to ``self.process``.
            on_stream: Optional callback, forwarded unchanged to ``self.process``.

        Returns:
            The value returned by ``self.process`` on success. On a non-200
            status or any exception, the text accumulated so far (possibly
            ``''``) is returned instead.
        """
        data = {'prompt': prompt,
                'stop_sequence': [stop],
                'max_context_length': self.max_context_length,
                'max_length': self.max_length,
                'temperature': self.temperature,
                'rep_pen': self.rep_pen,
                'rep_pen_range': 600,
                'rep_pen_slope': 0,
                'tfs': 1,
                'top_a': 0,
                'top_k': self.top_k,
                'top_p': self.top_p,
                'min_p': self.min_p,
                'typical': self.typical,
                'sampler_order': self.sampler_order,
                'use_story': False,
                'use_memory': False,
                'use_authors_note': False,
                'use_world_info': False,
                'singleline': False
                }
        result = ''
        r = None
        try:
            r = requests.post(self.api_url + 'generate/stream/', json=data, stream=True)
            if r.status_code != 200:
                print('model_kobold', r.status_code, r.reason)
                return result
            # Server-sent events are always UTF-8. If the response's text/*
            # content type carries no charset, requests falls back to ISO-8859-1,
            # which would garble every non-ASCII token. So force UTF-8 here.
            r.encoding = 'utf-8'
            lines = r.iter_lines(decode_unicode=True)

            def next_chunk():
                """Read up to the next ``data:`` event and append its token.

                Returns the full text accumulated so far, or None once the
                stream is exhausted. The shared line iterator lets each call
                resume where the previous one stopped.
                """
                nonlocal result
                for line in lines:
                    if line.startswith('data:'):
                        result += json.loads(line[5:])['token']
                        return result
                return None

            return self.process(next_chunk, stop, on_stream)
        except Exception as e:
            print('backend_kobold', type(e), e)
            return result
        finally:
            # With stream=True the connection stays checked out until the body is
            # fully read or the response is closed. Release it on every path,
            # including the early non-200 return.
            # NOTE(review): assumes self.process consumes next_chunk eagerly
            # before returning; confirm against the base class.
            if r is not None:
                r.close()