Commit f775e7f

Use PCRE2 for regex
1 parent 7830ed5 commit f775e7f

2 files changed: 62 additions & 65 deletions

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ pyo3 = { version = "0.17.3", features = ["extension-module"] }
 
 # tiktoken dependencies
 fancy-regex = "0.10.0"
-regex = "1.7.0"
+pcre2 = { git = "https://github.com/nistath/rust-pcre2/" }
 rustc-hash = "1.1.0"
 bstr = "1.0.1"
 

src/lib.rs

Lines changed: 61 additions & 64 deletions
@@ -2,9 +2,8 @@
 #![allow(clippy::borrow_deref_ref)]
 
 use std::collections::HashSet;
-use std::thread;
 
-use fancy_regex::Regex;
+use pcre2::bytes::{Regex, RegexBuilder};
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyList, PyTuple};
@@ -69,20 +68,34 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) ->
 //
 // Regex
 // =====
-// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
-// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
-// the usual regex we use.
+// A considerable amount of time is spent in regex.
+// The easiest way to speed this up is by using a faster regex engine.
+// One reason performance varies across engines is their feature richness:
+// the `fancy_regex` crate can parse more kinds of regex than the `regex` crate.
+// Another reason is the implementation itself, in particular whether the engine
+// has a JIT compiler capable of generating optimized code during pattern compilation.
+// `pcre2` is a C library with such JIT capabilities, wrapped by the `pcre2` crate.
 //
-// However, given that we're using a regex parse-able by `regex`, there isn't much difference
-// between using the `regex` crate and using the `fancy_regex` crate.
+// Given that we're using a regex that can be parsed by all of the above packages,
+// we chose `pcre2` for its superior performance.
 //
-// There is an important interaction between threading, `regex` and `fancy_regex`.
-// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
-// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
-// old `regex`, we don't hit this, because `find_iter` has a different code path.
-// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
-// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
-// each thread.
+// There is an important interaction between threading and `pcre2`.
+// `pcre2` uses scratch space, `pcre2::ffi::MatchData`, that may only be used by one thread
+// at a time. Internally, `pcre2::Regex` uses a `thread_local::ThreadLocal` to manage a pool
+// of copies of the scratch space. If a new thread is created, it incurs a penalty
+// when allocating its copy of this space. There are also internal mutexes on which there
+// will be contention if multiple new threads allocate scratch space at once.
+// Thus, it is recommended to keep threads alive, for example by using a thread pool.
+//
+// There are a couple of potentially better designs to consider:
+// 1. Have each thread explicitly own its scratch space as opposed to looking it up in the
+//    `thread_local::ThreadLocal`. This would require the user (us) to manage those copies,
+//    which involves adjusting the `Encoding` Python class to keep track of them.
+// 2. Have a lock-free object pool of scratch space and pull from it whenever necessary,
+//    allocating if the pool is empty. This is agnostic to the requesting thread and thus
+//    more flexible. However, there could be contention on the pool's internal linked list,
+//    as well as CPU-cache inefficiencies, since the scratch space could be residing in
+//    another core's cache.
 //
 // Threading
 // =========
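For illustration, a minimal sketch of the usage pattern the comment above recommends: long-lived worker threads sharing one `pcre2` regex, so each thread allocates its thread-local `MatchData` once and reuses it across searches. The pattern and inputs here are made up; only `RegexBuilder`, `jit_if_available`, and `find_iter` come from the `pcre2` crate.

    use pcre2::bytes::RegexBuilder;

    fn main() -> Result<(), pcre2::Error> {
        // Build the regex once, enabling the JIT where the platform supports it.
        let re = RegexBuilder::new().jit_if_available(true).build(r"\w+")?;

        // Long-lived threads: each pays the scratch-space allocation only on its
        // first search; subsequent searches reuse the thread-local MatchData.
        std::thread::scope(|s| {
            for chunk in ["some text", "more text"] {
                let re = &re;
                s.spawn(move || {
                    for m in re.find_iter(chunk.as_bytes()) {
                        let _piece = m.unwrap().as_bytes();
                    }
                });
            }
        });
        Ok(())
    }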
@@ -108,44 +121,18 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) ->
 
 use std::num::NonZeroU64;
 pub struct FakeThreadId(NonZeroU64);
-
-fn hash_current_thread() -> usize {
-    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
-    // that works great for our use case of avoiding collisions in our array. Unfortunately,
-    // it's private. However, there are only so many ways you can layout a u64, so just transmute
-    // https://github.com/rust-lang/rust/issues/67939
-    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
-    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
-    let x = unsafe {
-        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
-    };
-    u64::from(x) as usize
-}
-
-const MAX_NUM_THREADS: usize = 128;
 #[pyclass]
-struct CoreBPE {
+pub struct CoreBPE {
     encoder: HashMap<Vec<u8>, usize>,
     special_tokens_encoder: HashMap<String, usize>,
     decoder: HashMap<usize, Vec<u8>>,
     special_tokens_decoder: HashMap<usize, Vec<u8>>,
-    regex_tls: Vec<Regex>,
-    special_regex_tls: Vec<Regex>,
+    regex: Regex,
+    special_regex: Regex,
     sorted_token_bytes: Vec<Vec<u8>>,
 }
 
 impl CoreBPE {
-    fn _get_tl_regex(&self) -> &Regex {
-        // See performance notes above for what this is about
-        // It's also a little janky, please make a better version of it!
-        // However, it's nice that this doesn't leak memory to short-lived threads
-        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
-    }
-
-    fn _get_tl_special_regex(&self) -> &Regex {
-        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
-    }
-
     fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
         let mut ret = Vec::with_capacity(tokens.len() * 2);
         for token in tokens {
@@ -161,10 +148,9 @@ impl CoreBPE {
     fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
         // This is the core of the encoding logic; the other functions in here
         // just make things complicated :-)
-        let regex = self._get_tl_regex();
         let mut ret = vec![];
-        for mat in regex.find_iter(text) {
-            let piece = mat.unwrap().as_str().as_bytes();
+        for mat in self.regex.find_iter(text.as_bytes()) {
+            let piece = mat.unwrap().as_bytes();
             if let Some(token) = self.encoder.get(piece) {
                 ret.push(*token);
                 continue;
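As a standalone illustration of the byte-oriented API used above (the pattern and input are made up): `pcre2::bytes::Regex` searches `&[u8]`, a `Match` exposes `as_bytes()` rather than `fancy_regex`'s `as_str()`, and `find_iter` yields `Result`s because PCRE2 can report errors at match time.

    use pcre2::bytes::RegexBuilder;

    fn main() -> Result<(), pcre2::Error> {
        let re = RegexBuilder::new().jit_if_available(true).build(r"\d+|\w+")?;
        for mat in re.find_iter("abc 123".as_bytes()) {
            // Each item is a Result; a Match yields raw bytes, not &str.
            let piece: &[u8] = mat?.as_bytes();
            println!("{:?}", std::str::from_utf8(piece));
        }
        Ok(())
    }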
@@ -175,8 +161,6 @@ impl CoreBPE {
     }
 
     fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
-        let special_regex = self._get_tl_special_regex();
-        let regex = self._get_tl_regex();
         let mut ret = vec![];
 
         let mut start = 0;
@@ -186,8 +170,11 @@ impl CoreBPE {
             let mut start_find = start;
             loop {
                 // Find the next allowed special token, if any
-                next_special = special_regex.find_from_pos(text, start_find).unwrap();
-                match next_special {
+                next_special = self
+                    .special_regex
+                    .find_at(text.as_bytes(), start_find)
+                    .unwrap();
+                match &next_special {
                     Some(m) => {
                         if allowed_special.contains(&text[m.start()..m.end()]) {
                             break;
@@ -197,11 +184,11 @@ impl CoreBPE {
                     None => break,
                 }
             }
             let end = next_special.map_or(text.len(), |m| m.start());
 
             // Okay, here we go, compare this logic to _encode_ordinary_native
-            for mat in regex.find_iter(&text[start..end]) {
-                let piece = mat.unwrap().as_str().as_bytes();
+            for mat in self.regex.find_iter(text[start..end].as_bytes()) {
+                let piece = mat.unwrap().as_bytes();
                 if let Some(token) = self.encoder.get(piece) {
                     last_piece_token_len = 1;
                     ret.push(*token);
@@ -215,10 +202,10 @@ impl CoreBPE {
             match next_special {
                 // And here we push the special token
                 Some(m) => {
-                    let piece = m.as_str();
+                    let piece = std::str::from_utf8(m.as_bytes()).unwrap();
                     let token = self.special_tokens_encoder[piece];
                     ret.push(token);
                     start = m.end();
                     last_piece_token_len = 0;
                 }
                 None => break,
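For reference, a small sketch (token and text are made up) of `find_at`, which the loop above relies on: it resumes a search at a byte offset, reports `start()`/`end()` as absolute offsets into the subject, and advancing by `m.end()` is what keeps the scan from re-matching the same special token.

    use pcre2::bytes::RegexBuilder;

    fn main() -> Result<(), pcre2::Error> {
        let re = RegexBuilder::new()
            .jit_if_available(true)
            .build(r"<\|endoftext\|>")?;
        let text = b"foo<|endoftext|>bar<|endoftext|>";
        let mut start = 0;
        while let Some(m) = re.find_at(text, start)? {
            // start()/end() are absolute offsets into `text`.
            println!("special token at {}..{}", m.start(), m.end());
            start = m.end(); // resume after the match
        }
        Ok(())
    }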
@@ -394,18 +381,28 @@ impl CoreBPE {
         special_tokens_encoder: HashMap<String, usize>,
         pattern: &str,
     ) -> PyResult<Self> {
-        let regex = Regex::new(pattern)
-            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
+        let builder = {
+            let mut builder = RegexBuilder::new();
+            builder.jit_if_available(true);
+            builder
+        };
+
+        fn pcre2_error_mapper(e: pcre2::Error) -> pyo3::PyErr {
+            PyErr::new::<exceptions::PyValueError, _>(e.to_string())
+        }
+
+        let regex = builder.build(pattern).map_err(pcre2_error_mapper)?;
 
-        let special_regex = {
+        let special_pattern = {
             let _parts = special_tokens_encoder
                 .keys()
                 .map(|s| fancy_regex::escape(s))
                 .collect::<Vec<_>>();
-            Regex::new(&_parts.join("|"))
-                .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
+            _parts.join("|")
         };
 
+        let special_regex = builder.build(&special_pattern).map_err(pcre2_error_mapper)?;
+
         let decoder: HashMap<usize, Vec<u8>> =
             encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
 

@@ -425,10 +422,8 @@ impl CoreBPE {
             special_tokens_encoder,
             decoder,
             special_tokens_decoder,
-            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
-            special_regex_tls: (0..MAX_NUM_THREADS)
-                .map(|_| special_regex.clone())
-                .collect(),
+            regex,
+            special_regex,
             sorted_token_bytes,
         })
     }
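Finally, a sketch of how the constructor above assembles the special-token pattern (the tokens here are made up): each token is escaped with `fancy_regex::escape` so metacharacters like `|` are taken literally, and the escaped tokens are joined into a single alternation before being compiled with the JIT-enabled builder.

    use pcre2::bytes::RegexBuilder;

    fn main() -> Result<(), pcre2::Error> {
        let special_tokens = ["<|endoftext|>", "<|fim_prefix|>"];
        let special_pattern = special_tokens
            .iter()
            .map(|s| fancy_regex::escape(s)) // escapes `|` and other metacharacters
            .collect::<Vec<_>>()
            .join("|"); // roughly r"<\|endoftext\|>|<\|fim_prefix\|>"
        let special_regex = RegexBuilder::new()
            .jit_if_available(true)
            .build(&special_pattern)?;
        assert!(special_regex.is_match(b"...<|endoftext|>...")?);
        Ok(())
    }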
