22#![ allow( clippy:: borrow_deref_ref) ]
33
44use std:: collections:: HashSet ;
5- use std:: thread;
65
7- use fancy_regex :: Regex ;
6+ use pcre2 :: bytes :: { Regex , RegexBuilder } ;
87use pyo3:: exceptions;
98use pyo3:: prelude:: * ;
109use pyo3:: types:: { PyBytes , PyList , PyTuple } ;
@@ -69,20 +68,34 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) ->
6968//
7069// Regex
7170// =====
72- // Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
73- // regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
74- // the usual regex we use.
71+ // A considerable amount of time is spent in regex.
72+ // The easiest way to speed this up is by using a faster regex parser.
73+ // A reason performance varies across parsers might be their feature richness.
74+ // The `fancy_regex` crate can parse more kinds of regex than the `regex` crate.
75+ // Another reason could be the implementation of the parser and whether it
76+ // has a JIT compiler capable of generating optimized code during pattern compilation.
77+ // `pcre2` is a C library that has such JIT capabilities and is wrapped in the `pcre2` crate.
7578//
76- // However, given that we're using a regex parse-able by `regex`, there isn't much difference
77- // between using the `regex` crate and using the `fancy_regex` crate .
79+ // Given that we're using a regex that can be parsed by all of the above packages,
80+ // we chose to use `pcre2` due to its superior performance .
7881//
79- // There is an important interaction between threading, `regex` and `fancy_regex`.
80- // When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
81- // some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
82- // old `regex`, we don't hit this, because `find_iter` has a different code path.
83- // Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
84- // Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
85- // each thread.
82+ // There is an important interaction between threading and `pcre2`.
83+ // `pcre2` uses scratch space, `pcre2::ffi::MatchData`, that may only be used by one thread at
84+ // a time. Internally, `pcre2::Regex` uses a `thread_local::ThreadLocal` to manage a pool
85+ // of copies of the scratch space. If a new thread is created, it will incur a penalty
86+ // when allocating a copy of this space. There are also internal mutexes on which there
87+ // will be contention if there are multiple new threads making scratch space.
88+ // Thus, it is recommended to keep the threads alive, for example, using a thread pool.
89+ //
90+ // There are a couple potentially better designs to consider:
91+ // 1. Have each thread explicitly own its own scratch space as opposed to looking it up in the
92+ // `thread_local::ThreadLocal`. This would require the user (us) to manage these, which involves
93+ // adjusting the `Encoding` Python class to keep track.
94+ // 2. Another option would be to have a lock-free object pool of scratch space and pull from it
95+ // whenever necessary, allocating if the pool is empty. This is agnostic to the thread that is
96+ // requesting scratch space and thus more flexible. However, there could be contention on the
97+ // internal linked-list and CPU-cache inefficiencies considering that the scratch space
98+ // could have been residing on another core's cache.
8699//
87100// Threading
88101// =========
@@ -108,44 +121,18 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) ->
108121
109122use std:: num:: NonZeroU64 ;
110123pub struct FakeThreadId ( NonZeroU64 ) ;
111-
112- fn hash_current_thread ( ) -> usize {
113- // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
114- // that works great for our use case of avoiding collisions in our array. Unfortunately,
115- // it's private. However, there are only so many ways you can layout a u64, so just transmute
116- // https://github.com/rust-lang/rust/issues/67939
117- const _: [ u8 ; 8 ] = [ 0 ; std:: mem:: size_of :: < std:: thread:: ThreadId > ( ) ] ;
118- const _: [ u8 ; 8 ] = [ 0 ; std:: mem:: size_of :: < FakeThreadId > ( ) ] ;
119- let x = unsafe {
120- std:: mem:: transmute :: < std:: thread:: ThreadId , FakeThreadId > ( thread:: current ( ) . id ( ) ) . 0
121- } ;
122- u64:: from ( x) as usize
123- }
124-
125- const MAX_NUM_THREADS : usize = 128 ;
126124#[ pyclass]
127- struct CoreBPE {
125+ pub struct CoreBPE {
128126 encoder : HashMap < Vec < u8 > , usize > ,
129127 special_tokens_encoder : HashMap < String , usize > ,
130128 decoder : HashMap < usize , Vec < u8 > > ,
131129 special_tokens_decoder : HashMap < usize , Vec < u8 > > ,
132- regex_tls : Vec < Regex > ,
133- special_regex_tls : Vec < Regex > ,
130+ regex : Regex ,
131+ special_regex : Regex ,
134132 sorted_token_bytes : Vec < Vec < u8 > > ,
135133}
136134
137135impl CoreBPE {
138- fn _get_tl_regex ( & self ) -> & Regex {
139- // See performance notes above for what this is about
140- // It's also a little janky, please make a better version of it!
141- // However, it's nice that this doesn't leak memory to short-lived threads
142- & self . regex_tls [ hash_current_thread ( ) % MAX_NUM_THREADS ]
143- }
144-
145- fn _get_tl_special_regex ( & self ) -> & Regex {
146- & self . special_regex_tls [ hash_current_thread ( ) % MAX_NUM_THREADS ]
147- }
148-
149136 fn _decode_native ( & self , tokens : & [ usize ] ) -> Vec < u8 > {
150137 let mut ret = Vec :: with_capacity ( tokens. len ( ) * 2 ) ;
151138 for token in tokens {
@@ -161,10 +148,9 @@ impl CoreBPE {
161148 fn _encode_ordinary_native ( & self , text : & str ) -> Vec < usize > {
162149 // This is the core of the encoding logic; the other functions in here
163150 // just make things complicated :-)
164- let regex = self . _get_tl_regex ( ) ;
165151 let mut ret = vec ! [ ] ;
166- for mat in regex. find_iter ( text) {
167- let piece = mat. unwrap ( ) . as_str ( ) . as_bytes ( ) ;
152+ for mat in self . regex . find_iter ( text. as_bytes ( ) ) {
153+ let piece = mat. unwrap ( ) . as_bytes ( ) ;
168154 if let Some ( token) = self . encoder . get ( piece) {
169155 ret. push ( * token) ;
170156 continue ;
@@ -175,8 +161,6 @@ impl CoreBPE {
175161 }
176162
177163 fn _encode_native ( & self , text : & str , allowed_special : & HashSet < & str > ) -> ( Vec < usize > , usize ) {
178- let special_regex = self . _get_tl_special_regex ( ) ;
179- let regex = self . _get_tl_regex ( ) ;
180164 let mut ret = vec ! [ ] ;
181165
182166 let mut start = 0 ;
@@ -186,8 +170,11 @@ impl CoreBPE {
186170 let mut start_find = start;
187171 loop {
188172 // Find the next allowed special token, if any
189- next_special = special_regex. find_from_pos ( text, start_find) . unwrap ( ) ;
190- match next_special {
173+ next_special = self
174+ . special_regex
175+ . find_at ( text. as_bytes ( ) , start_find)
176+ . unwrap ( ) ;
177+ match & next_special {
191178 Some ( m) => {
192179 if allowed_special. contains ( & text[ m. start ( ) ..m. end ( ) ] ) {
193180 break ;
@@ -197,11 +184,13 @@ impl CoreBPE {
197184 None => break ,
198185 }
199186 }
200- let end = next_special. map_or ( text. len ( ) , |m| m. start ( ) ) ;
187+ let end: usize = next_special
188+ // .as_ref()
189+ . map_or ( text. len ( ) , |m| -> usize { m. start ( ) } ) ;
201190
202191 // Okay, here we go, compare this logic to _encode_ordinary_native
203- for mat in regex. find_iter ( & text[ start..end] ) {
204- let piece = mat. unwrap ( ) . as_str ( ) . as_bytes ( ) ;
192+ for mat in self . regex . find_iter ( & text[ start..end] . as_bytes ( ) ) {
193+ let piece = mat. unwrap ( ) . as_bytes ( ) ;
205194 if let Some ( token) = self . encoder . get ( piece) {
206195 last_piece_token_len = 1 ;
207196 ret. push ( * token) ;
@@ -215,10 +204,10 @@ impl CoreBPE {
215204 match next_special {
216205 // And here we push the special token
217206 Some ( m) => {
218- let piece = m . as_str ( ) ;
207+ let piece = std :: str :: from_utf8 ( m . as_bytes ( ) ) . unwrap ( ) ;
219208 let token = self . special_tokens_encoder [ piece] ;
220209 ret. push ( token) ;
221- start = m. end ( ) ;
210+ start = m. end ( ) ;
222211 last_piece_token_len = 0 ;
223212 }
224213 None => break ,
@@ -394,18 +383,28 @@ impl CoreBPE {
394383 special_tokens_encoder : HashMap < String , usize > ,
395384 pattern : & str ,
396385 ) -> PyResult < Self > {
397- let regex = Regex :: new ( pattern)
398- . map_err ( |e| PyErr :: new :: < exceptions:: PyValueError , _ > ( e. to_string ( ) ) ) ?;
386+ let builder = {
387+ let mut builder = RegexBuilder :: new ( ) ;
388+ builder. jit_if_available ( true ) ;
389+ builder
390+ } ;
391+
392+ fn pcre2_error_mapper ( e : pcre2:: Error ) -> pyo3:: PyErr {
393+ PyErr :: new :: < exceptions:: PyValueError , _ > ( e. to_string ( ) )
394+ }
395+
396+ let regex = builder. build ( pattern) . map_err ( pcre2_error_mapper) ?;
399397
400- let special_regex = {
398+ let special_pattern = {
401399 let _parts = special_tokens_encoder
402400 . keys ( )
403401 . map ( |s| fancy_regex:: escape ( s) )
404402 . collect :: < Vec < _ > > ( ) ;
405- Regex :: new ( & _parts. join ( "|" ) )
406- . map_err ( |e| PyErr :: new :: < exceptions:: PyValueError , _ > ( e. to_string ( ) ) ) ?
403+ & _parts. join ( "|" )
407404 } ;
408405
406+ let special_regex = builder. build ( special_pattern) . map_err ( pcre2_error_mapper) ?;
407+
409408 let decoder: HashMap < usize , Vec < u8 > > =
410409 encoder. iter ( ) . map ( |( k, v) | ( * v, k. clone ( ) ) ) . collect ( ) ;
411410
@@ -425,10 +424,8 @@ impl CoreBPE {
425424 special_tokens_encoder,
426425 decoder,
427426 special_tokens_decoder,
428- regex_tls : ( 0 ..MAX_NUM_THREADS ) . map ( |_| regex. clone ( ) ) . collect ( ) ,
429- special_regex_tls : ( 0 ..MAX_NUM_THREADS )
430- . map ( |_| special_regex. clone ( ) )
431- . collect ( ) ,
427+ regex : regex,
428+ special_regex : special_regex,
432429 sorted_token_bytes,
433430 } )
434431 }
0 commit comments