
Commit ea0da14

current draft!
1 parent 916df54 commit ea0da14

4 files changed: +88, -2 lines changed


bindings/python/src/tokenizer.rs

Lines changed: 26 additions & 0 deletions
@@ -782,6 +782,16 @@ impl PyTokenizer {
         self.tokenizer.get_vocab(with_added_tokens)
     }

+    /// Get the special tokens mapping
+    ///
+    /// Returns:
+    ///     :obj:`Dict[str, List[str]]`: The special tokens mapping
+    #[pyo3(signature = ())]
+    #[pyo3(text_signature = "(self)")]
+    fn get_special_tokens_mapping(&self) -> Option<&HashMap<String, Vec<String>>> {
+        self.tokenizer.get_special_tokens_mapping()
+    }
+
     /// Get the underlying vocabulary
     ///
     /// Returns:
@@ -1848,6 +1858,22 @@ impl PyTokenizer {
     fn set_decoder(&mut self, decoder: Option<PyRef<PyDecoder>>) {
         self.tokenizer.with_decoder(decoder.map(|d| d.clone()));
     }
+
+    /// The optional `eos_token` entries from the special tokens mapping
+    #[getter]
+    fn get_eos_token(&self, py: Python<'_>) -> Option<Vec<String>> {
+        self.tokenizer
+            .get_special_tokens_mapping()
+            .and_then(|token| token.get("eos_token"))
+            // into_pyobject -> Bound<PyAny>. Turn that into PyObject.
+            .map(|v| v.clone())
+    }
+
+    /// Set the `eos_token` entry of the special tokens mapping
+    #[setter]
+    fn set_eos_token(&mut self, new_eos_token: Option<String>) {
+        self.tokenizer.with_special_tokens_mapping();
+    }
 }

 #[cfg(test)]
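
Note: in this draft the eos_token setter is a stub; it calls with_special_tokens_mapping() with no argument and never reads new_eos_token. Below is a minimal sketch of a possible completion, using only the APIs this commit adds (SpecialTokensMapping::new and TokenizerImpl::with_special_tokens_mapping, see mod.rs below) and assuming SpecialTokensMapping is imported into the bindings crate; it rebuilds the whole mapping, since no single-entry update method exists yet.

    #[setter]
    fn set_eos_token(&mut self, new_eos_token: Option<String>) {
        use std::collections::{BTreeMap, BTreeSet};
        // Sketch only, not part of this commit: rebuild the mapping with a single
        // `eos_token` entry (left empty when `None` is passed).
        let mut inner: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
        if let Some(token) = new_eos_token {
            inner.insert("eos_token".to_string(), BTreeSet::from([token]));
        }
        self.tokenizer
            .with_special_tokens_mapping(Some(SpecialTokensMapping::new(inner)));
    }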

tokenizers/src/tokenizer/mod.rs

Lines changed: 35 additions & 1 deletion
@@ -20,16 +20,17 @@ use std::{
 use serde::de::DeserializeOwned;
 use serde::{Deserialize, Serialize};

-use crate::utils::iter::ResultShunt;
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
+use crate::{special_tokens_mapping::SpecialTokensMapping, utils::iter::ResultShunt};

 mod added_vocabulary;
 mod encoding;
 pub mod normalizer;
 pub mod pattern;
 pub mod pre_tokenizer;
 mod serialization;
+pub mod special_tokens_mapping;

 // Re-export wrappers
 pub use crate::decoders::DecoderWrapper;
@@ -293,6 +294,7 @@ pub struct TokenizerBuilder<M, N, PT, PP, D> {

     truncation: Option<TruncationParams>,
     padding: Option<PaddingParams>,
+    special_tokens_mapping: Option<SpecialTokensMapping>,
 }

 impl<M, N, PT, PP, D> Default for TokenizerBuilder<M, N, PT, PP, D>
@@ -327,6 +329,7 @@ where
             added_vocabulary: AddedVocabulary::new(),
             truncation: None,
             padding: None,
+            special_tokens_mapping: None,
         }
     }

@@ -347,6 +350,7 @@ where
             added_vocabulary: self.added_vocabulary,
             truncation: self.truncation,
             padding: self.padding,
+            special_tokens_mapping: self.special_tokens_mapping,
         })
     }

@@ -404,6 +408,14 @@ where
         self.padding = padding;
         self
     }
+
+    pub fn with_special_tokens_mapping(
+        mut self,
+        special_tokens_mapping: Option<SpecialTokensMapping>,
+    ) -> Self {
+        self.special_tokens_mapping = special_tokens_mapping;
+        self
+    }
 }

 #[derive(Serialize, Deserialize, Debug, Clone)]
@@ -480,6 +492,7 @@ where
             added_vocabulary: t.added_vocabulary,
             padding: t.padding,
             truncation: t.truncation,
+            special_tokens_mapping: t.special_tokens_mapping,
         })
     }
 }
@@ -524,6 +537,7 @@ pub struct TokenizerImpl<M, N, PT, PP, D> {
     // General processing parameters
     truncation: Option<TruncationParams>,
     padding: Option<PaddingParams>,
+    special_tokens_mapping: Option<SpecialTokensMapping>,
 }

 impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
@@ -547,6 +561,7 @@ where

             truncation: None,
             padding: None,
+            special_tokens_mapping: None,
         }
     }

@@ -654,6 +669,25 @@ where
         self.padding.as_ref()
     }

+    /// Set the special tokens mapping
+    pub fn with_special_tokens_mapping(
+        &mut self,
+        special_tokens_mapping: Option<SpecialTokensMapping>,
+    ) -> &mut Self {
+        self.special_tokens_mapping = special_tokens_mapping;
+        self
+    }
+
+    /// Get the currently set special tokens mapping
+    pub fn get_special_tokens_mapping(&self) -> Option<&SpecialTokensMapping> {
+        self.special_tokens_mapping.as_ref()
+    }
+
+    /// Get a mutable reference to the currently set special tokens mapping
+    pub fn get_extra_token_muts(&mut self) -> Option<&mut SpecialTokensMapping> {
+        self.special_tokens_mapping.as_mut()
+    }
+
     /// Get a mutable reference to the currently set padding parameters
     pub fn get_padding_mut(&mut self) -> Option<&mut PaddingParams> {
         self.padding.as_mut()
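
Taken together, the builder and TokenizerImpl additions allow a mapping to be attached from Rust roughly as sketched below. This is a usage sketch, not code from this commit: it assumes the new module is reachable as tokenizers::tokenizer::special_tokens_mapping (the final re-export path is not settled here) and that the concrete Tokenizer wrapper derefs to TokenizerImpl as it does for the existing with_* setters.

    use std::collections::{BTreeMap, BTreeSet};

    use tokenizers::tokenizer::special_tokens_mapping::SpecialTokensMapping;
    use tokenizers::Tokenizer;

    /// Attach a single `eos_token` entry to an already-built tokenizer.
    fn attach_eos(tokenizer: &mut Tokenizer) {
        let mut inner = BTreeMap::new();
        inner.insert("eos_token".to_string(), BTreeSet::from(["</s>".to_string()]));
        tokenizer.with_special_tokens_mapping(Some(SpecialTokensMapping::new(inner)));
        assert!(tokenizer.get_special_tokens_mapping().is_some());
    }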

tokenizers/src/tokenizer/serialization.rs

Lines changed: 7 additions & 1 deletion
@@ -42,6 +42,7 @@ where
         tokenizer.serialize_field("post_processor", &self.post_processor)?;
         tokenizer.serialize_field("decoder", &self.decoder)?;
         tokenizer.serialize_field("model", &self.model)?;
+        tokenizer.serialize_field("special_tokens_mapping", &self.special_tokens_mapping)?;

         tokenizer.end()
     }
@@ -63,6 +64,7 @@ where
             "Tokenizer",
             &[
                 "version",
+                "special_tokens_mapping",
                 "truncation",
                 "padding",
                 "added_tokens",
@@ -143,6 +145,9 @@ where
                 "post_processor" => {
                     builder = builder.with_post_processor(map.next_value()?);
                 }
+                "special_tokens_mapping" => {
+                    builder = builder.with_special_tokens_mapping(map.next_value()?);
+                }
                 _ => {}
             };
         }
@@ -221,7 +226,8 @@ mod tests {
            "continuing_subword_prefix": "",
            "max_input_chars_per_word": 100,
            "vocab": {}
-        }
+        },
+        "special_tokens_mapping": null
    }"#;
        let tokenizer = Tokenizer::from_str(tok_json).unwrap();

tokenizers/src/tokenizer/special_tokens_mapping.rs (new file)

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+use std::collections::{BTreeMap, BTreeSet};
+
+use serde::Serialize;
+
+#[derive(Debug, Clone, Serialize)]
+// A struct that represents the mapping between standard special token names like
+// `eos_token`, `bos_token`, or `my_token` and the corresponding string tokens.
+//
+// We choose a BTreeMap and BTreeSet for ordered serialization + fast membership checks.
+// Supports updating a single entry or replacing the whole mapping.
+// Example
+pub struct SpecialTokensMapping {
+    inner: BTreeMap<String, BTreeSet<String>>,
+}
+
+impl SpecialTokensMapping {
+    pub fn new(inner: BTreeMap<String, BTreeSet<String>>) -> Self {
+        Self { inner }
+    }
+}
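
The Python bindings above call .get("eos_token") on the mapping, but this first cut only implements new. Below is a hypothetical sketch of the accessors the bindings appear to expect, written as if added inside special_tokens_mapping.rs (inner is private); the names and signatures are guesses, and the Vec&lt;String&gt;/HashMap types in the binding signatures would still need to be reconciled with the BTreeMap/BTreeSet used here.

    impl SpecialTokensMapping {
        /// Look up every string token registered under a standard name such as "eos_token".
        pub fn get(&self, name: &str) -> Option<&BTreeSet<String>> {
            self.inner.get(name)
        }

        /// Insert or replace the tokens stored under a single name.
        pub fn set(&mut self, name: impl Into<String>, tokens: BTreeSet<String>) {
            self.inner.insert(name.into(), tokens);
        }
    }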
