Conversation

@ggerganov
Member

@ggerganov ggerganov commented Aug 22, 2023

ref: #1602

This PR adds support for Falcon models in llama.cpp

Currently, I've put everything inside the llama.cpp source file.
We can think about better refactoring later so that we can scale the process and add more models without the source file becoming too big. For now, the main goal was to see what it takes to add support for a new LLM and to serve as an example.

The PR also implements a more accurate BPE tokenizer utilizing merges. I used a reference implementation from:

https://github.com/cmp-nct/ggllm.cpp

However, I've dropped the unicode lib for clarity, so the implementation does not produce exactly correct tokenization in all cases. It should be good enough for Latin languages. The advantage is that the code is more compact. Hopefully we will improve it in the future. In any case, 3rd party tokenizers are always an option and work well with llama.cpp, so if accuracy is essential, that is the recommended way.
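For context, here is a minimal sketch of what merge-based BPE does (greedy application of the highest-priority merge), assuming a list of ranked token pairs as found in the tokenizer's merges data. This is only an illustration, not the llama.cpp implementation:

# Minimal sketch of rank-based BPE merging (illustrative, not the llama.cpp code).
# `merges` is a list of (left, right) pairs in priority order, e.g. from tokenizer.json.
def bpe(word: str, merges: list[tuple[str, str]]) -> list[str]:
    ranks = {pair: i for i, pair in enumerate(merges)}
    symbols = list(word)                      # start from individual characters/bytes
    while len(symbols) > 1:
        # find the adjacent pair with the lowest (best) merge rank
        pairs = [(ranks.get((a, b), float("inf")), i)
                 for i, (a, b) in enumerate(zip(symbols, symbols[1:]))]
        best_rank, i = min(pairs)
        if best_rank == float("inf"):         # no applicable merge left
            break
        symbols[i:i + 2] = [symbols[i] + symbols[i + 1]]
    return symbols

# e.g. with merges [("l", "o"), ("lo", "w")], bpe("low", merges) -> ["low"]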

Supported backends:

  • CPU
  • Metal
  • CUDA
  • OpenCL (not tested)

CUDA offloading does not work yet due to the missing RoPE NeoX kernel.
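For reference, a NumPy sketch of what the missing kernel has to compute: NeoX-style rotary embeddings (ggml RoPE mode = 2) rotate dimension pairs (i, i + n_dims/2) rather than adjacent pairs. This follows my reading of the standard GPT-NeoX formulation and is not the eventual CUDA code:

import numpy as np

def rope_neox(x: np.ndarray, pos: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """x: (n_pos, n_dims) activations for one head; pos: (n_pos,) token positions."""
    n_dims = x.shape[-1]
    half = n_dims // 2
    # frequency for each rotated pair: theta_i = base^(-2i / n_dims)
    theta = base ** (-2.0 * np.arange(half) / n_dims)      # (half,)
    angles = pos[:, None] * theta[None, :]                  # (n_pos, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]                    # NeoX pairs (i, i + half)
    return np.concatenate([x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos], axis=-1)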

TODO

  • GGUF typed access to KV and tensor names
  • refactor llama_model_load_internal() into reusable steps
  • implement llm_load_falcon() similar to llm_load_llama()
  • implement llm_build_falcon() similar to llm_build_llama()
  • add Metal support
  • add quantization support
  • GPU offloading for CUDA
  • RoPE (mode = 2) implementation for CUDA backends
  • Clean-up the llm_build_falcon() function
  • Fix tokenization
  • The bpe_gpt2_preprocess() is quite slow (regex)

Usage

# get HF model
git clone https://huggingface.co/tiiuae/falcon-40b

# convert to F16 .gguf
python3 convert-falcon-hf-to-gguf.py ./falcon-40b/ 1

# quantize to Q4_0
make -j && ./quantize ./falcon-40b/ggml-model-f16.gguf ./falcon-40b/ggml-model-q4_0.gguf q4_0

# run
make -j && ./main -m ./falcon-40b/ggml-model-q4_0.gguf

Performance

  • M2 Ultra
| model | backend | n_gpu_layers | test | t/s |
| --- | --- | --- | --- | --- |
| Falcon 7B mostly F16 | Metal | 999 | pp 512 | 400.84 ± 1.94 |
| Falcon 7B mostly Q8_0 | Metal | 999 | pp 512 | 387.38 ± 1.31 |
| Falcon 7B mostly Q4_0 | Metal | 999 | pp 512 | 388.54 ± 1.85 |
| Falcon 7B mostly Q4_1 | Metal | 999 | pp 512 | 390.00 ± 1.07 |
| Falcon 7B mostly F16 | Metal | 999 | tg 64 | 29.84 ± 0.04 |
| Falcon 7B mostly Q8_0 | Metal | 999 | tg 64 | 61.96 ± 0.07 |
| Falcon 7B mostly Q4_0 | Metal | 999 | tg 64 | 91.64 ± 0.16 |
| Falcon 7B mostly Q4_1 | Metal | 999 | tg 64 | 84.04 ± 0.04 |
| Falcon 40B mostly F16 | Metal | 999 | pp 512 | 78.00 ± 0.14 |
| Falcon 40B mostly Q8_0 | Metal | 999 | pp 512 | 77.31 ± 0.13 |
| Falcon 40B mostly Q4_0 | Metal | 999 | pp 512 | 77.43 ± 0.19 |
| Falcon 40B mostly Q4_1 | Metal | 999 | pp 512 | 77.87 ± 0.08 |
| Falcon 40B mostly F16 | Metal | 999 | tg 64 | 5.94 ± 0.00 |
| Falcon 40B mostly Q8_0 | Metal | 999 | tg 64 | 13.78 ± 0.01 |
| Falcon 40B mostly Q4_0 | Metal | 999 | tg 64 | 22.86 ± 0.02 |
| Falcon 40B mostly Q4_1 | Metal | 999 | tg 64 | 21.26 ± 0.01 |

build: 38b16df (1052)

| model | backend | n_gpu_layers | test | t/s |
| --- | --- | --- | --- | --- |
| LLaMA v2 7B mostly Q4_0 | Metal | 1 | pp 512 | 632.43 ± 0.18 |
| LLaMA v2 13B mostly Q4_0 | Metal | 1 | pp 512 | 369.02 ± 0.07 |
| LLaMA v2 70B mostly Q4_0 | Metal | 1 | pp 512 | 76.41 ± 0.01 |
| LLaMA v2 7B mostly Q4_0 | Metal | 1 | tg 128 | 86.13 ± 0.11 |
| LLaMA v2 13B mostly Q4_0 | Metal | 1 | tg 128 | 54.52 ± 0.06 |
| LLaMA v2 70B mostly Q4_0 | Metal | 1 | tg 128 | 14.39 ± 0.01 |

build: 176ea71 (1052)

  • CUDA under WSL2:
    Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6
| model | backend | n_gpu_layers | test | t/s |
| --- | --- | --- | --- | --- |
| LLaMA 7B mostly Q4_0 (guessed) | CUDA | 33 | pp 512 | 686.76 ± 15.63 |
| falcon-7b 7B mostly Q4_0 | CUDA | 33 | pp 512 | 476.94 ± 1.95 |
| LLaMA 7B mostly Q4_0 (guessed) | CUDA | 33 | tg 128 | 44.67 ± 0.28 |
| falcon-7b 7B mostly Q4_0 | CUDA | 33 | tg 128 | 68.52 ± 0.20 |

| model | backend | n_gpu_layers | test | t/s |
| --- | --- | --- | --- | --- |
| LLaMA 30B mostly Q4_0 | CUDA | 99 | pp 512 | 615.45 ± 3.24 |
| falcon-40b 40B mostly Q4_0 | CUDA | 99 | pp 512 | 441.00 ± 0.93 |
| LLaMA 30B mostly Q4_0 | CUDA | 99 | tg 128 | 18.07 ± 0.03 |
| falcon-40b 40B mostly Q4_0 | CUDA | 99 | tg 128 | 17.29 ± 0.01 |

build: 176ea71 (1052)

@klosax
Contributor

klosax commented Aug 22, 2023

Conversion of the 7B model does not work. The qkv transform needs n_head_kv = 1 for it to work:
https://github.com/ggerganov/llama.cpp/blob/3c7c325b9867a30637529b0328fbc73b4e527004/convert-falcon-hf-to-gguf.py#L204
https://github.com/ggerganov/llama.cpp/blob/3c7c325b9867a30637529b0328fbc73b4e527004/convert-falcon-hf-to-gguf.py#L241-L246

Traceback (most recent call last):
  File "convert-falcon-hf-to-gguf.py", line 242, in <module>
    qkv = data.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
RuntimeError: shape '[71, 3, 64, 4544]' is invalid for input of size 21229568
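For reference, the arithmetic works out once n_head_kv = 1 is used: the 7B model has multi-query attention, so its fused query_key_value weight has (71 + 2) * 64 = 4672 rows of width 4544, i.e. 21,229,568 elements, which matches a view of [1, 73, 64, 4544] but not [71, 3, 64, 4544]. A small sketch of the check (the tensor below is a stand-in with shapes taken from the traceback, not the converter code):

import torch

# Falcon-7B attention dimensions, as they appear in the traceback above
n_head, head_dim, hidden = 71, 64, 4544
n_head_kv = 1  # multi-query attention: one shared K/V head in the 7B model

# stand-in for the fused query_key_value weight: (71 + 2) * 64 rows of width 4544
qkv = torch.zeros((n_head + 2) * head_dim, hidden)

view = qkv.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
print(view.shape, qkv.numel())  # torch.Size([1, 73, 64, 4544]) 21229568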

@klosax
Contributor

klosax commented Aug 22, 2023

main using falcon-7b:
Error loading model: create_tensor: tensor 'blk.0.attn_norm_2.weight' not found

There is no norm_2 in the 7B model:

https://github.com/ggerganov/llama.cpp/blob/2d58444dae1545b96de9929366c7dffe09605c5e/llama.cpp#L1865-L1867

@ggerganov
Member Author

@klosax Reconvert and it works

Strangely, with Metal enabled we crash here:

https://github.com/ggerganov/llama.cpp/blob/0ec27ad66c56a4831f62a8106b7873a6818bf051/llama.cpp#L2679

If I disable the graph concurrency optimization, it does not crash:

https://github.com/ggerganov/llama.cpp/blob/0ec27ad66c56a4831f62a8106b7873a6818bf051/llama.cpp#L4799-L4800

Any guesses what could be wrong?

@klosax
Contributor

klosax commented Aug 22, 2023

Any guesses what could be wrong?

Not a clue. I can try enabling cublas to see if that works.

Edit: Enabling cublas works on 40b-q4_0 model.

@ggerganov ggerganov marked this pull request as ready for review August 22, 2023 20:32
@klosax
Contributor

klosax commented Aug 22, 2023

Perplexity on 7B and 40B won't work:

terminate called after throwing an instance of 'std::out_of_range'
  what():  _Map_base::at
Aborted (core dumped)

It looks like it is something with the tokenizer, since using wiki.test.raw.406 as input to main also gives the same error. Using a simple dataset works fine.

@ggerganov
Member Author

Yes, something wrong with the tokenization. Here is the stack trace:

system_info: n_threads = 16 / 24 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
libc++abi: terminating due to uncaught exception of type std::out_of_range: unordered_map::at: key not found
Process 26597 stopped
* thread #1, queue = 'com.apple.main-thread', stop reason = signal SIGABRT
    frame #0: 0x000000018b3d8764 libsystem_kernel.dylib`__pthread_kill + 8
libsystem_kernel.dylib`:
->  0x18b3d8764 <+8>:  b.lo   0x18b3d8784               ; <+40>
    0x18b3d8768 <+12>: pacibsp 
    0x18b3d876c <+16>: stp    x29, x30, [sp, #-0x10]!
    0x18b3d8770 <+20>: mov    x29, sp
Target 0: (perplexity) stopped.
(lldb) bt
* thread #1, queue = 'com.apple.main-thread', stop reason = signal SIGABRT
  * frame #0: 0x000000018b3d8764 libsystem_kernel.dylib`__pthread_kill + 8
    frame #1: 0x000000018b40fc28 libsystem_pthread.dylib`pthread_kill + 288
    frame #2: 0x000000018b31dae8 libsystem_c.dylib`abort + 180
    frame #3: 0x000000018b3c8b84 libc++abi.dylib`abort_message + 132
    frame #4: 0x000000018b3b83b4 libc++abi.dylib`demangling_terminate_handler() + 320
    frame #5: 0x000000018b08f03c libobjc.A.dylib`_objc_terminate() + 160
    frame #6: 0x000000018b3c7f48 libc++abi.dylib`std::__terminate(void (*)()) + 16
    frame #7: 0x000000018b3cad34 libc++abi.dylib`__cxxabiv1::failed_throw(__cxxabiv1::__cxa_exception*) + 36
    frame #8: 0x000000018b3cace0 libc++abi.dylib`__cxa_throw + 140
    frame #9: 0x0000000100026ba4 perplexity`std::__1::__throw_out_of_range[abi:v15006](__msg="unordered_map::at: key not found") at stdexcept:268:5
    frame #10: 0x00000001000619bc perplexity`std::__1::unordered_map<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, int, std::__1::hash<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > >, std::__1::equal_to<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > >, std::__1::allocator<std::__1::pair<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const, int> > >::at(this=0x00000001006040e0 size=65023, __k="<0xE7>") const at unordered_map:1863:9
    frame #11: 0x00000001000615a8 perplexity`llama_byte_to_token(vocab=0x00000001006040d8, ch='\xe7') at llama.cpp:2852:30
    frame #12: 0x000000010005cb04 perplexity`llama_tokenizer::resegment(this=0x000000016fdfdf08, symbol=0x0000000130fe5888, output=size=1581) at llama.cpp:2991:44
    frame #13: 0x000000010005be4c perplexity`llama_tokenizer::tokenize(this=0x000000016fdfdf08, text=" \n = Robert Boulter = \n \n Robert Boulter is an English film , television and theatre actor . He had a guest @-@ starring role on the television series The Bill in 2000 . This was followed by a starring role in the play Herons written by Simon Stephens , which was performed in 2001 at the Royal Court Theatre . He had a guest role in the television series Judge John Deed in 2002 . In 2004 Boulter landed a role as \" Craig \" in the episode \" Teddy 's Story \" of the television series The Long Firm ; he starred alongside actors Mark Strong and Derek Jacobi . He was cast in the 2005 theatre productions of the Philip Ridley play Mercury Fur , which was performed at the Drum Theatre in Plymouth and the Menier Chocolate Factory in London . He was directed by John Tiffany and starred alongside Ben Whishaw , Shane Zaza , Harry Kent , Fraser Ayres , Sophie Stanton and Dominic Hall . \n In 2006 , Boulter starred alongside Whishaw in the play Citizenship written by Mark Ravenhill . He appeared on a 2006 episode of the televi"..., output=size=1581) at llama.cpp:2971:13
    frame #14: 0x0000000100036bd8 perplexity`llama_tokenize_internal(vocab=0x00000001006040d8, raw_text=" \n = Robert Boulter = \n \n Robert Boulter is an English film , television and theatre actor . He had a guest @-@ starring role on the television series The Bill in 2000 . This was followed by a starring role in the play Herons written by Simon Stephens , which was performed in 2001 at the Royal Court Theatre . He had a guest role in the television series Judge John Deed in 2002 . In 2004 Boulter landed a role as \" Craig \" in the episode \" Teddy 's Story \" of the television series The Long Firm ; he starred alongside actors Mark Strong and Derek Jacobi . He was cast in the 2005 theatre productions of the Philip Ridley play Mercury Fur , which was performed at the Drum Theatre in Plymouth and the Menier Chocolate Factory in London . He was directed by John Tiffany and starred alongside Ben Whishaw , Shane Zaza , Harry Kent , Fraser Ayres , Sophie Stanton and Dominic Hall . \n In 2006 , Boulter starred alongside Whishaw in the play Citizenship written by Mark Ravenhill . He appeared on a 2006 episode of the televi"..., bos=true, escape=false) at llama.cpp:3055:15
    frame #15: 0x00000001000367b0 perplexity`::llama_tokenize_with_model(model=0x0000000100604080, text=" \n = Robert Boulter = \n \n Robert Boulter is an English film , television and theatre actor . He had a guest @-@ starring role on the television series The Bill in 2000 . This was followed by a starring role in the play Herons written by Simon Stephens , which was performed in 2001 at the Royal Court Theatre . He had a guest role in the television series Judge John Deed in 2002 . In 2004 Boulter landed a role as \" Craig \" in the episode \" Teddy 's Story \" of the television series The Long Firm ; he starred alongside actors Mark Strong and Derek Jacobi . He was cast in the 2005 theatre productions of the Philip Ridley play Mercury Fur , which was performed at the Drum Theatre in Plymouth and the Menier Chocolate Factory in London . He was directed by John Tiffany and starred alongside Ben Whishaw , Shane Zaza , Harry Kent , Fraser Ayres , Sophie Stanton and Dominic Hall . \n In 2006 , Boulter starred alongside Whishaw in the play Citizenship written by Mark Ravenhill . He appeared on a 2006 episode of the televi"..., tokens=0x00000001187a0000, n_max_tokens=1290590, add_bos=true) at llama.cpp:5480:16
    frame #16: 0x000000010003671c perplexity`::llama_tokenize(ctx=0x0000000102008200, text=" \n = Robert Boulter = \n \n Robert Boulter is an English film , television and theatre actor . He had a guest @-@ starring role on the television series The Bill in 2000 . This was followed by a starring role in the play Herons written by Simon Stephens , which was performed in 2001 at the Royal Court Theatre . He had a guest role in the television series Judge John Deed in 2002 . In 2004 Boulter landed a role as \" Craig \" in the episode \" Teddy 's Story \" of the television series The Long Firm ; he starred alongside actors Mark Strong and Derek Jacobi . He was cast in the 2005 theatre productions of the Philip Ridley play Mercury Fur , which was performed at the Drum Theatre in Plymouth and the Menier Chocolate Factory in London . He was directed by John Tiffany and starred alongside Ben Whishaw , Shane Zaza , Harry Kent , Fraser Ayres , Sophie Stanton and Dominic Hall . \n In 2006 , Boulter starred alongside Whishaw in the play Citizenship written by Mark Ravenhill . He appeared on a 2006 episode of the televi"..., tokens=0x00000001187a0000, n_max_tokens=1290590, add_bos=true) at llama.cpp:5450:12
    frame #17: 0x0000000100014f50 perplexity`llama_tokenize(ctx=0x0000000102008200, text=" \n = Robert Boulter = \n \n Robert Boulter is an English film , television and theatre actor . He had a guest @-@ starring role on the television series The Bill in 2000 . This was followed by a starring role in the play Herons written by Simon Stephens , which was performed in 2001 at the Royal Court Theatre . He had a guest role in the television series Judge John Deed in 2002 . In 2004 Boulter landed a role as \" Craig \" in the episode \" Teddy 's Story \" of the television series The Long Firm ; he starred alongside actors Mark Strong and Derek Jacobi . He was cast in the 2005 theatre productions of the Philip Ridley play Mercury Fur , which was performed at the Drum Theatre in Plymouth and the Menier Chocolate Factory in London . He was directed by John Tiffany and starred alongside Ben Whishaw , Shane Zaza , Harry Kent , Fraser Ayres , Sophie Stanton and Dominic Hall . \n In 2006 , Boulter starred alongside Whishaw in the play Citizenship written by Mark Ravenhill . He appeared on a 2006 episode of the televi"..., add_bos=true) at common.cpp:711:16
    frame #18: 0x0000000100003284 perplexity`perplexity(ctx=0x0000000102008200, params=0x000000016fdfedd8) at perplexity.cpp:35:19
    frame #19: 0x0000000100006034 perplexity`main(argc=8, argv=0x000000016fdff258) at perplexity.cpp:412:9
    frame #20: 0x000000018b0b7f28 dyld`start + 2236
(lldb) frame select 11
frame #11: 0x00000001000615a8 perplexity`llama_byte_to_token(vocab=0x00000001006040d8, ch='\xe7') at llama.cpp:2852:30
   2849	   char buf[7];
   2850	   int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch);
   2851	   GGML_ASSERT(0 <= result && result < 7);
-> 2852	   return vocab.token_to_id.at(buf);
   2853	}
   2854	
   2855	static std::string llama_escape_whitespace(const std::string& text) {
(lldb) print ch
(uint8_t) $0 = '\xe7'
(lldb) 
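For context, the crash is in the sentencepiece-style byte-fallback path, which looks up tokens spelled <0xNN>; a GPT-2/BPE vocab has no such tokens. BPE vocabs instead remap raw bytes to printable unicode characters before merging, along the lines of GPT-2's bytes_to_unicode table. A sketch of that mapping, for illustration only (not the llama.cpp code):

def bytes_to_unicode() -> dict[int, str]:
    """GPT-2 style byte -> unicode character mapping used by BPE vocabs (sketch)."""
    bs = (list(range(ord("!"), ord("~") + 1)) +
          list(range(ord("\xa1"), ord("\xac") + 1)) +
          list(range(ord("\xae"), ord("\xff") + 1)))
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)   # map the remaining bytes to code points above 255
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))

# the byte 0xE7 from the crash maps to a printable character, not to a "<0xE7>" token
print(bytes_to_unicode()[0xE7])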

@goerch
Contributor

goerch commented Aug 22, 2023

Yes, something wrong with the tokenization. Here is the stack trace:

Is Falcon using a sentencepiece-based tokenizer? Otherwise you should not reach that function.

@klosax
Contributor

klosax commented Aug 22, 2023

Is Falcon using a sentencepiece-based tokenizer?

No, it is bpe.

@goerch
Contributor

goerch commented Aug 22, 2023

No, it is bpe.

So I'd recommend switching vocabulary type then. That will be a first for me ;)

Edit: is there an easy way for me to test this?

@klosax
Contributor

klosax commented Aug 22, 2023

llama_tokenize_bpe() and llama_token_to_str_bpe() are never used in llama.cpp?

@klosax
Contributor

klosax commented Aug 22, 2023

Edit: is there an easy way for me to test this?

Build the branch and test with Falcon-7b from here:
https://huggingface.co/tiiuae/falcon-7b/tree/main

@slaren
Member

slaren commented Aug 22, 2023

Seeing that Falcon is entirely in bfloat16, should it be converted to F32?
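For what it's worth, bfloat16 keeps float32's 8-bit exponent (with a 7-bit mantissa), so converting it to F16 can overflow or flush to zero, while widening to F32 is lossless: a bf16 value is simply the top 16 bits of the corresponding f32. A NumPy illustration of that widening:

import numpy as np

def bf16_to_f32(raw: np.ndarray) -> np.ndarray:
    """raw: uint16 array holding bfloat16 bit patterns."""
    # bf16 is the upper half of an f32: pad 16 zero bits on the right
    return (raw.astype(np.uint32) << 16).view(np.float32)

bits = np.array([0x3F80, 0x4000, 0xC2C8], dtype=np.uint16)  # 1.0, 2.0, -100.0 in bf16
print(bf16_to_f32(bits))  # [   1.    2. -100.]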

@goerch
Contributor

goerch commented Aug 22, 2023

llama_tokenize_bpe() and llama_token_to_str_bpe() are never used in llama.cpp?

These are used in tests to assert low-level behavior of the tokenizer.

@ggerganov ggerganov merged commit cf658ad into master Aug 23, 2023
@klosax
Contributor

klosax commented Aug 23, 2023

I believe the unicode implementation did not use regex for that reason?
https://github.com/ggerganov/llama.cpp/blob/fae8faa135942918a7c9ecc8c2fc26be7f61140d/llama.cpp#L3324

@ggerganov ggerganov deleted the falcon branch August 23, 2023 20:08
@klosax
Contributor

klosax commented Aug 24, 2023

Metal inference now works correctly without the graph concurrency optimization. The bug was related to tanh returning NaNs sometimes:

@ggerganov Getting NaNs using cuBLAS without offloading. That was the reason the HellaSwag score was so low when quantizing the output tensor to Q4_0. It seems to work fine without BLAS. Don't know if this is the CUDA GELU or something else.

task 46 ending 3 logprob -4.36228347
47	70.21276596
task 47 ending 0 logprob -3.56702414
task 47 ending 1 logprob -2.74259720
task 47 ending 2 logprob -2.01594493
task 47 ending 3 logprob -3.34398872
48	70.83333333
task 48 ending 0 logprob nan
task 48 ending 1 logprob nan
task 48 ending 2 logprob nan
task 48 ending 3 logprob nan
49	69.38775510
task 49 ending 0 logprob nan
task 49 ending 1 logprob nan

@ggerganov
Member Author

Since you have a repro, can you try if the NaNs disappear after changing tanhf -> tanh in the following places:

https://github.com/ggerganov/ggml/blob/6319ae9ad7bdf9f834b2855d7e9fa70508e82f57/src/ggml.c#L3553

https://github.com/ggerganov/ggml/blob/6319ae9ad7bdf9f834b2855d7e9fa70508e82f57/src/ggml.c#L3562
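For context, ggml's GELU here is the tanh-based approximation, so the proposed change only swaps the precision at which that tanh is evaluated. A NumPy sketch of the formula (the coefficients are the commonly used ones and are assumed to match ggml.c; this does not reproduce the NaN itself):

import numpy as np

# Commonly used coefficients of the tanh GELU approximation (assumed to match ggml.c)
SQRT_2_OVER_PI = 0.7978845608028654
GELU_COEF_A = 0.044715

def gelu_tanh(x, dtype=np.float32):
    """GELU approximation; dtype float32 ~ tanhf, float64 ~ tanh (cast back to f32)."""
    x = np.asarray(x, dtype=dtype)
    inner = SQRT_2_OVER_PI * (x + GELU_COEF_A * x ** 3)
    return (0.5 * x * (1.0 + np.tanh(inner))).astype(np.float32)

print(gelu_tanh(2.0, np.float32), gelu_tanh(2.0, np.float64))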

@klosax
Contributor

klosax commented Aug 24, 2023

Since you have a repro, can you try if the NaNs disappear after changing tanhf -> tanh in the following places:

That seems to work. I will run some more tests.

@cmp-nct
Contributor

cmp-nct commented Aug 24, 2023

@klosax

The perplexity results using the "simple" BPE tokenizer are significantly worse than the "advanced" one. I was hoping the differences would be smaller.

Looks like we will need the advanced BPE tokenizer added to llama.cpp after all.

I looked at the example implementation that you provided - this will be very helpful to use while implementing. I wish the code was simpler so we can put it straight into llama.cpp, but it's not the case, so it will have to live in a separate C++ header file for now.

Will probably implement this after we merge the current PR

Should I change this string to "falcon" or keep it "gpt2"?

https://github.com/ggerganov/llama.cpp/blob/176ea716b355cff5a2b6a97c7648cee138183818/convert-falcon-hf-to-gguf.py#L122

Also, are there other models that use this specific tokenizer? I assume most GPT-NeoX-based models (e.g. Dolly, MPT) should be compatible, correct?

I created that tokenizer when I realized that without BPE "merges" support barely any longer token is tokenized correctly, including the special tokens which are frequently used in fine-tunes; any European language requires it as well.

@klosax I did not use regex only because I did not like adding it as a huge dependency and I thought it was going to be very slow; in the end that decision cost me quite some time. Same reasoning for creating the custom Unicode C++ library, though in that case I'm sure it was the right decision.

@klosax
Contributor

klosax commented Aug 25, 2023

I created that tokenizer when I realized that without BPE "merges" support barely any longer token is tokenized correctly, including the special tokens which are frequently used in fine-tunes; any European language requires it as well.

Yes, the merges are important; adding the merges lowered the perplexity by 33%. The value of having merges in BPE is comparable to having the scores in sentencepiece.

I did not use regex only because I did not like adding it as a huge dependency and I thought it was going to be very slow; in the end that decision cost me quite some time.

The current llama.cpp implementation of the tokenizer uses regex for simplicity and it is slow. I guess it will be replaced later.
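For reference, the pre-tokenization step (bpe_gpt2_preprocess) corresponds to the GPT-2 split that runs before the merges; in Python it needs the third-party regex module because of the \p{L}/\p{N} unicode classes. A sketch, assuming Falcon uses the standard GPT-2 pattern:

import regex  # pip install regex; needed for the \p{L} / \p{N} classes

# The standard GPT-2 pre-tokenization pattern (assumed here to apply to Falcon's tokenizer)
GPT2_SPLIT = regex.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)

print(GPT2_SPLIT.findall("Hello world, it's 2023!"))
# ['Hello', ' world', ',', ' it', "'s", ' 2023', '!']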

Same reasoning for creating the custom Unicode C++ library, though in that case I'm sure it was the right decision.

Yes, your Unicode library is a much better choice than depending on the huge ICU for full Unicode support. If this is implemented in llama.cpp, it could possibly also be used by the LLaMA sentencepiece tokenizer.

The importance of a good tokenizer should not be underestimated when it comes to generation quality.

@logicchains
Contributor

Falcon 180B has been released now; it would be great to support it: https://huggingface.co/blog/falcon-180b

@Green-Sky
Collaborator

o.o 360GB safetensors ......
The architecture in the model card looks like the 40B one (but scaled up), so it should be compatible. It would be cool if someone with a large enough RAM system could try this.

@logicchains
Contributor

I've got 256GB of CPU RAM I could try it on, once there's a GGUF.

@Green-Sky
Collaborator

You can try to create the GGUF yourself using the https://github.com/ggerganov/llama.cpp/blob/master/convert-falcon-hf-to-gguf.py

@Green-Sky
Collaborator

On further inspection, the falcon convert script won't work. The normal convert.py might work, but the tokenizer and tensor reshaping might be an issue there.

Issues with the falcon converter script (see the sketch below):

  • the architecture is now called FalconForCausalLM (previously RWForCausalLM)
  • the convert script expects pytorch-named files (not safetensors)
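A hedged sketch of what handling those two points could look like: accept either architecture name and read the safetensors shards directly. The paths and the config field are the ones mentioned in this thread; everything else is an assumption, not the actual converter:

import json
from pathlib import Path
from safetensors.torch import load_file  # shards ship as .safetensors, not pytorch .bin

model_dir = Path("./falcon-180b")  # hypothetical local checkout
config = json.loads((model_dir / "config.json").read_text())

arch = config["architectures"][0]
if arch not in ("FalconForCausalLM", "RWForCausalLM"):  # new name vs. the old one
    raise SystemExit(f"unexpected architecture: {arch}")

for shard in sorted(model_dir.glob("model-*-of-*.safetensors")):
    tensors = load_file(str(shard))  # dict of tensor name -> torch.Tensor
    print(shard.name, len(tensors), "tensors")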

@logicchains
Contributor

The normal convert.py fails with:

Traceback (most recent call last):
  File "llama.cpp/./convert.py", line 1224, in <module>
    main()
  File "llama.cpp/./convert.py", line 1165, in main
    model_plus = load_some_model(args.model)
  File "llama.cpp/./convert.py", line 1087, in load_some_model
    model_plus = merge_multifile_models(models_plus)
  File "llama.cpp/./convert.py", line 620, in merge_multifile_models
    model = merge_sharded([mp.model for mp in models_plus])
  File "llama.cpp/./convert.py", line 599, in merge_sharded
    return {name: convert(name) for name in names}
  File "llama.cpp/./convert.py", line 599, in <dictcomp>
    return {name: convert(name) for name in names}
  File "llama.cpp/./convert.py", line 574, in convert
    lazy_tensors: list[LazyTensor] = [model[name] for model in models]
  File "llama.cpp/./convert.py", line 574, in <listcomp>
    lazy_tensors: list[LazyTensor] = [model[name] for model in models]
KeyError: 'transformer.h.0.mlp.dense_h_to_4h.weight'

@logicchains
Contributor

Looks like convert.py expects all the model-*-of-00081.safetensors files to contain the same tensors, but actually for this model they contain different tensors.
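In other words, there appear to be two sharding layouts to handle: LLaMA-style shards where every file carries the same tensor names split along an axis (which merge_sharded seems to merge per name), and the Falcon-180B layout where each file carries a disjoint subset of tensors and a plain union is enough. A hedged sketch of the union case, using plain dicts for brevity (not the convert.py code):

def merge_disjoint_shards(shards):
    """Merge shards that each contain a *different* subset of tensors."""
    merged = {}
    for shard in shards:
        overlap = merged.keys() & shard.keys()
        assert not overlap, f"tensor(s) present in more than one shard: {overlap}"
        merged.update(shard)
    return merged

# e.g. merge_disjoint_shards([shard_a, shard_b]) where each dict maps tensor name -> tensor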

@logicchains
Contributor

It seems there is code in there for processing different tensors split across files, but switching to that code path still fails, due to config.json missing intermediate_size and rms_norm_eps.

@logicchains
Contributor

If setting n_ff to hidden_size * 4 like in convert-falcon-hf-to-gguf, the hard-coded logic in Params.find_n_mult ends up failing: failed to find n_mult for (n_ff=59392, n_embd=14848). If we just hard-code n_mult to 1, then the code ends up failing later when looking for the tokenizer, which I guess normal convert.py doesn't support.
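A quick check of those numbers: Falcon's MLP (dense_h_to_4h) is a plain 4× expansion, so n_ff comes straight from the hidden size and there is no LLaMA-style n_mult to recover. For illustration:

# Figures from the error message above
n_embd = 14848
n_ff = 4 * n_embd      # Falcon's dense_h_to_4h is a plain 4x expansion
assert n_ff == 59392   # matches "failed to find n_mult for (n_ff=59392, n_embd=14848)"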

@Green-Sky
Collaborator

Green-Sky commented Sep 6, 2023

It seems there is code in there for processing different tensors split across files, but switching to that code path still fails, due to config.json missing intermediate_size and rms_norm_eps.

the config.json specifies:

"layer_norm_epsilon": 1e-05,

If setting n_ff to hidden_size * 4 like in convert-falcon-hf-to-gguf

IIRC this should be something different. Edit: this seems to be right.

the hard-coded logic in Params.find_n_mult ends up failing: failed to find n_mult for (n_ff=59392, n_embd=14848)

Should be safe to ignore n_mult 👍

then the code ends up failing later when looking for the tokenizer, which I guess normal convert.py doesn't support.

Yeah, I think it is the same as Falcon 40B.

@TheBloke
Contributor

TheBloke commented Sep 6, 2023

I've been told that the model architecture should be identical to 40B. I'm making GPTQs right now and it seems to work fine with the same code that worked with Falcon 40B.

I'd assumed convert-falcon-to-gguf.py would be the one to update, rather than convert.py?

I had a quick look at it myself earlier. There are a few easy fixes at the top, changing field names to match their new names in config.json. But it's written to load from pytorch_model.bin, and Falcon 180B is in safetensors; when I looked at the convert.py code for safetensors it looked complicated, with memory mapping and such, so I gave up at that point :)

If anyone can get a working script out, I can have GGUFs up soon after that.

@logicchains
Contributor

I got a version of convert-falcon-to-gguf.py working: https://github.com/logicchains/llama.cpp/blob/falcon180B/convert-falcon180-hf-to-gguf.py (forked because I can't make a branch here). And it seems to just work! Not too slow either: 0.8 tokens/second for the 6-bit quantisation, while with LLaMA 2 70B 8-bit I get around 1.5 tokens/second.

./main -m ./models/falcon-180B-q6_K.gguf -c 2048 --temp 0.7 -t 32 -p "The secrets to a happy marriage are as follows:"

The secrets to a happy marriage are as follows:

  • Communicate.
  • Be honest.
  • Do not be afraid to fight.
  • Have s*x.
    Those are the basics. Those are the things that will keep your marriage healthy. Those are things you should try to do as much as possible. But those things aren’t magic. You can be doing all of those things, and your marriage can still fail. If you want the secret sauce for a happy marriage — the one thing that most people forget to do — then keep reading.
    Here’s what I know for sure: A happy marriage is not a marriage full of happiness. It’s not a marriage that lacks sadness, or pain, or anger, or frustration. A happy marriage is not a marriage without hard days and hard times. It’s not a marriage without arguments, or tears, or long nights where you lay in bed thinking, What have I done?
    A happy marriage is a marriage that is full of growth. It’s a marriage that has two people who are willing to be vulnerable, to open themselves up, and to say the hard things. A happy marriage is a marriage with two people who are both willing to change for the better. It’s a marriage with two people who are both willing to face their fears, their anxieties, and their demons head-on, no matter how scary it can be.
    A happy marriage is a marriage that is full of trust. It’s a marriage that has two people who are willing to share their deepest thoughts, their darkest secrets, and their most intimate moments. A happy marriage is a marriage where you feel safe enough to open yourself up completely.
    So what do you do when you want a happy marriage? You start by communicating. You start by being honest. You start by not being afraid to fight. And yes, of course, you have s*x. But more than anything else, you have to be willing to grow together.
    It’s not easy, and it’s not always fun, but it’s worth it. Trust me. I know from experience.<|endoftext|> [end of text]
llama_print_timings:        load time = 53526.48 ms
llama_print_timings:      sample time =   749.78 ms /   428 runs   (    1.75 ms per token,   570.83 tokens per second)
llama_print_timings: prompt eval time =  4232.80 ms /    10 tokens (  423.28 ms per token,     2.36 tokens per second)
llama_print_timings:        eval time = 532203.03 ms /   427 runs   ( 1246.38 ms per token,     0.80 tokens per second)
llama_print_timings:       total time = 537415.52 ms

@Green-Sky
Collaborator

@logicchains Very nice, can you open a PR? Even if it's dirty, just mark it as a draft. :)

@logicchains
Contributor

Done: #3049. It works as-is if we're fine with having a lot of duplication between convert-falcon180-hf-to-gguf.py and convert-falcon-hf-to-gguf.py.

@TheBloke
Contributor

TheBloke commented Sep 7, 2023

@logicchains thank you so much! That's awesome

I used your script successfully, and am currently uploading all the quant formats to: https://huggingface.co/TheBloke/Falcon-180B-Chat-GGUF

Even the Q2_K is larger than 50GB and so unfortunately I have to split the files, so there's some manual work required by the user to rejoin them after download. (Oh how I wish GGUF had implemented support for multi-part/sharded files :( )

But they work! Thanks again
