Cargo.toml: 2 changes (1 addition & 1 deletion)

@@ -13,7 +13,7 @@ pyo3 = { version = "0.17.3", features = ["extension-module"] }

# tiktoken dependencies
fancy-regex = "0.10.0"
regex = "1.7.0"
pcre2 = { git = "https://github.com/nistath/rust-pcre2/" }
rustc-hash = "1.1.0"
bstr = "1.0.1"

src/lib.rs: 223 changes (135 additions & 88 deletions)

@@ -2,20 +2,61 @@
#![allow(clippy::borrow_deref_ref)]

use std::collections::HashSet;
use std::thread;

use fancy_regex::Regex;
use pcre2::bytes::{Regex, RegexBuilder};
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList, PyTuple};
use pyo3::PyResult;
use rustc_hash::FxHashMap as HashMap;

fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> {
    let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect();
fn _byte_pair_merge<T>(
    piece: &[u8],
    ranks: &HashMap<Vec<u8>, usize>,
    f: impl Fn(std::ops::Range<usize>) -> T,
) -> Vec<T> {
    // This is a vector of (start, rank).
    // The rank is of the byte pair starting at position start.
    // The rank of the last item in the vector is not a valid value.
    let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();

    // NOTE: using a macro here because a closure fails to get inlined
    // according to optimization remarks.
    macro_rules! get_rank {
Collaborator:
I think you can #[inline(always)] a closure, you just might need to surround it in braces

Contributor Author:
Unfortunately, "error[E0658]: attributes on expressions are experimental". So "error[E0518]: attribute should be applied to function or closure" is a little optimistic for now.

Contributor Author:
Also, the best closure I can come up with is

let get_rank_skip = |parts: &Vec<(usize, usize)>, start_idx: usize, skip: usize|  {
    if (start_idx + skip + 2) < parts.len() {
        ranks
            .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
            .map(|r| *r)
    } else {
        None
    }
};

let get_rank = |parts: &Vec<(usize, usize)>, start_idx: usize| {
    get_rank_skip(parts, start_idx, 0)
};

Passing parts by reference, even though it can be captured, seems to be necessary. Otherwise, the borrow checker will complain that we mutably borrow (when assigning to parts) while the closure has an immutable borrow via the captured reference. It's possible there is a better solution, but I am not aware of one.
I think the best option is to keep the macro.
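
A minimal standalone sketch of the borrow-checker conflict described in this thread; the names mirror the diff, but the code itself is illustrative, not the PR's:

// Sketch: why the capturing closure fails while the parameter-passing
// closure works. `parts` and `get_rank` mirror the diff's names.
fn main() {
    let mut parts: Vec<(usize, usize)> = vec![(0, 1), (1, 2)];

    // A capturing closure holds an immutable borrow of `parts` across
    // the region where the closure is still used...
    let get_rank = |i: usize| parts[i].1;
    let r = get_rank(0);
    // ...so mutating between two uses of the closure is rejected:
    // parts[0].1 = r; // error[E0502]: cannot borrow `parts` as mutable
    let _ = get_rank(1);

    // Passing `parts` as an explicit parameter keeps the borrows disjoint,
    // at the cost of threading it through every call site:
    let get_rank_param = |parts: &[(usize, usize)], i: usize| parts[i].1;
    let r2 = get_rank_param(&parts, 0);
    parts[0].1 = r + r2; // fine: no closure-held borrow is outstanding
}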

        ($start_idx:expr, $skip:expr) => {{
            if ($start_idx + $skip + 2) < parts.len() {
                ranks
                    .get(&piece[parts[$start_idx].0..parts[$start_idx + $skip + 2].0])
                    .map(|r| *r)
            } else {
                None
            }
        }};
        ($idx:expr) => {{
            get_rank!($idx, 0)
        }};
    }

    // If you have n parts and m merges, this does O(mn) work
    // We could do something with a heap and do O(m log n) work
    // We look up the ranks once in the beginning and iteratively update
    // them during each merge, which reduces the number of rank lookups.
    for i in 0..parts.len() - 2 {
        match get_rank!(i) {
            Some(rank) => {
                // usize::MAX is a sentinel value and cannot be a valid rank
                debug_assert!(rank != usize::MAX);
                parts[i].1 = rank;
            }
            None => {
                continue;
            }
        };
    }

    // If you have n parts and m merges, this does O(mn) work.
    // We could do something with a heap and do O(m log n) work.
    // It is important to consider that n is often small (<100), and as such
    // the cache-locality benefits outweigh the algorithmic complexity downsides
    // of the `parts` vector data structure above.

    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
    // currently do, this is equivalent. An easy way to break this would be to decouple
@@ -24,65 +65,87 @@ fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>>
        if parts.len() == 1 {
            break;
        }
        let mut min_rank: Option<(usize, usize)> = None;
        for i in 0..parts.len() - 1 {
            let rank = if let Some(r) = ranks.get(&piece[parts[i].start..parts[i + 1].end]) {
                *r
            } else {
                continue;
            };
            if min_rank.is_none() || rank < min_rank.unwrap().0 {
                min_rank = Some((rank, i));

        // usize::MAX is a sentinel rank value allowing us to
        // take the min more quickly
        let mut min_rank: (usize, usize) = (usize::MAX, 0);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
        if let Some((_, i)) = min_rank {
            parts[i] = parts[i].start..parts[i + 1].end;

        if min_rank.0 != usize::MAX {
            let i = min_rank.1;

            // NOTE: We are about to remove parts[i + 1]. We do not do it
            // yet because there are cache-locality benefits to updating
            // parts[i] and parts[i - 1] before removing, which could thrash
            // the cache. Thus, we update the rank calculation by skipping over
            // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
            parts[i].1 = get_rank!(i, 1).unwrap_or(usize::MAX);
            if i > 0 {
                parts[i - 1].1 = get_rank!(i - 1, 1).unwrap_or(usize::MAX);
            }

            parts.remove(i + 1);
        } else {
            break;
        }
    }
    parts
    let mut out: Vec<T> = Vec::with_capacity(parts.len() - 1);
    for i in 0..parts.len() - 1 {
        out.push(f(parts[i].0..parts[i + 1].0));
    }
    out
}

pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
    if piece.len() == 1 {
        return vec![ranks[piece]];
    }
    _byte_pair_merge(piece, ranks)
        .iter()
        .map(|p| ranks[&piece[p.start..p.end]])
        .collect()
    _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
}

pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
    if piece.len() == 1 {
        return vec![piece];
    }
    _byte_pair_merge(piece, ranks)
        .iter()
        .map(|p| &piece[p.start..p.end])
        .collect()
    _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end])
}
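
To make the (start, rank) bookkeeping concrete, here is a hedged walk-through on a toy vocabulary. The two ranks are invented for illustration, and the final assertion assumes this crate's byte_pair_encode is in scope; none of this is part of the PR.

// Trace of _byte_pair_merge for piece = b"abab" with invented ranks
// { "ab": 0, "abab": 1 }. `parts` holds one boundary per byte plus a
// terminating boundary; each entry's rank is the rank of the pair
// starting there (usize::MAX = no merge possible).
//
//   init:    [(0, 0), (1, MAX), (2, 0), (3, MAX), (4, MAX)]
//   merge 1: [(0, MAX), (2, 0), (3, MAX), (4, MAX)]  // "ab" at 0; "aba" has no rank
//   merge 2: [(0, 1), (2, MAX), (4, MAX)]            // "ab" at 2; 0..4 is "abab", rank 1
//   merge 3: [(0, MAX), (4, MAX)]                    // "abab" merged; nothing left
//
// The two surviving boundaries describe a single part, so the piece
// encodes to the single token for "abab".
use rustc_hash::FxHashMap as HashMap;

fn main() {
    let mut ranks: HashMap<Vec<u8>, usize> = HashMap::default();
    ranks.insert(b"ab".to_vec(), 0);
    ranks.insert(b"abab".to_vec(), 1);

    assert_eq!(byte_pair_encode(b"abab", &ranks), vec![1]);
}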

// Various performance notes:
//
// Regex
// =====
// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
// the usual regex we use.
// A considerable amount of time is spent in regex matching.
// The easiest way to speed this up is to use a faster regex engine.
// One reason performance varies across engines is feature richness:
// the `fancy_regex` crate can parse more kinds of regex than the `regex` crate.
// Another is the implementation itself, in particular whether the engine has a
// JIT compiler capable of generating optimized code when a pattern is compiled.
// `pcre2` is a C library with such JIT capabilities, wrapped by the `pcre2` crate.
//
// However, given that we're using a regex parse-able by `regex`, there isn't much difference
// between using the `regex` crate and using the `fancy_regex` crate.
// Given that we're using a regex that can be parsed by all of the above packages,
// we chose to use `pcre2` due to its superior performance.
//
// There is an important interaction between threading, `regex` and `fancy_regex`.
// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
// old `regex`, we don't hit this, because `find_iter` has a different code path.
// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
// each thread.
// There is an important interaction between threading and `pcre2`.
// `pcre2` uses scratch space, `pcre2::ffi::MatchData`, that may only be used by one thread at
// a time. Internally, `pcre2::Regex` uses a `thread_local::ThreadLocal` to manage a pool
// of copies of the scratch space. If a new thread is created, it will incur a penalty
// when allocating a copy of this space. There are also internal mutexes on which there
// will be contention if there are multiple new threads making scratch space.
// Thus, it is recommended to keep the threads alive, for example, using a thread pool.
//
// There are a couple of potentially better designs to consider:
// 1. Have each thread explicitly own its own scratch space, as opposed to looking it up
//    in the `thread_local::ThreadLocal`. This would require the user (us) to manage these,
//    which involves adjusting the `Encoding` Python class to keep track of them.
// 2. Have a lock-free object pool of scratch space and pull from it whenever necessary,
//    allocating if the pool is empty. This is agnostic to the requesting thread and thus
//    more flexible. However, there could be contention on the internal linked list, plus
//    CPU-cache inefficiencies, since the scratch space may be resident in another core's
//    cache.
//
// Threading
// =========
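
A minimal sketch of the pattern these notes recommend, using only pcre2 crate calls that appear elsewhere in this diff (RegexBuilder, jit_if_available, find_iter); the worker count, pattern, and input are illustrative. Compile once with the JIT, share the compiled regex, and keep worker threads alive so each one allocates its scratch space only once.

use std::sync::Arc;
use std::thread;

use pcre2::bytes::RegexBuilder;

fn main() -> Result<(), pcre2::Error> {
    // Compile once, with the PCRE2 JIT when the platform supports it.
    let regex = Arc::new({
        let mut builder = RegexBuilder::new();
        builder.jit_if_available(true);
        builder.build(r"\w+")?
    });

    // Long-lived workers: each thread allocates its match-data scratch
    // space once (via the internal thread-local pool) and then reuses it.
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let regex = Arc::clone(&regex);
            thread::spawn(move || {
                let mut matches = 0usize;
                for _ in 0..1_000 {
                    for m in regex.find_iter(b"some text to scan") {
                        let _piece = m.unwrap().as_bytes();
                        matches += 1;
                    }
                }
                matches
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }
    Ok(())
}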
@@ -108,44 +171,18 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]>

use std::num::NonZeroU64;
pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can lay out a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
    let x = unsafe {
        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
    };
    u64::from(x) as usize
}

const MAX_NUM_THREADS: usize = 128;
#[pyclass]
struct CoreBPE {
pub struct CoreBPE {
    encoder: HashMap<Vec<u8>, usize>,
    special_tokens_encoder: HashMap<String, usize>,
    decoder: HashMap<usize, Vec<u8>>,
    special_tokens_decoder: HashMap<usize, Vec<u8>>,
    regex_tls: Vec<Regex>,
    special_regex_tls: Vec<Regex>,
    regex: Regex,
    special_regex: Regex,
    sorted_token_bytes: Vec<Vec<u8>>,
}

impl CoreBPE {
    fn _get_tl_regex(&self) -> &Regex {
        // See performance notes above for what this is about
        // It's also a little janky, please make a better version of it!
        // However, it's nice that this doesn't leak memory to short-lived threads
        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    fn _get_tl_special_regex(&self) -> &Regex {
        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
        let mut ret = Vec::with_capacity(tokens.len() * 2);
        for token in tokens {
@@ -161,10 +198,9 @@ impl CoreBPE {
    fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
        // This is the core of the encoding logic; the other functions in here
        // just make things complicated :-)
        let regex = self._get_tl_regex();
        let mut ret = vec![];
        for mat in regex.find_iter(text) {
            let piece = mat.unwrap().as_str().as_bytes();
        for mat in self.regex.find_iter(text.as_bytes()) {
            let piece = mat.unwrap().as_bytes();
            if let Some(token) = self.encoder.get(piece) {
                ret.push(*token);
                continue;
@@ -175,8 +211,6 @@ impl CoreBPE {
    }

    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
        let special_regex = self._get_tl_special_regex();
        let regex = self._get_tl_regex();
        let mut ret = vec![];

        let mut start = 0;
@@ -186,8 +220,11 @@ impl CoreBPE {
            let mut start_find = start;
            loop {
                // Find the next allowed special token, if any
                next_special = special_regex.find_from_pos(text, start_find).unwrap();
                match next_special {
                next_special = self
                    .special_regex
                    .find_at(text.as_bytes(), start_find)
                    .unwrap();
                match &next_special {
                    Some(m) => {
                        if allowed_special.contains(&text[m.start()..m.end()]) {
                            break;
@@ -197,11 +234,13 @@ impl CoreBPE {
                    None => break,
                }
            }
            let end = next_special.map_or(text.len(), |m| m.start());

            // Okay, here we go, compare this logic to _encode_ordinary_native
            for mat in regex.find_iter(&text[start..end]) {
                let piece = mat.unwrap().as_str().as_bytes();
            for mat in self.regex.find_iter(text[start..end].as_bytes()) {
                let piece = mat.unwrap().as_bytes();
                if let Some(token) = self.encoder.get(piece) {
                    last_piece_token_len = 1;
                    ret.push(*token);

            match next_special {
                // And here we push the special token
                Some(m) => {
                    let piece = m.as_str();
                    let piece = std::str::from_utf8(m.as_bytes()).unwrap();
                    let token = self.special_tokens_encoder[piece];
                    ret.push(token);
                    start = m.end();
                    last_piece_token_len = 0;
                }
                None => break,
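
The loop above must resume scanning after each consumed special token. A standalone sketch of that resume-after-match pattern with find_at (the pattern and input are invented):

use pcre2::bytes::RegexBuilder;

fn main() -> Result<(), pcre2::Error> {
    let regex = RegexBuilder::new().build(r"<\|endoftext\|>")?;
    let text = b"one<|endoftext|>two<|endoftext|>three";

    let mut start = 0;
    while let Some(m) = regex.find_at(text, start)? {
        println!("special token at {}..{}", m.start(), m.end());
        // Resume after the match; `start = m.start()` would rediscover the
        // same match forever.
        start = m.end();
    }
    Ok(())
}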
@@ -394,18 +433,28 @@ impl CoreBPE {
        special_tokens_encoder: HashMap<String, usize>,
        pattern: &str,
    ) -> PyResult<Self> {
        let regex = Regex::new(pattern)
            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
        let builder = {
            let mut builder = RegexBuilder::new();
            builder.jit_if_available(true);
            builder
        };

        let special_regex = {
        fn pcre2_error_mapper(e: pcre2::Error) -> pyo3::PyErr {
            PyErr::new::<exceptions::PyValueError, _>(e.to_string())
        }

        let regex = builder.build(pattern).map_err(pcre2_error_mapper)?;

        let special_pattern = {
            let _parts = special_tokens_encoder
                .keys()
                .map(|s| fancy_regex::escape(s))
                .collect::<Vec<_>>();
            Regex::new(&_parts.join("|"))
                .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
            _parts.join("|")
        };

        let special_regex = builder.build(&special_pattern).map_err(pcre2_error_mapper)?;

        let decoder: HashMap<usize, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

@@ -425,10 +474,8 @@ impl CoreBPE {
            special_tokens_encoder,
            decoder,
            special_tokens_decoder,
            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
            special_regex_tls: (0..MAX_NUM_THREADS)
                .map(|_| special_regex.clone())
                .collect(),
            regex,
            special_regex,
            sorted_token_bytes,
        })
    }
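
The special regex is built as an alternation of escaped literals. A small sketch of why the escaping matters, with invented token strings: special tokens such as <|endoftext|> contain regex metacharacters, so joining the raw keys with | would produce spurious alternation branches.

use pcre2::bytes::RegexBuilder;

fn main() -> Result<(), pcre2::Error> {
    let special_tokens = ["<|endoftext|>", "<|fim_prefix|>"];

    // Escape each token before joining, as the constructor above does; the
    // "|" characters inside the tokens would otherwise be parsed as
    // alternation operators.
    let pattern = special_tokens
        .iter()
        .map(|s| fancy_regex::escape(s))
        .collect::<Vec<_>>()
        .join("|");

    let special_regex = RegexBuilder::new().build(&pattern)?;
    assert!(special_regex.is_match(b"prefix <|endoftext|> suffix")?);
    Ok(())
}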