Skip to content

Commit 9ac5287

Browse files
authored
Merge pull request #5996 from oobabooga/dev
Merge dev branch
2 parents 8f12fb0 + 7a728a3 commit 9ac5287

20 files changed

Lines changed: 138 additions & 109 deletions

README.md

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ List of command-line flags
256256
| Flag | Description |
257257
|-------------|-------------|
258258
| `--tensorcores` | Use llama-cpp-python compiled with tensor cores support. This increases performance on RTX cards. NVIDIA only. |
259+
| `--flash-attn` | Use flash-attention. |
259260
| `--n_ctx N_CTX` | Size of the prompt context. |
260261
| `--threads` | Number of threads to use. |
261262
| `--threads-batch THREADS_BATCH` | Number of threads to use for batches/prompt processing. |
@@ -425,9 +426,3 @@ If you would like to contribute to the project, check out the [Contributing guid
425426
## Acknowledgment
426427

427428
In August 2023, [Andreessen Horowitz](https://a16z.com/) (a16z) provided a generous grant to encourage and support my independent work on this project. I am **extremely** grateful for their trust and recognition.
428-
429-
## GitHub Sponsors
430-
431-
The following is a list of top-tier sponsors for this project here on GitHub:
432-
433-
* Be the first one! Visit https://github.com/sponsors/oobabooga/.

download-model.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,17 @@ def get_single_file(self, url, output_folder, start_from_scratch=False):
190190
headers = {}
191191
mode = 'wb'
192192

193-
if output_path.exists() and not start_from_scratch:
194-
# Resume download
195-
r = session.get(url, stream=True, timeout=20)
196-
total_size = int(r.headers.get('content-length', 0))
197-
if output_path.stat().st_size >= total_size:
198-
return
193+
try:
194+
if output_path.exists() and not start_from_scratch:
195+
# Resume download
196+
r = session.get(url, stream=True, timeout=20)
197+
total_size = int(r.headers.get('content-length', 0))
198+
if output_path.stat().st_size >= total_size:
199+
return
199200

200-
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
201-
mode = 'ab'
201+
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
202+
mode = 'ab'
202203

203-
try:
204204
with session.get(url, stream=True, headers=headers, timeout=30) as r:
205205
r.raise_for_status() # If status is not 2xx, raise an error
206206
total_size = int(r.headers.get('content-length', 0))
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
instruction_template: |-
2+
{%- set ns = namespace(found=false) -%}
3+
{%- for message in messages -%}
4+
{%- if message['role'] == 'system' -%}
5+
{%- set ns.found = true -%}
6+
{%- endif -%}
7+
{%- endfor -%}
8+
{%- if not ns.found -%}
9+
{{- '' -}}
10+
{%- endif %}
11+
{%- for message in messages %}
12+
{%- if message['role'] == 'system' -%}
13+
{{- 'System:' + message['content'] + '\n\n' -}}
14+
{%- else -%}
15+
{%- if message['role'] == 'user' -%}
16+
{{-'User: ' + message['content'] + '\n\n'-}}
17+
{%- else -%}
18+
{{-'Assistant: ' + message['content'] + '\n\n' -}}
19+
{%- endif -%}
20+
{%- endif -%}
21+
{%- endfor -%}
22+
{%- if add_generation_prompt -%}
23+
{{-'Assistant:'-}}
24+
{%- endif -%}
25+

js/main.js

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -144,22 +144,21 @@ targetElement.addEventListener("scroll", function() {
144144

145145
// Create a MutationObserver instance
146146
const observer = new MutationObserver(function(mutations) {
147-
mutations.forEach(function(mutation) {
148-
updateCssProperties();
149-
150-
const firstChild = targetElement.children[0];
151-
if (firstChild.classList.contains("generating")) {
152-
typing.parentNode.classList.add("visible-dots");
153-
document.getElementById("stop").style.display = "flex";
154-
document.getElementById("Generate").style.display = "none";
155-
} else {
156-
typing.parentNode.classList.remove("visible-dots");
157-
document.getElementById("stop").style.display = "none";
158-
document.getElementById("Generate").style.display = "flex";
159-
}
147+
updateCssProperties();
148+
149+
const firstChild = targetElement.children[0];
150+
if (firstChild.classList.contains("generating")) {
151+
typing.parentNode.classList.add("visible-dots");
152+
document.getElementById("stop").style.display = "flex";
153+
document.getElementById("Generate").style.display = "none";
154+
} else {
155+
typing.parentNode.classList.remove("visible-dots");
156+
document.getElementById("stop").style.display = "none";
157+
document.getElementById("Generate").style.display = "flex";
158+
}
160159

161-
doSyntaxHighlighting();
162-
});
160+
161+
doSyntaxHighlighting();
163162

164163
if(!isScrolled) {
165164
targetElement.scrollTop = targetElement.scrollHeight;
@@ -215,6 +214,9 @@ function doSyntaxHighlighting() {
215214
indexes.forEach((index) => {
216215
const element = elements[index];
217216

217+
// Tag this element to prevent it from being highlighted twice
218+
element.setAttribute("data-highlighted", "true");
219+
218220
// Perform syntax highlighting
219221
const codeBlocks = element.querySelectorAll("pre code");
220222

@@ -231,8 +233,6 @@ function doSyntaxHighlighting() {
231233
],
232234
});
233235

234-
// Tag this element to indicate it has been syntax highlighted
235-
element.setAttribute("data-highlighted", "true");
236236
});
237237

238238
observer.observe(targetElement, config);

models/config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,5 @@
204204
instruction_template: 'ChatML'
205205
.*airoboros-3_1-yi-34b-200k:
206206
instruction_template: 'Llama-v2'
207+
.*chatqa:
208+
instruction_template: 'NVIDIA-ChatQA'

modules/llamacpp_hf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
217217
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
218218
'logits_all': shared.args.logits_all,
219219
'offload_kqv': not shared.args.no_offload_kqv,
220-
'split_mode': 1 if not shared.args.row_split else 2
220+
'split_mode': 1 if not shared.args.row_split else 2,
221+
'flash_attn': shared.args.flash_attn
221222
}
222223

223224
Llama = llama_cpp_lib().Llama

modules/llamacpp_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def from_pretrained(self, path):
9696
'tensor_split': tensor_split_list,
9797
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
9898
'offload_kqv': not shared.args.no_offload_kqv,
99-
'split_mode': 1 if not shared.args.row_split else 2
99+
'split_mode': 1 if not shared.args.row_split else 2,
100+
'flash_attn': shared.args.flash_attn
100101
}
101102

102103
result.model = Llama(**params)

modules/loaders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
'no_offload_kqv',
4747
'row_split',
4848
'tensorcores',
49+
'flash-attn',
4950
'streaming_llm',
5051
'attention_sink_size',
5152
],
@@ -71,6 +72,7 @@
7172
'no_offload_kqv',
7273
'row_split',
7374
'tensorcores',
75+
'flash-attn',
7476
'streaming_llm',
7577
'attention_sink_size',
7678
'llamacpp_HF_info',

modules/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ def load_model(model_name, loader=None):
107107
elif loader in ['llama.cpp', 'llamacpp_HF']:
108108
shared.settings['truncation_length'] = shared.args.n_ctx
109109

110+
logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
110111
logger.info(f"LOADER: \"{loader}\"")
111112
logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}")
112113
logger.info(f"INSTRUCTION TEMPLATE: \"{metadata['instruction_template']}\"")
113-
logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
114114
return model, tokenizer
115115

116116

modules/shared.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114

115115
# llama.cpp
116116
group = parser.add_argument_group('llama.cpp')
117+
group.add_argument('--flash-attn', action='store_true', help='Use flash-attention.')
117118
group.add_argument('--tensorcores', action='store_true', help='Use llama-cpp-python compiled with tensor cores support. This increases performance on RTX cards. NVIDIA only.')
118119
group.add_argument('--n_ctx', type=int, default=2048, help='Size of the prompt context.')
119120
group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')

0 commit comments

Comments (0)