From cdaf034ca2a1aed80f660f58b48cdf2b47eba212 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Wed, 10 Dec 2025 03:36:11 -0500 Subject: [PATCH 01/31] rewrote passage chunking --- src/tevatron/retriever/arguments.py | 4 + src/tevatron/retriever/collator.py | 97 +++++++++++++++++----- src/tevatron/retriever/driver/train.py | 1 + src/tevatron/retriever/modeling/dense.py | 35 +++++++- src/tevatron/retriever/modeling/encoder.py | 55 ++++++++---- src/tevatron/retriever/trainer.py | 5 +- 6 files changed, 157 insertions(+), 40 deletions(-) diff --git a/src/tevatron/retriever/arguments.py b/src/tevatron/retriever/arguments.py index 00034903..cce3285f 100644 --- a/src/tevatron/retriever/arguments.py +++ b/src/tevatron/retriever/arguments.py @@ -203,6 +203,10 @@ class DataArguments: metadata={"help": "padding side for the tokenizer, can be 'left' or 'right'"} ) + passage_chunk_size: int = field( + default=0, + metadata={"help": "Chunk size for chunked passage encoding with MaxSim. 0=disabled, >0=chunk size in tokens"} + ) @dataclass diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index 20e02ef5..6323fcc1 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -24,7 +24,7 @@ def __call__(self, features: List[Tuple[str, List[str]]]): """ Collate function for training. 
:param features: list of (query, passages) tuples - :return: tokenized query_ids, passage_ids + :return: tokenized query_ids, passage_ids, [eos_positions if chunked] """ all_queries = [f[0] for f in features] all_passages = [] @@ -32,6 +32,8 @@ def __call__(self, features: List[Tuple[str, List[str]]]): all_passages.extend(f[1]) all_queries = [q[0] for q in all_queries] all_passages = [p[0] for p in all_passages] + + # Query tokenization q_collated = self.tokenizer( all_queries, padding=False, @@ -41,20 +43,8 @@ def __call__(self, features: List[Tuple[str, List[str]]]): return_token_type_ids=False, add_special_tokens=True, ) - d_collated = self.tokenizer( - all_passages, - padding=False, - truncation=True, - max_length=self.data_args.passage_max_len-1 if self.data_args.append_eos_token else self.data_args.passage_max_len, - return_attention_mask=False, - return_token_type_ids=False, - add_special_tokens=True, - ) - if self.data_args.append_eos_token: q_collated['input_ids'] = [q + [self.tokenizer.eos_token_id] for q in q_collated['input_ids']] - d_collated['input_ids'] = [d + [self.tokenizer.eos_token_id] for d in d_collated['input_ids']] - q_collated = self.tokenizer.pad( q_collated, padding=True, @@ -62,14 +52,79 @@ def __call__(self, features: List[Tuple[str, List[str]]]): return_attention_mask=True, return_tensors='pt', ) - d_collated = self.tokenizer.pad( - d_collated, - padding=True, - pad_to_multiple_of=self.data_args.pad_to_multiple_of, - return_attention_mask=True, - return_tensors='pt', - ) - return q_collated, d_collated + + # Passage tokenization + if self.data_args.passage_chunk_size > 0: + d_collated, eos_positions = self._tokenize_chunked_passages(all_passages) + return q_collated, d_collated, eos_positions + else: + d_collated = self.tokenizer( + all_passages, + padding=False, + truncation=True, + max_length=self.data_args.passage_max_len-1 if self.data_args.append_eos_token else self.data_args.passage_max_len, + return_attention_mask=False, + 
return_token_type_ids=False, + add_special_tokens=True, + ) + if self.data_args.append_eos_token: + d_collated['input_ids'] = [d + [self.tokenizer.eos_token_id] for d in d_collated['input_ids']] + d_collated = self.tokenizer.pad( + d_collated, + padding=True, + pad_to_multiple_of=self.data_args.pad_to_multiple_of, + return_attention_mask=True, + return_tensors='pt', + ) + return q_collated, d_collated + + def _tokenize_chunked_passages(self, passages: List[str]): + """ + Tokenize passages with EOS separators between chunks. + Each chunk ends with EOS, enabling extraction of chunk embeddings from EOS positions. + """ + chunk_size = self.data_args.passage_chunk_size + eos_id = self.tokenizer.eos_token_id + pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 + + all_input_ids = [] + all_eos_positions = [] + + for passage in passages: + tokens = self.tokenizer.encode(passage, add_special_tokens=False) + + new_tokens = [] + eos_positions = [] + for i in range(0, max(len(tokens), 1), chunk_size): + chunk = tokens[i:i + chunk_size] + new_tokens.extend(chunk) + new_tokens.append(eos_id) + eos_positions.append(len(new_tokens) - 1) + + all_input_ids.append(new_tokens) + all_eos_positions.append(eos_positions) + + # Padding + max_len = min(max(len(ids) for ids in all_input_ids), self.data_args.passage_max_len) + if self.data_args.pad_to_multiple_of: + max_len = ((max_len + self.data_args.pad_to_multiple_of - 1) + // self.data_args.pad_to_multiple_of * self.data_args.pad_to_multiple_of) + + padded_ids, padded_mask, final_eos_positions = [], [], [] + for input_ids, eos_pos in zip(all_input_ids, all_eos_positions): + if len(input_ids) > max_len: + input_ids = input_ids[:max_len] + eos_pos = [p for p in eos_pos if p < max_len] + pad_len = max_len - len(input_ids) + padded_ids.append(input_ids + [pad_id] * pad_len) + padded_mask.append([1] * len(input_ids) + [0] * pad_len) + final_eos_positions.append(eos_pos) + + d_collated = { + 'input_ids': 
torch.tensor(padded_ids, dtype=torch.long), + 'attention_mask': torch.tensor(padded_mask, dtype=torch.long), + } + return d_collated, final_eos_positions @dataclass diff --git a/src/tevatron/retriever/driver/train.py b/src/tevatron/retriever/driver/train.py index 39abab45..15b13adc 100644 --- a/src/tevatron/retriever/driver/train.py +++ b/src/tevatron/retriever/driver/train.py @@ -87,6 +87,7 @@ def main(): torch_dtype=torch_dtype, attn_implementation=model_args.attn_implementation, ) + model.passage_chunk_size = data_args.passage_chunk_size train_dataset = TrainDataset(data_args) collator = TrainCollator(data_args, tokenizer) diff --git a/src/tevatron/retriever/modeling/dense.py b/src/tevatron/retriever/modeling/dense.py index 8bc50106..0cf64445 100644 --- a/src/tevatron/retriever/modeling/dense.py +++ b/src/tevatron/retriever/modeling/dense.py @@ -1,4 +1,5 @@ import torch +import torch.nn.functional as F import logging from transformers import Qwen2_5OmniThinkerForConditionalGeneration from .encoder import EncoderModel @@ -8,14 +9,44 @@ class DenseModel(EncoderModel): + def __init__(self, encoder, pooling='cls', normalize=False, temperature=1.0): + super().__init__(encoder, pooling, normalize, temperature) + self.passage_chunk_size = 0 + self._eos_positions = None + + def set_eos_positions(self, eos_positions): + self._eos_positions = eos_positions + def encode_query(self, qry): query_hidden_states = self.encoder(**qry, return_dict=True) query_hidden_states = query_hidden_states.last_hidden_state return self._pooling(query_hidden_states, qry['attention_mask']) def encode_passage(self, psg): - # encode passage is the same as encode query - return self.encode_query(psg) + hidden_states = self.encoder(**psg, return_dict=True).last_hidden_state + + if self.passage_chunk_size > 0 and self._eos_positions is not None: + return self._encode_chunked_passage(hidden_states) + return self._pooling(hidden_states, psg['attention_mask']) + + def _encode_chunked_passage(self, 
hidden_states): + """Extract EOS position embeddings as chunk representations.""" + batch_size, seq_len, hidden_size = hidden_states.shape + max_chunks = max(len(pos) for pos in self._eos_positions) + + chunk_embs = torch.zeros(batch_size, max_chunks, hidden_size, + device=hidden_states.device, dtype=hidden_states.dtype) + chunk_mask = torch.zeros(batch_size, max_chunks, device=hidden_states.device) + + for i, positions in enumerate(self._eos_positions): + for j, pos in enumerate(positions): + if pos < seq_len: + chunk_embs[i, j] = hidden_states[i, pos] + chunk_mask[i, j] = 1.0 + + if self.normalize: + chunk_embs = F.normalize(chunk_embs, p=2, dim=-1) + return chunk_embs, chunk_mask def _pooling(self, last_hidden_state, attention_mask): diff --git a/src/tevatron/retriever/modeling/encoder.py b/src/tevatron/retriever/modeling/encoder.py index c3eedc35..56536e18 100644 --- a/src/tevatron/retriever/modeling/encoder.py +++ b/src/tevatron/retriever/modeling/encoder.py @@ -38,6 +38,7 @@ def __init__(self, self.pooling = pooling self.normalize = normalize self.temperature = temperature + self.passage_chunk_size = 0 self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') self.is_ddp = dist.is_initialized() if self.is_ddp: @@ -46,40 +47,50 @@ def __init__(self, def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None): q_reps = self.encode_query(query) if query else None - p_reps = self.encode_passage(passage) if passage else None + + # Handle chunked vs normal passage encoding + if passage is not None: + p_result = self.encode_passage(passage) + if self.passage_chunk_size > 0 and isinstance(p_result, tuple): + p_reps, chunk_mask = p_result + else: + p_reps, chunk_mask = p_result, None + else: + p_reps, chunk_mask = None, None # for inference if q_reps is None or p_reps is None: - return EncoderOutput( - q_reps=q_reps, - p_reps=p_reps - ) + return EncoderOutput(q_reps=q_reps, p_reps=p_reps) # for training if self.training: if self.is_ddp: 
q_reps = self._dist_gather_tensor(q_reps) p_reps = self._dist_gather_tensor(p_reps) + if chunk_mask is not None: + chunk_mask = self._dist_gather_tensor(chunk_mask) - scores = self.compute_similarity(q_reps, p_reps) + if self.passage_chunk_size > 0 and chunk_mask is not None: + scores = self.compute_maxsim_similarity(q_reps, p_reps, chunk_mask) + else: + scores = self.compute_similarity(q_reps, p_reps) scores = scores.view(q_reps.size(0), -1) - target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long) - target = target * (p_reps.size(0) // q_reps.size(0)) + num_psg_per_query = scores.size(1) // q_reps.size(0) + target = torch.arange(q_reps.size(0), device=scores.device, dtype=torch.long) + target = target * num_psg_per_query loss = self.compute_loss(scores / self.temperature, target) if self.is_ddp: - loss = loss * self.world_size # counter average weight reduction + loss = loss * self.world_size # for eval else: - scores = self.compute_similarity(q_reps, p_reps) + if self.passage_chunk_size > 0 and chunk_mask is not None: + scores = self.compute_maxsim_similarity(q_reps, p_reps, chunk_mask) + else: + scores = self.compute_similarity(q_reps, p_reps) loss = None - return EncoderOutput( - loss=loss, - scores=scores, - q_reps=q_reps, - p_reps=p_reps, - ) + return EncoderOutput(loss=loss, scores=scores, q_reps=q_reps, p_reps=p_reps) def encode_passage(self, psg): raise NotImplementedError('EncoderModel is an abstract class') @@ -90,6 +101,18 @@ def encode_query(self, qry): def compute_similarity(self, q_reps, p_reps): return torch.matmul(q_reps, p_reps.transpose(0, 1)) + def compute_maxsim_similarity(self, q_reps, p_reps, chunk_mask): + """ + MaxSim: max similarity between query and passage chunks. 
+ q_reps: [Q, H], p_reps: [P, C, H], chunk_mask: [P, C] + Returns: [Q, P] + """ + chunk_scores = torch.einsum('qh,pch->qpc', q_reps, p_reps) + if chunk_mask is not None: + padding_mask = ~chunk_mask.unsqueeze(0).bool() + chunk_scores = chunk_scores.masked_fill(padding_mask, float('-inf')) + return chunk_scores.max(dim=-1).values + def compute_loss(self, scores, target): return self.cross_entropy(scores, target) diff --git a/src/tevatron/retriever/trainer.py b/src/tevatron/retriever/trainer.py index 0c6ceb58..30504361 100644 --- a/src/tevatron/retriever/trainer.py +++ b/src/tevatron/retriever/trainer.py @@ -45,7 +45,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - query, passage = inputs + query, passage, *rest = inputs + eos_positions = rest[0] if rest else None + if hasattr(model, 'set_eos_positions'): + model.set_eos_positions(eos_positions) return model(query=query, passage=passage).loss def training_step(self, *args): From 1536517cb02e628a04a02998bfe92ccdaf558c02 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Wed, 10 Dec 2025 14:05:42 -0500 Subject: [PATCH 02/31] added logic for left padding --- src/tevatron/retriever/arguments.py | 4 ++-- src/tevatron/retriever/collator.py | 18 +++++++++++++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/tevatron/retriever/arguments.py b/src/tevatron/retriever/arguments.py index cce3285f..f0e7ffa2 100644 --- a/src/tevatron/retriever/arguments.py +++ b/src/tevatron/retriever/arguments.py @@ -199,8 +199,8 @@ class DataArguments: ) padding_side: str = field( - default='right', - metadata={"help": "padding side for the tokenizer, can be 'left' or 'right'"} + default='left', + metadata={"help": "padding side for the tokenizer, can be 'left' or 'right'. 
Use 'left' for last-token pooling (decoder models like Qwen/LLaMA), 'right' for cls pooling (encoder models like BERT)"} ) passage_chunk_size: int = field( diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index 6323fcc1..d2435a1d 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -82,10 +82,12 @@ def _tokenize_chunked_passages(self, passages: List[str]): """ Tokenize passages with EOS separators between chunks. Each chunk ends with EOS, enabling extraction of chunk embeddings from EOS positions. + Respects tokenizer.padding_side for consistent padding direction. """ chunk_size = self.data_args.passage_chunk_size eos_id = self.tokenizer.eos_token_id pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 + use_left_padding = self.tokenizer.padding_side == 'left' all_input_ids = [] all_eos_positions = [] @@ -115,10 +117,20 @@ def _tokenize_chunked_passages(self, passages: List[str]): if len(input_ids) > max_len: input_ids = input_ids[:max_len] eos_pos = [p for p in eos_pos if p < max_len] + pad_len = max_len - len(input_ids) - padded_ids.append(input_ids + [pad_id] * pad_len) - padded_mask.append([1] * len(input_ids) + [0] * pad_len) - final_eos_positions.append(eos_pos) + + if use_left_padding: + # Left padding: [PAD, PAD, ..., content] + padded_ids.append([pad_id] * pad_len + input_ids) + padded_mask.append([0] * pad_len + [1] * len(input_ids)) + # Adjust EOS positions: shift right by pad_len + final_eos_positions.append([p + pad_len for p in eos_pos]) + else: + # Right padding: [content, ..., PAD, PAD] + padded_ids.append(input_ids + [pad_id] * pad_len) + padded_mask.append([1] * len(input_ids) + [0] * pad_len) + final_eos_positions.append(eos_pos) d_collated = { 'input_ids': torch.tensor(padded_ids, dtype=torch.long), From dccfaf89874a313afabf6af28851996fa49fae33 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Wed, 10 Dec 2025 22:48:13 -0500 Subject: [PATCH 
03/31] added search --- src/tevatron/retriever/collator.py | 88 +++++++++++++++++++++++- src/tevatron/retriever/driver/encode.py | 79 +++++++++++++++++---- src/tevatron/retriever/driver/search.py | 91 +++++++++++++++++++++++-- 3 files changed, 235 insertions(+), 23 deletions(-) diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index d2435a1d..0eb37ac2 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -82,10 +82,15 @@ def _tokenize_chunked_passages(self, passages: List[str]): """ Tokenize passages with EOS separators between chunks. Each chunk ends with EOS, enabling extraction of chunk embeddings from EOS positions. - Respects tokenizer.padding_side for consistent padding direction. + + Uses the same token that tokenizer.add_special_tokens adds (e.g., <|endoftext|>) + so that query and passage use the same pooling token automatically. """ chunk_size = self.data_args.passage_chunk_size - eos_id = self.tokenizer.eos_token_id + # Get the token that tokenizer adds with add_special_tokens=True + # This ensures query and passage use the same token for pooling + sample = self.tokenizer.encode("x", add_special_tokens=True) + eos_id = sample[-1] # Last token added by tokenizer pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 use_left_padding = self.tokenizer.padding_side == 'left' @@ -289,6 +294,85 @@ def __call__(self, features): ) return content_ids, collated_inputs + +@dataclass +class ChunkedEncodeCollator: + """ + Collator for chunked passage encoding (inference/search). + Splits passages into chunks with EOS separators, similar to training. + """ + data_args: DataArguments + tokenizer: PreTrainedTokenizer + + def __call__(self, features): + """ + Collate function for chunked passage encoding. 
+ :param features: list of (doc_id, text, image, video, audio) tuples + :return: (doc_ids, collated_inputs, eos_positions, chunk_counts) + + Uses the same token that tokenizer.add_special_tokens adds for consistency with query. + """ + doc_ids = [x[0] for x in features] + texts = [x[1] for x in features] + + chunk_size = self.data_args.passage_chunk_size + # Get the token that tokenizer adds with add_special_tokens=True + sample = self.tokenizer.encode("x", add_special_tokens=True) + eos_id = sample[-1] # Last token added by tokenizer + pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 + use_left_padding = self.tokenizer.padding_side == 'left' + + all_input_ids = [] + all_eos_positions = [] + chunk_counts = [] + + for text in texts: + if text is None: + text = "" + tokens = self.tokenizer.encode(text, add_special_tokens=False) + + new_tokens = [] + eos_positions = [] + for i in range(0, max(len(tokens), 1), chunk_size): + chunk = tokens[i:i + chunk_size] + new_tokens.extend(chunk) + new_tokens.append(eos_id) + eos_positions.append(len(new_tokens) - 1) + + all_input_ids.append(new_tokens) + all_eos_positions.append(eos_positions) + chunk_counts.append(len(eos_positions)) + + # Padding + max_len = min(max(len(ids) for ids in all_input_ids), self.data_args.passage_max_len) + if self.data_args.pad_to_multiple_of: + max_len = ((max_len + self.data_args.pad_to_multiple_of - 1) // self.data_args.pad_to_multiple_of * self.data_args.pad_to_multiple_of) + + padded_ids, padded_mask, final_eos_positions = [], [], [] + for input_ids, eos_pos in zip(all_input_ids, all_eos_positions): + if len(input_ids) > max_len: + input_ids = input_ids[:max_len] + eos_pos = [p for p in eos_pos if p < max_len] + + pad_len = max_len - len(input_ids) + + if use_left_padding: + padded_ids.append([pad_id] * pad_len + input_ids) + padded_mask.append([0] * pad_len + [1] * len(input_ids)) + final_eos_positions.append([p + pad_len for p in eos_pos]) + else: + 
padded_ids.append(input_ids + [pad_id] * pad_len) + padded_mask.append([1] * len(input_ids) + [0] * pad_len) + final_eos_positions.append(eos_pos) + + collated_inputs = { + 'input_ids': torch.tensor(padded_ids, dtype=torch.long), + 'attention_mask': torch.tensor(padded_mask, dtype=torch.long), + } + + return doc_ids, collated_inputs, final_eos_positions, chunk_counts + + @dataclass class MultiModalEncodeCollator: """ diff --git a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py index 8749dfda..6908d789 100644 --- a/src/tevatron/retriever/driver/encode.py +++ b/src/tevatron/retriever/driver/encode.py @@ -8,6 +8,7 @@ from tqdm import tqdm import torch +import torch.nn.functional as F from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -18,7 +19,7 @@ from tevatron.retriever.arguments import ModelArguments, DataArguments, \ TevatronTrainingArguments as TrainingArguments from tevatron.retriever.dataset import EncodeDataset -from tevatron.retriever.collator import EncodeCollator +from tevatron.retriever.collator import EncodeCollator, ChunkedEncodeCollator from tevatron.retriever.modeling import EncoderOutput, DenseModel logger = logging.getLogger(__name__) @@ -78,10 +79,23 @@ def main(): data_args=data_args, ) - encode_collator = EncodeCollator( - data_args=data_args, - tokenizer=tokenizer, + # Check if using chunked passage encoding + use_chunked = ( + not data_args.encode_is_query and + data_args.passage_chunk_size > 0 ) + + if use_chunked: + logger.info(f"Using chunked passage encoding with chunk_size={data_args.passage_chunk_size}") + encode_collator = ChunkedEncodeCollator( + data_args=data_args, + tokenizer=tokenizer, + ) + else: + encode_collator = EncodeCollator( + data_args=data_args, + tokenizer=tokenizer, + ) encode_loader = DataLoader( encode_dataset, @@ -96,23 +110,58 @@ def main(): model = model.to(training_args.device) model.eval() - for (batch_ids, batch) in tqdm(encode_loader): - 
lookup_indices.extend(batch_ids) + for batch in tqdm(encode_loader): with torch.amp.autocast('cuda') if training_args.fp16 or training_args.bf16 else nullcontext(): with torch.no_grad(): - for k, v in batch.items(): - batch[k] = v.to(training_args.device) - if data_args.encode_is_query: - model_output: EncoderOutput = model(query=batch) - encoded.append(model_output.q_reps.cpu().detach().numpy()) + if use_chunked: + # Chunked passage encoding + doc_ids, batch_inputs, eos_positions, chunk_counts = batch + + for k, v in batch_inputs.items(): + batch_inputs[k] = v.to(training_args.device) + + # Get hidden states from encoder + hidden_states = model.encoder(**batch_inputs, return_dict=True).last_hidden_state + # hidden_states: [batch_size, seq_len, hidden_size] + + # Extract embeddings at EOS positions + for i, (doc_id, positions) in enumerate(zip(doc_ids, eos_positions)): + for chunk_idx, pos in enumerate(positions): + if pos < hidden_states.shape[1]: + chunk_emb = hidden_states[i, pos] + + # Normalize if needed + if model.normalize: + chunk_emb = F.normalize(chunk_emb, p=2, dim=-1) + + encoded.append(chunk_emb.cpu().numpy()) + lookup_indices.append((doc_id, chunk_idx)) else: - model_output: EncoderOutput = model(passage=batch) - encoded.append(model_output.p_reps.cpu().detach().numpy()) - - encoded = np.concatenate(encoded) + # Standard query or passage encoding + batch_ids, batch_inputs = batch + lookup_indices.extend(batch_ids) + + for k, v in batch_inputs.items(): + batch_inputs[k] = v.to(training_args.device) + + if data_args.encode_is_query: + model_output: EncoderOutput = model(query=batch_inputs) + encoded.append(model_output.q_reps.cpu().detach().numpy()) + else: + model_output: EncoderOutput = model(passage=batch_inputs) + encoded.append(model_output.p_reps.cpu().detach().numpy()) + + # Combine encoded embeddings + if use_chunked: + encoded = np.stack(encoded) + logger.info(f"Encoded {len(set(d for d, c in lookup_indices))} docs into {len(lookup_indices)} 
chunks") + else: + encoded = np.concatenate(encoded) with open(data_args.encode_output_path, 'wb') as f: pickle.dump((encoded, lookup_indices), f) + + logger.info(f"Saved embeddings to {data_args.encode_output_path}, shape: {encoded.shape}") if __name__ == "__main__": diff --git a/src/tevatron/retriever/driver/search.py b/src/tevatron/retriever/driver/search.py index 1f374eac..5a7f84e8 100644 --- a/src/tevatron/retriever/driver/search.py +++ b/src/tevatron/retriever/driver/search.py @@ -3,6 +3,7 @@ import numpy as np import glob from argparse import ArgumentParser +from collections import defaultdict from itertools import chain from tqdm import tqdm import faiss @@ -29,6 +30,41 @@ def search_queries(retriever, q_reps, p_lookup, args): return all_scores, psg_indices +def search_queries_chunked(retriever, q_reps, p_lookup, args): + """ + Search with chunked passages and aggregate by document using MaxSim. + """ + # Search more chunks to ensure good recall after aggregation + search_depth = args.depth * args.chunk_multiplier + + if args.batch_size > 0: + all_scores, all_indices = retriever.batch_search(q_reps, search_depth, args.batch_size, args.quiet) + else: + all_scores, all_indices = retriever.search(q_reps, search_depth) + + # Aggregate by document ID using MaxSim + aggregated_results = [] + for q_idx in range(len(q_reps)): + scores = all_scores[q_idx] + indices = all_indices[q_idx] + + doc_max_scores = defaultdict(lambda: float('-inf')) + + for score, idx in zip(scores, indices): + if idx < 0: # FAISS returns -1 for insufficient results + continue + + doc_id, chunk_idx = p_lookup[idx] + # MaxSim: keep the maximum score for each document + doc_max_scores[doc_id] = max(doc_max_scores[doc_id], score) + + # Sort by score and take top-depth + sorted_docs = sorted(doc_max_scores.items(), key=lambda x: x[1], reverse=True)[:args.depth] + aggregated_results.append(sorted_docs) + + return aggregated_results + + def write_ranking(corpus_indices, corpus_scores, q_lookup, 
ranking_save_file): with open(ranking_save_file, 'w') as f: for qid, q_doc_scores, q_doc_indices in zip(q_lookup, corpus_scores, corpus_indices): @@ -38,6 +74,17 @@ def write_ranking(corpus_indices, corpus_scores, q_lookup, ranking_save_file): f.write(f'{qid}\t{idx}\t{s}\n') +def write_ranking_chunked(results, q_lookup, ranking_save_file): + """ + Write ranking results from chunked search. + results: List[List[Tuple[doc_id, score]]] + """ + with open(ranking_save_file, 'w') as f: + for qid, doc_scores in zip(q_lookup, results): + for doc_id, score in doc_scores: + f.write(f'{qid}\t{doc_id}\t{score}\n') + + def pickle_load(path): with open(path, 'rb') as f: reps, lookup = pickle.load(f) @@ -58,6 +105,11 @@ def main(): parser.add_argument('--save_ranking_to', required=True) parser.add_argument('--save_text', action='store_true') parser.add_argument('--quiet', action='store_true') + # Chunked search arguments + parser.add_argument('--chunked', action='store_true', + help='Enable chunked search with document-level MaxSim aggregation') + parser.add_argument('--chunk_multiplier', type=int, default=10, + help='Multiply search depth by this factor for chunked search to ensure recall') args = parser.parse_args() @@ -75,6 +127,14 @@ def main(): retriever.add(p_reps) look_up += p_lookup + # Auto-detect chunked format: lookup entries are tuples (doc_id, chunk_idx) + is_chunked = args.chunked or (len(look_up) > 0 and isinstance(look_up[0], tuple)) + + if is_chunked: + unique_docs = len(set(doc_id for doc_id, _ in look_up)) + logger.info(f"Chunked mode: {len(look_up)} chunks from {unique_docs} documents") + logger.info(f"Search depth: {args.depth} docs, chunk search depth: {args.depth * args.chunk_multiplier}") + q_reps, q_lookup = pickle_load(args.query_reps) q_reps = q_reps @@ -96,14 +156,33 @@ def main(): ngpu=num_gpus) logger.info('Index Search Start') - all_scores, psg_indices = search_queries(retriever, q_reps, look_up, args) - logger.info('Index Search Finished') - - if 
args.save_text: - write_ranking(psg_indices, all_scores, q_lookup, args.save_ranking_to) + + if is_chunked: + # Chunked search with MaxSim aggregation + aggregated_results = search_queries_chunked(retriever, q_reps, look_up, args) + logger.info('Index Search Finished (chunked mode with MaxSim aggregation)') + + if args.save_text: + write_ranking_chunked(aggregated_results, q_lookup, args.save_ranking_to) + else: + # Convert to arrays for pickle + all_scores = [] + all_doc_ids = [] + for doc_scores in aggregated_results: + scores = [s for _, s in doc_scores] + doc_ids = [d for d, _ in doc_scores] + all_scores.append(scores) + all_doc_ids.append(doc_ids) + pickle_save((all_scores, all_doc_ids), args.save_ranking_to) else: - pickle_save((all_scores, psg_indices), args.save_ranking_to) + # Standard search + all_scores, psg_indices = search_queries(retriever, q_reps, look_up, args) + logger.info('Index Search Finished') + if args.save_text: + write_ranking(psg_indices, all_scores, q_lookup, args.save_ranking_to) + else: + pickle_save((all_scores, psg_indices), args.save_ranking_to) if __name__ == '__main__': main() From 1769715e6e2e2f6187c406091ae80cadc558f9f2 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Sun, 14 Dec 2025 00:10:16 -0500 Subject: [PATCH 04/31] changed the tokenizer logic --- src/tevatron/retriever/collator.py | 72 +++++++++++------------------- 1 file changed, 25 insertions(+), 47 deletions(-) diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index 0eb37ac2..58ada992 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -86,61 +86,39 @@ def _tokenize_chunked_passages(self, passages: List[str]): Uses the same token that tokenizer.add_special_tokens adds (e.g., <|endoftext|>) so that query and passage use the same pooling token automatically. 
""" - chunk_size = self.data_args.passage_chunk_size - # Get the token that tokenizer adds with add_special_tokens=True - # This ensures query and passage use the same token for pooling - sample = self.tokenizer.encode("x", add_special_tokens=True) - eos_id = sample[-1] # Last token added by tokenizer - pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 - use_left_padding = self.tokenizer.padding_side == 'left' + chunk_length = self.data_args.passage_chunk_size + sep_id = 151645 # <|separator|> + eos_id = 151643 # <|endoftext|> all_input_ids = [] - all_eos_positions = [] + all_sep_positions = [] for passage in passages: - tokens = self.tokenizer.encode(passage, add_special_tokens=False) - + tokens = self.tokenizer.encode(passage, add_special_tokens=False) # There maybe some differences between models, this is for qwen3-embedding-0.6b, it only adds <|separator|> and endoftext new_tokens = [] - eos_positions = [] - for i in range(0, max(len(tokens), 1), chunk_size): - chunk = tokens[i:i + chunk_size] - new_tokens.extend(chunk) - new_tokens.append(eos_id) - eos_positions.append(len(new_tokens) - 1) - + sep_positions = [] + i = 1 + while i < len(tokens): + if i % chunk_length == 0: + new_tokens.append(sep_id) + sep_positions.append(i-1) + else: + new_tokens.append(tokens[i-1]) + i += 1 + new_tokens.append(eos_id) # edge case, what if the new_tokens[-1] is sep_id? 
+ new_tokens.append(sep_id) + sep_positions.append(len(new_tokens)-1) all_input_ids.append(new_tokens) - all_eos_positions.append(eos_positions) + all_sep_positions.append(sep_positions) # Padding - max_len = min(max(len(ids) for ids in all_input_ids), self.data_args.passage_max_len) - if self.data_args.pad_to_multiple_of: - max_len = ((max_len + self.data_args.pad_to_multiple_of - 1) - // self.data_args.pad_to_multiple_of * self.data_args.pad_to_multiple_of) - - padded_ids, padded_mask, final_eos_positions = [], [], [] - for input_ids, eos_pos in zip(all_input_ids, all_eos_positions): - if len(input_ids) > max_len: - input_ids = input_ids[:max_len] - eos_pos = [p for p in eos_pos if p < max_len] - - pad_len = max_len - len(input_ids) - - if use_left_padding: - # Left padding: [PAD, PAD, ..., content] - padded_ids.append([pad_id] * pad_len + input_ids) - padded_mask.append([0] * pad_len + [1] * len(input_ids)) - # Adjust EOS positions: shift right by pad_len - final_eos_positions.append([p + pad_len for p in eos_pos]) - else: - # Right padding: [content, ..., PAD, PAD] - padded_ids.append(input_ids + [pad_id] * pad_len) - padded_mask.append([1] * len(input_ids) + [0] * pad_len) - final_eos_positions.append(eos_pos) - - d_collated = { - 'input_ids': torch.tensor(padded_ids, dtype=torch.long), - 'attention_mask': torch.tensor(padded_mask, dtype=torch.long), - } + d_collated = self.tokenizer.pad( + d_collated, + padding=True, + pad_to_multiple_of=self.data_args.pad_to_multiple_of, + return_attention_mask=True, + return_tensors='pt', + ) return d_collated, final_eos_positions From 16d604a21b3e89ec21f43b22dd074dcce863a52c Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Sun, 14 Dec 2025 01:23:06 -0500 Subject: [PATCH 05/31] added train collator debug --- src/tevatron/retriever/collator.py | 35 +++++++++++++++--------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index 
58ada992..6129c1c1 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -86,7 +86,7 @@ def _tokenize_chunked_passages(self, passages: List[str]): Uses the same token that tokenizer.add_special_tokens adds (e.g., <|endoftext|>) so that query and passage use the same pooling token automatically. """ - chunk_length = self.data_args.passage_chunk_size + chunk_len = self.data_args.passage_chunk_size -1 sep_id = 151645 # <|separator|> eos_id = 151643 # <|endoftext|> @@ -94,23 +94,20 @@ def _tokenize_chunked_passages(self, passages: List[str]): all_sep_positions = [] for passage in passages: - tokens = self.tokenizer.encode(passage, add_special_tokens=False) # There maybe some differences between models, this is for qwen3-embedding-0.6b, it only adds <|separator|> and endoftext - new_tokens = [] - sep_positions = [] - i = 1 - while i < len(tokens): - if i % chunk_length == 0: - new_tokens.append(sep_id) - sep_positions.append(i-1) - else: - new_tokens.append(tokens[i-1]) - i += 1 - new_tokens.append(eos_id) # edge case, what if the new_tokens[-1] is sep_id? 
- new_tokens.append(sep_id) - sep_positions.append(len(new_tokens)-1) - all_input_ids.append(new_tokens) - all_sep_positions.append(sep_positions) + tokens = self.tokenizer.encode(passage, add_special_tokens=False) + tokens.append(eos_id) + ids = [] + sep_pos = [] + for i in range(0, len(tokens), chunk_len): + chunk = tokens[i:i + chunk_len] # up to 128 tokens + ids.extend(chunk) + ids.append(sep_id) # SEP at end of this chunk + sep_pos.append(len(ids) - 1) # position of SEP + + all_input_ids.append(ids) + all_sep_positions.append(sep_pos) + d_collated = {'input_ids': all_input_ids} # Padding d_collated = self.tokenizer.pad( d_collated, @@ -119,7 +116,9 @@ def _tokenize_chunked_passages(self, passages: List[str]): return_attention_mask=True, return_tensors='pt', ) - return d_collated, final_eos_positions + print(d_collated['input_ids'][0]) + print(all_sep_positions[0]) + return d_collated, all_sep_positions @dataclass From 61dbf6eb3fb2f2ac1b73d714eb11d5f72100f689 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Sun, 14 Dec 2025 01:40:12 -0500 Subject: [PATCH 06/31] traincollator is done --- src/tevatron/retriever/collator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index 6129c1c1..a9082379 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -116,8 +116,6 @@ def _tokenize_chunked_passages(self, passages: List[str]): return_attention_mask=True, return_tensors='pt', ) - print(d_collated['input_ids'][0]) - print(all_sep_positions[0]) return d_collated, all_sep_positions From 6ec22a9ff46694c7fd3f23a56526b80266d18863 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Sun, 14 Dec 2025 02:06:33 -0500 Subject: [PATCH 07/31] fixed some comments --- src/tevatron/retriever/collator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index a9082379..cb322603 100644 --- 
a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -55,7 +55,7 @@ def __call__(self, features: List[Tuple[str, List[str]]]): # Passage tokenization if self.data_args.passage_chunk_size > 0: - d_collated, eos_positions = self._tokenize_chunked_passages(all_passages) + d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages) return q_collated, d_collated, eos_positions else: d_collated = self.tokenizer( @@ -78,7 +78,7 @@ def __call__(self, features: List[Tuple[str, List[str]]]): ) return q_collated, d_collated - def _tokenize_chunked_passages(self, passages: List[str]): + def _tokenize_and_pad_chunked_passages(self, passages: List[str]): """ Tokenize passages with EOS separators between chunks. Each chunk ends with EOS, enabling extraction of chunk embeddings from EOS positions. @@ -99,7 +99,7 @@ def _tokenize_chunked_passages(self, passages: List[str]): ids = [] sep_pos = [] for i in range(0, len(tokens), chunk_len): - chunk = tokens[i:i + chunk_len] # up to 128 tokens + chunk = tokens[i:i + chunk_len] # up to self.data_args.passage_chunk_size -1 tokens ids.extend(chunk) ids.append(sep_id) # SEP at end of this chunk sep_pos.append(len(ids) - 1) # position of SEP From 40eddf8d81710b427379e5976ba711658f056b6a Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Sun, 14 Dec 2025 14:57:25 -0500 Subject: [PATCH 08/31] modified chunkedencoder --- src/tevatron/retriever/collator.py | 74 +++++++++---------------- src/tevatron/retriever/driver/encode.py | 16 +----- 2 files changed, 30 insertions(+), 60 deletions(-) diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index cb322603..e489fa29 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -275,6 +275,7 @@ class ChunkedEncodeCollator: """ Collator for chunked passage encoding (inference/search). Splits passages into chunks with EOS separators, similar to training. 
+ Uses the same chunking logic as TrainCollator._tokenize_and_pad_chunked_passages. """ data_args: DataArguments tokenizer: PreTrainedTokenizer @@ -283,69 +284,48 @@ def __call__(self, features): """ Collate function for chunked passage encoding. :param features: list of (doc_id, text, image, video, audio) tuples - :return: (doc_ids, collated_inputs, eos_positions, chunk_counts) - - Uses the same token that tokenizer.add_special_tokens adds for consistency with query. + :return: (doc_ids, collated_inputs, sep_positions, chunk_counts) """ doc_ids = [x[0] for x in features] texts = [x[1] for x in features] - chunk_size = self.data_args.passage_chunk_size - # Get the token that tokenizer adds with add_special_tokens=True - sample = self.tokenizer.encode("x", add_special_tokens=True) - eos_id = sample[-1] # Last token added by tokenizer - pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 - use_left_padding = self.tokenizer.padding_side == 'left' + chunk_len = self.data_args.passage_chunk_size - 1 + sep_id = 151645 # <|separator|> + eos_id = 151643 # <|endoftext|> all_input_ids = [] - all_eos_positions = [] + all_sep_positions = [] chunk_counts = [] for text in texts: if text is None: text = "" tokens = self.tokenizer.encode(text, add_special_tokens=False) + tokens.append(eos_id) - new_tokens = [] - eos_positions = [] - for i in range(0, max(len(tokens), 1), chunk_size): - chunk = tokens[i:i + chunk_size] - new_tokens.extend(chunk) - new_tokens.append(eos_id) - eos_positions.append(len(new_tokens) - 1) - - all_input_ids.append(new_tokens) - all_eos_positions.append(eos_positions) - chunk_counts.append(len(eos_positions)) - - # Padding - max_len = min(max(len(ids) for ids in all_input_ids), self.data_args.passage_max_len) - if self.data_args.pad_to_multiple_of: - max_len = ((max_len + self.data_args.pad_to_multiple_of - 1) // self.data_args.pad_to_multiple_of * self.data_args.pad_to_multiple_of) - - padded_ids, padded_mask, 
final_eos_positions = [], [], [] - for input_ids, eos_pos in zip(all_input_ids, all_eos_positions): - if len(input_ids) > max_len: - input_ids = input_ids[:max_len] - eos_pos = [p for p in eos_pos if p < max_len] - - pad_len = max_len - len(input_ids) + ids = [] + sep_pos = [] + for i in range(0, len(tokens), chunk_len): + chunk = tokens[i:i + chunk_len] # up to passage_chunk_size - 1 tokens + ids.extend(chunk) + ids.append(sep_id) # SEP at end of this chunk + sep_pos.append(len(ids) - 1) # position of SEP - if use_left_padding: - padded_ids.append([pad_id] * pad_len + input_ids) - padded_mask.append([0] * pad_len + [1] * len(input_ids)) - final_eos_positions.append([p + pad_len for p in eos_pos]) - else: - padded_ids.append(input_ids + [pad_id] * pad_len) - padded_mask.append([1] * len(input_ids) + [0] * pad_len) - final_eos_positions.append(eos_pos) + all_input_ids.append(ids) + all_sep_positions.append(sep_pos) + chunk_counts.append(len(sep_pos)) - collated_inputs = { - 'input_ids': torch.tensor(padded_ids, dtype=torch.long), - 'attention_mask': torch.tensor(padded_mask, dtype=torch.long), - } + # Use tokenizer.pad() for consistent padding + d_collated = {'input_ids': all_input_ids} + d_collated = self.tokenizer.pad( + d_collated, + padding=True, + pad_to_multiple_of=self.data_args.pad_to_multiple_of, + return_attention_mask=True, + return_tensors='pt', + ) - return doc_ids, collated_inputs, final_eos_positions, chunk_counts + return doc_ids, d_collated, all_sep_positions, chunk_counts @dataclass diff --git a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py index 6908d789..b0c4d862 100644 --- a/src/tevatron/retriever/driver/encode.py +++ b/src/tevatron/retriever/driver/encode.py @@ -79,23 +79,13 @@ def main(): data_args=data_args, ) - # Check if using chunked passage encoding - use_chunked = ( - not data_args.encode_is_query and - data_args.passage_chunk_size > 0 - ) + use_chunked = not data_args.encode_is_query and 
data_args.passage_chunk_size > 0 if use_chunked: logger.info(f"Using chunked passage encoding with chunk_size={data_args.passage_chunk_size}") - encode_collator = ChunkedEncodeCollator( - data_args=data_args, - tokenizer=tokenizer, - ) + encode_collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=tokenizer) else: - encode_collator = EncodeCollator( - data_args=data_args, - tokenizer=tokenizer, - ) + encode_collator = EncodeCollator(data_args=data_args, tokenizer=tokenizer) encode_loader = DataLoader( encode_dataset, From 0ebcf37959f17345a5e41e12945ca10466a2549e Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Sun, 14 Dec 2025 16:25:54 -0500 Subject: [PATCH 09/31] Modified forward and model --- src/tevatron/retriever/driver/encode.py | 27 +++++--------- src/tevatron/retriever/modeling/dense.py | 41 ++++++++++++---------- src/tevatron/retriever/modeling/encoder.py | 37 ++++++++++--------- 3 files changed, 53 insertions(+), 52 deletions(-) diff --git a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py index b0c4d862..2583496d 100644 --- a/src/tevatron/retriever/driver/encode.py +++ b/src/tevatron/retriever/driver/encode.py @@ -104,30 +104,21 @@ def main(): with torch.amp.autocast('cuda') if training_args.fp16 or training_args.bf16 else nullcontext(): with torch.no_grad(): if use_chunked: - # Chunked passage encoding - doc_ids, batch_inputs, eos_positions, chunk_counts = batch - + doc_ids, batch_inputs, sep_positions, chunk_counts = batch for k, v in batch_inputs.items(): batch_inputs[k] = v.to(training_args.device) - # Get hidden states from encoder - hidden_states = model.encoder(**batch_inputs, return_dict=True).last_hidden_state - # hidden_states: [batch_size, seq_len, hidden_size] + # Use DenseModel's encode_passage to extract chunk embeddings + chunk_embs, chunk_mask = model.encode_passage(batch_inputs, sep_positions=sep_positions) - # Extract embeddings at EOS positions - for i, (doc_id, positions) in enumerate(zip(doc_ids, 
eos_positions)): - for chunk_idx, pos in enumerate(positions): - if pos < hidden_states.shape[1]: - chunk_emb = hidden_states[i, pos] - - # Normalize if needed - if model.normalize: - chunk_emb = F.normalize(chunk_emb, p=2, dim=-1) - - encoded.append(chunk_emb.cpu().numpy()) + # Flatten chunk embeddings and create lookup indices + batch_size, max_chunks, hidden_size = chunk_embs.shape + for i, doc_id in enumerate(doc_ids): + for chunk_idx in range(max_chunks): + if chunk_mask[i, chunk_idx] > 0: # Valid chunk + encoded.append(chunk_embs[i, chunk_idx].cpu().detach().numpy()) lookup_indices.append((doc_id, chunk_idx)) else: - # Standard query or passage encoding batch_ids, batch_inputs = batch lookup_indices.extend(batch_ids) diff --git a/src/tevatron/retriever/modeling/dense.py b/src/tevatron/retriever/modeling/dense.py index 0cf64445..fe817455 100644 --- a/src/tevatron/retriever/modeling/dense.py +++ b/src/tevatron/retriever/modeling/dense.py @@ -12,40 +12,45 @@ class DenseModel(EncoderModel): def __init__(self, encoder, pooling='cls', normalize=False, temperature=1.0): super().__init__(encoder, pooling, normalize, temperature) self.passage_chunk_size = 0 - self._eos_positions = None - - def set_eos_positions(self, eos_positions): - self._eos_positions = eos_positions + self.sep_positions = None def encode_query(self, qry): query_hidden_states = self.encoder(**qry, return_dict=True) query_hidden_states = query_hidden_states.last_hidden_state return self._pooling(query_hidden_states, qry['attention_mask']) - def encode_passage(self, psg): + def encode_passage(self, psg, sep_positions=None): hidden_states = self.encoder(**psg, return_dict=True).last_hidden_state - - if self.passage_chunk_size > 0 and self._eos_positions is not None: - return self._encode_chunked_passage(hidden_states) + if self.passage_chunk_size > 0: + return self._pooling_chunked(hidden_states, self.sep_positions) return self._pooling(hidden_states, psg['attention_mask']) - def 
_encode_chunked_passage(self, hidden_states): - """Extract EOS position embeddings as chunk representations.""" - batch_size, seq_len, hidden_size = hidden_states.shape - max_chunks = max(len(pos) for pos in self._eos_positions) + def _pooling_chunked(self, last_hidden_state, sep_positions): + batch_size, seq_len, hidden_size = last_hidden_state.shape - chunk_embs = torch.zeros(batch_size, max_chunks, hidden_size, - device=hidden_states.device, dtype=hidden_states.dtype) - chunk_mask = torch.zeros(batch_size, max_chunks, device=hidden_states.device) + if not sep_positions: + # No chunks, return empty + return torch.zeros(batch_size, 0, hidden_size, device=last_hidden_state.device, dtype=last_hidden_state.dtype), \ + torch.zeros(batch_size, 0, device=last_hidden_state.device) - for i, positions in enumerate(self._eos_positions): + # Find max number of chunks across all passages + max_chunks = max(len(pos_list) for pos_list in sep_positions) + + chunk_embs = torch.zeros(batch_size, max_chunks, hidden_size, device=last_hidden_state.device, dtype=last_hidden_state.dtype) + chunk_mask = torch.zeros(batch_size, max_chunks, device=last_hidden_state.device, dtype=torch.float) + + # Extract embeddings at sep_positions (this is the pooling operation for chunked passages) + for i, positions in enumerate(sep_positions): for j, pos in enumerate(positions): - if pos < seq_len: - chunk_embs[i, j] = hidden_states[i, pos] + if 0 <= pos < seq_len: + chunk_embs[i, j] = last_hidden_state[i, pos] chunk_mask[i, j] = 1.0 + else: + logger.warning(f"Position {pos} out of bounds for sequence length {seq_len} in batch {i}, chunk {j}") if self.normalize: chunk_embs = F.normalize(chunk_embs, p=2, dim=-1) + return chunk_embs, chunk_mask diff --git a/src/tevatron/retriever/modeling/encoder.py b/src/tevatron/retriever/modeling/encoder.py index 56536e18..1ccd678d 100644 --- a/src/tevatron/retriever/modeling/encoder.py +++ b/src/tevatron/retriever/modeling/encoder.py @@ -47,28 +47,24 @@ def 
__init__(self, def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None): q_reps = self.encode_query(query) if query else None - - # Handle chunked vs normal passage encoding - if passage is not None: - p_result = self.encode_passage(passage) - if self.passage_chunk_size > 0 and isinstance(p_result, tuple): - p_reps, chunk_mask = p_result - else: - p_reps, chunk_mask = p_result, None - else: - p_reps, chunk_mask = None, None + p_reps, chunk_mask = None, None + if passage: + p_reps = self.encode_passage(passage) + if self.passage_chunk_size > 0 and isinstance(p_reps, tuple): + p_reps, chunk_mask = p_reps # for inference if q_reps is None or p_reps is None: - return EncoderOutput(q_reps=q_reps, p_reps=p_reps) + return EncoderOutput( + q_reps=q_reps, + p_reps=p_reps + ) # for training if self.training: if self.is_ddp: q_reps = self._dist_gather_tensor(q_reps) p_reps = self._dist_gather_tensor(p_reps) - if chunk_mask is not None: - chunk_mask = self._dist_gather_tensor(chunk_mask) if self.passage_chunk_size > 0 and chunk_mask is not None: scores = self.compute_maxsim_similarity(q_reps, p_reps, chunk_mask) @@ -82,7 +78,7 @@ def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = loss = self.compute_loss(scores / self.temperature, target) if self.is_ddp: - loss = loss * self.world_size + loss = loss * self.world_size # counter average weight reduction # for eval else: if self.passage_chunk_size > 0 and chunk_mask is not None: @@ -90,7 +86,12 @@ def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = else: scores = self.compute_similarity(q_reps, p_reps) loss = None - return EncoderOutput(loss=loss, scores=scores, q_reps=q_reps, p_reps=p_reps) + return EncoderOutput( + loss=loss, + scores=scores, + q_reps=q_reps, + p_reps=p_reps, + ) def encode_passage(self, psg): raise NotImplementedError('EncoderModel is an abstract class') @@ -105,9 +106,13 @@ def compute_maxsim_similarity(self, q_reps, p_reps, 
chunk_mask): """ MaxSim: max similarity between query and passage chunks. q_reps: [Q, H], p_reps: [P, C, H], chunk_mask: [P, C] + Q: number of queries + P: number of passages + C: number of chunks per passage + H: dimension of the embeddings Returns: [Q, P] """ - chunk_scores = torch.einsum('qh,pch->qpc', q_reps, p_reps) + chunk_scores = torch.einsum('qh,pch->qpc', q_reps, p_reps) # 第 q 个 query 和第 p 个 passage 的第 c 个 chunk 的相似度 if chunk_mask is not None: padding_mask = ~chunk_mask.unsqueeze(0).bool() chunk_scores = chunk_scores.masked_fill(padding_mask, float('-inf')) From e7e3bc368700240e26675f294ea3e241263d4c4f Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Sun, 14 Dec 2025 18:27:07 -0500 Subject: [PATCH 10/31] Modified inference on chunked passage, in progress --- src/tevatron/retriever/arguments.py | 4 ++-- src/tevatron/retriever/driver/encode.py | 7 ++++--- src/tevatron/retriever/modeling/dense.py | 12 ++++++------ 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/tevatron/retriever/arguments.py b/src/tevatron/retriever/arguments.py index f0e7ffa2..cce3285f 100644 --- a/src/tevatron/retriever/arguments.py +++ b/src/tevatron/retriever/arguments.py @@ -199,8 +199,8 @@ class DataArguments: ) padding_side: str = field( - default='left', - metadata={"help": "padding side for the tokenizer, can be 'left' or 'right'. 
Use 'left' for last-token pooling (decoder models like Qwen/LLaMA), 'right' for cls pooling (encoder models like BERT)"} + default='right', + metadata={"help": "padding side for the tokenizer, can be 'left' or 'right'"} ) passage_chunk_size: int = field( diff --git a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py index 2583496d..55044691 100644 --- a/src/tevatron/retriever/driver/encode.py +++ b/src/tevatron/retriever/driver/encode.py @@ -10,6 +10,8 @@ import torch import torch.nn.functional as F +from rich import print + from torch.utils.data import DataLoader from transformers import AutoTokenizer from transformers import ( @@ -105,11 +107,10 @@ def main(): with torch.no_grad(): if use_chunked: doc_ids, batch_inputs, sep_positions, chunk_counts = batch + print(batch_inputs) for k, v in batch_inputs.items(): batch_inputs[k] = v.to(training_args.device) - - # Use DenseModel's encode_passage to extract chunk embeddings - chunk_embs, chunk_mask = model.encode_passage(batch_inputs, sep_positions=sep_positions) + chunk_embs, chunk_mask = model.encode_passage(batch_inputs, sep_positions) # Flatten chunk embeddings and create lookup indices batch_size, max_chunks, hidden_size = chunk_embs.shape diff --git a/src/tevatron/retriever/modeling/dense.py b/src/tevatron/retriever/modeling/dense.py index fe817455..43fec213 100644 --- a/src/tevatron/retriever/modeling/dense.py +++ b/src/tevatron/retriever/modeling/dense.py @@ -21,8 +21,8 @@ def encode_query(self, qry): def encode_passage(self, psg, sep_positions=None): hidden_states = self.encoder(**psg, return_dict=True).last_hidden_state - if self.passage_chunk_size > 0: - return self._pooling_chunked(hidden_states, self.sep_positions) + if self.passage_chunk_size > 0 and sep_positions: + return self._pooling_chunked(hidden_states, sep_positions) return self._pooling(hidden_states, psg['attention_mask']) def _pooling_chunked(self, last_hidden_state, sep_positions): @@ -36,22 +36,22 @@ def 
_pooling_chunked(self, last_hidden_state, sep_positions): # Find max number of chunks across all passages max_chunks = max(len(pos_list) for pos_list in sep_positions) - chunk_embs = torch.zeros(batch_size, max_chunks, hidden_size, device=last_hidden_state.device, dtype=last_hidden_state.dtype) + chunk_reps = torch.zeros(batch_size, max_chunks, hidden_size, device=last_hidden_state.device, dtype=last_hidden_state.dtype) chunk_mask = torch.zeros(batch_size, max_chunks, device=last_hidden_state.device, dtype=torch.float) # Extract embeddings at sep_positions (this is the pooling operation for chunked passages) for i, positions in enumerate(sep_positions): for j, pos in enumerate(positions): if 0 <= pos < seq_len: - chunk_embs[i, j] = last_hidden_state[i, pos] + chunk_reps[i, j] = last_hidden_state[i, pos] chunk_mask[i, j] = 1.0 else: logger.warning(f"Position {pos} out of bounds for sequence length {seq_len} in batch {i}, chunk {j}") if self.normalize: - chunk_embs = F.normalize(chunk_embs, p=2, dim=-1) + chunk_reps = F.normalize(chunk_reps, p=2, dim=-1) - return chunk_embs, chunk_mask + return chunk_reps, chunk_mask def _pooling(self, last_hidden_state, attention_mask): From 2d9939f2c6bebc01a1a5bb855bb54a2e02a7d632 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Sun, 14 Dec 2025 18:53:02 -0500 Subject: [PATCH 11/31] fixed a chunk size not passed to model --- src/tevatron/retriever/driver/encode.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py index 55044691..24cfc514 100644 --- a/src/tevatron/retriever/driver/encode.py +++ b/src/tevatron/retriever/driver/encode.py @@ -85,6 +85,7 @@ def main(): if use_chunked: logger.info(f"Using chunked passage encoding with chunk_size={data_args.passage_chunk_size}") + model.passage_chunk_size = data_args.passage_chunk_size encode_collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=tokenizer) else: encode_collator = 
EncodeCollator(data_args=data_args, tokenizer=tokenizer) From ed3e302c2d49a97a5c1c751884fad514a09ef4c7 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Mon, 15 Dec 2025 14:29:40 -0500 Subject: [PATCH 12/31] changed eos to sep --- src/tevatron/retriever/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tevatron/retriever/trainer.py b/src/tevatron/retriever/trainer.py index 30504361..e759c42c 100644 --- a/src/tevatron/retriever/trainer.py +++ b/src/tevatron/retriever/trainer.py @@ -46,9 +46,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): query, passage, *rest = inputs - eos_positions = rest[0] if rest else None - if hasattr(model, 'set_eos_positions'): - model.set_eos_positions(eos_positions) + sep_positions = rest[0] if rest else None + if hasattr(model, 'sep_positions'): + model.sep_positions = sep_positions return model(query=query, passage=passage).loss def training_step(self, *args): From add3832f2071525e257658cbe42cf9f9bbb3b928 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Tue, 16 Dec 2025 00:49:27 -0500 Subject: [PATCH 13/31] added logs --- src/tevatron/retriever/collator.py | 4 +- src/tevatron/retriever/modeling/encoder.py | 62 ++++++++++++++++++++-- 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index e489fa29..710cca91 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -55,8 +55,8 @@ def __call__(self, features: List[Tuple[str, List[str]]]): # Passage tokenization if self.data_args.passage_chunk_size > 0: - d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages) - return q_collated, d_collated, eos_positions + d_collated, sep_positions = self._tokenize_and_pad_chunked_passages(all_passages) + return q_collated, d_collated, sep_positions else: d_collated = 
self.tokenizer( all_passages, diff --git a/src/tevatron/retriever/modeling/encoder.py b/src/tevatron/retriever/modeling/encoder.py index 1ccd678d..ed8cf123 100644 --- a/src/tevatron/retriever/modeling/encoder.py +++ b/src/tevatron/retriever/modeling/encoder.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import Dict, Optional +import os import torch import torch.distributed as dist from torch import nn, Tensor @@ -49,7 +50,21 @@ def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = q_reps = self.encode_query(query) if query else None p_reps, chunk_mask = None, None if passage: - p_reps = self.encode_passage(passage) + # If training with chunked passages, sep_positions is produced by the collator and + # attached to the model by TevatronTrainer.compute_loss(). Forward() needs to pass it + # into encode_passage() to actually get chunk reps/masks. + sep_positions = getattr(self, "sep_positions", None) + if self.passage_chunk_size > 0 and sep_positions is not None: + print(f"sep_positions: {sep_positions}") + try: + p_reps = self.encode_passage(passage, sep_positions=sep_positions) + except TypeError: + # Some models (e.g., multimodal) don't accept sep_positions. 
+ p_reps = self.encode_passage(passage) + else: + p_reps = self.encode_passage(passage) + print(f"p_reps: {p_reps}") + print(f"type(p_reps): {type(p_reps)}") if self.passage_chunk_size > 0 and isinstance(p_reps, tuple): p_reps, chunk_mask = p_reps @@ -65,10 +80,14 @@ def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = if self.is_ddp: q_reps = self._dist_gather_tensor(q_reps) p_reps = self._dist_gather_tensor(p_reps) - + print(f"passage_chunk_size: {self.passage_chunk_size}") + print(f"chunk_mask: {chunk_mask}") if self.passage_chunk_size > 0 and chunk_mask is not None: + print(f"start compute maxsim similarity==========================") scores = self.compute_maxsim_similarity(q_reps, p_reps, chunk_mask) + print(f"end compute maxsim similarity==========================") else: + print(f"start compute similarity==========================") scores = self.compute_similarity(q_reps, p_reps) scores = scores.view(q_reps.size(0), -1) @@ -116,7 +135,44 @@ def compute_maxsim_similarity(self, q_reps, p_reps, chunk_mask): if chunk_mask is not None: padding_mask = ~chunk_mask.unsqueeze(0).bool() chunk_scores = chunk_scores.masked_fill(padding_mask, float('-inf')) - return chunk_scores.max(dim=-1).values + max_vals, max_idx = chunk_scores.max(dim=-1) # [Q, P], [Q, P] + + # Print argmax chunk index + (optional) original token position from sep_positions + if True: + # only log from rank-0 if DDP + if (not getattr(self, "is_ddp", False)) or getattr(self, "process_rank", 0) == 0: + sep_positions = getattr(self, "sep_positions", None) + # If DDP gathered passages, sep_positions may not align; only use when sizes match. 
+ sep_ok = ( + isinstance(sep_positions, (list, tuple)) + and len(sep_positions) == p_reps.size(0) + ) + qn, pn = max_idx.size(0), max_idx.size(1) + for qi in range(qn): + for pi in range(pn): + ci = int(max_idx[qi, pi].item()) + # last valid chunk index for this passage (by mask) + if chunk_mask is not None: + valid = int(chunk_mask[pi].sum().item()) + last_ci = max(valid - 1, 0) + else: + last_ci = p_reps.size(1) - 1 + + if sep_ok and sep_positions[pi]: + pos_list = sep_positions[pi] + best_pos = pos_list[ci] if 0 <= ci < len(pos_list) else None + last_pos = pos_list[-1] + logger.info( + f"[maxsim] q={qi} p={pi} best_chunk={ci} best_pos={best_pos} " + f"last_chunk={last_ci} last_pos={last_pos} best_score={float(max_vals[qi, pi].item()):.6f}" + ) + else: + logger.info( + f"[maxsim] q={qi} p={pi} best_chunk={ci} last_chunk={last_ci} " + f"best_score={float(max_vals[qi, pi].item()):.6f}" + ) + + return max_vals def compute_loss(self, scores, target): return self.cross_entropy(scores, target) From 249dd9da7c69046f75f5b805a895284c9a78ea25 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Fri, 19 Dec 2025 14:22:49 -0500 Subject: [PATCH 14/31] added some scripts --- finetune.sh | 27 ++++ finetune_with_chunk.sh | 27 ++++ req.txt | 271 +++++++++++++++++++++++++++++++++++++++ run_retrieval.sh | 65 ++++++++++ run_retrieval_chunked.sh | 65 ++++++++++ 5 files changed, 455 insertions(+) create mode 100755 finetune.sh create mode 100755 finetune_with_chunk.sh create mode 100644 req.txt create mode 100755 run_retrieval.sh create mode 100755 run_retrieval_chunked.sh diff --git a/finetune.sh b/finetune.sh new file mode 100755 index 00000000..dd983c33 --- /dev/null +++ b/finetune.sh @@ -0,0 +1,27 @@ +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.train \ + --output_dir retriever-qwen3-emb-ft-chunk-1219-no-chunk-4-group-512-passage \ + --model_name_or_path Qwen/Qwen3-Embedding-0.6B \ + --do_train \ + --lora \ + --lora_target_modules 
q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ + --save_steps 50 \ + --dataset_name Tevatron/scifact \ + --dataset_split train \ + --query_prefix "Instruct: Given a scientific claim, retrieve documents that support or refute the claim.\nQuery:" \ + --passage_prefix "" \ + --bf16 \ + --pooling last \ + --padding_side left \ + --normalize \ + --temperature 0.01 \ + --per_device_train_batch_size 4 \ + --gradient_checkpointing \ + --train_group_size 16 \ + --learning_rate 1e-4 \ + --query_max_len 32 \ + --passage_max_len 512 \ + --num_train_epochs 10 \ + --logging_steps 10 \ + --overwrite_output_dir \ + --gradient_accumulation_steps 1 \ + --passage_chunk_size 0 diff --git a/finetune_with_chunk.sh b/finetune_with_chunk.sh new file mode 100755 index 00000000..712fdf09 --- /dev/null +++ b/finetune_with_chunk.sh @@ -0,0 +1,27 @@ +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.train \ + --output_dir retriever-qwen3-emb-ft-chunk-1219-1 \ + --model_name_or_path Qwen/Qwen3-Embedding-0.6B \ + --do_train \ + --lora \ + --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ + --save_steps 50 \ + --dataset_name Tevatron/scifact \ + --dataset_split train \ + --query_prefix "Instruct: Given a scientific claim, retrieve documents that support or refute the claim.\nQuery:" \ + --passage_prefix "" \ + --bf16 \ + --pooling last \ + --padding_side right \ + --normalize \ + --temperature 0.01 \ + --per_device_train_batch_size 4 \ + --gradient_checkpointing \ + --train_group_size 16 \ + --learning_rate 1e-4 \ + --query_max_len 32 \ + --passage_max_len 512 \ + --num_train_epochs 10 \ + --logging_steps 10 \ + --overwrite_output_dir \ + --gradient_accumulation_steps 1 \ + --passage_chunk_size 256 diff --git a/req.txt b/req.txt new file mode 100644 index 00000000..b033b240 --- /dev/null +++ b/req.txt @@ -0,0 +1,271 @@ +accelerate==1.10.1 +aiohappyeyeballs==2.6.1 +aiohttp==3.13.0 +aiosignal==1.4.0 +annotated-types==0.7.0 
+antlr4-python3-runtime==4.9.3 +anyio==4.11.0 +astor==0.8.1 +attrs==25.4.0 +audioread==3.0.1 +Authlib==1.6.5 +av==16.0.1 +beautifulsoup4==4.14.2 +beir==2.2.0 +blake3==1.0.8 +blinker==1.9.0 +blis==1.3.0 +cachetools==6.2.1 +catalogue==2.0.10 +cbor==1.0.0 +cbor2==5.7.0 +certifi==2025.10.5 +cffi==2.0.0 +charset-normalizer==3.4.3 +click==8.2.1 +clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1 +cloudpathlib==0.23.0 +cloudpickle==3.1.1 +coloredlogs==15.0.1 +compressed-tensors==0.11.0 +confection==0.1.5 +contourpy==1.3.3 +cramjam==2.11.0 +cryptography==46.0.2 +cupy-cuda12x==13.6.0 +cycler==0.12.1 +cyclopts==3.24.0 +cymem==2.0.11 +Cython==3.1.4 +datasets==2.19.0 +decorator==5.2.1 +decord==0.6.0 +deepspeed==0.18.0 +depyf==0.19.0 +dill==0.3.8 +diskcache==5.6.3 +distro==1.9.0 +dnspython==2.8.0 +docstring_parser==0.17.0 +docutils==0.22.2 +einops==0.8.1 +email-validator==2.3.0 +exceptiongroup==1.3.0 +fairscale==0.4.13 +faiss-cpu==1.12.0 +fastapi==0.119.0 +fastapi-cli==0.0.13 +fastapi-cloud-cli==0.3.1 +fastmcp==2.12.4 +fastparquet==2024.11.0 +fastrlock==0.8.3 +filelock==3.20.0 +flash_attn==2.8.3 +Flask==3.1.2 +flatbuffers==25.9.23 +fonttools==4.60.1 +frozendict==2.4.6 +frozenlist==1.8.0 +fsspec==2024.3.1 +ftfy==6.3.1 +gguf==0.17.1 +h11==0.16.0 +hf-xet==1.1.10 +hjson==3.1.0 +httpcore==1.0.9 +httptools==0.7.1 +httpx==0.28.1 +httpx-sse==0.4.3 +huggingface-hub==0.35.3 +humanfriendly==10.0 +idna==3.10 +ijson==3.4.0.post0 +iniconfig==2.3.0 +inscriptis==2.6.0 +interegular==0.3.3 +ir_datasets==0.5.11 +isodate==0.7.2 +itsdangerous==2.2.0 +Jinja2==3.1.6 +jiter==0.11.0 +joblib==1.5.2 +jsonschema==4.25.1 +jsonschema-path==0.3.4 +jsonschema-specifications==2025.9.1 +kiwisolver==1.4.9 +langcodes==3.5.0 +language_data==1.3.0 +lark==1.2.2 +lazy-object-proxy==1.12.0 +lazy_loader==0.4 +librosa==0.11.0 +llguidance==0.7.30 +llvmlite==0.44.0 +lm-format-enforcer==0.11.3 +lxml==6.0.2 +lz4==4.4.4 +marisa-trie==1.3.1 +markdown-it-py==4.0.0 
+MarkupSafe==3.0.3 +matplotlib==3.10.7 +mcp==1.17.0 +mdurl==0.1.2 +mistral_common==1.8.5 +ml_dtypes==0.5.3 +more-itertools==10.8.0 +mpmath==1.3.0 +msgpack==1.1.2 +msgspec==0.19.0 +multidict==6.7.0 +multiprocess==0.70.16 +murmurhash==1.0.13 +networkx==3.5 +ninja==1.13.0 +numba==0.61.2 +numpy==2.2.6 +nvidia-cublas-cu12==12.8.4.1 +nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cudnn-cu12==9.10.2.21 +nvidia-cufft-cu12==11.3.3.83 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand-cu12==10.3.9.90 +nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusparse-cu12==12.5.8.93 +nvidia-cusparselt-cu12==0.7.1 +nvidia-nccl-cu12==2.27.3 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvtx-cu12==12.8.90 +omegaconf==2.3.0 +onnx==1.19.1 +onnxoptimizer==0.3.13 +onnxruntime==1.23.1 +openai==2.3.0 +openai-harmony==0.0.4 +openapi-core==0.19.5 +openapi-pydantic==0.5.1 +openapi-schema-validator==0.6.3 +openapi-spec-validator==0.7.2 +opencv-python==4.12.0.88 +opencv-python-headless==4.12.0.88 +orjson==3.11.3 +outlines_core==0.2.11 +packaging==25.0 +pandas==2.3.3 +parse==1.20.2 +partial-json-parser==0.2.1.1.post6 +pathable==0.4.4 +peft==0.17.1 +pillow==11.3.0 +platformdirs==4.5.0 +pluggy==1.6.0 +pooch==1.8.2 +preshed==3.0.10 +prometheus-fastapi-instrumentator==7.1.0 +prometheus_client==0.23.1 +propcache==0.4.1 +protobuf==6.32.1 +psutil==7.1.0 +py-cpuinfo==9.0.0 +pyarrow==21.0.0 +pyarrow-hotfix==0.7 +pybase64==1.4.2 +pybind11==3.0.1 +pycountry==24.6.1 +pycparser==2.23 +pydantic==2.12.0 +pydantic-extra-types==2.10.6 +pydantic-settings==2.11.0 +pydantic_core==2.41.1 +Pygments==2.19.2 +pyjnius==1.7.0 +pynndescent==0.5.13 +pyparsing==3.2.5 +pyperclip==1.11.0 +-e git+ssh://git@github.com/FarmersWrap/pyserini.git@a1995bffa243636c89029735236348c1e5206161#egg=pyserini +pytest==9.0.1 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 +python-json-logger==4.0.0 +python-multipart==0.0.20 +pytrec_eval-terrier==0.5.9 +pytz==2025.2 +PyYAML==6.0.3 
+pyzmq==27.1.0 +qwen-omni-utils==0.0.8 +ranx==0.3.21 +ray==2.50.0 +referencing==0.36.2 +regex==2025.9.18 +requests==2.32.5 +rfc3339-validator==0.1.4 +rich==14.2.0 +rich-rst==1.3.1 +rich-toolkit==0.15.1 +rignore==0.7.1 +rpds-py==0.27.1 +safetensors==0.6.2 +scikit-learn==1.7.2 +scipy==1.16.2 +seaborn==0.13.2 +sentence-transformers==5.1.1 +sentencepiece==0.2.1 +sentry-sdk==2.42.0 +setproctitle==1.3.7 +setuptools==80.9.0 +shellingham==1.5.4 +six==1.17.0 +smart_open==7.3.1 +sniffio==1.3.1 +soundfile==0.13.1 +soupsieve==2.8 +soxr==1.0.0 +spacy==3.8.7 +spacy-legacy==3.0.12 +spacy-loggers==1.0.5 +srsly==2.5.1 +sse-starlette==3.0.2 +starlette==0.48.0 +sympy==1.14.0 +tabulate==0.9.0 +-e git+ssh://git@github.com/FarmersWrap/tevatron.git@add3832f2071525e257658cbe42cf9f9bbb3b928#egg=tevatron +thinc==8.3.6 +threadpoolctl==3.6.0 +tiktoken==0.12.0 +timm==1.0.20 +tokenizers==0.22.1 +torch==2.8.0 +torchaudio==2.8.0 +torchvision==0.23.0 +tqdm==4.67.1 +transformers==4.57.0 +trec-car-tools==2.6 +triton==3.4.0 +typeguard==4.4.4 +typer==0.19.2 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +tzdata==2025.2 +umap-learn==0.5.9.post2 +uniir_for_pyserini==0.1.1 +unlzw3==0.2.3 +urllib3==2.5.0 +uvicorn==0.37.0 +uvloop==0.22.1 +vllm==0.11.0 +warc3-wet==0.2.5 +warc3-wet-clueweb09==0.2.5 +wasabi==1.1.3 +watchfiles==1.1.1 +wcwidth==0.2.14 +weasel==0.4.1 +websockets==15.0.1 +Werkzeug==3.1.1 +wheel==0.45.1 +wrapt==1.17.3 +xformers==0.0.32.post1 +xgrammar==0.1.25 +xxhash==3.6.0 +yarl==1.22.0 +zlib-state==0.1.10 diff --git a/run_retrieval.sh b/run_retrieval.sh new file mode 100755 index 00000000..9ee8d347 --- /dev/null +++ b/run_retrieval.sh @@ -0,0 +1,65 @@ +output_dir=retriever-qwen3-emb-ft-chunk-1219-no-chunk-4-group-512-passage +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path Qwen/Qwen3-Embedding-0.6B \ + --bf16 \ + --per_device_eval_batch_size 4 \ + --normalize \ + --pooling last \ + --padding_side right \ + 
--query_prefix "Instruct: Given a scientific claim, retrieve documents that support or refute the claim.\nQuery:" \ + --query_max_len 512 \ + --dataset_name Tevatron/beir \ + --dataset_config scifact \ + --dataset_split test \ + --encode_output_path ${output_dir}/queries_scifact.pkl \ + --encode_is_query + + +# Encode corpus +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path Qwen/Qwen3-Embedding-0.6B \ + --bf16 \ + --per_device_eval_batch_size 4 \ + --normalize \ + --pooling last \ + --padding_side right \ + --passage_prefix "" \ + --passage_max_len 512 \ + --dataset_name Tevatron/beir-corpus \ + --dataset_config scifact \ + --dataset_split train \ + --encode_output_path ${output_dir}/corpus_scifact.pkl \ + --passage_chunk_size 0 + +python -m tevatron.retriever.driver.search \ + --query_reps ${output_dir}/queries_scifact.pkl \ + --passage_reps ${output_dir}/corpus_scifact.pkl \ + --depth 100 \ + --batch_size 64 \ + --save_text \ + --save_ranking_to ${output_dir}/rank.scifact.txt + +# Convert to TREC format +python -m tevatron.utils.format.convert_result_to_trec --input ${output_dir}/rank.scifact.txt \ + --output ${output_dir}/rank.scifact.trec \ + --remove_query + +python -m tevatron.retriever.driver.search \ + --query_reps ${output_dir}/queries_scifact.pkl \ + --passage_reps ${output_dir}/corpus_scifact.pkl \ + --depth 100 \ + --batch_size 64 \ + --save_text \ + --save_ranking_to ${output_dir}/rank.scifact.txt + +# Convert to TREC format +python -m tevatron.utils.format.convert_result_to_trec --input ${output_dir}/rank.scifact.txt \ + --output ${output_dir}/rank.scifact.trec \ + --remove_query +python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 beir-v1.0.0-scifact-test ${output_dir}/rank.scifact.trec + +# recall_100 all 0.9767 +# ndcg_cut_10 all 0.7801 + diff --git a/run_retrieval_chunked.sh b/run_retrieval_chunked.sh new file mode 100755 index 00000000..b80ae37d --- /dev/null +++ 
b/run_retrieval_chunked.sh @@ -0,0 +1,65 @@ +output_dir=retriever-qwen3-emb-ft-chunk-1219-1 +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path Qwen/Qwen3-Embedding-0.6B \ + --bf16 \ + --per_device_eval_batch_size 4 \ + --normalize \ + --pooling last \ + --padding_side right \ + --query_prefix "Instruct: Given a scientific claim, retrieve documents that support or refute the claim.\nQuery:" \ + --query_max_len 512 \ + --dataset_name Tevatron/beir \ + --dataset_config scifact \ + --dataset_split test \ + --encode_output_path ${output_dir}/queries_scifact.pkl \ + --encode_is_query + + +# Encode corpus +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path Qwen/Qwen3-Embedding-0.6B \ + --bf16 \ + --per_device_eval_batch_size 4 \ + --normalize \ + --pooling last \ + --padding_side right \ + --passage_prefix "" \ + --passage_max_len 512 \ + --dataset_name Tevatron/beir-corpus \ + --dataset_config scifact \ + --dataset_split train \ + --encode_output_path ${output_dir}/corpus_scifact.pkl \ + --passage_chunk_size 256 + +python -m tevatron.retriever.driver.search \ + --query_reps ${output_dir}/queries_scifact.pkl \ + --passage_reps ${output_dir}/corpus_scifact.pkl \ + --depth 100 \ + --batch_size 64 \ + --save_text \ + --save_ranking_to ${output_dir}/rank.scifact.txt + +# Convert to TREC format +python -m tevatron.utils.format.convert_result_to_trec --input ${output_dir}/rank.scifact.txt \ + --output ${output_dir}/rank.scifact.trec \ + --remove_query + +python -m tevatron.retriever.driver.search \ + --query_reps ${output_dir}/queries_scifact.pkl \ + --passage_reps ${output_dir}/corpus_scifact.pkl \ + --depth 1000 \ + --batch_size 64 \ + --save_text \ + --save_ranking_to ${output_dir}/rank.scifact.txt + +# Convert to TREC format +python -m tevatron.utils.format.convert_result_to_trec --input ${output_dir}/rank.scifact.txt \ + --output 
${output_dir}/rank.scifact.trec \ + --remove_query +python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 beir-v1.0.0-scifact-test ${output_dir}/rank.scifact.trec + +# recall_100 all 0.9767 +# ndcg_cut_10 all 0.7801 + From 9efb43baeb4af0b70500af5f3273f08671a42709 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Sun, 21 Dec 2025 23:52:53 -0500 Subject: [PATCH 15/31] added tests --- pytest.ini | 4 + src/tevatron/retriever/collator.py | 69 +++- src/tevatron/retriever/driver/encode.py | 2 +- src/tevatron/retriever/modeling/encoder.py | 16 +- src/tevatron/retriever/trainer.py | 1 + tests/test_chunking.py | 460 +++++++++++++++++++++ 6 files changed, 528 insertions(+), 24 deletions(-) create mode 100644 pytest.ini create mode 100644 tests/test_chunking.py diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..9b649206 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + unit: fast unit tests (no external downloads) + diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index 710cca91..d7c849c0 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -5,9 +5,10 @@ from transformers import PreTrainedTokenizer, ProcessorMixin from qwen_omni_utils import process_mm_info from PIL import Image +from rich import print from tevatron.retriever.arguments import DataArguments - +torch.set_printoptions(threshold=float('inf'), linewidth=10000) logger = logging.getLogger(__name__) @@ -87,26 +88,44 @@ def _tokenize_and_pad_chunked_passages(self, passages: List[str]): so that query and passage use the same pooling token automatically. 
""" chunk_len = self.data_args.passage_chunk_size -1 - sep_id = 151645 # <|separator|> - eos_id = 151643 # <|endoftext|> + # sep_id = 151645 # <|separator|> + eos_id = self.tokenizer.eos_token_id + if eos_id is None: + raise ValueError("tokenizer.eos_token_id is None; cannot chunk passages with EOS separators.") + max_length = self.data_args.passage_max_len # cap total length (incl. EOS per chunk) all_input_ids = [] all_sep_positions = [] for passage in passages: tokens = self.tokenizer.encode(passage, add_special_tokens=False) - tokens.append(eos_id) ids = [] sep_pos = [] - for i in range(0, len(tokens), chunk_len): - chunk = tokens[i:i + chunk_len] # up to self.data_args.passage_chunk_size -1 tokens + + # Build chunked ids, optionally capped by max_length (total tokens including EOS separators). + i = 0 + while i < len(tokens): + if max_length and max_length > 0: + remaining = max_length - len(ids) + # Need at least 1 slot for EOS; otherwise stop (don't add empty chunks). + if remaining <= 1: + break + take = min(chunk_len, len(tokens) - i, remaining - 1) + if take <= 0: + break + else: + take = min(chunk_len, len(tokens) - i) + + chunk = tokens[i:i + take] # up to chunk_len tokens ids.extend(chunk) - ids.append(sep_id) # SEP at end of this chunk - sep_pos.append(len(ids) - 1) # position of SEP + ids.append(eos_id) # EOS at end of this chunk + sep_pos.append(len(ids) - 1) # position of EOS (pooling position) + i += take all_input_ids.append(ids) all_sep_positions.append(sep_pos) + print(f"all_input_ids: {all_input_ids}") d_collated = {'input_ids': all_input_ids} # Padding d_collated = self.tokenizer.pad( @@ -116,6 +135,12 @@ def _tokenize_and_pad_chunked_passages(self, passages: List[str]): return_attention_mask=True, return_tensors='pt', ) + # print(f"d_collated: {d_collated['input_ids']}") + # print(f"length of d_collated: {len(d_collated['input_ids'])}") + # print(f"attention mask: {d_collated['attention_mask']}") + # print(f"length of attention mask: 
{len(d_collated['attention_mask'])}") + # print(f"all_sep_positions: {all_sep_positions[0]}") + # input("Press Enter to continue...") return d_collated, all_sep_positions @@ -291,7 +316,10 @@ def __call__(self, features): chunk_len = self.data_args.passage_chunk_size - 1 sep_id = 151645 # <|separator|> - eos_id = 151643 # <|endoftext|> + eos_id = self.tokenizer.eos_token_id + if eos_id is None: + raise ValueError("tokenizer.eos_token_id is None; cannot chunk passages with EOS separators.") + max_length = self.data_args.passage_max_len # cap total length (incl. EOS per chunk) all_input_ids = [] all_sep_positions = [] @@ -301,15 +329,26 @@ def __call__(self, features): if text is None: text = "" tokens = self.tokenizer.encode(text, add_special_tokens=False) - tokens.append(eos_id) - ids = [] sep_pos = [] - for i in range(0, len(tokens), chunk_len): - chunk = tokens[i:i + chunk_len] # up to passage_chunk_size - 1 tokens + + i = 0 + while i < len(tokens): + if max_length and max_length > 0: + remaining = max_length - len(ids) + if remaining <= 1: + break + take = min(chunk_len, len(tokens) - i, remaining - 1) + if take <= 0: + break + else: + take = min(chunk_len, len(tokens) - i) + + chunk = tokens[i:i + take] # up to chunk_len tokens ids.extend(chunk) - ids.append(sep_id) # SEP at end of this chunk - sep_pos.append(len(ids) - 1) # position of SEP + ids.append(eos_id) # EOS at end of this chunk + sep_pos.append(len(ids) - 1) # position of EOS + i += take all_input_ids.append(ids) all_sep_positions.append(sep_pos) diff --git a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py index 24cfc514..48c33494 100644 --- a/src/tevatron/retriever/driver/encode.py +++ b/src/tevatron/retriever/driver/encode.py @@ -108,7 +108,7 @@ def main(): with torch.no_grad(): if use_chunked: doc_ids, batch_inputs, sep_positions, chunk_counts = batch - print(batch_inputs) + # print(batch_inputs) for k, v in batch_inputs.items(): batch_inputs[k] = 
v.to(training_args.device) chunk_embs, chunk_mask = model.encode_passage(batch_inputs, sep_positions) diff --git a/src/tevatron/retriever/modeling/encoder.py b/src/tevatron/retriever/modeling/encoder.py index ed8cf123..ff68aa29 100644 --- a/src/tevatron/retriever/modeling/encoder.py +++ b/src/tevatron/retriever/modeling/encoder.py @@ -55,7 +55,7 @@ def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = # into encode_passage() to actually get chunk reps/masks. sep_positions = getattr(self, "sep_positions", None) if self.passage_chunk_size > 0 and sep_positions is not None: - print(f"sep_positions: {sep_positions}") + # print(f"sep_positions: {sep_positions}") try: p_reps = self.encode_passage(passage, sep_positions=sep_positions) except TypeError: @@ -63,8 +63,8 @@ def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = p_reps = self.encode_passage(passage) else: p_reps = self.encode_passage(passage) - print(f"p_reps: {p_reps}") - print(f"type(p_reps): {type(p_reps)}") + # print(f"p_reps: {p_reps}") + # print(f"type(p_reps): {type(p_reps)}") if self.passage_chunk_size > 0 and isinstance(p_reps, tuple): p_reps, chunk_mask = p_reps @@ -80,14 +80,14 @@ def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = if self.is_ddp: q_reps = self._dist_gather_tensor(q_reps) p_reps = self._dist_gather_tensor(p_reps) - print(f"passage_chunk_size: {self.passage_chunk_size}") - print(f"chunk_mask: {chunk_mask}") + # print(f"passage_chunk_size: {self.passage_chunk_size}") + # print(f"chunk_mask: {chunk_mask}") if self.passage_chunk_size > 0 and chunk_mask is not None: - print(f"start compute maxsim similarity==========================") + # print(f"start compute maxsim similarity==========================") scores = self.compute_maxsim_similarity(q_reps, p_reps, chunk_mask) - print(f"end compute maxsim similarity==========================") + # print(f"end compute maxsim similarity==========================") 
else: - print(f"start compute similarity==========================") + # print(f"start compute similarity==========================") scores = self.compute_similarity(q_reps, p_reps) scores = scores.view(q_reps.size(0), -1) diff --git a/src/tevatron/retriever/trainer.py b/src/tevatron/retriever/trainer.py index e759c42c..403f8024 100644 --- a/src/tevatron/retriever/trainer.py +++ b/src/tevatron/retriever/trainer.py @@ -47,6 +47,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): query, passage, *rest = inputs sep_positions = rest[0] if rest else None + # input(f"trainer.compute_loss: sep_positions: {sep_positions}") if hasattr(model, 'sep_positions'): model.sep_positions = sep_positions return model(query=query, passage=passage).loss diff --git a/tests/test_chunking.py b/tests/test_chunking.py new file mode 100644 index 00000000..96d4f7a6 --- /dev/null +++ b/tests/test_chunking.py @@ -0,0 +1,460 @@ +import sys +from pathlib import Path + +import pytest + + +def _tevatron_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def _add_tevatron_src_to_path(): + # tevatron/tests/test_chunking.py -> tevatron/ -> tevatron/src + src = _tevatron_root() / "src" + sys.path.insert(0, str(src)) + + +def _strictly_increasing(xs): + return all(xs[i] > xs[i - 1] for i in range(1, len(xs))) + +REAL_TEXT = ( + "Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical " + "development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging " + "(MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to " + "calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in " + "preterm (n = 17) and full-term infants (n = 7). 
To assess effects of prematurity on cerebral white matter " + "development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white " + "matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to " + "1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both " + "times were similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with " + "greater absolute values in the internal capsule than in the central white matter. Preterm infants at term showed " + "higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 versus 1.15 +/- 0.09 microm2/ms, " + "p = 0.016) and lower relative anisotropy in both areas compared with full-term infants (white matter, 10.9 +/- " + "0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- 4.44 versus 33.1 +/- 0.6% p = 0.006). " + "Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term and " + "preterm infants at term showed marked differences in white matter fiber organization. The data indicate that " + "quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural " + "development in cerebral white matter in living infants" +) + +TOKENIZER_DIR_NAME = "retriever-qwen3-emb-ft-chunk-batch-2-group-16-maxlen-512-chunk-256-eos" + + +@pytest.fixture(scope="session") +def train_tokenizer(): + """ + Use the exact tokenizer saved by the finetune_with_chunk.sh output_dir, + and mimic tevatron.retriever.driver.train's tokenizer setup. 
+ """ + _add_tevatron_src_to_path() + from transformers import AutoTokenizer + + tok_dir = _tevatron_root() / TOKENIZER_DIR_NAME + if not tok_dir.exists(): + pytest.skip(f"local tokenizer dir not found: {tok_dir}") + + tok = AutoTokenizer.from_pretrained(str(tok_dir), local_files_only=True) + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + tok.padding_side = "right" # finetune_with_chunk.sh uses --padding_side right + return tok + + +@pytest.mark.unit +def test_train_collator_chunked_passages(train_tokenizer): + """ + Restore finetune_with_chunk.sh passage chunking scene: + - passage_max_len=512 + - passage_chunk_size=256 + - pad_to_multiple_of=16 (DataArguments default) + - padding_side=right + """ + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + data_args = DataArguments( + passage_max_len=512, + passage_chunk_size=256, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + + # ======================================================================== + # NOTE: This test directly calls _tokenize_and_pad_chunked_passages() instead + # of collator.__call__() to test chunking in isolation. 
+ # + # If we used collator.__call__(features) with passage_chunk_size > 0, it would return: + # (q_batch, p_batch, sep_positions) # 3-element tuple + # + # Where: + # - q_batch: dict with "input_ids" and "attention_mask" for queries + # - p_batch: dict with "input_ids" and "attention_mask" for chunked passages + # - sep_positions: list of lists, e.g., [[255, 430]] - EOS token positions per passage + # Used by the model to extract chunk embeddings via MaxSim pooling + # ======================================================================== + d_collated, sep_positions = collator._tokenize_and_pad_chunked_passages([REAL_TEXT]) + + got_ids = d_collated["input_ids"][0].tolist() + got_mask = d_collated["attention_mask"][0].tolist() + got_unpadded_len = sum(got_mask) + + assert got_unpadded_len == 431 + assert sep_positions == [[255, 430]] + # EOS token at sep positions + assert got_ids[255] == train_tokenizer.eos_token_id + assert got_ids[430] == train_tokenizer.eos_token_id + print("length of got_ids: ", len(got_ids)) + + expected_ids = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13, 758, 279, 8622, 4158, 4925, 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, 73760, 572, 1550, 11, 220, 16, 13, 23, 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, 311, 220, 16, 13, 17, 19197, 441, 17, 58634, 13, 758, 279, 44900, 47594, 315, 279, 5306, 
47639, 11, 279, 3076, 9981, 57330, 36829, 518, 2176, 3039, 1033, 4428, 320, 16, 13, 17, 19041, 220, 16, 13, 16, 19197, 441, 17, 58634, 568, 39402, 458, 285, 354, 17764, 572, 5080, 279, 12128, 7194, 572, 311, 4647, 448, 7046, 10740, 2750, 304, 279, 5306, 47639, 1091, 304, 279, 8622, 4158, 4925, 13, 4968, 4991, 41434, 518, 4647, 8542, 5080, 3076, 57330, 36829, 304, 279, 8622, 4158, 4925, 320, 16, 13, 19, 51615, 220, 15, 13, 17, 19, 19041, 220, 16, 13, 16, 151645, 20, 51615, 220, 15, 13, 15, 24, 19197, 441, 17, 58634, 11, 281, 284, 220, 15, 13, 15, 16, 21, 8, 323, 4722, 8674, 458, 285, 354, 17764, 304, 2176, 5671, 7707, 448, 2480, 9663, 41434, 320, 5782, 4925, 11, 220, 16, 15, 13, 24, 51615, 220, 15, 13, 21, 19041, 220, 17, 17, 13, 24, 51615, 220, 18, 13, 15, 13384, 281, 284, 220, 15, 13, 15, 15, 16, 26, 5306, 47639, 11, 220, 17, 19, 13, 15, 51615, 220, 19, 13, 19, 19, 19041, 220, 18, 18, 13, 16, 51615, 220, 15, 13, 21, 4, 281, 284, 220, 15, 13, 15, 15, 21, 568, 11581, 2408, 301, 15479, 48674, 304, 279, 42094, 1620, 385, 1242, 1033, 9434, 553, 57330, 15626, 51360, 438, 4124, 438, 220, 17, 23, 73760, 26, 2480, 9663, 323, 855, 4991, 41434, 518, 4647, 8542, 12864, 11799, 304, 4158, 4925, 23788, 7321, 13, 576, 821, 13216, 429, 46516, 15449, 315, 3015, 57330, 553, 57330, 15626, 51360, 5707, 20017, 1119, 8003, 95697, 4401, 304, 59645, 4158, 4925, 304, 5382, 41434, 151645, 151643 + ] + assert got_ids == expected_ids + + # Hardcoded attention_mask: 431 ones (unpadded tokens) + 1 zero (padding) + # Padded to multiple of 16: 431 -> 432 + expected_mask = [1] * 431 + [0] * 1 + assert len(got_mask) == 432 + assert got_mask == expected_mask + # Verify attention_mask is 1 at sep_positions (EOS tokens should be attended) + assert got_mask[255] == 1 + assert got_mask[430] == 1 + + +@pytest.mark.unit +def test_chunk_size_zero_with_train_tokenizer_disables_chunking_and_truncates(train_tokenizer): + """ + With passage_chunk_size=0, TrainCollator should take the non-chunk path and + 
truncate passages to passage_max_len (like finetune_mldr_dev.sh). + + Hardcoded golden output: both passages are truncated to exactly 64 tokens + (passage_max_len), with no padding needed since 64 is already a multiple of 16. + """ + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + data_args = DataArguments( + query_max_len=32, + passage_max_len=64, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + train_group_size=2, + passage_chunk_size=0, + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + + # ======================================================================== + # HOW features IS CONSTRUCTED: + # ======================================================================== + # features mimics what TrainDataset.__getitem__() returns. Each element is: + # (query_tuple, list_of_passage_tuples) + # + # Where: + # - query_tuple: (text, image, video, audio) - in this test, only text is used + # - list_of_passage_tuples: [(text, image, video, audio), ...] 
- one per passage + # + # Structure breakdown: + # - ("q1", None, None, None) = query with text="q1", no multimodal content + # - [(REAL_TEXT, ...), (REAL_TEXT, ...)] = 2 passages (train_group_size=2) + # Each passage tuple: (text=REAL_TEXT, image=None, video=None, audio=None) + # ======================================================================== + features = [ + (("q1", None, None, None), [(REAL_TEXT, None, None, None), (REAL_TEXT, None, None, None)]), + ] + + # ======================================================================== + # WHAT collator(features) RETURNS: + # ======================================================================== + # Since passage_chunk_size=0 (no chunking), TrainCollator.__call__() returns: + # (q_batch, p_batch) # 2-element tuple + # + # Where: + # q_batch: dict with PyTorch tensors for queries + # - "input_ids": tensor([[token_ids for "q1"]]) # shape: [num_queries, query_seq_len] + # - "attention_mask": tensor([[1, 1, ...]]) # shape: [num_queries, query_seq_len] + # + # p_batch: dict with PyTorch tensors for passages (FLATTENED across all queries) + # - "input_ids": tensor([ + # [token_ids for passage 1 (truncated to passage_max_len=64)], + # [token_ids for passage 2 (truncated to passage_max_len=64)] + # ]) # shape: [total_passages, passage_seq_len] + # - "attention_mask": tensor([ + # [1, 1, ..., 1], # 64 ones (no padding since 64 is multiple of 16) + # [1, 1, ..., 1] # 64 ones + # ]) # shape: [total_passages, passage_seq_len] + # + # Note: The collator flattens all passages from all queries into a single batch. + # With 1 query and train_group_size=2, we get 2 passages in p_batch. 
+ # ======================================================================== + out = collator(features) + assert len(out) == 2 # Verify non-chunking path returns 2 elements + q_batch, p_batch = out # Unpack: q_batch (queries), p_batch (passages) + + # Hardcoded golden output (both passages are identical since same input text) + expected_ids = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, + 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, + 349, 2326, 151643 + ] + expected_mask = [1] * 64 # No padding needed (64 is multiple of 16) + + assert p_batch["input_ids"].shape[0] == 2 # train_group_size=2 + + for i in range(p_batch["input_ids"].shape[0]): + got_ids = p_batch["input_ids"][i].tolist() + got_mask = p_batch["attention_mask"][i].tolist() + unpadded_len = sum(got_mask) + + assert unpadded_len == 64 + assert len(got_ids) == 64 + assert len(got_mask) == 64 + assert got_ids == expected_ids + assert got_mask == expected_mask + + +@pytest.mark.unit +def test_chunking_chunk_size_equal_maxlen_is_capped_to_single_chunk(train_tokenizer): + """ + When chunk_size == max_len, chunking should be capped to exactly max_len total tokens + (incl. EOS), with exactly one EOS at the end. 
+ """ + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + long_text = (REAL_TEXT + " ") * 20 + data_args = DataArguments( + passage_chunk_size=64, + passage_max_len=64, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + d_collated, sep_positions = collator._tokenize_and_pad_chunked_passages([long_text]) + ids = d_collated["input_ids"][0].tolist() + mask = d_collated["attention_mask"][0].tolist() + + # Hardcoded golden output + expected_ids = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, + 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, + 349, 2326, 151645 + ] + expected_sep_positions = [[63]] + expected_mask = [1] * 64 + + assert sum(mask) == 64 + assert len(ids) == 64 + assert sep_positions == expected_sep_positions + assert ids == expected_ids + assert ids[63] == 151645 + assert mask == expected_mask + assert _strictly_increasing(sep_positions[0]) + + +@pytest.mark.unit +def test_chunking_chunk_size_greater_than_maxlen_is_capped_to_single_chunk(train_tokenizer): + """ + When chunk_size > max_len, chunking should still be capped to exactly max_len total tokens + (incl. EOS), with exactly one EOS at the end. 
+ """ + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + long_text = (REAL_TEXT + " ") * 20 + data_args = DataArguments( + passage_chunk_size=128, + passage_max_len=64, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + d_collated, sep_positions = collator._tokenize_and_pad_chunked_passages([long_text]) + ids = d_collated["input_ids"][0].tolist() + mask = d_collated["attention_mask"][0].tolist() + + # Hardcoded golden output (same as chunk_size == max_len case) + expected_ids = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, + 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, + 349, 2326, 151645 + ] + expected_sep_positions = [[63]] + expected_mask = [1] * 64 + + assert sum(mask) == 64 + assert len(ids) == 64 + assert sep_positions == expected_sep_positions + assert ids == expected_ids + assert ids[63] == 151645 + assert mask == expected_mask + assert _strictly_increasing(sep_positions[0]) + + +@pytest.mark.unit +def test_chunking_short_passage_shorter_than_chunk_size(train_tokenizer): + """ + When passage is shorter than chunk_size, it should still get one chunk with EOS, + and padding should be applied to pad_to_multiple_of. 
+ """ + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + short_text = "Hello world" + data_args = DataArguments( + passage_chunk_size=64, + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + d_collated, sep_positions = collator._tokenize_and_pad_chunked_passages([short_text]) + ids = d_collated["input_ids"][0].tolist() + mask = d_collated["attention_mask"][0].tolist() + + # Hardcoded golden output: "Hello world" -> 2 tokens + 1 EOS = 3 tokens, padded to 16 + expected_ids = [9707, 1879, 151645] + [151643] * 13 # 3 content + 13 padding + expected_sep_positions = [[2]] + expected_mask = [1, 1, 1] + [0] * 13 # 3 ones + 13 zeros + + assert sum(mask) == 3 + assert len(ids) == 16 # Padded to multiple of 16 + assert sep_positions == expected_sep_positions + assert ids == expected_ids + assert ids[2] == 151645 # EOS at position 2 + assert mask == expected_mask + assert _strictly_increasing(sep_positions[0]) + + +@pytest.mark.unit +def test_chunking_passage_needs_padding_unpadded_not_multiple_of_pad_to_multiple_of(train_tokenizer): + """ + When unpadded length is not a multiple of pad_to_multiple_of, padding should be added. + This tests: unpadded_len=50, pad_to_multiple_of=16 -> padded_len=64. 
+    """
+    from tevatron.retriever.arguments import DataArguments
+    from tevatron.retriever.collator import TrainCollator
+
+    data_args = DataArguments(
+        passage_chunk_size=32,
+        passage_max_len=50,
+        pad_to_multiple_of=16,
+        padding_side="right",
+        append_eos_token=False,
+    )
+    collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer)
+    d_collated, sep_positions = collator._tokenize_and_pad_chunked_passages([REAL_TEXT])
+    ids = d_collated["input_ids"][0].tolist()
+    mask = d_collated["attention_mask"][0].tolist()
+
+    # Hardcoded golden output: 50 unpadded tokens (2 chunks: 31+1 EOS, 17+1 EOS), padded to 64
+    expected_ids = [
+        74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802,
+        82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970,
+        151645, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629,
+        279, 9981, 57330, 151645
+    ] + [151643] * 14  # 50 content + 14 padding
+    expected_sep_positions = [[31, 49]]
+    expected_mask = [1] * 50 + [0] * 14  # 50 ones + 14 zeros
+
+    assert sum(mask) == 50
+    assert len(ids) == 64  # Padded to multiple of 16
+    assert sep_positions == expected_sep_positions
+    assert ids == expected_ids
+    assert ids[31] == 151645  # First EOS
+    assert ids[49] == 151645  # Second EOS
+    assert mask == expected_mask
+    assert _strictly_increasing(sep_positions[0])
+
+
+@pytest.mark.unit
+def test_chunking_multiple_passages_different_lengths(train_tokenizer):
+    """
+    Test batch processing with multiple passages of different lengths:
+    - Short passage (2 tokens)
+    - Medium passage (18 tokens)
+    - Long passage (128 tokens, multiple chunks)
+    All should be padded to the same length (longest unpadded length rounded up to pad_to_multiple_of).
+ """ + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + texts = ["Short", REAL_TEXT[:100], REAL_TEXT] + data_args = DataArguments( + passage_chunk_size=64, + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + d_collated, sep_positions = collator._tokenize_and_pad_chunked_passages(texts) + + # Hardcoded golden outputs + # Passage 0: "Short" -> 1 token + 1 EOS = 2 tokens, padded to 128 + expected_ids_0 = [12472, 151645] + [151643] * 126 + expected_mask_0 = [1, 1] + [0] * 126 + expected_sep_0 = [1] + + # Passage 1: REAL_TEXT[:100] -> 17 tokens + 1 EOS = 18 tokens, padded to 128 + expected_ids_1 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 1062, 151645 + ] + [151643] * 110 + expected_mask_1 = [1] * 18 + [0] * 110 + expected_sep_1 = [17] + + # Passage 2: REAL_TEXT -> 2 chunks (63+1 EOS, 63+1 EOS) = 128 tokens, no padding needed + expected_ids_2 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, + 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, + 349, 2326, 151645, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, + 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239, + 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434, + 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13, 758, 279, 8622, + 151645 + ] + expected_mask_2 = [1] * 128 + expected_sep_2 = [63, 127] + + ids_0 = d_collated["input_ids"][0].tolist() + mask_0 = 
d_collated["attention_mask"][0].tolist() + ids_1 = d_collated["input_ids"][1].tolist() + mask_1 = d_collated["attention_mask"][1].tolist() + ids_2 = d_collated["input_ids"][2].tolist() + mask_2 = d_collated["attention_mask"][2].tolist() + + # Passage 0 assertions + assert sum(mask_0) == 2 + assert len(ids_0) == 128 + assert ids_0 == expected_ids_0 + assert mask_0 == expected_mask_0 + assert sep_positions[0] == expected_sep_0 + + # Passage 1 assertions + assert sum(mask_1) == 18 + assert len(ids_1) == 128 + assert ids_1 == expected_ids_1 + assert mask_1 == expected_mask_1 + assert sep_positions[1] == expected_sep_1 + + # Passage 2 assertions + assert sum(mask_2) == 128 + assert len(ids_2) == 128 + assert ids_2 == expected_ids_2 + assert mask_2 == expected_mask_2 + assert sep_positions[2] == expected_sep_2 + assert _strictly_increasing(sep_positions[2]) + From 9c37e291ad39b7a67b66d7f6077d8ce0818661db Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Mon, 22 Dec 2025 02:33:15 -0500 Subject: [PATCH 16/31] added tests --- src/tevatron/retriever/collator.py | 194 ++++++++------------ src/tevatron/retriever/driver/encode.py | 4 +- src/tevatron/retriever/modeling/dense.py | 18 +- src/tevatron/retriever/modeling/encoder.py | 28 +-- src/tevatron/retriever/trainer.py | 8 +- tests/test_chunking.py | 204 ++++++++++++--------- 6 files changed, 221 insertions(+), 235 deletions(-) diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index d7c849c0..7a1ff052 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -13,6 +13,76 @@ logger = logging.getLogger(__name__) +def _tokenize_and_pad_chunked_passages( + passages: List[str], + tokenizer: PreTrainedTokenizer, + data_args: DataArguments, +) -> Tuple[dict, List[List[int]]]: + """ + Tokenize passages with EOS separators between chunks. + Each chunk ends with EOS, enabling extraction of chunk embeddings from EOS positions. 
+
+    Uses the same token that tokenizer.add_special_tokens adds (e.g., <|endoftext|>)
+    so that query and passage use the same pooling token automatically.
+
+    :param passages: List of passage texts to tokenize and chunk
+    :param tokenizer: Tokenizer to use for encoding
+    :param data_args: DataArguments containing passage_chunk_size, passage_max_len, pad_to_multiple_of
+    :return: Tuple of (collated_dict, eos_positions) where:
+        - collated_dict: dict with 'input_ids' and 'attention_mask' tensors
+        - eos_positions: list of lists, one per passage, containing EOS token positions
+    """
+    chunk_len = data_args.passage_chunk_size - 1
+    eos_id = tokenizer.eos_token_id
+    if eos_id is None:
+        raise ValueError("tokenizer.eos_token_id is None; cannot chunk passages with EOS separators.")
+    max_length = data_args.passage_max_len  # cap total length (incl. EOS per chunk)
+
+    all_input_ids = []
+    all_eos_positions = []
+
+    for passage in passages:
+        if passage is None:
+            passage = ""
+        tokens = tokenizer.encode(passage, add_special_tokens=False)
+        ids = []
+        eos_pos = []
+
+        # Build chunked ids, optionally capped by max_length (total tokens including EOS separators).
+        i = 0
+        while i < len(tokens):
+            if max_length and max_length > 0:
+                remaining = max_length - len(ids)
+                # Need at least 1 slot for EOS; otherwise stop (don't add empty chunks).
+ if remaining <= 1: + break + take = min(chunk_len, len(tokens) - i, remaining - 1) + if take <= 0: + break + else: + take = min(chunk_len, len(tokens) - i) + + chunk = tokens[i:i + take] # up to chunk_len tokens + ids.extend(chunk) + ids.append(eos_id) # EOS at end of this chunk + eos_pos.append(len(ids) - 1) # position of EOS (pooling position) + i += take + + all_input_ids.append(ids) + all_eos_positions.append(eos_pos) + + d_collated = {'input_ids': all_input_ids} + # Padding + d_collated = tokenizer.pad( + d_collated, + padding=True, + pad_to_multiple_of=data_args.pad_to_multiple_of, + return_attention_mask=True, + return_tensors='pt', + ) + return d_collated, all_eos_positions + + @dataclass class TrainCollator: """ @@ -56,8 +126,8 @@ def __call__(self, features: List[Tuple[str, List[str]]]): # Passage tokenization if self.data_args.passage_chunk_size > 0: - d_collated, sep_positions = self._tokenize_and_pad_chunked_passages(all_passages) - return q_collated, d_collated, sep_positions + d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages) + return q_collated, d_collated, eos_positions else: d_collated = self.tokenizer( all_passages, @@ -80,68 +150,7 @@ def __call__(self, features: List[Tuple[str, List[str]]]): return q_collated, d_collated def _tokenize_and_pad_chunked_passages(self, passages: List[str]): - """ - Tokenize passages with EOS separators between chunks. - Each chunk ends with EOS, enabling extraction of chunk embeddings from EOS positions. - - Uses the same token that tokenizer.add_special_tokens adds (e.g., <|endoftext|>) - so that query and passage use the same pooling token automatically. - """ - chunk_len = self.data_args.passage_chunk_size -1 - # sep_id = 151645 # <|separator|> - eos_id = self.tokenizer.eos_token_id - if eos_id is None: - raise ValueError("tokenizer.eos_token_id is None; cannot chunk passages with EOS separators.") - max_length = self.data_args.passage_max_len # cap total length (incl. 
EOS per chunk) - - all_input_ids = [] - all_sep_positions = [] - - for passage in passages: - tokens = self.tokenizer.encode(passage, add_special_tokens=False) - ids = [] - sep_pos = [] - - # Build chunked ids, optionally capped by max_length (total tokens including EOS separators). - i = 0 - while i < len(tokens): - if max_length and max_length > 0: - remaining = max_length - len(ids) - # Need at least 1 slot for EOS; otherwise stop (don't add empty chunks). - if remaining <= 1: - break - take = min(chunk_len, len(tokens) - i, remaining - 1) - if take <= 0: - break - else: - take = min(chunk_len, len(tokens) - i) - - chunk = tokens[i:i + take] # up to chunk_len tokens - ids.extend(chunk) - ids.append(eos_id) # EOS at end of this chunk - sep_pos.append(len(ids) - 1) # position of EOS (pooling position) - i += take - - all_input_ids.append(ids) - all_sep_positions.append(sep_pos) - - print(f"all_input_ids: {all_input_ids}") - d_collated = {'input_ids': all_input_ids} - # Padding - d_collated = self.tokenizer.pad( - d_collated, - padding=True, - pad_to_multiple_of=self.data_args.pad_to_multiple_of, - return_attention_mask=True, - return_tensors='pt', - ) - # print(f"d_collated: {d_collated['input_ids']}") - # print(f"length of d_collated: {len(d_collated['input_ids'])}") - # print(f"attention mask: {d_collated['attention_mask']}") - # print(f"length of attention mask: {len(d_collated['attention_mask'])}") - # print(f"all_sep_positions: {all_sep_positions[0]}") - # input("Press Enter to continue...") - return d_collated, all_sep_positions + return _tokenize_and_pad_chunked_passages(passages, self.tokenizer, self.data_args) @dataclass @@ -309,62 +318,17 @@ def __call__(self, features): """ Collate function for chunked passage encoding. 
:param features: list of (doc_id, text, image, video, audio) tuples - :return: (doc_ids, collated_inputs, sep_positions, chunk_counts) + :return: (doc_ids, collated_inputs, eos_positions) """ doc_ids = [x[0] for x in features] texts = [x[1] for x in features] - chunk_len = self.data_args.passage_chunk_size - 1 - sep_id = 151645 # <|separator|> - eos_id = self.tokenizer.eos_token_id - if eos_id is None: - raise ValueError("tokenizer.eos_token_id is None; cannot chunk passages with EOS separators.") - max_length = self.data_args.passage_max_len # cap total length (incl. EOS per chunk) - - all_input_ids = [] - all_sep_positions = [] - chunk_counts = [] - - for text in texts: - if text is None: - text = "" - tokens = self.tokenizer.encode(text, add_special_tokens=False) - ids = [] - sep_pos = [] - - i = 0 - while i < len(tokens): - if max_length and max_length > 0: - remaining = max_length - len(ids) - if remaining <= 1: - break - take = min(chunk_len, len(tokens) - i, remaining - 1) - if take <= 0: - break - else: - take = min(chunk_len, len(tokens) - i) - - chunk = tokens[i:i + take] # up to chunk_len tokens - ids.extend(chunk) - ids.append(eos_id) # EOS at end of this chunk - sep_pos.append(len(ids) - 1) # position of EOS - i += take - - all_input_ids.append(ids) - all_sep_positions.append(sep_pos) - chunk_counts.append(len(sep_pos)) - - # Use tokenizer.pad() for consistent padding - d_collated = {'input_ids': all_input_ids} - d_collated = self.tokenizer.pad( - d_collated, - padding=True, - pad_to_multiple_of=self.data_args.pad_to_multiple_of, - return_attention_mask=True, - return_tensors='pt', - ) + d_collated, all_eos_positions = self._tokenize_and_pad_chunked_passages(texts) - return doc_ids, d_collated, all_sep_positions, chunk_counts + return doc_ids, d_collated, all_eos_positions + + def _tokenize_and_pad_chunked_passages(self, passages: List[str]): + return _tokenize_and_pad_chunked_passages(passages, self.tokenizer, self.data_args) @dataclass diff --git 
a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py index 48c33494..c9133c58 100644 --- a/src/tevatron/retriever/driver/encode.py +++ b/src/tevatron/retriever/driver/encode.py @@ -107,11 +107,11 @@ def main(): with torch.amp.autocast('cuda') if training_args.fp16 or training_args.bf16 else nullcontext(): with torch.no_grad(): if use_chunked: - doc_ids, batch_inputs, sep_positions, chunk_counts = batch + doc_ids, batch_inputs, eos_positions = batch # print(batch_inputs) for k, v in batch_inputs.items(): batch_inputs[k] = v.to(training_args.device) - chunk_embs, chunk_mask = model.encode_passage(batch_inputs, sep_positions) + chunk_embs, chunk_mask = model.encode_passage(batch_inputs, eos_positions) # Flatten chunk embeddings and create lookup indices batch_size, max_chunks, hidden_size = chunk_embs.shape diff --git a/src/tevatron/retriever/modeling/dense.py b/src/tevatron/retriever/modeling/dense.py index 43fec213..f0f222ab 100644 --- a/src/tevatron/retriever/modeling/dense.py +++ b/src/tevatron/retriever/modeling/dense.py @@ -12,35 +12,35 @@ class DenseModel(EncoderModel): def __init__(self, encoder, pooling='cls', normalize=False, temperature=1.0): super().__init__(encoder, pooling, normalize, temperature) self.passage_chunk_size = 0 - self.sep_positions = None + self.eos_positions = None def encode_query(self, qry): query_hidden_states = self.encoder(**qry, return_dict=True) query_hidden_states = query_hidden_states.last_hidden_state return self._pooling(query_hidden_states, qry['attention_mask']) - def encode_passage(self, psg, sep_positions=None): + def encode_passage(self, psg, eos_positions=None): hidden_states = self.encoder(**psg, return_dict=True).last_hidden_state - if self.passage_chunk_size > 0 and sep_positions: - return self._pooling_chunked(hidden_states, sep_positions) + if self.passage_chunk_size > 0 and eos_positions: + return self._pooling_chunked(hidden_states, eos_positions) return self._pooling(hidden_states, 
psg['attention_mask']) - def _pooling_chunked(self, last_hidden_state, sep_positions): + def _pooling_chunked(self, last_hidden_state, eos_positions): batch_size, seq_len, hidden_size = last_hidden_state.shape - if not sep_positions: + if not eos_positions: # No chunks, return empty return torch.zeros(batch_size, 0, hidden_size, device=last_hidden_state.device, dtype=last_hidden_state.dtype), \ torch.zeros(batch_size, 0, device=last_hidden_state.device) # Find max number of chunks across all passages - max_chunks = max(len(pos_list) for pos_list in sep_positions) + max_chunks = max(len(pos_list) for pos_list in eos_positions) chunk_reps = torch.zeros(batch_size, max_chunks, hidden_size, device=last_hidden_state.device, dtype=last_hidden_state.dtype) chunk_mask = torch.zeros(batch_size, max_chunks, device=last_hidden_state.device, dtype=torch.float) - # Extract embeddings at sep_positions (this is the pooling operation for chunked passages) - for i, positions in enumerate(sep_positions): + # Extract embeddings at eos_positions (this is the pooling operation for chunked passages) + for i, positions in enumerate(eos_positions): for j, pos in enumerate(positions): if 0 <= pos < seq_len: chunk_reps[i, j] = last_hidden_state[i, pos] diff --git a/src/tevatron/retriever/modeling/encoder.py b/src/tevatron/retriever/modeling/encoder.py index ff68aa29..8a25c556 100644 --- a/src/tevatron/retriever/modeling/encoder.py +++ b/src/tevatron/retriever/modeling/encoder.py @@ -50,16 +50,16 @@ def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = q_reps = self.encode_query(query) if query else None p_reps, chunk_mask = None, None if passage: - # If training with chunked passages, sep_positions is produced by the collator and + # If training with chunked passages, eos_positions is produced by the collator and # attached to the model by TevatronTrainer.compute_loss(). Forward() needs to pass it # into encode_passage() to actually get chunk reps/masks. 
- sep_positions = getattr(self, "sep_positions", None) - if self.passage_chunk_size > 0 and sep_positions is not None: - # print(f"sep_positions: {sep_positions}") + eos_positions = getattr(self, "eos_positions", None) + if self.passage_chunk_size > 0 and eos_positions is not None: + # print(f"eos_positions: {eos_positions}") try: - p_reps = self.encode_passage(passage, sep_positions=sep_positions) + p_reps = self.encode_passage(passage, eos_positions=eos_positions) except TypeError: - # Some models (e.g., multimodal) don't accept sep_positions. + # Some models (e.g., multimodal) don't accept eos_positions. p_reps = self.encode_passage(passage) else: p_reps = self.encode_passage(passage) @@ -137,15 +137,15 @@ def compute_maxsim_similarity(self, q_reps, p_reps, chunk_mask): chunk_scores = chunk_scores.masked_fill(padding_mask, float('-inf')) max_vals, max_idx = chunk_scores.max(dim=-1) # [Q, P], [Q, P] - # Print argmax chunk index + (optional) original token position from sep_positions + # Print argmax chunk index + (optional) original token position from eos_positions if True: # only log from rank-0 if DDP if (not getattr(self, "is_ddp", False)) or getattr(self, "process_rank", 0) == 0: - sep_positions = getattr(self, "sep_positions", None) - # If DDP gathered passages, sep_positions may not align; only use when sizes match. - sep_ok = ( - isinstance(sep_positions, (list, tuple)) - and len(sep_positions) == p_reps.size(0) + eos_positions = getattr(self, "eos_positions", None) + # If DDP gathered passages, eos_positions may not align; only use when sizes match. 
+ eos_ok = ( + isinstance(eos_positions, (list, tuple)) + and len(eos_positions) == p_reps.size(0) ) qn, pn = max_idx.size(0), max_idx.size(1) for qi in range(qn): @@ -158,8 +158,8 @@ def compute_maxsim_similarity(self, q_reps, p_reps, chunk_mask): else: last_ci = p_reps.size(1) - 1 - if sep_ok and sep_positions[pi]: - pos_list = sep_positions[pi] + if eos_ok and eos_positions[pi]: + pos_list = eos_positions[pi] best_pos = pos_list[ci] if 0 <= ci < len(pos_list) else None last_pos = pos_list[-1] logger.info( diff --git a/src/tevatron/retriever/trainer.py b/src/tevatron/retriever/trainer.py index 403f8024..22beac16 100644 --- a/src/tevatron/retriever/trainer.py +++ b/src/tevatron/retriever/trainer.py @@ -46,10 +46,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): query, passage, *rest = inputs - sep_positions = rest[0] if rest else None - # input(f"trainer.compute_loss: sep_positions: {sep_positions}") - if hasattr(model, 'sep_positions'): - model.sep_positions = sep_positions + eos_positions = rest[0] if rest else None + # input(f"trainer.compute_loss: eos_positions: {eos_positions}") + if hasattr(model, 'eos_positions'): + model.eos_positions = eos_positions return model(query=query, passage=passage).loss def training_step(self, *args): diff --git a/tests/test_chunking.py b/tests/test_chunking.py index 96d4f7a6..01965716 100644 --- a/tests/test_chunking.py +++ b/tests/test_chunking.py @@ -36,24 +36,18 @@ def _strictly_increasing(xs): "quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural " "development in cerebral white matter in living infants" ) - -TOKENIZER_DIR_NAME = "retriever-qwen3-emb-ft-chunk-batch-2-group-16-maxlen-512-chunk-256-eos" - +EOS_TOKEN_ID = 151645 +PADDING_TOKEN_ID = 151643 @pytest.fixture(scope="session") def train_tokenizer(): """ - Use the exact tokenizer saved by the 
finetune_with_chunk.sh output_dir, - and mimic tevatron.retriever.driver.train's tokenizer setup. + Use the Qwen 0.6B tokenizer. """ _add_tevatron_src_to_path() from transformers import AutoTokenizer - tok_dir = _tevatron_root() / TOKENIZER_DIR_NAME - if not tok_dir.exists(): - pytest.skip(f"local tokenizer dir not found: {tok_dir}") - - tok = AutoTokenizer.from_pretrained(str(tok_dir), local_files_only=True) + tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") if tok.pad_token_id is None: tok.pad_token_id = tok.eos_token_id tok.padding_side = "right" # finetune_with_chunk.sh uses --padding_side right @@ -86,29 +80,29 @@ def test_train_collator_chunked_passages(train_tokenizer): # of collator.__call__() to test chunking in isolation. # # If we used collator.__call__(features) with passage_chunk_size > 0, it would return: - # (q_batch, p_batch, sep_positions) # 3-element tuple + # (q_batch, p_batch, eos_positions) # 3-element tuple # # Where: # - q_batch: dict with "input_ids" and "attention_mask" for queries # - p_batch: dict with "input_ids" and "attention_mask" for chunked passages - # - sep_positions: list of lists, e.g., [[255, 430]] - EOS token positions per passage + # - eos_positions: list of lists, e.g., [[255, 430]] - EOS token positions per passage # Used by the model to extract chunk embeddings via MaxSim pooling # ======================================================================== - d_collated, sep_positions = collator._tokenize_and_pad_chunked_passages([REAL_TEXT]) + d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([REAL_TEXT]) got_ids = d_collated["input_ids"][0].tolist() got_mask = d_collated["attention_mask"][0].tolist() got_unpadded_len = sum(got_mask) assert got_unpadded_len == 431 - assert sep_positions == [[255, 430]] - # EOS token at sep positions + assert eos_positions == [[255, 430]] + # EOS token at eos positions assert got_ids[255] == train_tokenizer.eos_token_id assert got_ids[430] == 
train_tokenizer.eos_token_id print("length of got_ids: ", len(got_ids)) expected_ids = [ - 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13, 758, 279, 8622, 4158, 4925, 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, 73760, 572, 1550, 11, 220, 16, 13, 23, 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, 311, 220, 16, 13, 17, 19197, 441, 17, 58634, 13, 758, 279, 44900, 47594, 315, 279, 5306, 47639, 11, 279, 3076, 9981, 57330, 36829, 518, 2176, 3039, 1033, 4428, 320, 16, 13, 17, 19041, 220, 16, 13, 16, 19197, 441, 17, 58634, 568, 39402, 458, 285, 354, 17764, 572, 5080, 279, 12128, 7194, 572, 311, 4647, 448, 7046, 10740, 2750, 304, 279, 5306, 47639, 1091, 304, 279, 8622, 4158, 4925, 13, 4968, 4991, 41434, 518, 4647, 8542, 5080, 3076, 57330, 36829, 304, 279, 8622, 4158, 4925, 320, 16, 13, 19, 51615, 220, 15, 13, 17, 19, 19041, 220, 16, 13, 16, 151645, 20, 51615, 220, 15, 13, 15, 24, 19197, 441, 17, 58634, 11, 281, 284, 220, 15, 13, 15, 16, 21, 8, 323, 4722, 8674, 458, 285, 354, 17764, 304, 2176, 5671, 7707, 448, 2480, 9663, 41434, 320, 5782, 4925, 11, 220, 16, 15, 13, 24, 51615, 220, 15, 13, 21, 19041, 220, 17, 17, 13, 24, 51615, 220, 18, 13, 15, 13384, 281, 284, 220, 15, 13, 15, 15, 16, 26, 5306, 47639, 11, 220, 17, 19, 13, 15, 51615, 220, 19, 13, 19, 19, 19041, 220, 18, 18, 13, 16, 51615, 220, 15, 13, 21, 4, 281, 284, 220, 15, 13, 15, 
15, 21, 568, 11581, 2408, 301, 15479, 48674, 304, 279, 42094, 1620, 385, 1242, 1033, 9434, 553, 57330, 15626, 51360, 438, 4124, 438, 220, 17, 23, 73760, 26, 2480, 9663, 323, 855, 4991, 41434, 518, 4647, 8542, 12864, 11799, 304, 4158, 4925, 23788, 7321, 13, 576, 821, 13216, 429, 46516, 15449, 315, 3015, 57330, 553, 57330, 15626, 51360, 5707, 20017, 1119, 8003, 95697, 4401, 304, 59645, 4158, 4925, 304, 5382, 41434, 151645, 151643 + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13, 758, 279, 8622, 4158, 4925, 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, 73760, 572, 1550, 11, 220, 16, 13, 23, 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, 311, 220, 16, 13, 17, 19197, 441, 17, 58634, 13, 758, 279, 44900, 47594, 315, 279, 5306, 47639, 11, 279, 3076, 9981, 57330, 36829, 518, 2176, 3039, 1033, 4428, 320, 16, 13, 17, 19041, 220, 16, 13, 16, 19197, 441, 17, 58634, 568, 39402, 458, 285, 354, 17764, 572, 5080, 279, 12128, 7194, 572, 311, 4647, 448, 7046, 10740, 2750, 304, 279, 5306, 47639, 1091, 304, 279, 8622, 4158, 4925, 13, 4968, 4991, 41434, 518, 4647, 8542, 5080, 3076, 57330, 36829, 304, 279, 8622, 4158, 4925, 320, 16, 13, 19, 51615, 220, 15, 13, 17, 19, 19041, 220, 16, 13, 16, EOS_TOKEN_ID, 20, 51615, 220, 15, 13, 15, 24, 19197, 441, 17, 58634, 11, 281, 284, 220, 15, 13, 15, 16, 21, 8, 323, 4722, 8674, 458, 285, 354, 17764, 304, 2176, 
5671, 7707, 448, 2480, 9663, 41434, 320, 5782, 4925, 11, 220, 16, 15, 13, 24, 51615, 220, 15, 13, 21, 19041, 220, 17, 17, 13, 24, 51615, 220, 18, 13, 15, 13384, 281, 284, 220, 15, 13, 15, 15, 16, 26, 5306, 47639, 11, 220, 17, 19, 13, 15, 51615, 220, 19, 13, 19, 19, 19041, 220, 18, 18, 13, 16, 51615, 220, 15, 13, 21, 4, 281, 284, 220, 15, 13, 15, 15, 21, 568, 11581, 2408, 301, 15479, 48674, 304, 279, 42094, 1620, 385, 1242, 1033, 9434, 553, 57330, 15626, 51360, 438, 4124, 438, 220, 17, 23, 73760, 26, 2480, 9663, 323, 855, 4991, 41434, 518, 4647, 8542, 12864, 11799, 304, 4158, 4925, 23788, 7321, 13, 576, 821, 13216, 429, 46516, 15449, 315, 3015, 57330, 553, 57330, 15626, 51360, 5707, 20017, 1119, 8003, 95697, 4401, 304, 59645, 4158, 4925, 304, 5382, 41434, EOS_TOKEN_ID, PADDING_TOKEN_ID ] assert got_ids == expected_ids @@ -117,7 +111,7 @@ def test_train_collator_chunked_passages(train_tokenizer): expected_mask = [1] * 431 + [0] * 1 assert len(got_mask) == 432 assert got_mask == expected_mask - # Verify attention_mask is 1 at sep_positions (EOS tokens should be attended) + # Verify attention_mask is 1 at eos_positions (EOS tokens should be attended) assert got_mask[255] == 1 assert got_mask[430] == 1 @@ -125,11 +119,9 @@ def test_train_collator_chunked_passages(train_tokenizer): @pytest.mark.unit def test_chunk_size_zero_with_train_tokenizer_disables_chunking_and_truncates(train_tokenizer): """ - With passage_chunk_size=0, TrainCollator should take the non-chunk path and - truncate passages to passage_max_len (like finetune_mldr_dev.sh). + With passage_chunk_size > 0, TrainCollator should take the chunking path. - Hardcoded golden output: both passages are truncated to exactly 64 tokens - (passage_max_len), with no padding needed since 64 is already a multiple of 16. + Tests chunked passages with passage_max_len=64 and passage_chunk_size=32. 
""" from tevatron.retriever.arguments import DataArguments from tevatron.retriever.collator import TrainCollator @@ -141,7 +133,7 @@ def test_chunk_size_zero_with_train_tokenizer_disables_chunking_and_truncates(tr padding_side="right", append_eos_token=False, train_group_size=2, - passage_chunk_size=0, + passage_chunk_size=32, ) collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) @@ -167,53 +159,58 @@ def test_chunk_size_zero_with_train_tokenizer_disables_chunking_and_truncates(tr # ======================================================================== # WHAT collator(features) RETURNS: # ======================================================================== - # Since passage_chunk_size=0 (no chunking), TrainCollator.__call__() returns: - # (q_batch, p_batch) # 2-element tuple + # Since passage_chunk_size > 0 (chunking enabled), TrainCollator.__call__() returns: + # (q_batch, p_batch, eos_positions) # 3-element tuple # # Where: # q_batch: dict with PyTorch tensors for queries # - "input_ids": tensor([[token_ids for "q1"]]) # shape: [num_queries, query_seq_len] # - "attention_mask": tensor([[1, 1, ...]]) # shape: [num_queries, query_seq_len] # - # p_batch: dict with PyTorch tensors for passages (FLATTENED across all queries) + # p_batch: dict with PyTorch tensors for chunked passages (FLATTENED across all queries) # - "input_ids": tensor([ - # [token_ids for passage 1 (truncated to passage_max_len=64)], - # [token_ids for passage 2 (truncated to passage_max_len=64)] + # [token_ids for passage 1 (chunked, padded to multiple of 16)], + # [token_ids for passage 2 (chunked, padded to multiple of 16)] # ]) # shape: [total_passages, passage_seq_len] # - "attention_mask": tensor([ - # [1, 1, ..., 1], # 64 ones (no padding since 64 is multiple of 16) - # [1, 1, ..., 1] # 64 ones + # [1, 1, ..., 0, 0, ...], # attention mask with padding + # [1, 1, ..., 0, 0, ...] 
# ]) # shape: [total_passages, passage_seq_len] # + # eos_positions: list of lists, e.g., [[31, 63], [31, 63]] - EOS token positions per passage + # Used by the model to extract chunk embeddings via MaxSim pooling + # # Note: The collator flattens all passages from all queries into a single batch. # With 1 query and train_group_size=2, we get 2 passages in p_batch. # ======================================================================== out = collator(features) - assert len(out) == 2 # Verify non-chunking path returns 2 elements - q_batch, p_batch = out # Unpack: q_batch (queries), p_batch (passages) - - # Hardcoded golden output (both passages are identical since same input text) - expected_ids = [ - 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, - 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, - 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, - 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, - 349, 2326, 151643 - ] - expected_mask = [1] * 64 # No padding needed (64 is multiple of 16) + assert len(out) == 3 # Verify chunking path returns 3 elements + q_batch, p_batch, eos_positions = out # Unpack: q_batch (queries), p_batch (passages), eos_positions assert p_batch["input_ids"].shape[0] == 2 # train_group_size=2 + assert len(eos_positions) == 2 # One list of eos positions per passage for i in range(p_batch["input_ids"].shape[0]): got_ids = p_batch["input_ids"][i].tolist() got_mask = p_batch["attention_mask"][i].tolist() unpadded_len = sum(got_mask) - assert unpadded_len == 64 + # Verify chunking structure + assert len(eos_positions[i]) > 0 # Should have at least one chunk + assert _strictly_increasing(eos_positions[i]) # EOS positions should be strictly increasing + + # Verify EOS tokens at eos positions + for eos_pos in eos_positions[i]: + assert got_ids[eos_pos] == train_tokenizer.eos_token_id + 
assert got_mask[eos_pos] == 1  # EOS tokens should be attended
+        assert eos_positions[0][0] == 31
+        assert eos_positions[0][1] == 63
+        assert eos_positions[1][0] == 31
+        assert eos_positions[1][1] == 63
+
         # Verify padding to multiple of 16
         assert len(got_ids) == 64
         assert len(got_mask) == 64
-        assert got_ids == expected_ids
-        assert got_mask == expected_mask
+        assert len(got_ids) == len(got_mask)
 
 
 @pytest.mark.unit
@@ -234,7 +231,7 @@ def test_chunking_chunk_size_equal_maxlen_is_capped_to_single_chunk(train_tokeni
         append_eos_token=False,
     )
     collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer)
-    d_collated, sep_positions = collator._tokenize_and_pad_chunked_passages([long_text])
+    d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([long_text])
 
     ids = d_collated["input_ids"][0].tolist()
     mask = d_collated["attention_mask"][0].tolist()
@@ -244,18 +241,19 @@ def test_chunking_chunk_size_equal_maxlen_is_capped_to_single_chunk(train_tokeni
         82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970,
         56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279,
         9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684,
-        349, 2326, 151645
+        349, 2326, EOS_TOKEN_ID
     ]
-    expected_sep_positions = [[63]]
+    expected_eos_positions = [[63]]
     expected_mask = [1] * 64
 
     assert sum(mask) == 64
     assert len(ids) == 64
-    assert sep_positions == expected_sep_positions
+    assert eos_positions == expected_eos_positions
     assert ids == expected_ids
-    assert ids[63] == 151645
+    assert ids[63] == EOS_TOKEN_ID
+    assert EOS_TOKEN_ID not in ids[0:63]  # EOS token should not be in the first 63 tokens
     assert mask == expected_mask
-    assert _strictly_increasing(sep_positions[0])
+    assert _strictly_increasing(eos_positions[0])
 
 
 @pytest.mark.unit
@@ -276,7 +274,7 @@ def test_chunking_chunk_size_greater_than_maxlen_is_capped_to_single_chunk(train
         append_eos_token=False,
     )
     collator = TrainCollator(data_args=data_args, 
tokenizer=train_tokenizer) - d_collated, sep_positions = collator._tokenize_and_pad_chunked_passages([long_text]) + d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([long_text]) ids = d_collated["input_ids"][0].tolist() mask = d_collated["attention_mask"][0].tolist() @@ -286,18 +284,18 @@ def test_chunking_chunk_size_greater_than_maxlen_is_capped_to_single_chunk(train 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, - 349, 2326, 151645 + 349, 2326, EOS_TOKEN_ID ] - expected_sep_positions = [[63]] + expected_eos_positions = [[63]] expected_mask = [1] * 64 assert sum(mask) == 64 assert len(ids) == 64 - assert sep_positions == expected_sep_positions + assert eos_positions == expected_eos_positions assert ids == expected_ids - assert ids[63] == 151645 + assert ids[63] == EOS_TOKEN_ID assert mask == expected_mask - assert _strictly_increasing(sep_positions[0]) + assert _strictly_increasing(eos_positions[0]) @pytest.mark.unit @@ -318,22 +316,22 @@ def test_chunking_short_passage_shorter_than_chunk_size(train_tokenizer): append_eos_token=False, ) collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) - d_collated, sep_positions = collator._tokenize_and_pad_chunked_passages([short_text]) + d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([short_text]) ids = d_collated["input_ids"][0].tolist() mask = d_collated["attention_mask"][0].tolist() # Hardcoded golden output: "Hello world" -> 2 tokens + 1 EOS = 3 tokens, padded to 16 - expected_ids = [9707, 1879, 151645] + [151643] * 13 # 3 content + 13 padding - expected_sep_positions = [[2]] + expected_ids = [9707, 1879, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 13 # 3 content + 13 padding + expected_eos_positions = [[2]] expected_mask = [1, 1, 1] + [0] * 13 # 3 
ones + 13 zeros assert sum(mask) == 3 assert len(ids) == 16 # Padded to multiple of 16 - assert sep_positions == expected_sep_positions + assert eos_positions == expected_eos_positions assert ids == expected_ids - assert ids[2] == 151645 # EOS at position 2 + assert ids[2] == EOS_TOKEN_ID # EOS at position 2 assert mask == expected_mask - assert _strictly_increasing(sep_positions[0]) + assert _strictly_increasing(eos_positions[0]) @pytest.mark.unit @@ -353,7 +351,7 @@ def test_chunking_passage_needs_padding_unpadded_not_multiple_of_pad_to_multiple append_eos_token=False, ) collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) - d_collated, sep_positions = collator._tokenize_and_pad_chunked_passages([REAL_TEXT]) + d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([REAL_TEXT]) ids = d_collated["input_ids"][0].tolist() mask = d_collated["attention_mask"][0].tolist() @@ -361,20 +359,19 @@ def test_chunking_passage_needs_padding_unpadded_not_multiple_of_pad_to_multiple expected_ids = [ 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, - 151645, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, - 279, 9981, 57330, 151645 - ] + [151643] * 14 # 50 content + 14 padding - expected_sep_positions = [[31, 49]] + EOS_TOKEN_ID, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, + 279, 9981, 57330, EOS_TOKEN_ID + ] + [PADDING_TOKEN_ID] * 14 # 50 content + 14 padding + expected_eos_positions = [[31, 49]] expected_mask = [1] * 50 + [0] * 14 # 50 ones + 14 zeros - assert sum(mask) == 50 assert len(ids) == 64 # Padded to multiple of 16 - assert sep_positions == expected_sep_positions + assert eos_positions == expected_eos_positions assert ids == expected_ids - assert ids[31] == 151645 # First EOS - assert ids[49] == 151645 # Second EOS + assert ids[31] 
== EOS_TOKEN_ID # First EOS + assert ids[49] == EOS_TOKEN_ID # Second EOS assert mask == expected_mask - assert _strictly_increasing(sep_positions[0]) + assert _strictly_increasing(eos_positions[0]) @pytest.mark.unit @@ -384,12 +381,17 @@ def test_chunking_multiple_passages_different_lengths(train_tokenizer): - Short passage (2 tokens) - Medium passage (18 tokens) - Long passage (128 tokens, multiple chunks) + - Very long passage (158 tokens, multiple chunks) All should be padded to the same length (longest unpadded length rounded up to pad_to_multiple_of). """ from tevatron.retriever.arguments import DataArguments from tevatron.retriever.collator import TrainCollator - texts = ["Short", REAL_TEXT[:100], REAL_TEXT] + # Create a passage that will result in ~158 tokens + # REAL_TEXT is ~431 tokens, so we'll use a portion of it repeated or extended + long_passage = REAL_TEXT + " " + REAL_TEXT[:200] + + texts = ["Short", REAL_TEXT[:100], REAL_TEXT, long_passage] data_args = DataArguments( passage_chunk_size=64, passage_max_len=128, @@ -398,36 +400,48 @@ def test_chunking_multiple_passages_different_lengths(train_tokenizer): append_eos_token=False, ) collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) - d_collated, sep_positions = collator._tokenize_and_pad_chunked_passages(texts) + d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages(texts) - # Hardcoded golden outputs - # Passage 0: "Short" -> 1 token + 1 EOS = 2 tokens, padded to 128 - expected_ids_0 = [12472, 151645] + [151643] * 126 + expected_ids_0 = [12472, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 126 expected_mask_0 = [1, 1] + [0] * 126 - expected_sep_0 = [1] + expected_eos_0 = [1] - # Passage 1: REAL_TEXT[:100] -> 17 tokens + 1 EOS = 18 tokens, padded to 128 + # Passage 1: REAL_TEXT[:100] -> 17 tokens + 1 EOS = 18 tokens, padded to 160 expected_ids_1 = [ 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, - 1062, 151645 - ] + [151643] 
* 110 + 1062, EOS_TOKEN_ID + ] + [PADDING_TOKEN_ID] * 110 expected_mask_1 = [1] * 18 + [0] * 110 - expected_sep_1 = [17] + expected_eos_1 = [17] - # Passage 2: REAL_TEXT -> 2 chunks (63+1 EOS, 63+1 EOS) = 128 tokens, no padding needed + # Passage 2: REAL_TEXT -> 2 chunks (63+1 EOS, 63+1 EOS) = 128 tokens, padded to 160 expected_ids_2 = [ 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, - 349, 2326, 151645, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, + 349, 2326, EOS_TOKEN_ID, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13, 758, 279, 8622, - 151645 + EOS_TOKEN_ID ] expected_mask_2 = [1] * 128 - expected_sep_2 = [63, 127] + expected_eos_2 = [63, 127] + + expected_ids_3 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, + 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, + 349, 2326, EOS_TOKEN_ID, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, + 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239, + 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434, + 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 
4647, 13, 758, 279, 8622, + EOS_TOKEN_ID + ] + expected_mask_3 = [1] * 128 + expected_eos_3 = [63, 127] ids_0 = d_collated["input_ids"][0].tolist() mask_0 = d_collated["attention_mask"][0].tolist() @@ -435,26 +449,34 @@ def test_chunking_multiple_passages_different_lengths(train_tokenizer): mask_1 = d_collated["attention_mask"][1].tolist() ids_2 = d_collated["input_ids"][2].tolist() mask_2 = d_collated["attention_mask"][2].tolist() + ids_3 = d_collated["input_ids"][3].tolist() + mask_3 = d_collated["attention_mask"][3].tolist() # Passage 0 assertions assert sum(mask_0) == 2 assert len(ids_0) == 128 assert ids_0 == expected_ids_0 assert mask_0 == expected_mask_0 - assert sep_positions[0] == expected_sep_0 + assert eos_positions[0] == expected_eos_0 # Passage 1 assertions assert sum(mask_1) == 18 assert len(ids_1) == 128 assert ids_1 == expected_ids_1 assert mask_1 == expected_mask_1 - assert sep_positions[1] == expected_sep_1 + assert eos_positions[1] == expected_eos_1 # Passage 2 assertions assert sum(mask_2) == 128 assert len(ids_2) == 128 assert ids_2 == expected_ids_2 assert mask_2 == expected_mask_2 - assert sep_positions[2] == expected_sep_2 - assert _strictly_increasing(sep_positions[2]) - + assert eos_positions[2] == expected_eos_2 + assert _strictly_increasing(eos_positions[2]) + + # Passage 3 assertions + assert sum(mask_3) == 128 + assert len(ids_3) == 128 + assert eos_positions[3] == expected_eos_3 + assert ids_3 == expected_ids_3 + assert mask_3 == expected_mask_3 From a18c5787f6b5afcbe13f2aa70807e1196fcf1851 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Mon, 22 Dec 2025 03:52:45 -0500 Subject: [PATCH 17/31] added log --- src/tevatron/retriever/driver/encode.py | 22 ++++++++++++++++------ src/tevatron/retriever/modeling/dense.py | 12 +++++++++++- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py index c9133c58..08ce2ca5 100644 --- 
a/src/tevatron/retriever/driver/encode.py +++ b/src/tevatron/retriever/driver/encode.py @@ -82,7 +82,9 @@ def main(): ) use_chunked = not data_args.encode_is_query and data_args.passage_chunk_size > 0 - + print("data_args.encode_is_query: ", data_args.encode_is_query) + print("data_args.passage_chunk_size: ", data_args.passage_chunk_size) + print("use_chunked: ", use_chunked) if use_chunked: logger.info(f"Using chunked passage encoding with chunk_size={data_args.passage_chunk_size}") model.passage_chunk_size = data_args.passage_chunk_size @@ -108,12 +110,13 @@ def main(): with torch.no_grad(): if use_chunked: doc_ids, batch_inputs, eos_positions = batch - # print(batch_inputs) + # batch_inputs: input_ids, attention_mask for k, v in batch_inputs.items(): batch_inputs[k] = v.to(training_args.device) + print(f"eos_positions: {eos_positions}") chunk_embs, chunk_mask = model.encode_passage(batch_inputs, eos_positions) - - # Flatten chunk embeddings and create lookup indices + # chunk_embs: [batch_size, max_chunks, hidden_size] + # chunk_mask: [batch_size, max_chunks] batch_size, max_chunks, hidden_size = chunk_embs.shape for i, doc_id in enumerate(doc_ids): for chunk_idx in range(max_chunks): @@ -133,11 +136,18 @@ def main(): else: model_output: EncoderOutput = model(passage=batch_inputs) encoded.append(model_output.p_reps.cpu().detach().numpy()) - - # Combine encoded embeddings if use_chunked: + print("use_chunked: ", use_chunked) + print(f"encoded: {encoded}") + print(f"lookup_indices: {lookup_indices}") + print(f"length of encoded: {len(encoded)}") + print(f"length of lookup_indices: {len(lookup_indices)}") + # Combine encoded embeddings encoded = np.stack(encoded) logger.info(f"Encoded {len(set(d for d, c in lookup_indices))} docs into {len(lookup_indices)} chunks") + print(f"encoded.shape: {encoded.shape}") + print(f"length of encoded: {len(encoded)}") + input("Press Enter to continue...") else: encoded = np.concatenate(encoded) diff --git 
a/src/tevatron/retriever/modeling/dense.py b/src/tevatron/retriever/modeling/dense.py index f0f222ab..232144d4 100644 --- a/src/tevatron/retriever/modeling/dense.py +++ b/src/tevatron/retriever/modeling/dense.py @@ -5,7 +5,7 @@ from .encoder import EncoderModel logger = logging.getLogger(__name__) - +EOS_TOKEN_ID = 151645 class DenseModel(EncoderModel): @@ -20,13 +20,20 @@ def encode_query(self, qry): return self._pooling(query_hidden_states, qry['attention_mask']) def encode_passage(self, psg, eos_positions=None): + print(f"eos_positions: {eos_positions}") hidden_states = self.encoder(**psg, return_dict=True).last_hidden_state if self.passage_chunk_size > 0 and eos_positions: + for i, ep in enumerate(eos_positions): + for eos_pos in ep: + assert psg['input_ids'][i][eos_pos] == EOS_TOKEN_ID + return self._pooling_chunked(hidden_states, eos_positions) return self._pooling(hidden_states, psg['attention_mask']) def _pooling_chunked(self, last_hidden_state, eos_positions): batch_size, seq_len, hidden_size = last_hidden_state.shape + print(f"last_hidden_state.shape: {last_hidden_state.shape}") + print(f"eos_positions: {eos_positions}") if not eos_positions: # No chunks, return empty @@ -34,6 +41,9 @@ def _pooling_chunked(self, last_hidden_state, eos_positions): torch.zeros(batch_size, 0, device=last_hidden_state.device) # Find max number of chunks across all passages + for eos_pos in eos_positions: + print(f"eos_pos: {eos_pos}") + print(f"type(eos_pos): {type(eos_pos)}") max_chunks = max(len(pos_list) for pos_list in eos_positions) chunk_reps = torch.zeros(batch_size, max_chunks, hidden_size, device=last_hidden_state.device, dtype=last_hidden_state.dtype) From f0ee78685413dfe36efe4ee8e24c6c825506fc64 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Tue, 23 Dec 2025 11:30:33 -0500 Subject: [PATCH 18/31] review --- src/tevatron/retriever/collator.py | 22 + src/tevatron/retriever/driver/encode.py | 3 +- src/tevatron/retriever/driver/train.py | 1 + 
src/tevatron/retriever/modeling/dense.py | 5 +- tests/test_chunking.py | 349 ++++++++++++++ tests/test_chunking_pooling_equivalence.py | 153 ++++++ tests/test_forward.py | 270 +++++++++++ tests/test_pooling.py | 532 +++++++++++++++++++++ tests/test_search.py | 211 ++++++++ 9 files changed, 1543 insertions(+), 3 deletions(-) create mode 100644 tests/test_chunking_pooling_equivalence.py create mode 100644 tests/test_forward.py create mode 100644 tests/test_pooling.py create mode 100644 tests/test_search.py diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index 7a1ff052..7b27b282 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -72,6 +72,14 @@ def _tokenize_and_pad_chunked_passages( all_eos_positions.append(eos_pos) d_collated = {'input_ids': all_input_ids} + + # Store original lengths before padding to adjust eos_positions for left padding + original_lengths = [len(ids) for ids in all_input_ids] + + # Set tokenizer padding_side before padding + original_padding_side = tokenizer.padding_side + tokenizer.padding_side = data_args.padding_side + # Padding d_collated = tokenizer.pad( d_collated, @@ -80,6 +88,20 @@ def _tokenize_and_pad_chunked_passages( return_attention_mask=True, return_tensors='pt', ) + + # Restore original padding_side + tokenizer.padding_side = original_padding_side + + # Adjust eos_positions for left padding + # When padding_side is 'left', padding tokens are added at the beginning, + # so EOS positions need to be shifted by the padding length + if data_args.padding_side == 'left': + padded_lengths = d_collated['input_ids'].shape[1] # All sequences have same length after padding + for i, eos_pos_list in enumerate(all_eos_positions): + padding_length = padded_lengths - original_lengths[i] + # Shift each EOS position by the padding length + all_eos_positions[i] = [pos + padding_length for pos in eos_pos_list] + return d_collated, all_eos_positions diff --git 
a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py
index 08ce2ca5..59a92e97 100644
--- a/src/tevatron/retriever/driver/encode.py
+++ b/src/tevatron/retriever/driver/encode.py
@@ -54,7 +54,8 @@ def main():
     )
     if tokenizer.pad_token_id is None:
         tokenizer.pad_token_id = tokenizer.eos_token_id
-
+
+    tokenizer.eos_token_id = tokenizer.pad_token_id
     if data_args.padding_side == 'right':
         tokenizer.padding_side = 'right'
     else:
diff --git a/src/tevatron/retriever/driver/train.py b/src/tevatron/retriever/driver/train.py
index 15b13adc..aaa6c163 100644
--- a/src/tevatron/retriever/driver/train.py
+++ b/src/tevatron/retriever/driver/train.py
@@ -64,6 +64,7 @@ def main():
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
     )
     if tokenizer.pad_token_id is None:
         tokenizer.pad_token_id = tokenizer.eos_token_id
+    tokenizer.eos_token_id = tokenizer.pad_token_id
 
diff --git a/src/tevatron/retriever/modeling/dense.py b/src/tevatron/retriever/modeling/dense.py
index 232144d4..3904854b 100644
--- a/src/tevatron/retriever/modeling/dense.py
+++ b/src/tevatron/retriever/modeling/dense.py
@@ -5,7 +5,7 @@
 from .encoder import EncoderModel
 
 logger = logging.getLogger(__name__)
-EOS_TOKEN_ID = 151645
+EOS_TOKEN_ID = 151643
 
 class DenseModel(EncoderModel):
@@ -28,6 +28,7 @@ def encode_passage(self, psg, eos_positions=None):
                 assert psg['input_ids'][i][eos_pos] == EOS_TOKEN_ID
 
             return self._pooling_chunked(hidden_states, eos_positions)
+
         return self._pooling(hidden_states, psg['attention_mask'])
 
     def _pooling_chunked(self, last_hidden_state, eos_positions):
@@ -113,4 +114,4 @@ def encode_query(self, qry):
 
     def encode_passage(self, psg):
         # encode passage is the same as encode query
-        return self.encode_query(psg)
\ No newline at end of file
+        return self.encode_query(psg)
diff --git a/tests/test_chunking.py b/tests/test_chunking.py
index 01965716..8c736858 100644
--- a/tests/test_chunking.py
+++ 
b/tests/test_chunking.py @@ -480,3 +480,352 @@ def test_chunking_multiple_passages_different_lengths(train_tokenizer): assert eos_positions[3] == expected_eos_3 assert ids_3 == expected_ids_3 assert mask_3 == expected_mask_3 + + +@pytest.mark.unit +def test_non_chunked_padding_side_behavior(train_tokenizer): + """ + Test non-chunked passage encoding behavior with left vs right padding. + This verifies that padding_side affects how _pooling('last'/'eos') extracts embeddings. + """ + import torch + from unittest.mock import Mock + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + from tevatron.retriever.modeling.dense import DenseModel + + # Test passage - will be truncated to max_len + test_passage = REAL_TEXT # Long passage that will be truncated + + # Test Case 1: Right padding + data_args_right = DataArguments( + passage_max_len=64, + passage_chunk_size=0, # No chunking + pad_to_multiple_of=16, + padding_side="right", + passage_prefix="", + append_eos_token=False, + ) + + collator_right = TrainCollator(data_args=data_args_right, tokenizer=train_tokenizer) + q_batch_right, p_batch_right = collator_right([("query", [test_passage], [])]) + + # Verify right padding structure + input_ids_right = p_batch_right['input_ids'][0] + attention_mask_right = p_batch_right['attention_mask'][0] + seq_len_right = len(attention_mask_right) + + # With right padding, content tokens are at the beginning, padding at the end + # Last position should be padding (since passage is truncated and padded) + # Note: first position might be special token (BOS) due to add_special_tokens=True + assert attention_mask_right[-1] == 0, "Right padding: last position should be padding" + + # Last valid token position + last_valid_pos_right = attention_mask_right.sum().item() - 1 + + # Test Case 2: Left padding + data_args_left = DataArguments( + passage_max_len=64, + passage_chunk_size=0, # No chunking + pad_to_multiple_of=16, + 
padding_side="left", + passage_prefix="", + append_eos_token=False, + ) + + collator_left = TrainCollator(data_args=data_args_left, tokenizer=train_tokenizer) + q_batch_left, p_batch_left = collator_left([("query", [test_passage], [])]) + + # Verify left padding structure + input_ids_left = p_batch_left['input_ids'][0] + attention_mask_left = p_batch_left['attention_mask'][0] + seq_len_left = len(attention_mask_left) + + # With left padding, padding tokens are at the beginning, content at the end + # Due to pad_to_multiple_of, the actual behavior depends on content length + # Key observation: The pooling logic checks if last position is valid to determine left padding + num_valid_left = attention_mask_left.sum().item() + + # The _pooling logic: left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) + # If last position is 1 for all sequences, it treats it as left padding + is_detected_as_left_padding = (attention_mask_left[-1] == 1).item() + + # Verify both versions tokenized the same content (ignoring padding) + content_tokens_right = input_ids_right[attention_mask_right.bool()].tolist() + content_tokens_left = input_ids_left[attention_mask_left.bool()].tolist() + assert content_tokens_right == content_tokens_left, "Content tokens should be identical" + + # Test Case 3: Verify pooling behavior with mock model + hidden_size = 64 + + class MockEncoderOutput: + def __init__(self, last_hidden_state): + self.last_hidden_state = last_hidden_state + + def mock_encoder_forward(**kwargs): + input_ids = kwargs['input_ids'] + batch_size, seq_len = input_ids.shape + # Create hidden states where each position encodes its position index + hidden_states = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) + for i in range(batch_size): + for j in range(seq_len): + # Encode position j in the first dimension + hidden_states[i, j, 0] = float(j) + return MockEncoderOutput(last_hidden_state=hidden_states) + + mock_encoder = 
Mock(side_effect=mock_encoder_forward) + mock_encoder.config = Mock() + mock_encoder.config.hidden_size = hidden_size + + model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) + model.passage_chunk_size = 0 # No chunking + + # Test right padding pooling + p_reps_right = model.encode_passage(p_batch_right) + + # Test left padding pooling + p_reps_left = model.encode_passage(p_batch_left) + + # Verify pooling extracts from correct positions + # Right padding: uses sequence_lengths calculation (attention_mask.sum() - 1) + expected_pos_right = last_valid_pos_right + assert torch.allclose(p_reps_right[0, 0], torch.tensor(float(expected_pos_right))) + + # Left padding: The _pooling logic checks: left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) + # If last position is 1, it uses last_hidden_state[:, -1] + # Otherwise, it calculates sequence_lengths = attention_mask.sum(dim=1) - 1 + if is_detected_as_left_padding: + expected_pos_left = seq_len_left - 1 + else: + expected_pos_left = num_valid_left - 1 + assert torch.allclose(p_reps_left[0, 0], torch.tensor(float(expected_pos_left))) + + # Verify the key difference: right padding always uses sequence_lengths calculation + # Left padding uses last position if detected as left padding, otherwise sequence_lengths + # The actual positions depend on the padding structure + print(f"Right padding: extracted from position {expected_pos_right} (last_valid_pos)") + print(f"Left padding: extracted from position {expected_pos_left} (is_left_padding={is_detected_as_left_padding})") + print(f"Right padding mask: first={attention_mask_right[0].item()}, last={attention_mask_right[-1].item()}") + print(f"Left padding mask: first={attention_mask_left[0].item()}, last={attention_mask_left[-1].item()}") + + +@pytest.mark.unit +def test_chunked_passages_left_padding(train_tokenizer): + """ + Test chunked passage encoding with left padding. 
+ This verifies that EOS positions are correctly adjusted when padding is on the left. + """ + import torch + from unittest.mock import Mock + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + from tevatron.retriever.modeling.dense import DenseModel + + # Test passage that will be chunked + test_passage = REAL_TEXT + + # Test Case 1: Right padding (baseline) + data_args_right = DataArguments( + passage_max_len=128, + passage_chunk_size=64, + pad_to_multiple_of=16, + padding_side="right", + passage_prefix="", + append_eos_token=False, + ) + + collator_right = TrainCollator(data_args=data_args_right, tokenizer=train_tokenizer) + q_batch_right, p_batch_right, eos_positions_right = collator_right([("query", [test_passage], [])]) + + # Verify right padding structure + input_ids_right = p_batch_right['input_ids'][0] + attention_mask_right = p_batch_right['attention_mask'][0] + seq_len_right = len(attention_mask_right) + + # With right padding, content tokens are at the beginning, padding at the end + assert attention_mask_right[-1] == 0, "Right padding: last position should be padding" + + # Verify EOS positions are correct (should be in the content area, before padding) + for eos_pos in eos_positions_right[0]: + assert eos_pos < attention_mask_right.sum().item(), f"EOS position {eos_pos} should be in valid token range" + assert input_ids_right[eos_pos] == train_tokenizer.eos_token_id, f"Position {eos_pos} should be EOS token" + + # Test Case 2: Left padding + data_args_left = DataArguments( + passage_max_len=128, + passage_chunk_size=64, + pad_to_multiple_of=16, + padding_side="left", + passage_prefix="", + append_eos_token=False, + ) + + collator_left = TrainCollator(data_args=data_args_left, tokenizer=train_tokenizer) + q_batch_left, p_batch_left, eos_positions_left = collator_left([("query", [test_passage], [])]) + + # Verify left padding structure + input_ids_left = p_batch_left['input_ids'][0] + 
attention_mask_left = p_batch_left['attention_mask'][0] + seq_len_left = len(attention_mask_left) + + # With left padding, padding tokens are at the beginning, content at the end + # Note: Due to pad_to_multiple_of, the actual padding structure may vary + # Check that there is padding at the beginning + num_valid_tokens = attention_mask_left.sum().item() + padding_length = seq_len_left - num_valid_tokens + if padding_length > 0: + # If there's padding, first positions should be padding + assert attention_mask_left[0] == 0, "Left padding: first position should be padding when padding exists" + assert attention_mask_left[-1] == 1, "Left padding: last position should be content (valid token)" + + # Verify EOS positions are correctly adjusted for left padding + # EOS positions should be shifted by the padding length + + # Verify all EOS positions are in the valid token range (after padding) + for eos_pos in eos_positions_left[0]: + assert eos_pos >= padding_length, f"EOS position {eos_pos} should be after padding (padding_length={padding_length})" + assert eos_pos < seq_len_left, f"EOS position {eos_pos} should be within sequence length {seq_len_left}" + assert input_ids_left[eos_pos] == train_tokenizer.eos_token_id, f"Position {eos_pos} should be EOS token" + assert attention_mask_left[eos_pos] == 1, f"EOS position {eos_pos} should be in valid token range" + + # Verify that EOS positions are correctly shifted + # The relative positions within the content should be the same, but absolute positions differ + # Right padding: EOS at positions like [63, 127] (before padding) + # Left padding: EOS at positions like [padding_length + 63, padding_length + 127] (after padding) + assert len(eos_positions_right[0]) == len(eos_positions_left[0]), "Should have same number of chunks" + + # Verify the relative positions are preserved (EOS positions differ by padding_length) + for i, (eos_right, eos_left) in enumerate(zip(eos_positions_right[0], eos_positions_left[0])): + 
expected_left_pos = eos_right + padding_length + assert eos_left == expected_left_pos, \ + f"Chunk {i}: EOS position should be shifted by padding_length. " \ + f"Expected {expected_left_pos}, got {eos_left} (right={eos_right}, padding_length={padding_length})" + + # Test Case 3: Verify pooling behavior with mock model + hidden_size = 64 + + class MockEncoderOutput: + def __init__(self, last_hidden_state): + self.last_hidden_state = last_hidden_state + + def mock_encoder_forward(**kwargs): + input_ids = kwargs['input_ids'] + batch_size, seq_len = input_ids.shape + # Create hidden states where each position encodes its position index + hidden_states = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) + for i in range(batch_size): + for j in range(seq_len): + # Encode position j in the first dimension + hidden_states[i, j, 0] = float(j) + return MockEncoderOutput(last_hidden_state=hidden_states) + + mock_encoder = Mock(side_effect=mock_encoder_forward) + mock_encoder.config = Mock() + mock_encoder.config.hidden_size = hidden_size + + model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) + model.passage_chunk_size = 64 + + # Test right padding pooling + chunk_reps_right, chunk_mask_right = model.encode_passage(p_batch_right, eos_positions_right) + + # Test left padding pooling + chunk_reps_left, chunk_mask_left = model.encode_passage(p_batch_left, eos_positions_left) + + # Verify pooling extracts from correct EOS positions + # Right padding: extracts from eos_positions_right + # Left padding: extracts from eos_positions_left (which are adjusted) + assert chunk_reps_right.shape == chunk_reps_left.shape, "Should have same number of chunks" + assert chunk_mask_right.shape == chunk_mask_left.shape, "Should have same chunk mask shape" + + # Verify that embeddings are extracted from the correct positions + # For right padding, EOS at position 63 should give embedding with value 63.0 + # For left padding, EOS at position (padding_length + 
63) should give embedding with value (padding_length + 63.0) + for i, (eos_right, eos_left) in enumerate(zip(eos_positions_right[0], eos_positions_left[0])): + # Right padding: embedding should encode position eos_right + assert torch.allclose(chunk_reps_right[0, i, 0], torch.tensor(float(eos_right))), \ + f"Right padding chunk {i}: embedding should encode EOS position {eos_right}" + + # Left padding: embedding should encode position eos_left + assert torch.allclose(chunk_reps_left[0, i, 0], torch.tensor(float(eos_left))), \ + f"Left padding chunk {i}: embedding should encode EOS position {eos_left}" + + # Verify masks are correct + assert chunk_mask_right[0, i] == 1.0, f"Right padding chunk {i} should be valid" + assert chunk_mask_left[0, i] == 1.0, f"Left padding chunk {i} should be valid" + + # Verify that the embeddings differ by the padding length (in the first dimension) + # This confirms that EOS positions are correctly adjusted + for i in range(len(eos_positions_right[0])): + expected_diff = float(padding_length) + actual_diff = chunk_reps_left[0, i, 0] - chunk_reps_right[0, i, 0] + assert torch.allclose(actual_diff, torch.tensor(expected_diff)), \ + f"Chunk {i}: embedding difference should equal padding_length. 
" \ + f"Expected {expected_diff}, got {actual_diff.item()}" + + print(f"Right padding EOS positions: {eos_positions_right[0]}") + print(f"Left padding EOS positions: {eos_positions_left[0]}") + print(f"Padding length: {padding_length}") + print(f"Sequence length: {seq_len_left}") + print(f"Valid tokens: {num_valid_tokens}") + + # Test Case 4: Verify with append_eos_token=True + data_args_right_eos = DataArguments( + passage_max_len=64, + passage_chunk_size=0, + pad_to_multiple_of=16, + padding_side="right", + passage_prefix="", + append_eos_token=True, + ) + + data_args_left_eos = DataArguments( + passage_max_len=64, + passage_chunk_size=0, + pad_to_multiple_of=16, + padding_side="left", + passage_prefix="", + append_eos_token=True, + ) + + collator_right_eos = TrainCollator(data_args=data_args_right_eos, tokenizer=train_tokenizer) + collator_left_eos = TrainCollator(data_args=data_args_left_eos, tokenizer=train_tokenizer) + + q_batch_eos_right, p_batch_eos_right = collator_right_eos([("query", [test_passage], [])]) + q_batch_eos_left, p_batch_eos_left = collator_left_eos([("query", [test_passage], [])]) + + # Verify EOS token is present in both + content_right_eos = p_batch_eos_right['input_ids'][0][p_batch_eos_right['attention_mask'][0].bool()].tolist() + content_left_eos = p_batch_eos_left['input_ids'][0][p_batch_eos_left['attention_mask'][0].bool()].tolist() + + assert content_right_eos[-1] == train_tokenizer.eos_token_id + assert content_left_eos[-1] == train_tokenizer.eos_token_id + + # Test pooling with EOS + p_reps_eos_right = model.encode_passage(p_batch_eos_right) + p_reps_eos_left = model.encode_passage(p_batch_eos_left) + + # Both should extract from EOS position + mask_eos_right = p_batch_eos_right['attention_mask'][0] + mask_eos_left = p_batch_eos_left['attention_mask'][0] + + # Right padding: uses sequence_lengths calculation + last_valid_eos_right = mask_eos_right.sum().item() - 1 + + # Left padding: checks if last position is valid + 
is_left_padding_eos = (mask_eos_left[-1] == 1).item()
+    if is_left_padding_eos:
+        last_valid_eos_left = mask_eos_left.shape[0] - 1
+    else:
+        last_valid_eos_left = mask_eos_left.sum().item() - 1
+
+    assert torch.allclose(p_reps_eos_right[0, 0], torch.tensor(float(last_valid_eos_right)))
+    assert torch.allclose(p_reps_eos_left[0, 0], torch.tensor(float(last_valid_eos_left)))
+
+    # With EOS, the extracted positions should be where EOS is located
+    assert p_batch_eos_right['input_ids'][0][last_valid_eos_right] == train_tokenizer.eos_token_id
+    assert p_batch_eos_left['input_ids'][0][last_valid_eos_left] == train_tokenizer.eos_token_id
+
+    # Summary: This test verifies that padding_side affects pooling position calculation
+    # Right padding: always uses attention_mask.sum() - 1
+    # Left padding: uses seq_len - 1 if last position is valid, otherwise attention_mask.sum() - 1
diff --git a/tests/test_chunking_pooling_equivalence.py b/tests/test_chunking_pooling_equivalence.py
new file mode 100644
index 00000000..ca2ec4a0
--- /dev/null
+++ b/tests/test_chunking_pooling_equivalence.py
@@ -0,0 +1,153 @@
+"""
+Test documenting that even when chunk_size == passage_max_len with a single chunk,
+chunked and non-chunked modes produce DIFFERENT embeddings (EOS vs last content token).
+"""
+import sys
+from pathlib import Path
+import pytest
+import torch
+from transformers import AutoTokenizer
+
+
+def _tevatron_root() -> Path:
+    return Path(__file__).resolve().parents[1]
+
+
+def _add_tevatron_src_to_path():
+    src = _tevatron_root() / "src"
+    sys.path.insert(0, str(src))
+
+
+@pytest.fixture
+def train_tokenizer():
+    from transformers import AutoTokenizer
+    return AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
+
+
+@pytest.mark.unit
+def test_chunked_vs_non_chunked_when_chunk_size_equals_max_len(train_tokenizer):
+    """
+    When chunk_size == passage_max_len and the passage fits in one chunk,
+    chunked mode still pools from the appended EOS token, so embeddings differ. 
+ """ + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator, ChunkedEncodeCollator + from tevatron.retriever.modeling.dense import DenseModel + from unittest.mock import Mock + + # Test passage that fits in one chunk + test_passage = "This is a test passage that will fit in one chunk." + + # Configuration: chunk_size == passage_max_len + passage_max_len = 64 + chunk_size = 64 # Same as max_len + + # Test Case 1: Non-chunked mode + data_args_non_chunked = DataArguments( + passage_max_len=passage_max_len, + passage_chunk_size=0, # No chunking + pad_to_multiple_of=16, + padding_side="right", + passage_prefix="", + append_eos_token=False, # Default: False + ) + + collator_non_chunked = TrainCollator(data_args=data_args_non_chunked, tokenizer=train_tokenizer) + q_batch_non_chunked, p_batch_non_chunked = collator_non_chunked([("query", [test_passage], [])]) + + # Test Case 2: Chunked mode with chunk_size == max_len + data_args_chunked = DataArguments( + passage_max_len=passage_max_len, + passage_chunk_size=chunk_size, # Same as max_len + pad_to_multiple_of=16, + padding_side="right", + passage_prefix="", + append_eos_token=False, # Same as non-chunked + ) + + collator_chunked = TrainCollator(data_args=data_args_chunked, tokenizer=train_tokenizer) + q_batch_chunked, p_batch_chunked, eos_positions = collator_chunked([("query", [test_passage], [])]) + + # Verify tokenization differences + input_ids_non_chunked = p_batch_non_chunked['input_ids'][0] + input_ids_chunked = p_batch_chunked['input_ids'][0] + + # Chunked mode adds EOS after chunk, non-chunked doesn't (when append_eos_token=False) + # So chunked should have one more token (the EOS) + non_chunked_content = input_ids_non_chunked[p_batch_non_chunked['attention_mask'][0].bool()].tolist() + chunked_content = input_ids_chunked[p_batch_chunked['attention_mask'][0].bool()].tolist() + + print(f"Non-chunked content tokens: 
{len(non_chunked_content)}") + print(f"Chunked content tokens: {len(chunked_content)}") + print(f"EOS positions: {eos_positions}") + + # Chunked should have EOS token at the end of the chunk + assert chunked_content[-1] == train_tokenizer.eos_token_id, "Chunked mode should have EOS at end" + # Non-chunked should NOT have EOS (when append_eos_token=False) + assert non_chunked_content[-1] != train_tokenizer.eos_token_id, "Non-chunked mode should NOT have EOS" + + # The content tokens (excluding EOS) should be the same + chunked_content_without_eos = chunked_content[:-1] + assert non_chunked_content == chunked_content_without_eos, "Content tokens should be identical (excluding EOS)" + + # Now test pooling behavior + hidden_size = 64 + + class MockEncoderOutput: + def __init__(self, last_hidden_state): + self.last_hidden_state = last_hidden_state + + def mock_encoder_forward(**kwargs): + input_ids = kwargs['input_ids'] + batch_size, seq_len = input_ids.shape + # Create hidden states where each position encodes its position index + hidden_states = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) + for i in range(batch_size): + for j in range(seq_len): + # Encode position j in the first dimension + hidden_states[i, j, 0] = float(j) + return MockEncoderOutput(last_hidden_state=hidden_states) + + mock_encoder = Mock(side_effect=mock_encoder_forward) + mock_encoder.config = Mock() + mock_encoder.config.hidden_size = hidden_size + + # Non-chunked model + model_non_chunked = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) + model_non_chunked.passage_chunk_size = 0 + + # Chunked model + model_chunked = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) + model_chunked.passage_chunk_size = chunk_size + + # Get embeddings + p_reps_non_chunked = model_non_chunked.encode_passage(p_batch_non_chunked) + p_reps_chunked_tuple = model_chunked.encode_passage(p_batch_chunked, eos_positions) + p_reps_chunked, chunk_mask = 
p_reps_chunked_tuple + + # Non-chunked: extracts from last content token position + # Chunked: extracts from EOS position (which is one position after last content token) + + # Get the actual positions + mask_non_chunked = p_batch_non_chunked['attention_mask'][0] + last_valid_pos_non_chunked = mask_non_chunked.sum().item() - 1 + + # Chunked: EOS position + eos_pos = eos_positions[0][0] # First (and only) chunk's EOS position + + print(f"Non-chunked extracts from position: {last_valid_pos_non_chunked}") + print(f"Chunked extracts from position: {eos_pos}") + print(f"Non-chunked embedding value: {p_reps_non_chunked[0, 0].item()}") + print(f"Chunked embedding value: {p_reps_chunked[0, 0, 0].item()}") + + # These should be DIFFERENT because they extract from different positions + # Non-chunked: last content token + # Chunked: EOS token (one position after) + assert eos_pos == last_valid_pos_non_chunked + 1, \ + f"EOS should be one position after last content token: {eos_pos} vs {last_valid_pos_non_chunked}" + + # The embeddings will be different because they're extracted from different positions + # This is the root cause of the inconsistency! + assert not torch.allclose(p_reps_non_chunked[0], p_reps_chunked[0, 0]), \ + "Embeddings should be different because they're extracted from different token positions" diff --git a/tests/test_forward.py b/tests/test_forward.py new file mode 100644 index 00000000..50ec91d9 --- /dev/null +++ b/tests/test_forward.py @@ -0,0 +1,270 @@ +import sys +from pathlib import Path + +import pytest +import torch +from unittest.mock import Mock + + +def _tevatron_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def _add_tevatron_src_to_path(): + # tevatron/tests/test_forward.py -> tevatron/ -> tevatron/src + src = _tevatron_root() / "src" + sys.path.insert(0, str(src)) + + +@pytest.fixture(scope="session") +def train_tokenizer(): + """ + Use the Qwen 0.6B tokenizer. 
+ """ + _add_tevatron_src_to_path() + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + tok.padding_side = "right" + return tok + + +@pytest.mark.unit +def test_compute_maxsim_similarity(): + """ + Test compute_maxsim_similarity function to verify MaxSim pooling logic. + """ + _add_tevatron_src_to_path() + from tevatron.retriever.modeling.encoder import EncoderModel + + # Create a concrete implementation for testing + class TestEncoderModel(EncoderModel): + def encode_query(self, qry): + raise NotImplementedError + def encode_passage(self, psg): + raise NotImplementedError + + model = TestEncoderModel(encoder=Mock(), pooling='last', normalize=False) + + # Test Case 1: Basic MaxSim computation + # Q=2 queries, P=3 passages, C=4 chunks per passage, H=8 hidden size + Q, P, C, H = 2, 3, 4, 8 + + q_reps = torch.randn(Q, H) + p_reps = torch.randn(P, C, H) + chunk_mask = torch.ones(P, C) # All chunks valid + + scores = model.compute_maxsim_similarity(q_reps, p_reps, chunk_mask) + + # Verify output shape + assert scores.shape == (Q, P) + + # Verify scores are computed correctly + # For each query-passage pair, score should be max of chunk similarities + for q_idx in range(Q): + for p_idx in range(P): + # Compute chunk scores manually + chunk_scores = torch.einsum('h,ch->c', q_reps[q_idx], p_reps[p_idx]) + expected_score = chunk_scores.max().item() + assert torch.allclose(scores[q_idx, p_idx], torch.tensor(expected_score)) + + # Test Case 2: With padding (some chunks are invalid) + chunk_mask_padded = torch.tensor([ + [1.0, 1.0, 1.0, 0.0], # Passage 0: 3 valid chunks + [1.0, 1.0, 0.0, 0.0], # Passage 1: 2 valid chunks + [1.0, 0.0, 0.0, 0.0], # Passage 2: 1 valid chunk + ]) + + scores_padded = model.compute_maxsim_similarity(q_reps, p_reps, chunk_mask_padded) + + # Verify shape + assert scores_padded.shape == (Q, P) + + # Verify that padding chunks 
don't affect the max + for q_idx in range(Q): + for p_idx in range(P): + # Compute chunk scores manually, masking out invalid chunks + chunk_scores = torch.einsum('h,ch->c', q_reps[q_idx], p_reps[p_idx]) + # Mask invalid chunks with -inf + valid_mask = chunk_mask_padded[p_idx].bool() + chunk_scores_masked = chunk_scores.clone() + chunk_scores_masked[~valid_mask] = float('-inf') + expected_score = chunk_scores_masked.max().item() + assert torch.allclose(scores_padded[q_idx, p_idx], torch.tensor(expected_score)) + + # Test Case 3: Single chunk per passage + P_single, C_single = 2, 1 + p_reps_single = torch.randn(P_single, C_single, H) + chunk_mask_single = torch.ones(P_single, C_single) + + scores_single = model.compute_maxsim_similarity(q_reps, p_reps_single, chunk_mask_single) + assert scores_single.shape == (Q, P_single) + + # With single chunk, MaxSim should equal the single chunk similarity + for q_idx in range(Q): + for p_idx in range(P_single): + expected_score = torch.dot(q_reps[q_idx], p_reps_single[p_idx, 0]).item() + assert torch.allclose(scores_single[q_idx, p_idx], torch.tensor(expected_score)) + + # Test Case 4: Different number of chunks per passage + # This tests that max_chunks is handled correctly + p_reps_uneven = torch.randn(P, C, H) + # Passage 0: all 4 chunks valid + # Passage 1: first 2 chunks valid + # Passage 2: first 1 chunk valid + chunk_mask_uneven = torch.tensor([ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + ]) + + scores_uneven = model.compute_maxsim_similarity(q_reps, p_reps_uneven, chunk_mask_uneven) + assert scores_uneven.shape == (Q, P) + + # Verify that only valid chunks are considered + for q_idx in range(Q): + for p_idx in range(P): + chunk_scores = torch.einsum('h,ch->c', q_reps[q_idx], p_reps_uneven[p_idx]) + valid_mask = chunk_mask_uneven[p_idx].bool() + chunk_scores_masked = chunk_scores.clone() + chunk_scores_masked[~valid_mask] = float('-inf') + expected_score = chunk_scores_masked.max().item() + 
assert torch.allclose(scores_uneven[q_idx, p_idx], torch.tensor(expected_score)) + + +@pytest.mark.unit +def test_forward_with_chunking(train_tokenizer): + """ + Test model forward function with chunked passages. + This tests the integration of encode_query, encode_passage, and compute_maxsim_similarity. + """ + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + from tevatron.retriever.modeling.dense import DenseModel + + REAL_TEXT = ( + "Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical " + "development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging " + "(MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient." + ) + + # Setup data arguments + data_args = DataArguments( + passage_chunk_size=32, + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + passage_prefix="", + query_prefix="", + append_eos_token=False, + ) + + # Create collator + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + + # Create test data: 2 queries, 2 passages (1 positive each) + queries = ["What is cerebral white matter?", "What is MRI?"] + passages = [REAL_TEXT, "MRI stands for Magnetic Resonance Imaging."] + + # Collate data + q_batch, p_batch, eos_positions = collator([ + (q, [p], []) for q, p in zip(queries, passages) + ]) + + # Create mock encoder + hidden_size = 64 + + class MockEncoderOutput: + def __init__(self, last_hidden_state): + self.last_hidden_state = last_hidden_state + + def mock_encoder_forward(**kwargs): + input_ids = kwargs['input_ids'] + batch_size, seq_len = input_ids.shape + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + return MockEncoderOutput(last_hidden_state=hidden_states) + + mock_encoder = Mock(side_effect=mock_encoder_forward) + mock_encoder.config = Mock() + 
mock_encoder.config.hidden_size = hidden_size + + # Create model + model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) + model.passage_chunk_size = data_args.passage_chunk_size + model.eos_positions = eos_positions + model.training = True + + # Forward pass + output = model(query=q_batch, passage=p_batch) + + # Verify output structure + assert hasattr(output, 'q_reps') + assert hasattr(output, 'p_reps') + assert hasattr(output, 'scores') + assert hasattr(output, 'loss') + + # Verify query representations + assert output.q_reps.shape[0] == len(queries) + assert output.q_reps.shape[1] == hidden_size + + # Verify passage representations (chunked) + assert isinstance(output.p_reps, tuple) # Should be (chunk_reps, chunk_mask) + chunk_reps, chunk_mask = output.p_reps + assert chunk_reps.shape[0] == len(passages) + assert chunk_reps.shape[2] == hidden_size + assert chunk_mask.shape == (len(passages), chunk_reps.shape[1]) + + # Verify scores shape + # With 2 queries and 2 passages, scores should be [2, 2] + assert output.scores.shape == (len(queries), len(passages)) + + # Verify loss is computed + assert output.loss is not None + assert output.loss.item() >= 0 # Loss should be non-negative + + # Test Case 2: Verify MaxSim is used (not regular similarity) + # Create a scenario where MaxSim gives different result than mean pooling + model.eval() + with torch.no_grad(): + # Use known embeddings where max chunk is different from mean + q_reps_test = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) + # Passage with 2 chunks: first chunk similar to query, second chunk dissimilar + p_reps_test = torch.tensor([ + [[1.0, 0.0], [0.0, 1.0]], # Passage 0: chunk 0 matches query 0, chunk 1 matches query 1 + [[0.0, 1.0], [1.0, 0.0]], # Passage 1: chunk 0 matches query 1, chunk 1 matches query 0 + ], dtype=torch.float32) + chunk_mask_test = torch.ones(2, 2) + + scores_test = model.compute_maxsim_similarity(q_reps_test, p_reps_test, chunk_mask_test) + + # 
Query 0 with Passage 0: max similarity should be with chunk 0 (1.0) + assert torch.allclose(scores_test[0, 0], torch.tensor(1.0)) + # Query 0 with Passage 1: max similarity should be with chunk 1 (1.0) + assert torch.allclose(scores_test[0, 1], torch.tensor(1.0)) + # Query 1 with Passage 0: max similarity should be with chunk 1 (1.0) + assert torch.allclose(scores_test[1, 0], torch.tensor(1.0)) + # Query 1 with Passage 1: max similarity should be with chunk 0 (1.0) + assert torch.allclose(scores_test[1, 1], torch.tensor(1.0)) + + # Test Case 3: Verify padding chunks are ignored + p_reps_padded = torch.randn(2, 3, hidden_size) + chunk_mask_padded = torch.tensor([ + [1.0, 1.0, 0.0], # Passage 0: 2 valid chunks + [1.0, 0.0, 0.0], # Passage 1: 1 valid chunk + ]) + + scores_padded = model.compute_maxsim_similarity(output.q_reps, p_reps_padded, chunk_mask_padded) + assert scores_padded.shape == (len(queries), len(passages)) + + # Verify that padding doesn't affect the max (should only consider valid chunks) + for q_idx in range(len(queries)): + for p_idx in range(len(passages)): + chunk_scores = torch.einsum('h,ch->c', output.q_reps[q_idx], p_reps_padded[p_idx]) + valid_mask = chunk_mask_padded[p_idx].bool() + chunk_scores_masked = chunk_scores.clone() + chunk_scores_masked[~valid_mask] = float('-inf') + expected_score = chunk_scores_masked.max().item() + assert torch.allclose(scores_padded[q_idx, p_idx], torch.tensor(expected_score)) diff --git a/tests/test_pooling.py b/tests/test_pooling.py new file mode 100644 index 00000000..0f0ae921 --- /dev/null +++ b/tests/test_pooling.py @@ -0,0 +1,532 @@ +import sys +from pathlib import Path + +import pytest + + +def _tevatron_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def _add_tevatron_src_to_path(): + # tevatron/tests/test_pooling.py -> tevatron/ -> tevatron/src + src = _tevatron_root() / "src" + sys.path.insert(0, str(src)) + + +REAL_TEXT = ( + "Alterations of the architecture of cerebral white matter 
in the developing human brain can affect cortical " + "development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging " + "(MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to " + "calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in " + "preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter " + "development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white " + "matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to " + "1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both " + "times were similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with " + "greater absolute values in the internal capsule than in the central white matter. Preterm infants at term showed " + "higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 versus 1.15 +/- 0.09 microm2/ms, " + "p = 0.016) and lower relative anisotropy in both areas compared with full-term infants (white matter, 10.9 +/- " + "0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- 4.44 versus 33.1 +/- 0.6% p = 0.006). " + "Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term and " + "preterm infants at term showed marked differences in white matter fiber organization. The data indicate that " + "quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural " + "development in cerebral white matter in living infants" +) +EOS_TOKEN_ID = 151645 +PADDING_TOKEN_ID = 151643 + +@pytest.fixture(scope="session") +def train_tokenizer(): + """ + Use the Qwen 0.6B tokenizer. 
+ """ + _add_tevatron_src_to_path() + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + tok.padding_side = "right" # finetune_with_chunk.sh uses --padding_side right + return tok + + +@pytest.mark.unit +def test_encode_with_chunking(train_tokenizer, tmp_path): + """ + Test the full encode functionality with chunking enabled. + This tests the integration of: + - EncodeDataset loading JSONL data + - ChunkedEncodeCollator creating batches with eos_positions + - DenseModel.encode_passage with chunking + - Output shape and lookup_indices creation + """ + import json + import numpy as np + import torch + from torch.utils.data import DataLoader + from unittest.mock import Mock + + from tevatron.retriever.arguments import DataArguments, TevatronTrainingArguments as TrainingArguments + from tevatron.retriever.dataset import EncodeDataset + from tevatron.retriever.collator import ChunkedEncodeCollator + from tevatron.retriever.modeling.dense import DenseModel + + # Create temporary JSONL file with test passages + test_passages = [ + {"docid": "doc1", "text": REAL_TEXT}, # Long passage that will be chunked + {"docid": "doc2", "text": "Short passage."}, # Short passage + ] + + jsonl_file = tmp_path / "test_corpus.jsonl" + with open(jsonl_file, 'w') as f: + for passage in test_passages: + f.write(json.dumps(passage) + '\n') + + # Setup data arguments for chunked encoding + data_args = DataArguments( + dataset_name='json', + dataset_path=str(jsonl_file), + dataset_split='train', + passage_chunk_size=32, + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + passage_prefix="", + encode_is_query=False, + ) + + # Setup training arguments + training_args = TrainingArguments( + output_dir=str(tmp_path / "output"), + per_device_eval_batch_size=2, + dataloader_num_workers=0, + fp16=False, + bf16=False, + ) + + # Create dataset + 
encode_dataset = EncodeDataset(data_args=data_args) + assert len(encode_dataset) == 2 + + # Create chunked collator + encode_collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + # Create data loader + encode_loader = DataLoader( + encode_dataset, + batch_size=training_args.per_device_eval_batch_size, + collate_fn=encode_collator, + shuffle=False, + drop_last=False, + num_workers=training_args.dataloader_num_workers, + ) + + # Create a mock encoder model + hidden_size = 64 + + # Create a proper mock that returns an object with last_hidden_state + class MockEncoderOutput: + def __init__(self, last_hidden_state): + self.last_hidden_state = last_hidden_state + + # Mock the encoder forward pass to return hidden states + def mock_encoder_forward(**kwargs): + input_ids = kwargs['input_ids'] + batch_size, seq_len = input_ids.shape + # Create dummy hidden states with positional encoding for testing + hidden_states = torch.arange(batch_size * seq_len * hidden_size, dtype=torch.float32) + hidden_states = hidden_states.reshape(batch_size, seq_len, hidden_size) + # Add some variation based on input_ids for testing + hidden_states = hidden_states + input_ids.unsqueeze(-1).float() * 0.01 + return MockEncoderOutput(last_hidden_state=hidden_states) + + mock_encoder = Mock(side_effect=mock_encoder_forward) + mock_encoder.config = Mock() + mock_encoder.config.hidden_size = hidden_size + + # Create DenseModel with mock encoder + model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) + model.passage_chunk_size = data_args.passage_chunk_size + model.eval() + + # Simulate the encode loop + encoded = [] + lookup_indices = [] + + for batch in encode_loader: + doc_ids, batch_inputs, eos_positions = batch + + # Verify batch structure + assert isinstance(doc_ids, list) + # batch_inputs is a BatchEncoding (from tokenizer.pad), which behaves like a dict + assert hasattr(batch_inputs, '__getitem__') # Check if it's dict-like + assert 'input_ids' 
in batch_inputs + assert 'attention_mask' in batch_inputs + assert isinstance(eos_positions, list) + assert len(eos_positions) == len(doc_ids) + + # Verify eos_positions structure + for i, eos_pos_list in enumerate(eos_positions): + assert isinstance(eos_pos_list, list) + assert len(eos_pos_list) > 0 # Should have at least one chunk + # Verify eos_positions are within sequence length + seq_len = batch_inputs['input_ids'].shape[1] + for pos in eos_pos_list: + assert 0 <= pos < seq_len + + # Encode with chunking + with torch.no_grad(): + chunk_embs, chunk_mask = model.encode_passage(batch_inputs, eos_positions) + + # Verify output shapes + batch_size, max_chunks, hidden_size_out = chunk_embs.shape + assert batch_size == len(doc_ids) + assert hidden_size_out == hidden_size + assert chunk_mask.shape == (batch_size, max_chunks) + + # Verify chunk_mask values (should be 0 or 1) + assert torch.all((chunk_mask == 0) | (chunk_mask == 1)) + + # Process chunks and create lookup indices + for i, doc_id in enumerate(doc_ids): + for chunk_idx in range(max_chunks): + if chunk_mask[i, chunk_idx] > 0: # Valid chunk + encoded.append(chunk_embs[i, chunk_idx].cpu().detach().numpy()) + lookup_indices.append((doc_id, chunk_idx)) + + # Verify results + assert len(encoded) > 0 + assert len(lookup_indices) == len(encoded) + + # Stack encoded embeddings + encoded_array = np.stack(encoded) + assert encoded_array.shape[0] == len(encoded) + assert encoded_array.shape[1] == hidden_size + + # Verify lookup_indices structure + unique_docs = set(doc_id for doc_id, _ in lookup_indices) + assert len(unique_docs) == 2 # Should have both doc1 and doc2 + + # Verify doc1 has multiple chunks (it's a long passage) + doc1_chunks = [chunk_idx for doc_id, chunk_idx in lookup_indices if doc_id == "doc1"] + assert len(doc1_chunks) > 1 # Should have multiple chunks + + # Verify doc2 has at least one chunk + doc2_chunks = [chunk_idx for doc_id, chunk_idx in lookup_indices if doc_id == "doc2"] + assert 
len(doc2_chunks) >= 1 + + # Verify chunk indices are sequential starting from 0 + for doc_id in unique_docs: + doc_chunks = sorted([chunk_idx for d, chunk_idx in lookup_indices if d == doc_id]) + assert doc_chunks == list(range(len(doc_chunks))) # Should be 0, 1, 2, ... + + # Verify embeddings are not all zeros (they should have been computed) + assert not np.allclose(encoded_array, 0) + + # Verify embeddings have reasonable values (not NaN or Inf) + assert np.all(np.isfinite(encoded_array)) + + +@pytest.mark.unit +def test_pooling_chunked_eos_positions_alignment(): + """ + Test _pooling_chunked to verify that eos_positions correctly align with hidden states. + This test uses known hidden states and eos_positions to verify exact alignment. + """ + import torch + from unittest.mock import Mock + from tevatron.retriever.modeling.dense import DenseModel + + # Create a mock encoder + mock_encoder = Mock() + mock_encoder.config.hidden_size = 8 + + # Create DenseModel + model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) + model.passage_chunk_size = 32 + + # Test Case 1: Simple case with known positions + # Batch size=2, seq_len=10, hidden_size=8 + # Passage 0: eos at positions [2, 5, 8] (3 chunks) + # Passage 1: eos at positions [3, 7] (2 chunks) + batch_size = 2 + seq_len = 10 + hidden_size = 8 + + # Create hidden states with known values - each position has a unique pattern + # We'll use position index as part of the embedding to make verification easy + hidden_states = torch.zeros(batch_size, seq_len, hidden_size) + for i in range(batch_size): + for j in range(seq_len): + # Set embedding at position j to have value j*100 + i*10 in first dimension + # This makes it easy to verify we're extracting the right positions + hidden_states[i, j, 0] = j * 100 + i * 10 + # Fill other dimensions with position-dependent values + for k in range(1, hidden_size): + hidden_states[i, j, k] = j * 10 + k + + eos_positions = [[2, 5, 8], [3, 7]] + + # Call 
_pooling_chunked + chunk_reps, chunk_mask = model._pooling_chunked(hidden_states, eos_positions) + + # Verify output shapes + assert chunk_reps.shape == (batch_size, 3, hidden_size) # max_chunks = 3 + assert chunk_mask.shape == (batch_size, 3) + + # Verify Passage 0: should extract positions [2, 5, 8] + # Position 2: should have 2*100 + 0*10 = 200 in first dim + assert torch.allclose(chunk_reps[0, 0, 0], torch.tensor(200.0)) + assert torch.allclose(chunk_reps[0, 0, 1], torch.tensor(21.0)) # 2*10 + 1 + + # Position 5: should have 5*100 + 0*10 = 500 in first dim + assert torch.allclose(chunk_reps[0, 1, 0], torch.tensor(500.0)) + assert torch.allclose(chunk_reps[0, 1, 1], torch.tensor(51.0)) # 5*10 + 1 + + # Position 8: should have 8*100 + 0*10 = 800 in first dim + assert torch.allclose(chunk_reps[0, 2, 0], torch.tensor(800.0)) + assert torch.allclose(chunk_reps[0, 2, 1], torch.tensor(81.0)) # 8*10 + 1 + + # Verify Passage 1: should extract positions [3, 7] + # Position 3: should have 3*100 + 1*10 = 310 in first dim + assert torch.allclose(chunk_reps[1, 0, 0], torch.tensor(310.0)) + assert torch.allclose(chunk_reps[1, 0, 1], torch.tensor(31.0)) # 3*10 + 1 + + # Position 7: should have 7*100 + 1*10 = 710 in first dim + assert torch.allclose(chunk_reps[1, 1, 0], torch.tensor(710.0)) + assert torch.allclose(chunk_reps[1, 1, 1], torch.tensor(71.0)) # 7*10 + 1 + + # Verify chunk_mask + assert chunk_mask[0, 0] == 1.0 # Passage 0, chunk 0 (pos 2) + assert chunk_mask[0, 1] == 1.0 # Passage 0, chunk 1 (pos 5) + assert chunk_mask[0, 2] == 1.0 # Passage 0, chunk 2 (pos 8) + assert chunk_mask[1, 0] == 1.0 # Passage 1, chunk 0 (pos 3) + assert chunk_mask[1, 1] == 1.0 # Passage 1, chunk 1 (pos 7) + assert chunk_mask[1, 2] == 0.0 # Passage 1, chunk 2 (no chunk, should be 0) + + # Test Case 2: Verify exact tensor equality (not just close) + # Create hidden states where each position has a unique embedding + hidden_states_2 = torch.arange(batch_size * seq_len * hidden_size, 
dtype=torch.float32) + hidden_states_2 = hidden_states_2.reshape(batch_size, seq_len, hidden_size) + + # Extract embeddings manually for comparison + expected_chunk_0_0 = hidden_states_2[0, 2] # Passage 0, position 2 + expected_chunk_0_1 = hidden_states_2[0, 5] # Passage 0, position 5 + expected_chunk_0_2 = hidden_states_2[0, 8] # Passage 0, position 8 + expected_chunk_1_0 = hidden_states_2[1, 3] # Passage 1, position 3 + expected_chunk_1_1 = hidden_states_2[1, 7] # Passage 1, position 7 + + chunk_reps_2, chunk_mask_2 = model._pooling_chunked(hidden_states_2, eos_positions) + + # Verify exact equality + assert torch.equal(chunk_reps_2[0, 0], expected_chunk_0_0) + assert torch.equal(chunk_reps_2[0, 1], expected_chunk_0_1) + assert torch.equal(chunk_reps_2[0, 2], expected_chunk_0_2) + assert torch.equal(chunk_reps_2[1, 0], expected_chunk_1_0) + assert torch.equal(chunk_reps_2[1, 1], expected_chunk_1_1) + + # Test Case 3: Edge case - empty eos_positions + chunk_reps_empty, chunk_mask_empty = model._pooling_chunked(hidden_states, []) + assert chunk_reps_empty.shape == (batch_size, 0, hidden_size) + assert chunk_mask_empty.shape == (batch_size, 0) + + # Test Case 4: Edge case - out of bounds position (should be handled gracefully) + eos_positions_oob = [[2, 5, 15], [3, 7]] # 15 is out of bounds for seq_len=10 + chunk_reps_oob, chunk_mask_oob = model._pooling_chunked(hidden_states, eos_positions_oob) + + # Should still extract valid positions + assert chunk_reps_oob.shape == (batch_size, 3, hidden_size) + assert torch.allclose(chunk_reps_oob[0, 0], hidden_states[0, 2]) # Valid + assert torch.allclose(chunk_reps_oob[0, 1], hidden_states[0, 5]) # Valid + # Position 15 is out of bounds, so chunk_reps[0, 2] should be zeros + assert torch.allclose(chunk_reps_oob[0, 2], torch.zeros(hidden_size)) + assert chunk_mask_oob[0, 2] == 0.0 # Should be masked out + + # Test Case 5: Normalize=True + model.normalize = True + chunk_reps_norm, chunk_mask_norm = 
model._pooling_chunked(hidden_states_2, eos_positions) + + # Verify normalization (L2 norm should be 1 for non-zero chunks) + for i in range(batch_size): + for j in range(len(eos_positions[i])): + norm = torch.norm(chunk_reps_norm[i, j]) + assert torch.allclose(norm, torch.tensor(1.0), atol=1e-6) + + # Verify the normalized embeddings are proportional to original + model.normalize = False + chunk_reps_no_norm, _ = model._pooling_chunked(hidden_states_2, eos_positions) + for i in range(batch_size): + for j in range(len(eos_positions[i])): + # Normalized version should be original / norm + expected_norm = torch.norm(chunk_reps_no_norm[i, j]) + normalized_manual = chunk_reps_no_norm[i, j] / expected_norm + assert torch.allclose(chunk_reps_norm[i, j], normalized_manual, atol=1e-6) + + # Test Case 6: Single chunk per passage + eos_positions_single = [[4], [6]] + chunk_reps_single, chunk_mask_single = model._pooling_chunked(hidden_states_2, eos_positions_single) + + assert chunk_reps_single.shape == (batch_size, 1, hidden_size) + assert torch.equal(chunk_reps_single[0, 0], hidden_states_2[0, 4]) + assert torch.equal(chunk_reps_single[1, 0], hidden_states_2[1, 6]) + assert chunk_mask_single[0, 0] == 1.0 + assert chunk_mask_single[1, 0] == 1.0 + + # Test Case 7: Verify positions are extracted in correct order + # Use sequential positions to verify order + eos_positions_ordered = [[1, 3, 5], [2, 4]] + chunk_reps_ordered, _ = model._pooling_chunked(hidden_states_2, eos_positions_ordered) + + # Passage 0: should be in order [1, 3, 5] + assert torch.equal(chunk_reps_ordered[0, 0], hidden_states_2[0, 1]) + assert torch.equal(chunk_reps_ordered[0, 1], hidden_states_2[0, 3]) + assert torch.equal(chunk_reps_ordered[0, 2], hidden_states_2[0, 5]) + + # Passage 1: should be in order [2, 4] + assert torch.equal(chunk_reps_ordered[1, 0], hidden_states_2[1, 2]) + assert torch.equal(chunk_reps_ordered[1, 1], hidden_states_2[1, 4]) + + +@pytest.mark.unit +def 
test_pooling_chunked_real_tokenizer_alignment(train_tokenizer): + """ + Integration test: Verify that eos_positions from ChunkedEncodeCollator + correctly align with hidden states when using _pooling_chunked. + This uses real tokenizer to ensure end-to-end correctness. + """ + import torch + from unittest.mock import Mock + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import ChunkedEncodeCollator + from tevatron.retriever.modeling.dense import DenseModel + + # Setup data arguments + data_args = DataArguments( + passage_chunk_size=32, + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + passage_prefix="", + append_eos_token=False, + ) + + # Create collator + collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + # Test passages + passages = [ + REAL_TEXT, # Long passage that will be chunked + "Short passage for testing.", # Short passage + ] + + # Get tokenized and chunked data + d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages(passages) + + # Verify eos_positions are valid + input_ids = d_collated['input_ids'] + seq_len = input_ids.shape[1] + + for i, eos_pos_list in enumerate(eos_positions): + assert len(eos_pos_list) > 0, f"Passage {i} should have at least one chunk" + for pos in eos_pos_list: + assert 0 <= pos < seq_len, f"EOS position {pos} out of bounds for seq_len {seq_len}" + # Verify that the position actually contains EOS token + assert input_ids[i, pos] == train_tokenizer.eos_token_id, \ + f"Position {pos} should contain EOS token {train_tokenizer.eos_token_id}, got {input_ids[i, pos]}" + + # Create mock encoder that returns hidden states based on input_ids + # This allows us to verify exact alignment + hidden_size = 64 + + class MockEncoderOutput: + def __init__(self, last_hidden_state): + self.last_hidden_state = last_hidden_state + + def mock_encoder_forward(**kwargs): + input_ids = kwargs['input_ids'] + batch_size, seq_len = 
input_ids.shape + + # Create hidden states where each position's embedding encodes its position + # This makes it easy to verify we're extracting the right positions + hidden_states = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) + for i in range(batch_size): + for j in range(seq_len): + # Encode position j in the embedding + # Use input_ids[i, j] as seed to make it unique per token + hidden_states[i, j, 0] = float(j) # Position index + hidden_states[i, j, 1] = float(input_ids[i, j]) # Token ID + # Fill rest with position-dependent values + for k in range(2, hidden_size): + hidden_states[i, j, k] = float(j * hidden_size + k) + + return MockEncoderOutput(last_hidden_state=hidden_states) + + mock_encoder = Mock(side_effect=mock_encoder_forward) + mock_encoder.config = Mock() + mock_encoder.config.hidden_size = hidden_size + + # Create model + model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) + model.passage_chunk_size = data_args.passage_chunk_size + + # Convert BatchEncoding to dict for model + batch_inputs = { + 'input_ids': d_collated['input_ids'], + 'attention_mask': d_collated['attention_mask'], + } + + # Encode with chunking + chunk_reps, chunk_mask = model.encode_passage(batch_inputs, eos_positions) + + # Verify shapes + batch_size = len(passages) + max_chunks = max(len(pos_list) for pos_list in eos_positions) + assert chunk_reps.shape == (batch_size, max_chunks, hidden_size) + assert chunk_mask.shape == (batch_size, max_chunks) + + # Verify that extracted embeddings match the eos_positions + # We need to get the hidden states that were generated + # Since we can't easily access them, we'll verify by checking the mock was called correctly + # and that the extracted positions match what we expect + + # Re-create hidden states with the same logic to verify + hidden_states_expected = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) + for i in range(batch_size): + for j in range(seq_len): + 
hidden_states_expected[i, j, 0] = float(j) + hidden_states_expected[i, j, 1] = float(input_ids[i, j]) + for k in range(2, hidden_size): + hidden_states_expected[i, j, k] = float(j * hidden_size + k) + + # Verify each extracted chunk embedding matches the expected position + for i, eos_pos_list in enumerate(eos_positions): + for j, pos in enumerate(eos_pos_list): + # The extracted embedding should match the hidden state at position pos + expected_embedding = hidden_states_expected[i, pos] + extracted_embedding = chunk_reps[i, j] + + # Verify exact match (they should be identical) + assert torch.equal(extracted_embedding, expected_embedding), \ + f"Passage {i}, chunk {j} (eos_pos={pos}): extracted embedding doesn't match hidden state at position {pos}" + + # Verify chunk mask is set correctly + assert chunk_mask[i, j] == 1.0, f"Chunk mask should be 1.0 for valid chunk" + + # Verify that invalid chunks (beyond actual chunks) have mask=0 + for i in range(batch_size): + num_chunks = len(eos_positions[i]) + for j in range(num_chunks, max_chunks): + assert chunk_mask[i, j] == 0.0, f"Invalid chunk should have mask=0" + + # Verify that the first dimension of extracted embeddings contains position indices + for i, eos_pos_list in enumerate(eos_positions): + for j, pos in enumerate(eos_pos_list): + # First dimension should equal the position + assert torch.allclose(chunk_reps[i, j, 0], torch.tensor(float(pos))), \ + f"First dim should equal position {pos}, got {chunk_reps[i, j, 0]}" + + # Second dimension should equal the token ID at that position + expected_token_id = float(input_ids[i, pos]) + assert torch.allclose(chunk_reps[i, j, 1], torch.tensor(expected_token_id)), \ + f"Second dim should equal token ID {expected_token_id}, got {chunk_reps[i, j, 1]}" diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 00000000..155ecf89 --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,211 @@ +import sys +from pathlib import Path +import pickle +import 
numpy as np +import pytest +from collections import defaultdict + + +def _tevatron_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def _add_tevatron_src_to_path(): + src = _tevatron_root() / "src" + sys.path.insert(0, str(src)) + + +@pytest.mark.unit +def test_search_chunked_vs_non_chunked(): + """ + Test search behavior differences between chunked and non-chunked modes. + This verifies: + 1. Auto-detection of chunked format + 2. MaxSim aggregation logic + 3. Search depth handling + """ + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked, search_queries + from tevatron.retriever.searcher import FaissFlatSearcher + + # Create mock query and passage embeddings + num_queries = 3 + num_docs = 10 + hidden_size = 64 + + # Query embeddings + q_reps = np.random.randn(num_queries, hidden_size).astype(np.float32) + # Normalize for inner product search + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + # Test Case 1: Non-chunked format + # Each document has one embedding + p_reps_non_chunked = np.random.randn(num_docs, hidden_size).astype(np.float32) + # Normalize for inner product search + p_reps_non_chunked = p_reps_non_chunked / np.linalg.norm(p_reps_non_chunked, axis=1, keepdims=True) + p_lookup_non_chunked = [f"doc_{i}" for i in range(num_docs)] + + retriever_non_chunked = FaissFlatSearcher(p_reps_non_chunked) + # Need to add embeddings to index + retriever_non_chunked.add(p_reps_non_chunked) + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 10 + + args = MockArgs() + + # Non-chunked search + scores_non_chunked, indices_non_chunked = search_queries( + retriever_non_chunked, q_reps, p_lookup_non_chunked, args + ) + + # Verify non-chunked results + assert len(scores_non_chunked) == num_queries + assert len(indices_non_chunked) == num_queries + for q_idx in range(num_queries): + assert len(scores_non_chunked[q_idx]) == args.depth + assert 
len(indices_non_chunked[q_idx]) == args.depth + # indices_non_chunked contains document IDs (strings), not indices + assert all(isinstance(doc_id, (str, np.str_)) for doc_id in indices_non_chunked[q_idx][:5]) + + # Test Case 2: Chunked format - single chunk per document + # This simulates chunk_size == max_passage_size scenario + # Each document has exactly one chunk + p_reps_chunked_single = np.random.randn(num_docs, hidden_size).astype(np.float32) + # Normalize for inner product search + p_reps_chunked_single = p_reps_chunked_single / np.linalg.norm(p_reps_chunked_single, axis=1, keepdims=True) + q_reps_normalized = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_lookup_chunked_single = [(f"doc_{i}", 0) for i in range(num_docs)] + + retriever_chunked_single = FaissFlatSearcher(p_reps_chunked_single) + # Need to add embeddings to index + retriever_chunked_single.add(p_reps_chunked_single) + + # Chunked search with single chunk per doc + results_chunked_single = search_queries_chunked( + retriever_chunked_single, q_reps_normalized, p_lookup_chunked_single, args + ) + + # Verify chunked results + assert len(results_chunked_single) == num_queries + for q_idx in range(num_queries): + # Results might be less than depth if fewer documents exist + assert len(results_chunked_single[q_idx]) <= args.depth + assert len(results_chunked_single[q_idx]) > 0, "Should have at least some results" + # Each result should be (doc_id, score) tuple + for doc_id, score in results_chunked_single[q_idx]: + assert isinstance(doc_id, str) + assert isinstance(score, (int, float, np.floating)) + + # Test Case 3: Chunked format - multiple chunks per document + # Some documents have multiple chunks + num_chunks_total = 20 # More chunks than documents + p_reps_chunked_multi = np.random.randn(num_chunks_total, hidden_size).astype(np.float32) + # Normalize for inner product search + p_reps_chunked_multi = p_reps_chunked_multi / np.linalg.norm(p_reps_chunked_multi, axis=1, 
keepdims=True) + p_lookup_chunked_multi = [] + # Document 0-4: 2 chunks each (10 chunks) + # Document 5-9: 2 chunks each (10 chunks) + for doc_idx in range(num_docs): + for chunk_idx in range(2): + p_lookup_chunked_multi.append((f"doc_{doc_idx}", chunk_idx)) + + retriever_chunked_multi = FaissFlatSearcher(p_reps_chunked_multi) + retriever_chunked_multi.add(p_reps_chunked_multi) + + # Chunked search with multiple chunks per doc + results_chunked_multi = search_queries_chunked( + retriever_chunked_multi, q_reps, p_lookup_chunked_multi, args + ) + + # Verify MaxSim aggregation + assert len(results_chunked_multi) == num_queries + for q_idx in range(num_queries): + assert len(results_chunked_multi[q_idx]) == args.depth + # Verify MaxSim: each document should appear at most once + doc_ids = [doc_id for doc_id, _ in results_chunked_multi[q_idx]] + assert len(doc_ids) == len(set(doc_ids)), "Each document should appear only once (MaxSim aggregation)" + + # Verify scores are in descending order + scores = [score for _, score in results_chunked_multi[q_idx]] + assert scores == sorted(scores, reverse=True), "Scores should be in descending order" + + # Test Case 4: Verify MaxSim logic - same document with multiple chunks + # Create a scenario where one document has the best chunks + q_rep_test = np.random.randn(1, hidden_size).astype(np.float32) + q_rep_test = q_rep_test / np.linalg.norm(q_rep_test, axis=1, keepdims=True) + + # Create embeddings where doc_0 chunks are most similar to query + p_reps_test = np.random.randn(5, hidden_size).astype(np.float32) + # Make doc_0 chunks (indices 0, 1) more similar to query + p_reps_test[0] = q_rep_test[0] * 0.9 + np.random.randn(hidden_size) * 0.1 + p_reps_test[1] = q_rep_test[0] * 0.8 + np.random.randn(hidden_size) * 0.2 + # Other chunks less similar + p_reps_test[2:] = q_rep_test[0] * 0.3 + np.random.randn(3, hidden_size) * 0.7 + # Normalize + p_reps_test = p_reps_test / np.linalg.norm(p_reps_test, axis=1, keepdims=True) + + 
p_lookup_test = [ + ("doc_0", 0), # Best chunk + ("doc_0", 1), # Second best chunk + ("doc_1", 0), # Less similar + ("doc_2", 0), # Less similar + ("doc_3", 0), # Less similar + ] + + retriever_test = FaissFlatSearcher(p_reps_test) + retriever_test.add(p_reps_test) + results_test = search_queries_chunked(retriever_test, q_rep_test, p_lookup_test, args) + + # Verify MaxSim: doc_0 should be ranked first (max of its two chunks) + assert len(results_test) == 1 + assert len(results_test[0]) > 0, "Should have results" + top_doc = results_test[0][0][0] + assert top_doc == "doc_0", "doc_0 should be ranked first due to MaxSim (max of its chunks)" + + # Test Case 5: Verify search depth multiplier + args_large = MockArgs() + args_large.depth = 5 + args_large.chunk_multiplier = 10 + args_large.batch_size = 0 + args_large.quiet = True + + # With chunk_multiplier=10, should search 5 * 10 = 50 chunks + # But we only have 20 chunks, so should get all chunks + results_depth_test = search_queries_chunked( + retriever_chunked_multi, q_reps_normalized, p_lookup_chunked_multi, args_large + ) + + # Should return up to depth documents (after MaxSim aggregation) + assert len(results_depth_test[0]) <= args_large.depth + assert len(results_depth_test[0]) > 0, "Should have some results" + + # Test Case 6: Verify auto-detection logic + # Test that tuple format is detected as chunked + assert isinstance(p_lookup_chunked_single[0], tuple), "Chunked lookup should be tuple" + assert not isinstance(p_lookup_non_chunked[0], tuple), "Non-chunked lookup should be string" + + # Test Case 7: Verify that single chunk per doc behaves correctly + # When chunk_size == max_passage_size, each doc has one chunk + # In this case, MaxSim should give same result as non-chunked (if embeddings are identical) + # But search depth multiplier means we search more candidates + p_reps_single_chunk = p_reps_chunked_single.copy() + q_reps_single = q_reps_normalized.copy() + + # Search with same embeddings but different 
formats + results_single_chunk = search_queries_chunked( + retriever_chunked_single, q_reps_single, p_lookup_chunked_single, args + ) + + # Verify results structure + assert len(results_single_chunk) == num_queries + for q_idx in range(num_queries): + assert len(results_single_chunk[q_idx]) > 0 + # Each result should be (doc_id, score) + for doc_id, score in results_single_chunk[q_idx]: + assert isinstance(doc_id, str) + assert isinstance(score, (int, float, np.floating)) From 3c1752d75dc8f595c8c10bf48d8cca4310506fb8 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Tue, 23 Dec 2025 13:46:56 -0500 Subject: [PATCH 19/31] added collator helper functions --- src/tevatron/retriever/collator.py | 165 +++++--- tests/test_chunking.py | 465 +++------------------ tests/test_chunking_helper.py | 229 ++++++++++ tests/test_chunking_pooling_equivalence.py | 100 ++--- tests/test_forward.py | 69 +-- tests/test_padding_helper.py | 314 ++++++++++++++ tests/test_pooling.py | 237 ++--------- 7 files changed, 793 insertions(+), 786 deletions(-) create mode 100644 tests/test_chunking_helper.py create mode 100644 tests/test_padding_helper.py diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index 7b27b282..65397a39 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -13,6 +13,106 @@ logger = logging.getLogger(__name__) +def _chunk_tokens( + tokens: List[int], + chunk_size: int, + eos_token_id: int, + max_length: int = None, +) -> Tuple[List[int], List[int]]: + """ + Chunk a list of tokens into chunks of specified size, adding EOS token after each chunk. + + :param tokens: List of token IDs to chunk + :param chunk_size: Maximum size of each chunk (before adding EOS). Must be >= 2. + :param eos_token_id: EOS token ID to append after each chunk + :param max_length: Optional maximum total length (including EOS tokens). If None, no limit. 
+ :return: Tuple of (chunked_ids, eos_positions) where: + - chunked_ids: List of token IDs with EOS separators between chunks + - eos_positions: List of positions where EOS tokens were inserted + """ + if chunk_size < 2: + # chunk_size must be at least 2 to fit at least 1 token + 1 EOS + return [], [] + + chunk_len = chunk_size - 1 # Reserve 1 slot for EOS + ids = [] + eos_pos = [] + + i = 0 + while i < len(tokens): + if max_length and max_length > 0: + remaining = max_length - len(ids) + # Need at least 1 slot for EOS; otherwise stop (don't add empty chunks). + if remaining <= 1: + break + take = min(chunk_len, len(tokens) - i, remaining - 1) + if take <= 0: + break + else: + take = min(chunk_len, len(tokens) - i) + if take <= 0: + break + + chunk = tokens[i:i + take] # up to chunk_len tokens + ids.extend(chunk) + ids.append(eos_token_id) # EOS at end of this chunk + eos_pos.append(len(ids) - 1) # position of EOS (pooling position) + i += take + + return ids, eos_pos + + +def _pad_and_adjust_eos_positions( + all_input_ids: List[List[int]], + all_eos_positions: List[List[int]], + tokenizer: PreTrainedTokenizer, + padding_side: str, + pad_to_multiple_of: int, +) -> Tuple[dict, List[List[int]]]: + """ + Pad input IDs and adjust EOS positions based on padding side. 
+ + :param all_input_ids: List of lists of token IDs (one per passage) + :param all_eos_positions: List of lists of EOS positions (one per passage) + :param tokenizer: Tokenizer to use for padding + :param padding_side: 'left' or 'right' - side to pad on + :param pad_to_multiple_of: Pad sequences to multiple of this value + :return: Tuple of (padded_dict, adjusted_eos_positions) where: + - padded_dict: dict with 'input_ids' and 'attention_mask' tensors + - adjusted_eos_positions: List of lists with EOS positions adjusted for padding + """ + d_collated = {'input_ids': all_input_ids} + + # Store original lengths before padding to adjust eos_positions for left padding + original_lengths = [len(ids) for ids in all_input_ids] + + # Set tokenizer padding_side before padding + tokenizer.padding_side = padding_side + + # Padding + d_collated = tokenizer.pad( + d_collated, + padding=True, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=True, + return_tensors='pt', + ) + + # Adjust eos_positions for left padding + # When padding_side is 'left', padding tokens are added at the beginning, + # so EOS positions need to be shifted by the padding length + # Create a deep copy to avoid modifying the original + adjusted_eos_positions = [list(eos_pos_list) for eos_pos_list in all_eos_positions] + if padding_side == 'left': + padded_lengths = d_collated['input_ids'].shape[1] # All sequences have same length after padding + for i, eos_pos_list in enumerate(adjusted_eos_positions): + padding_length = padded_lengths - original_lengths[i] + # Shift each EOS position by the padding length + adjusted_eos_positions[i] = [pos + padding_length for pos in eos_pos_list] + + return d_collated, adjusted_eos_positions + + def _tokenize_and_pad_chunked_passages( passages: List[str], tokenizer: PreTrainedTokenizer, @@ -32,7 +132,6 @@ def _tokenize_and_pad_chunked_passages( - collated_dict: dict with 'input_ids' and 'attention_mask' tensors - eos_positions: list of lists, one per 
passage, containing EOS token positions """ - chunk_len = data_args.passage_chunk_size - 1 eos_id = tokenizer.eos_token_id if eos_id is None: raise ValueError("tokenizer.eos_token_id is None; cannot chunk passages with EOS separators.") @@ -45,64 +144,24 @@ def _tokenize_and_pad_chunked_passages( if passage is None: passage = "" tokens = tokenizer.encode(passage, add_special_tokens=False) - ids = [] - eos_pos = [] - - # Build chunked ids, optionally capped by max_length (total tokens including EOS separators). - i = 0 - while i < len(tokens): - if max_length and max_length > 0: - remaining = max_length - len(ids) - # Need at least 1 slot for EOS; otherwise stop (don't add empty chunks). - if remaining <= 1: - break - take = min(chunk_len, len(tokens) - i, remaining - 1) - if take <= 0: - break - else: - take = min(chunk_len, len(tokens) - i) - - chunk = tokens[i:i + take] # up to chunk_len tokens - ids.extend(chunk) - ids.append(eos_id) # EOS at end of this chunk - eos_pos.append(len(ids) - 1) # position of EOS (pooling position) - i += take - + ids, eos_pos = _chunk_tokens( + tokens=tokens, + chunk_size=data_args.passage_chunk_size, + eos_token_id=eos_id, + max_length=max_length, + ) all_input_ids.append(ids) all_eos_positions.append(eos_pos) - d_collated = {'input_ids': all_input_ids} - - # Store original lengths before padding to adjust eos_positions for left padding - original_lengths = [len(ids) for ids in all_input_ids] - - # Set tokenizer padding_side before padding - original_padding_side = tokenizer.padding_side - tokenizer.padding_side = data_args.padding_side - - # Padding - d_collated = tokenizer.pad( - d_collated, - padding=True, + d_collated, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=tokenizer, + padding_side=data_args.padding_side, pad_to_multiple_of=data_args.pad_to_multiple_of, - return_attention_mask=True, - return_tensors='pt', ) - # Restore original 
padding_side - tokenizer.padding_side = original_padding_side - - # Adjust eos_positions for left padding - # When padding_side is 'left', padding tokens are added at the beginning, - # so EOS positions need to be shifted by the padding length - if data_args.padding_side == 'left': - padded_lengths = d_collated['input_ids'].shape[1] # All sequences have same length after padding - for i, eos_pos_list in enumerate(all_eos_positions): - padding_length = padded_lengths - original_lengths[i] - # Shift each EOS position by the padding length - all_eos_positions[i] = [pos + padding_length for pos in eos_pos_list] - - return d_collated, all_eos_positions + return d_collated, adjusted_eos_positions @dataclass diff --git a/tests/test_chunking.py b/tests/test_chunking.py index 8c736858..10502cb0 100644 --- a/tests/test_chunking.py +++ b/tests/test_chunking.py @@ -56,13 +56,7 @@ def train_tokenizer(): @pytest.mark.unit def test_train_collator_chunked_passages(train_tokenizer): - """ - Restore finetune_with_chunk.sh passage chunking scene: - - passage_max_len=512 - - passage_chunk_size=256 - - pad_to_multiple_of=16 (DataArguments default) - - padding_side=right - """ + """Test chunking with passage_max_len=512, passage_chunk_size=256.""" from tevatron.retriever.arguments import DataArguments from tevatron.retriever.collator import TrainCollator @@ -74,55 +68,23 @@ def test_train_collator_chunked_passages(train_tokenizer): append_eos_token=False, ) collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) - - # ======================================================================== - # NOTE: This test directly calls _tokenize_and_pad_chunked_passages() instead - # of collator.__call__() to test chunking in isolation. 
- # - # If we used collator.__call__(features) with passage_chunk_size > 0, it would return: - # (q_batch, p_batch, eos_positions) # 3-element tuple - # - # Where: - # - q_batch: dict with "input_ids" and "attention_mask" for queries - # - p_batch: dict with "input_ids" and "attention_mask" for chunked passages - # - eos_positions: list of lists, e.g., [[255, 430]] - EOS token positions per passage - # Used by the model to extract chunk embeddings via MaxSim pooling - # ======================================================================== d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([REAL_TEXT]) got_ids = d_collated["input_ids"][0].tolist() got_mask = d_collated["attention_mask"][0].tolist() - got_unpadded_len = sum(got_mask) - assert got_unpadded_len == 431 + assert sum(got_mask) == 431 assert eos_positions == [[255, 430]] - # EOS token at eos positions assert got_ids[255] == train_tokenizer.eos_token_id assert got_ids[430] == train_tokenizer.eos_token_id - print("length of got_ids: ", len(got_ids)) - - expected_ids = [ - 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13, 758, 279, 8622, 4158, 4925, 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, 73760, 572, 1550, 11, 220, 16, 13, 23, 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, 311, 220, 16, 13, 17, 19197, 441, 17, 58634, 13, 758, 279, 44900, 47594, 
315, 279, 5306, 47639, 11, 279, 3076, 9981, 57330, 36829, 518, 2176, 3039, 1033, 4428, 320, 16, 13, 17, 19041, 220, 16, 13, 16, 19197, 441, 17, 58634, 568, 39402, 458, 285, 354, 17764, 572, 5080, 279, 12128, 7194, 572, 311, 4647, 448, 7046, 10740, 2750, 304, 279, 5306, 47639, 1091, 304, 279, 8622, 4158, 4925, 13, 4968, 4991, 41434, 518, 4647, 8542, 5080, 3076, 57330, 36829, 304, 279, 8622, 4158, 4925, 320, 16, 13, 19, 51615, 220, 15, 13, 17, 19, 19041, 220, 16, 13, 16, EOS_TOKEN_ID, 20, 51615, 220, 15, 13, 15, 24, 19197, 441, 17, 58634, 11, 281, 284, 220, 15, 13, 15, 16, 21, 8, 323, 4722, 8674, 458, 285, 354, 17764, 304, 2176, 5671, 7707, 448, 2480, 9663, 41434, 320, 5782, 4925, 11, 220, 16, 15, 13, 24, 51615, 220, 15, 13, 21, 19041, 220, 17, 17, 13, 24, 51615, 220, 18, 13, 15, 13384, 281, 284, 220, 15, 13, 15, 15, 16, 26, 5306, 47639, 11, 220, 17, 19, 13, 15, 51615, 220, 19, 13, 19, 19, 19041, 220, 18, 18, 13, 16, 51615, 220, 15, 13, 21, 4, 281, 284, 220, 15, 13, 15, 15, 21, 568, 11581, 2408, 301, 15479, 48674, 304, 279, 42094, 1620, 385, 1242, 1033, 9434, 553, 57330, 15626, 51360, 438, 4124, 438, 220, 17, 23, 73760, 26, 2480, 9663, 323, 855, 4991, 41434, 518, 4647, 8542, 12864, 11799, 304, 4158, 4925, 23788, 7321, 13, 576, 821, 13216, 429, 46516, 15449, 315, 3015, 57330, 553, 57330, 15626, 51360, 5707, 20017, 1119, 8003, 95697, 4401, 304, 59645, 4158, 4925, 304, 5382, 41434, EOS_TOKEN_ID, PADDING_TOKEN_ID - ] - assert got_ids == expected_ids - - # Hardcoded attention_mask: 431 ones (unpadded tokens) + 1 zero (padding) - # Padded to multiple of 16: 431 -> 432 - expected_mask = [1] * 431 + [0] * 1 - assert len(got_mask) == 432 - assert got_mask == expected_mask - # Verify attention_mask is 1 at eos_positions (EOS tokens should be attended) + assert len(got_ids) == 432 # Padded to multiple of 16 assert got_mask[255] == 1 assert got_mask[430] == 1 @pytest.mark.unit -def test_chunk_size_zero_with_train_tokenizer_disables_chunking_and_truncates(train_tokenizer): - """ 
- With passage_chunk_size > 0, TrainCollator should take the chunking path. - - Tests chunked passages with passage_max_len=64 and passage_chunk_size=32. - """ +def test_chunked_collator_with_multiple_passages(train_tokenizer): + """Test TrainCollator with chunking enabled returns (q_batch, p_batch, eos_positions).""" from tevatron.retriever.arguments import DataArguments from tevatron.retriever.collator import TrainCollator @@ -136,95 +98,37 @@ def test_chunk_size_zero_with_train_tokenizer_disables_chunking_and_truncates(tr passage_chunk_size=32, ) collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) - - # ======================================================================== - # HOW features IS CONSTRUCTED: - # ======================================================================== - # features mimics what TrainDataset.__getitem__() returns. Each element is: - # (query_tuple, list_of_passage_tuples) - # - # Where: - # - query_tuple: (text, image, video, audio) - in this test, only text is used - # - list_of_passage_tuples: [(text, image, video, audio), ...] 
- one per passage - # - # Structure breakdown: - # - ("q1", None, None, None) = query with text="q1", no multimodal content - # - [(REAL_TEXT, ...), (REAL_TEXT, ...)] = 2 passages (train_group_size=2) - # Each passage tuple: (text=REAL_TEXT, image=None, video=None, audio=None) - # ======================================================================== features = [ (("q1", None, None, None), [(REAL_TEXT, None, None, None), (REAL_TEXT, None, None, None)]), ] - - # ======================================================================== - # WHAT collator(features) RETURNS: - # ======================================================================== - # Since passage_chunk_size > 0 (chunking enabled), TrainCollator.__call__() returns: - # (q_batch, p_batch, eos_positions) # 3-element tuple - # - # Where: - # q_batch: dict with PyTorch tensors for queries - # - "input_ids": tensor([[token_ids for "q1"]]) # shape: [num_queries, query_seq_len] - # - "attention_mask": tensor([[1, 1, ...]]) # shape: [num_queries, query_seq_len] - # - # p_batch: dict with PyTorch tensors for chunked passages (FLATTENED across all queries) - # - "input_ids": tensor([ - # [token_ids for passage 1 (chunked, padded to multiple of 16)], - # [token_ids for passage 2 (chunked, padded to multiple of 16)] - # ]) # shape: [total_passages, passage_seq_len] - # - "attention_mask": tensor([ - # [1, 1, ..., 0, 0, ...], # attention mask with padding - # [1, 1, ..., 0, 0, ...] - # ]) # shape: [total_passages, passage_seq_len] - # - # eos_positions: list of lists, e.g., [[31, 63], [31, 63]] - EOS token positions per passage - # Used by the model to extract chunk embeddings via MaxSim pooling - # - # Note: The collator flattens all passages from all queries into a single batch. - # With 1 query and train_group_size=2, we get 2 passages in p_batch. 
- # ======================================================================== - out = collator(features) - assert len(out) == 3 # Verify chunking path returns 3 elements - q_batch, p_batch, eos_positions = out # Unpack: q_batch (queries), p_batch (passages), eos_positions - - assert p_batch["input_ids"].shape[0] == 2 # train_group_size=2 - assert len(eos_positions) == 2 # One list of eos positions per passage - + + q_batch, p_batch, eos_positions = collator(features) + + assert p_batch["input_ids"].shape[0] == 2 + assert len(eos_positions) == 2 + for i in range(p_batch["input_ids"].shape[0]): got_ids = p_batch["input_ids"][i].tolist() got_mask = p_batch["attention_mask"][i].tolist() - unpadded_len = sum(got_mask) - - # Verify chunking structure - assert len(eos_positions[i]) > 0 # Should have at least one chunk - assert _strictly_increasing(eos_positions[i]) # EOS positions should be strictly increasing - # Verify EOS tokens at eos positions + assert len(eos_positions[i]) > 0 + assert _strictly_increasing(eos_positions[i]) for eos_pos in eos_positions[i]: assert got_ids[eos_pos] == train_tokenizer.eos_token_id - assert got_mask[eos_pos] == 1 # EOS tokens should be attended - eos_positions[0][0] == 31 - eos_positions[0][1] == 63 - eos_positions[1][0] == 31 - eos_positions[1][1] == 63 - # Verify padding to multiple of 16 + assert got_mask[eos_pos] == 1 assert len(got_ids) == 64 - assert len(got_mask) == 64 - assert len(got_ids) == len(got_mask) @pytest.mark.unit -def test_chunking_chunk_size_equal_maxlen_is_capped_to_single_chunk(train_tokenizer): - """ - When chunk_size == max_len, chunking should be capped to exactly max_len total tokens - (incl. EOS), with exactly one EOS at the end. 
- """ +@pytest.mark.parametrize("chunk_size", [64, 128]) +def test_chunking_capped_to_maxlen(train_tokenizer, chunk_size): + """When chunk_size >= max_len, chunking is capped to max_len with one EOS.""" from tevatron.retriever.arguments import DataArguments from tevatron.retriever.collator import TrainCollator long_text = (REAL_TEXT + " ") * 20 data_args = DataArguments( - passage_chunk_size=64, + passage_chunk_size=chunk_size, passage_max_len=64, pad_to_multiple_of=16, padding_side="right", @@ -235,66 +139,11 @@ def test_chunking_chunk_size_equal_maxlen_is_capped_to_single_chunk(train_tokeni ids = d_collated["input_ids"][0].tolist() mask = d_collated["attention_mask"][0].tolist() - # Hardcoded golden output - expected_ids = [ - 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, - 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, - 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, - 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, - 349, 2326, EOS_TOKEN_ID - ] - expected_eos_positions = [[63]] - expected_mask = [1] * 64 - assert sum(mask) == 64 assert len(ids) == 64 - assert eos_positions == expected_eos_positions - assert ids == expected_ids + assert eos_positions == [[63]] assert ids[63] == EOS_TOKEN_ID - assert EOS_TOKEN_ID not in ids[0:63] # EOS token should not be in the first 63 tokens - assert mask == expected_mask - assert _strictly_increasing(eos_positions[0]) - - -@pytest.mark.unit -def test_chunking_chunk_size_greater_than_maxlen_is_capped_to_single_chunk(train_tokenizer): - """ - When chunk_size > max_len, chunking should still be capped to exactly max_len total tokens - (incl. EOS), with exactly one EOS at the end. 
- """ - from tevatron.retriever.arguments import DataArguments - from tevatron.retriever.collator import TrainCollator - - long_text = (REAL_TEXT + " ") * 20 - data_args = DataArguments( - passage_chunk_size=128, - passage_max_len=64, - pad_to_multiple_of=16, - padding_side="right", - append_eos_token=False, - ) - collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) - d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([long_text]) - ids = d_collated["input_ids"][0].tolist() - mask = d_collated["attention_mask"][0].tolist() - - # Hardcoded golden output (same as chunk_size == max_len case) - expected_ids = [ - 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, - 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, - 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, - 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, - 349, 2326, EOS_TOKEN_ID - ] - expected_eos_positions = [[63]] - expected_mask = [1] * 64 - - assert sum(mask) == 64 - assert len(ids) == 64 - assert eos_positions == expected_eos_positions - assert ids == expected_ids - assert ids[63] == EOS_TOKEN_ID - assert mask == expected_mask + assert EOS_TOKEN_ID not in ids[:63] assert _strictly_increasing(eos_positions[0]) @@ -484,78 +333,48 @@ def test_chunking_multiple_passages_different_lengths(train_tokenizer): @pytest.mark.unit def test_non_chunked_padding_side_behavior(train_tokenizer): - """ - Test non-chunked passage encoding behavior with left vs right padding. - This verifies that padding_side affects how _pooling('last'/'eos') extracts embeddings. 
- """ + """Test that padding_side affects pooling position extraction.""" import torch from unittest.mock import Mock from tevatron.retriever.arguments import DataArguments from tevatron.retriever.collator import TrainCollator from tevatron.retriever.modeling.dense import DenseModel - # Test passage - will be truncated to max_len - test_passage = REAL_TEXT # Long passage that will be truncated + test_passage = REAL_TEXT - # Test Case 1: Right padding + # Right padding data_args_right = DataArguments( passage_max_len=64, - passage_chunk_size=0, # No chunking + passage_chunk_size=0, pad_to_multiple_of=16, padding_side="right", - passage_prefix="", append_eos_token=False, ) - collator_right = TrainCollator(data_args=data_args_right, tokenizer=train_tokenizer) - q_batch_right, p_batch_right = collator_right([("query", [test_passage], [])]) - - # Verify right padding structure - input_ids_right = p_batch_right['input_ids'][0] + _, p_batch_right = collator_right([("query", [test_passage], [])]) attention_mask_right = p_batch_right['attention_mask'][0] - seq_len_right = len(attention_mask_right) - - # With right padding, content tokens are at the beginning, padding at the end - # Last position should be padding (since passage is truncated and padded) - # Note: first position might be special token (BOS) due to add_special_tokens=True - assert attention_mask_right[-1] == 0, "Right padding: last position should be padding" - - # Last valid token position last_valid_pos_right = attention_mask_right.sum().item() - 1 - # Test Case 2: Left padding + # Left padding data_args_left = DataArguments( passage_max_len=64, - passage_chunk_size=0, # No chunking + passage_chunk_size=0, pad_to_multiple_of=16, padding_side="left", - passage_prefix="", append_eos_token=False, ) - collator_left = TrainCollator(data_args=data_args_left, tokenizer=train_tokenizer) - q_batch_left, p_batch_left = collator_left([("query", [test_passage], [])]) - - # Verify left padding structure - input_ids_left 
= p_batch_left['input_ids'][0] + _, p_batch_left = collator_left([("query", [test_passage], [])]) attention_mask_left = p_batch_left['attention_mask'][0] - seq_len_left = len(attention_mask_left) - - # With left padding, padding tokens are at the beginning, content at the end - # Due to pad_to_multiple_of, the actual behavior depends on content length - # Key observation: The pooling logic checks if last position is valid to determine left padding num_valid_left = attention_mask_left.sum().item() + is_left_padding = (attention_mask_left[-1] == 1).item() - # The _pooling logic: left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) - # If last position is 1 for all sequences, it treats it as left padding - is_detected_as_left_padding = (attention_mask_left[-1] == 1).item() + # Verify content tokens are identical + content_right = p_batch_right['input_ids'][0][attention_mask_right.bool()].tolist() + content_left = p_batch_left['input_ids'][0][attention_mask_left.bool()].tolist() + assert content_right == content_left - # Verify both versions tokenized the same content (ignoring padding) - content_tokens_right = input_ids_right[attention_mask_right.bool()].tolist() - content_tokens_left = input_ids_left[attention_mask_left.bool()].tolist() - assert content_tokens_right == content_tokens_left, "Content tokens should be identical" - - # Test Case 3: Verify pooling behavior with mock model + # Test pooling with mock model hidden_size = 64 class MockEncoderOutput: @@ -565,11 +384,9 @@ def __init__(self, last_hidden_state): def mock_encoder_forward(**kwargs): input_ids = kwargs['input_ids'] batch_size, seq_len = input_ids.shape - # Create hidden states where each position encodes its position index hidden_states = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) for i in range(batch_size): for j in range(seq_len): - # Encode position j in the first dimension hidden_states[i, j, 0] = float(j) return 
MockEncoderOutput(last_hidden_state=hidden_states) @@ -578,130 +395,60 @@ def mock_encoder_forward(**kwargs): mock_encoder.config.hidden_size = hidden_size model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) - model.passage_chunk_size = 0 # No chunking + model.passage_chunk_size = 0 - # Test right padding pooling p_reps_right = model.encode_passage(p_batch_right) - - # Test left padding pooling p_reps_left = model.encode_passage(p_batch_left) - # Verify pooling extracts from correct positions - # Right padding: uses sequence_lengths calculation (attention_mask.sum() - 1) - expected_pos_right = last_valid_pos_right - assert torch.allclose(p_reps_right[0, 0], torch.tensor(float(expected_pos_right))) - - # Left padding: The _pooling logic checks: left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) - # If last position is 1, it uses last_hidden_state[:, -1] - # Otherwise, it calculates sequence_lengths = attention_mask.sum(dim=1) - 1 - if is_detected_as_left_padding: - expected_pos_left = seq_len_left - 1 - else: - expected_pos_left = num_valid_left - 1 + assert torch.allclose(p_reps_right[0, 0], torch.tensor(float(last_valid_pos_right))) + expected_pos_left = len(attention_mask_left) - 1 if is_left_padding else num_valid_left - 1 assert torch.allclose(p_reps_left[0, 0], torch.tensor(float(expected_pos_left))) - - # Verify the key difference: right padding always uses sequence_lengths calculation - # Left padding uses last position if detected as left padding, otherwise sequence_lengths - # The actual positions depend on the padding structure - print(f"Right padding: extracted from position {expected_pos_right} (last_valid_pos)") - print(f"Left padding: extracted from position {expected_pos_left} (is_left_padding={is_detected_as_left_padding})") - print(f"Right padding mask: first={attention_mask_right[0].item()}, last={attention_mask_right[-1].item()}") - print(f"Left padding mask: first={attention_mask_left[0].item()}, 
last={attention_mask_left[-1].item()}") @pytest.mark.unit def test_chunked_passages_left_padding(train_tokenizer): - """ - Test chunked passage encoding with left padding. - This verifies that EOS positions are correctly adjusted when padding is on the left. - """ + """Test that EOS positions are correctly adjusted for left padding.""" import torch from unittest.mock import Mock from tevatron.retriever.arguments import DataArguments from tevatron.retriever.collator import TrainCollator from tevatron.retriever.modeling.dense import DenseModel - # Test passage that will be chunked test_passage = REAL_TEXT - # Test Case 1: Right padding (baseline) + # Right padding (baseline) data_args_right = DataArguments( passage_max_len=128, passage_chunk_size=64, pad_to_multiple_of=16, padding_side="right", - passage_prefix="", append_eos_token=False, ) - collator_right = TrainCollator(data_args=data_args_right, tokenizer=train_tokenizer) - q_batch_right, p_batch_right, eos_positions_right = collator_right([("query", [test_passage], [])]) - - # Verify right padding structure - input_ids_right = p_batch_right['input_ids'][0] - attention_mask_right = p_batch_right['attention_mask'][0] - seq_len_right = len(attention_mask_right) - - # With right padding, content tokens are at the beginning, padding at the end - assert attention_mask_right[-1] == 0, "Right padding: last position should be padding" + _, p_batch_right, eos_positions_right = collator_right([("query", [test_passage], [])]) - # Verify EOS positions are correct (should be in the content area, before padding) - for eos_pos in eos_positions_right[0]: - assert eos_pos < attention_mask_right.sum().item(), f"EOS position {eos_pos} should be in valid token range" - assert input_ids_right[eos_pos] == train_tokenizer.eos_token_id, f"Position {eos_pos} should be EOS token" - - # Test Case 2: Left padding + # Left padding data_args_left = DataArguments( passage_max_len=128, passage_chunk_size=64, pad_to_multiple_of=16, 
padding_side="left", - passage_prefix="", append_eos_token=False, ) - collator_left = TrainCollator(data_args=data_args_left, tokenizer=train_tokenizer) - q_batch_left, p_batch_left, eos_positions_left = collator_left([("query", [test_passage], [])]) + _, p_batch_left, eos_positions_left = collator_left([("query", [test_passage], [])]) - # Verify left padding structure - input_ids_left = p_batch_left['input_ids'][0] attention_mask_left = p_batch_left['attention_mask'][0] - seq_len_left = len(attention_mask_left) - - # With left padding, padding tokens are at the beginning, content at the end - # Note: Due to pad_to_multiple_of, the actual padding structure may vary - # Check that there is padding at the beginning num_valid_tokens = attention_mask_left.sum().item() - padding_length = seq_len_left - num_valid_tokens - if padding_length > 0: - # If there's padding, first positions should be padding - assert attention_mask_left[0] == 0, "Left padding: first position should be padding when padding exists" - assert attention_mask_left[-1] == 1, "Left padding: last position should be content (valid token)" - - # Verify EOS positions are correctly adjusted for left padding - # EOS positions should be shifted by the padding length - - # Verify all EOS positions are in the valid token range (after padding) - for eos_pos in eos_positions_left[0]: - assert eos_pos >= padding_length, f"EOS position {eos_pos} should be after padding (padding_length={padding_length})" - assert eos_pos < seq_len_left, f"EOS position {eos_pos} should be within sequence length {seq_len_left}" - assert input_ids_left[eos_pos] == train_tokenizer.eos_token_id, f"Position {eos_pos} should be EOS token" - assert attention_mask_left[eos_pos] == 1, f"EOS position {eos_pos} should be in valid token range" - - # Verify that EOS positions are correctly shifted - # The relative positions within the content should be the same, but absolute positions differ - # Right padding: EOS at positions like [63, 127] 
(before padding) - # Left padding: EOS at positions like [padding_length + 63, padding_length + 127] (after padding) - assert len(eos_positions_right[0]) == len(eos_positions_left[0]), "Should have same number of chunks" - - # Verify the relative positions are preserved (EOS positions differ by padding_length) - for i, (eos_right, eos_left) in enumerate(zip(eos_positions_right[0], eos_positions_left[0])): - expected_left_pos = eos_right + padding_length - assert eos_left == expected_left_pos, \ - f"Chunk {i}: EOS position should be shifted by padding_length. " \ - f"Expected {expected_left_pos}, got {eos_left} (right={eos_right}, padding_length={padding_length})" + padding_length = len(attention_mask_left) - num_valid_tokens + + # Verify EOS positions are shifted by padding_length + assert len(eos_positions_right[0]) == len(eos_positions_left[0]) + for eos_right, eos_left in zip(eos_positions_right[0], eos_positions_left[0]): + assert eos_left == eos_right + padding_length + assert p_batch_left['input_ids'][0][eos_left] == train_tokenizer.eos_token_id - # Test Case 3: Verify pooling behavior with mock model + # Test pooling extracts from correct positions hidden_size = 64 class MockEncoderOutput: @@ -711,11 +458,9 @@ def __init__(self, last_hidden_state): def mock_encoder_forward(**kwargs): input_ids = kwargs['input_ids'] batch_size, seq_len = input_ids.shape - # Create hidden states where each position encodes its position index hidden_states = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) for i in range(batch_size): for j in range(seq_len): - # Encode position j in the first dimension hidden_states[i, j, 0] = float(j) return MockEncoderOutput(last_hidden_state=hidden_states) @@ -726,106 +471,14 @@ def mock_encoder_forward(**kwargs): model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) model.passage_chunk_size = 64 - # Test right padding pooling - chunk_reps_right, chunk_mask_right = model.encode_passage(p_batch_right, 
eos_positions_right) - - # Test left padding pooling - chunk_reps_left, chunk_mask_left = model.encode_passage(p_batch_left, eos_positions_left) - - # Verify pooling extracts from correct EOS positions - # Right padding: extracts from eos_positions_right - # Left padding: extracts from eos_positions_left (which are adjusted) - assert chunk_reps_right.shape == chunk_reps_left.shape, "Should have same number of chunks" - assert chunk_mask_right.shape == chunk_mask_left.shape, "Should have same chunk mask shape" + chunk_reps_right, _ = model.encode_passage(p_batch_right, eos_positions_right) + chunk_reps_left, _ = model.encode_passage(p_batch_left, eos_positions_left) - # Verify that embeddings are extracted from the correct positions - # For right padding, EOS at position 63 should give embedding with value 63.0 - # For left padding, EOS at position (padding_length + 63) should give embedding with value (padding_length + 63.0) + # Verify embeddings differ by padding_length for i, (eos_right, eos_left) in enumerate(zip(eos_positions_right[0], eos_positions_left[0])): - # Right padding: embedding should encode position eos_right - assert torch.allclose(chunk_reps_right[0, i, 0], torch.tensor(float(eos_right))), \ - f"Right padding chunk {i}: embedding should encode EOS position {eos_right}" - - # Left padding: embedding should encode position eos_left - assert torch.allclose(chunk_reps_left[0, i, 0], torch.tensor(float(eos_left))), \ - f"Left padding chunk {i}: embedding should encode EOS position {eos_left}" - - # Verify masks are correct - assert chunk_mask_right[0, i] == 1.0, f"Right padding chunk {i} should be valid" - assert chunk_mask_left[0, i] == 1.0, f"Left padding chunk {i} should be valid" - - # Verify that the embeddings differ by the padding length (in the first dimension) - # This confirms that EOS positions are correctly adjusted - for i in range(len(eos_positions_right[0])): - expected_diff = float(padding_length) - actual_diff = chunk_reps_left[0, i, 
0] - chunk_reps_right[0, i, 0] - assert torch.allclose(actual_diff, torch.tensor(expected_diff)), \ - f"Chunk {i}: embedding difference should equal padding_length. " \ - f"Expected {expected_diff}, got {actual_diff.item()}" - - print(f"Right padding EOS positions: {eos_positions_right[0]}") - print(f"Left padding EOS positions: {eos_positions_left[0]}") - print(f"Padding length: {padding_length}") - print(f"Sequence length: {seq_len_left}") - print(f"Valid tokens: {num_valid_tokens}") - - # Test Case 4: Verify with append_eos_token=True - data_args_right_eos = DataArguments( - passage_max_len=64, - passage_chunk_size=0, - pad_to_multiple_of=16, - padding_side="right", - passage_prefix="", - append_eos_token=True, - ) - - data_args_left_eos = DataArguments( - passage_max_len=64, - passage_chunk_size=0, - pad_to_multiple_of=16, - padding_side="left", - passage_prefix="", - append_eos_token=True, - ) - - collator_right_eos = TrainCollator(data_args=data_args_right_eos, tokenizer=train_tokenizer) - collator_left_eos = TrainCollator(data_args=data_args_left_eos, tokenizer=train_tokenizer) - - q_batch_eos_right, p_batch_eos_right = collator_right_eos([("query", [test_passage], [])]) - q_batch_eos_left, p_batch_eos_left = collator_left_eos([("query", [test_passage], [])]) - - # Verify EOS token is present in both - content_right_eos = p_batch_eos_right['input_ids'][0][p_batch_eos_right['attention_mask'][0].bool()].tolist() - content_left_eos = p_batch_eos_left['input_ids'][0][p_batch_eos_left['attention_mask'][0].bool()].tolist() - - assert content_right_eos[-1] == train_tokenizer.eos_token_id - assert content_left_eos[-1] == train_tokenizer.eos_token_id - - # Test pooling with EOS - p_reps_eos_right = model.encode_passage(p_batch_eos_right) - p_reps_eos_left = model.encode_passage(p_batch_eos_left) - - # Both should extract from EOS position - mask_eos_right = p_batch_eos_right['attention_mask'][0] - mask_eos_left = p_batch_eos_left['attention_mask'][0] - - # Right 
padding: uses sequence_lengths calculation - last_valid_eos_right = mask_eos_right.sum().item() - 1 - - # Left padding: checks if last position is valid - is_left_padding_eos = (mask_eos_left[-1] == 1).item() - if is_left_padding_eos: - last_valid_eos_left = mask_eos_left.shape[0] - 1 - else: - last_valid_eos_left = mask_eos_left.sum().item() - 1 - - assert torch.allclose(p_reps_eos_right[0, 0], torch.tensor(float(last_valid_eos_right))) - assert torch.allclose(p_reps_eos_left[0, 0], torch.tensor(float(last_valid_eos_left))) - - # With EOS, the extracted positions should be where EOS is located - assert p_batch_eos_right['input_ids'][0][last_valid_eos_right] == train_tokenizer.eos_token_id - assert p_batch_eos_left['input_ids'][0][last_valid_eos_left] == train_tokenizer.eos_token_id - - # Summary: This test verifies that padding_side affects pooling position calculation - # Right padding: always uses attention_mask.sum() - 1 - # Left padding: uses seq_len - 1 if last position is valid, otherwise attention_mask.sum() - 1 + assert torch.allclose(chunk_reps_right[0, i, 0], torch.tensor(float(eos_right))) + assert torch.allclose(chunk_reps_left[0, i, 0], torch.tensor(float(eos_left))) + assert torch.allclose( + chunk_reps_left[0, i, 0] - chunk_reps_right[0, i, 0], + torch.tensor(float(padding_length)) + ) diff --git a/tests/test_chunking_helper.py b/tests/test_chunking_helper.py new file mode 100644 index 00000000..32448ef0 --- /dev/null +++ b/tests/test_chunking_helper.py @@ -0,0 +1,229 @@ +""" +Unit tests for _chunk_tokens helper function. 
+""" +import sys +from pathlib import Path +import pytest + + +def _tevatron_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def _add_tevatron_src_to_path(): + src = _tevatron_root() / "src" + sys.path.insert(0, str(src)) + + +@pytest.mark.unit +def test_chunk_tokens_basic(): + """Test basic chunking functionality.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(10)) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + eos_id = 99 + chunk_size = 4 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + # chunk_size=4 means chunk_len=3, so chunks are: + # [0,1,2,99], [3,4,5,99], [6,7,8,99], [9,99] + expected_ids = [0, 1, 2, 99, 3, 4, 5, 99, 6, 7, 8, 99, 9, 99] + expected_eos_pos = [3, 7, 11, 13] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_with_max_length(): + """Test chunking with max_length constraint.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(20)) + eos_id = 99 + chunk_size = 5 + max_length = 12 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + # chunk_size=5 means chunk_len=4 + # First chunk: [0,1,2,3,99] = 5 tokens + # Second chunk: [4,5,6,7,99] = 5 tokens + # Total: 10 tokens, but max_length=12 allows one more EOS + # Third chunk would need at least 1 token + 1 EOS = 2 tokens, but we only have 2 left + # So we can fit: [8,99] = 2 tokens + # Total: 12 tokens + assert len(ids) == 12 + assert ids[-1] == eos_id # Last token should be EOS + assert len(eos_pos) == 3 + assert all(ids[pos] == eos_id for pos in eos_pos) + + +@pytest.mark.unit +def test_chunk_tokens_max_length_exact_fit(): + """Test chunking when max_length exactly fits chunks.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(10)) + eos_id = 99 + chunk_size = 4 + max_length = 14 # Exactly fits 3 chunks: 
3*4 + 2 = 14 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + # Should have 3 chunks: [0,1,2,99], [3,4,5,99], [6,7,8,99] = 12 tokens + # Plus [9,99] = 2 tokens, total 14 + assert len(ids) == 14 + assert len(eos_pos) == 4 + + +@pytest.mark.unit +def test_chunk_tokens_max_length_too_small(): + """Test chunking when max_length is too small for even one chunk.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(10)) + eos_id = 99 + chunk_size = 4 + max_length = 1 # Too small for even one chunk (need at least 2: 1 token + EOS) + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + # Should return empty since we can't fit even one chunk + assert ids == [] + assert eos_pos == [] + + +@pytest.mark.unit +def test_chunk_tokens_empty_input(): + """Test chunking with empty token list.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = [] + eos_id = 99 + chunk_size = 4 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + assert ids == [] + assert eos_pos == [] + + +@pytest.mark.unit +def test_chunk_tokens_single_token(): + """Test chunking with single token.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = [42] + eos_id = 99 + chunk_size = 4 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + assert ids == [42, 99] + assert eos_pos == [1] + + +@pytest.mark.unit +def test_chunk_tokens_no_max_length(): + """Test chunking without max_length constraint.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(15)) + eos_id = 99 + chunk_size = 5 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length=None) + + # chunk_size=5 means chunk_len=4 + # Should have 4 chunks: [0-3,99], [4-7,99], [8-11,99], [12-14,99] + assert len(ids) == 19 # 15 tokens + 4 EOS tokens + assert 
len(eos_pos) == 4 + assert all(ids[pos] == eos_id for pos in eos_pos) + + +@pytest.mark.unit +def test_chunk_tokens_chunk_size_one(): + """Test chunking with chunk_size=1 (invalid, should return empty).""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = [1, 2, 3] + eos_id = 99 + chunk_size = 1 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + # chunk_size=1 is invalid (need at least 2: 1 token + 1 EOS) + # Should return empty + assert ids == [] + assert eos_pos == [] + + +@pytest.mark.unit +def test_chunk_tokens_chunk_size_two(): + """Test chunking with chunk_size=2.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = [1, 2, 3, 4, 5] + eos_id = 99 + chunk_size = 2 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + # chunk_size=2 means chunk_len=1 + # Chunks: [1,99], [2,99], [3,99], [4,99], [5,99] + expected_ids = [1, 99, 2, 99, 3, 99, 4, 99, 5, 99] + expected_eos_pos = [1, 3, 5, 7, 9] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_eos_positions_are_correct(): + """Test that EOS positions correctly point to EOS tokens.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(10)) + eos_id = 99 + chunk_size = 4 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + # Verify all EOS positions contain EOS token + for pos in eos_pos: + assert ids[pos] == eos_id + + # Verify EOS positions are strictly increasing + assert all(eos_pos[i] < eos_pos[i + 1] for i in range(len(eos_pos) - 1)) + + +@pytest.mark.unit +def test_chunk_tokens_max_length_stops_at_boundary(): + """Test that max_length stops chunking at chunk boundary.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(20)) + eos_id = 99 + chunk_size = 5 + max_length = 10 # Exactly 2 chunks: 2*5 = 10 
+ + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + assert len(ids) == 10 + assert len(eos_pos) == 2 + # Should have exactly 2 chunks: [0,1,2,3,99], [4,5,6,7,99] + assert ids == [0, 1, 2, 3, 99, 4, 5, 6, 7, 99] + diff --git a/tests/test_chunking_pooling_equivalence.py b/tests/test_chunking_pooling_equivalence.py index ca2ec4a0..268d0cbf 100644 --- a/tests/test_chunking_pooling_equivalence.py +++ b/tests/test_chunking_pooling_equivalence.py @@ -1,6 +1,6 @@ """ Test to verify that when chunk_size == passage_max_len and there's only one chunk, -chunked and non-chunked modes should produce the same embeddings. +chunked and non-chunked modes extract embeddings from different positions. """ import sys from pathlib import Path @@ -26,72 +26,47 @@ def train_tokenizer(): @pytest.mark.unit def test_chunked_vs_non_chunked_when_chunk_size_equals_max_len(train_tokenizer): - """ - When chunk_size == passage_max_len and passage fits in one chunk, - chunked and non-chunked should produce identical embeddings. - """ + """When chunk_size == passage_max_len, chunked mode adds EOS and extracts from EOS position.""" _add_tevatron_src_to_path() from tevatron.retriever.arguments import DataArguments - from tevatron.retriever.collator import TrainCollator, ChunkedEncodeCollator + from tevatron.retriever.collator import TrainCollator from tevatron.retriever.modeling.dense import DenseModel from unittest.mock import Mock - # Test passage that fits in one chunk test_passage = "This is a test passage that will fit in one chunk." 
+ passage_max_len = chunk_size = 64 - # Configuration: chunk_size == passage_max_len - passage_max_len = 64 - chunk_size = 64 # Same as max_len - - # Test Case 1: Non-chunked mode + # Non-chunked mode data_args_non_chunked = DataArguments( passage_max_len=passage_max_len, - passage_chunk_size=0, # No chunking + passage_chunk_size=0, pad_to_multiple_of=16, padding_side="right", - passage_prefix="", - append_eos_token=False, # Default: False + append_eos_token=False, ) - collator_non_chunked = TrainCollator(data_args=data_args_non_chunked, tokenizer=train_tokenizer) - q_batch_non_chunked, p_batch_non_chunked = collator_non_chunked([("query", [test_passage], [])]) + _, p_batch_non_chunked = collator_non_chunked([("query", [test_passage], [])]) - # Test Case 2: Chunked mode with chunk_size == max_len + # Chunked mode data_args_chunked = DataArguments( passage_max_len=passage_max_len, - passage_chunk_size=chunk_size, # Same as max_len + passage_chunk_size=chunk_size, pad_to_multiple_of=16, padding_side="right", - passage_prefix="", - append_eos_token=False, # Same as non-chunked + append_eos_token=False, ) - collator_chunked = TrainCollator(data_args=data_args_chunked, tokenizer=train_tokenizer) - q_batch_chunked, p_batch_chunked, eos_positions = collator_chunked([("query", [test_passage], [])]) - - # Verify tokenization differences - input_ids_non_chunked = p_batch_non_chunked['input_ids'][0] - input_ids_chunked = p_batch_chunked['input_ids'][0] - - # Chunked mode adds EOS after chunk, non-chunked doesn't (when append_eos_token=False) - # So chunked should have one more token (the EOS) - non_chunked_content = input_ids_non_chunked[p_batch_non_chunked['attention_mask'][0].bool()].tolist() - chunked_content = input_ids_chunked[p_batch_chunked['attention_mask'][0].bool()].tolist() - - print(f"Non-chunked content tokens: {len(non_chunked_content)}") - print(f"Chunked content tokens: {len(chunked_content)}") - print(f"EOS positions: {eos_positions}") + _, p_batch_chunked, 
eos_positions = collator_chunked([("query", [test_passage], [])]) - # Chunked should have EOS token at the end of the chunk - assert chunked_content[-1] == train_tokenizer.eos_token_id, "Chunked mode should have EOS at end" - # Non-chunked should NOT have EOS (when append_eos_token=False) - assert non_chunked_content[-1] != train_tokenizer.eos_token_id, "Non-chunked mode should NOT have EOS" + # Verify tokenization: chunked adds EOS, non-chunked doesn't + non_chunked_content = p_batch_non_chunked['input_ids'][0][p_batch_non_chunked['attention_mask'][0].bool()].tolist() + chunked_content = p_batch_chunked['input_ids'][0][p_batch_chunked['attention_mask'][0].bool()].tolist() - # The content tokens (excluding EOS) should be the same - chunked_content_without_eos = chunked_content[:-1] - assert non_chunked_content == chunked_content_without_eos, "Content tokens should be identical (excluding EOS)" + assert chunked_content[-1] == train_tokenizer.eos_token_id + assert non_chunked_content[-1] != train_tokenizer.eos_token_id + assert non_chunked_content == chunked_content[:-1] - # Now test pooling behavior + # Test pooling: different positions yield different embeddings hidden_size = 64 class MockEncoderOutput: @@ -101,11 +76,9 @@ def __init__(self, last_hidden_state): def mock_encoder_forward(**kwargs): input_ids = kwargs['input_ids'] batch_size, seq_len = input_ids.shape - # Create hidden states where each position encodes its position index hidden_states = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) for i in range(batch_size): for j in range(seq_len): - # Encode position j in the first dimension hidden_states[i, j, 0] = float(j) return MockEncoderOutput(last_hidden_state=hidden_states) @@ -113,41 +86,16 @@ def mock_encoder_forward(**kwargs): mock_encoder.config = Mock() mock_encoder.config.hidden_size = hidden_size - # Non-chunked model model_non_chunked = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) 
model_non_chunked.passage_chunk_size = 0 - - # Chunked model model_chunked = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) model_chunked.passage_chunk_size = chunk_size - # Get embeddings p_reps_non_chunked = model_non_chunked.encode_passage(p_batch_non_chunked) - p_reps_chunked_tuple = model_chunked.encode_passage(p_batch_chunked, eos_positions) - p_reps_chunked, chunk_mask = p_reps_chunked_tuple - - # Non-chunked: extracts from last content token position - # Chunked: extracts from EOS position (which is one position after last content token) - - # Get the actual positions - mask_non_chunked = p_batch_non_chunked['attention_mask'][0] - last_valid_pos_non_chunked = mask_non_chunked.sum().item() - 1 - - # Chunked: EOS position - eos_pos = eos_positions[0][0] # First (and only) chunk's EOS position - - print(f"Non-chunked extracts from position: {last_valid_pos_non_chunked}") - print(f"Chunked extracts from position: {eos_pos}") - print(f"Non-chunked embedding value: {p_reps_non_chunked[0, 0].item()}") - print(f"Chunked embedding value: {p_reps_chunked[0, 0, 0].item()}") + p_reps_chunked, _ = model_chunked.encode_passage(p_batch_chunked, eos_positions) - # These should be DIFFERENT because they extract from different positions - # Non-chunked: last content token - # Chunked: EOS token (one position after) - assert eos_pos == last_valid_pos_non_chunked + 1, \ - f"EOS should be one position after last content token: {eos_pos} vs {last_valid_pos_non_chunked}" + last_valid_pos = p_batch_non_chunked['attention_mask'][0].sum().item() - 1 + eos_pos = eos_positions[0][0] - # The embeddings will be different because they're extracted from different positions - # This is the root cause of the inconsistency! 
- assert not torch.allclose(p_reps_non_chunked[0], p_reps_chunked[0, 0]), \ - "Embeddings should be different because they're extracted from different token positions" + assert eos_pos == last_valid_pos + 1 + assert not torch.allclose(p_reps_non_chunked[0], p_reps_chunked[0, 0]) diff --git a/tests/test_forward.py b/tests/test_forward.py index 50ec91d9..b57782fb 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -136,10 +136,7 @@ def encode_passage(self, psg): @pytest.mark.unit def test_forward_with_chunking(train_tokenizer): - """ - Test model forward function with chunked passages. - This tests the integration of encode_query, encode_passage, and compute_maxsim_similarity. - """ + """Test model forward with chunked passages: encode_query, encode_passage, compute_maxsim_similarity.""" _add_tevatron_src_to_path() from tevatron.retriever.arguments import DataArguments from tevatron.retriever.collator import TrainCollator @@ -151,30 +148,19 @@ def test_forward_with_chunking(train_tokenizer): "(MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient." 
) - # Setup data arguments data_args = DataArguments( passage_chunk_size=32, passage_max_len=128, pad_to_multiple_of=16, padding_side="right", - passage_prefix="", - query_prefix="", append_eos_token=False, ) - - # Create collator collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) - # Create test data: 2 queries, 2 passages (1 positive each) queries = ["What is cerebral white matter?", "What is MRI?"] passages = [REAL_TEXT, "MRI stands for Magnetic Resonance Imaging."] + q_batch, p_batch, eos_positions = collator([(q, [p], []) for q, p in zip(queries, passages)]) - # Collate data - q_batch, p_batch, eos_positions = collator([ - (q, [p], []) for q, p in zip(queries, passages) - ]) - - # Create mock encoder hidden_size = 64 class MockEncoderOutput: @@ -184,82 +170,49 @@ def __init__(self, last_hidden_state): def mock_encoder_forward(**kwargs): input_ids = kwargs['input_ids'] batch_size, seq_len = input_ids.shape - hidden_states = torch.randn(batch_size, seq_len, hidden_size) - return MockEncoderOutput(last_hidden_state=hidden_states) + return MockEncoderOutput(last_hidden_state=torch.randn(batch_size, seq_len, hidden_size)) mock_encoder = Mock(side_effect=mock_encoder_forward) mock_encoder.config = Mock() mock_encoder.config.hidden_size = hidden_size - # Create model model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) model.passage_chunk_size = data_args.passage_chunk_size model.eos_positions = eos_positions model.training = True - # Forward pass output = model(query=q_batch, passage=p_batch) - # Verify output structure assert hasattr(output, 'q_reps') assert hasattr(output, 'p_reps') assert hasattr(output, 'scores') assert hasattr(output, 'loss') + assert output.q_reps.shape == (len(queries), hidden_size) - # Verify query representations - assert output.q_reps.shape[0] == len(queries) - assert output.q_reps.shape[1] == hidden_size - - # Verify passage representations (chunked) - assert isinstance(output.p_reps, tuple) # 
Should be (chunk_reps, chunk_mask) chunk_reps, chunk_mask = output.p_reps assert chunk_reps.shape[0] == len(passages) assert chunk_reps.shape[2] == hidden_size - assert chunk_mask.shape == (len(passages), chunk_reps.shape[1]) - - # Verify scores shape - # With 2 queries and 2 passages, scores should be [2, 2] assert output.scores.shape == (len(queries), len(passages)) + assert output.loss.item() >= 0 - # Verify loss is computed - assert output.loss is not None - assert output.loss.item() >= 0 # Loss should be non-negative - - # Test Case 2: Verify MaxSim is used (not regular similarity) - # Create a scenario where MaxSim gives different result than mean pooling + # Test MaxSim with known embeddings model.eval() with torch.no_grad(): - # Use known embeddings where max chunk is different from mean q_reps_test = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) - # Passage with 2 chunks: first chunk similar to query, second chunk dissimilar p_reps_test = torch.tensor([ - [[1.0, 0.0], [0.0, 1.0]], # Passage 0: chunk 0 matches query 0, chunk 1 matches query 1 - [[0.0, 1.0], [1.0, 0.0]], # Passage 1: chunk 0 matches query 1, chunk 1 matches query 0 + [[1.0, 0.0], [0.0, 1.0]], + [[0.0, 1.0], [1.0, 0.0]], ], dtype=torch.float32) chunk_mask_test = torch.ones(2, 2) scores_test = model.compute_maxsim_similarity(q_reps_test, p_reps_test, chunk_mask_test) - - # Query 0 with Passage 0: max similarity should be with chunk 0 (1.0) - assert torch.allclose(scores_test[0, 0], torch.tensor(1.0)) - # Query 0 with Passage 1: max similarity should be with chunk 1 (1.0) - assert torch.allclose(scores_test[0, 1], torch.tensor(1.0)) - # Query 1 with Passage 0: max similarity should be with chunk 1 (1.0) - assert torch.allclose(scores_test[1, 0], torch.tensor(1.0)) - # Query 1 with Passage 1: max similarity should be with chunk 0 (1.0) - assert torch.allclose(scores_test[1, 1], torch.tensor(1.0)) + assert torch.allclose(scores_test, torch.ones(2, 2)) - # Test Case 3: Verify padding 
chunks are ignored + # Test padding chunks are ignored p_reps_padded = torch.randn(2, 3, hidden_size) - chunk_mask_padded = torch.tensor([ - [1.0, 1.0, 0.0], # Passage 0: 2 valid chunks - [1.0, 0.0, 0.0], # Passage 1: 1 valid chunk - ]) - + chunk_mask_padded = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]) scores_padded = model.compute_maxsim_similarity(output.q_reps, p_reps_padded, chunk_mask_padded) - assert scores_padded.shape == (len(queries), len(passages)) - # Verify that padding doesn't affect the max (should only consider valid chunks) for q_idx in range(len(queries)): for p_idx in range(len(passages)): chunk_scores = torch.einsum('h,ch->c', output.q_reps[q_idx], p_reps_padded[p_idx]) diff --git a/tests/test_padding_helper.py b/tests/test_padding_helper.py new file mode 100644 index 00000000..f496ea00 --- /dev/null +++ b/tests/test_padding_helper.py @@ -0,0 +1,314 @@ +""" +Unit tests for _pad_and_adjust_eos_positions helper function. +""" +import sys +from pathlib import Path +import pytest +import torch +from unittest.mock import Mock, MagicMock + + +def _tevatron_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def _add_tevatron_src_to_path(): + src = _tevatron_root() / "src" + sys.path.insert(0, str(src)) + + +@pytest.fixture +def mock_tokenizer(): + """Create a mock tokenizer for testing.""" + tokenizer = Mock() + tokenizer.pad_token_id = 0 + tokenizer.eos_token_id = 99 + + def pad_fn(encodings, padding=True, pad_to_multiple_of=None, return_attention_mask=True, return_tensors=None): + """Mock pad function that simulates tokenizer.pad behavior.""" + input_ids = encodings['input_ids'] + max_len = max(len(ids) for ids in input_ids) if input_ids else 0 + + # Round up to multiple of pad_to_multiple_of + if pad_to_multiple_of: + max_len = ((max_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of + + padded_ids = [] + attention_masks = [] + + for ids in input_ids: + if tokenizer.padding_side == 'right': + pad_length = 
max_len - len(ids) + padded = ids + [tokenizer.pad_token_id] * pad_length + mask = [1] * len(ids) + [0] * pad_length + else: # left padding + pad_length = max_len - len(ids) + padded = [tokenizer.pad_token_id] * pad_length + ids + mask = [0] * pad_length + [1] * len(ids) + + padded_ids.append(padded) + attention_masks.append(mask) + + result = {'input_ids': padded_ids, 'attention_mask': attention_masks} + + if return_tensors == 'pt': + result['input_ids'] = torch.tensor(result['input_ids']) + result['attention_mask'] = torch.tensor(result['attention_mask']) + + return result + + tokenizer.pad = MagicMock(side_effect=pad_fn) + tokenizer.padding_side = 'right' + return tokenizer + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_right_padding(mock_tokenizer): + """Test padding with right padding (no EOS position adjustment needed).""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + all_input_ids = [ + [1, 2, 3, 99], # Passage 0: 4 tokens, EOS at position 3 + [4, 5, 99], # Passage 1: 3 tokens, EOS at position 2 + ] + all_eos_positions = [[3], [2]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=mock_tokenizer, + padding_side='right', + pad_to_multiple_of=4, + ) + + # With pad_to_multiple_of=4, max_len=4 -> padded to 4 + assert padded_dict['input_ids'].shape == (2, 4) + assert padded_dict['attention_mask'].shape == (2, 4) + + # EOS positions should not change for right padding + assert adjusted_eos_positions == [[3], [2]] + + # Verify EOS tokens are at correct positions + assert padded_dict['input_ids'][0][3].item() == 99 + assert padded_dict['input_ids'][1][2].item() == 99 + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_left_padding(mock_tokenizer): + """Test padding with left padding (EOS positions should be shifted).""" + _add_tevatron_src_to_path() + from 
tevatron.retriever.collator import _pad_and_adjust_eos_positions + + all_input_ids = [ + [1, 2, 3, 99], # Passage 0: 4 tokens, EOS at position 3 + [4, 5, 99], # Passage 1: 3 tokens, EOS at position 2 + ] + all_eos_positions = [[3], [2]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=mock_tokenizer, + padding_side='left', + pad_to_multiple_of=4, + ) + + # With pad_to_multiple_of=4, max_len=4 -> padded to 4 + assert padded_dict['input_ids'].shape == (2, 4) + + # Passage 0: original length 4, padded length 4, padding_length=0, EOS stays at 3 + # Passage 1: original length 3, padded length 4, padding_length=1, EOS shifts from 2 to 3 + assert adjusted_eos_positions == [[3], [3]] + + # Verify EOS tokens are at correct positions after padding + assert padded_dict['input_ids'][0][3].item() == 99 + assert padded_dict['input_ids'][1][3].item() == 99 + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_multiple_eos(mock_tokenizer): + """Test padding with multiple EOS positions per passage.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + all_input_ids = [ + [1, 2, 99, 3, 4, 99], # Passage 0: 6 tokens, EOS at positions 2, 5 + [5, 99], # Passage 1: 2 tokens, EOS at position 1 + ] + all_eos_positions = [[2, 5], [1]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=mock_tokenizer, + padding_side='left', + pad_to_multiple_of=8, + ) + + # With pad_to_multiple_of=8, max_len=6 -> padded to 8 + assert padded_dict['input_ids'].shape == (2, 8) + + # Passage 0: original length 6, padded length 8, padding_length=2, EOS shift from [2,5] to [4,7] + # Passage 1: original length 2, padded length 8, padding_length=6, EOS shift from 1 to 7 + assert adjusted_eos_positions == [[4, 7], [7]] + + # Verify EOS tokens are 
at correct positions + assert padded_dict['input_ids'][0][4].item() == 99 + assert padded_dict['input_ids'][0][7].item() == 99 + assert padded_dict['input_ids'][1][7].item() == 99 + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_no_padding_needed(mock_tokenizer): + """Test when sequences are already the same length.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + all_input_ids = [ + [1, 2, 99], + [3, 4, 99], + ] + all_eos_positions = [[2], [2]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=mock_tokenizer, + padding_side='right', + pad_to_multiple_of=4, + ) + + # With pad_to_multiple_of=4, max_len=3 -> padded to 4 + assert padded_dict['input_ids'].shape == (2, 4) + + # EOS positions unchanged for right padding + assert adjusted_eos_positions == [[2], [2]] + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_empty_input(mock_tokenizer): + """Test with empty input.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + all_input_ids = [] + all_eos_positions = [] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=mock_tokenizer, + padding_side='right', + pad_to_multiple_of=4, + ) + + assert len(adjusted_eos_positions) == 0 + assert padded_dict['input_ids'].shape[0] == 0 + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_single_passage(mock_tokenizer): + """Test with single passage.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + all_input_ids = [[1, 2, 3, 99]] + all_eos_positions = [[3]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=mock_tokenizer, 
+ padding_side='right', + pad_to_multiple_of=4, + ) + + assert padded_dict['input_ids'].shape == (1, 4) + assert adjusted_eos_positions == [[3]] + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_pad_to_multiple_of_one(mock_tokenizer): + """Test with pad_to_multiple_of=1 (no rounding).""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + all_input_ids = [ + [1, 2, 99], + [3, 99], + ] + all_eos_positions = [[2], [1]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=mock_tokenizer, + padding_side='right', + pad_to_multiple_of=1, + ) + + # Should pad to max_len=3 (no rounding needed) + assert padded_dict['input_ids'].shape == (2, 3) + assert adjusted_eos_positions == [[2], [1]] + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_left_padding_multiple_chunks(mock_tokenizer): + """Test left padding with multiple chunks per passage.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + all_input_ids = [ + [1, 99, 2, 3, 99], # Passage 0: 5 tokens, EOS at positions 1, 4 + [4, 5, 99], # Passage 1: 3 tokens, EOS at position 2 + ] + all_eos_positions = [[1, 4], [2]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=mock_tokenizer, + padding_side='left', + pad_to_multiple_of=8, + ) + + # With pad_to_multiple_of=8, max_len=5 -> padded to 8 + assert padded_dict['input_ids'].shape == (2, 8) + + # Passage 0: original length 5, padded length 8, padding_length=3, EOS shift from [1,4] to [4,7] + # Passage 1: original length 3, padded length 8, padding_length=5, EOS shift from 2 to 7 + assert adjusted_eos_positions == [[4, 7], [7]] + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_tokenizer_padding_side_set(mock_tokenizer): + """Test 
that tokenizer.padding_side is set correctly.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + all_input_ids = [[1, 2, 99]] + all_eos_positions = [[2]] + + # Test right padding + mock_tokenizer.padding_side = 'right' + _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=mock_tokenizer, + padding_side='right', + pad_to_multiple_of=4, + ) + assert mock_tokenizer.padding_side == 'right' + + # Test left padding + mock_tokenizer.padding_side = 'left' + _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=mock_tokenizer, + padding_side='left', + pad_to_multiple_of=4, + ) + assert mock_tokenizer.padding_side == 'left' + diff --git a/tests/test_pooling.py b/tests/test_pooling.py index 0f0ae921..2c65eca3 100644 --- a/tests/test_pooling.py +++ b/tests/test_pooling.py @@ -230,213 +230,101 @@ def mock_encoder_forward(**kwargs): @pytest.mark.unit def test_pooling_chunked_eos_positions_alignment(): - """ - Test _pooling_chunked to verify that eos_positions correctly align with hidden states. - This test uses known hidden states and eos_positions to verify exact alignment. 
- """ + """Test _pooling_chunked extracts embeddings from correct EOS positions.""" import torch from unittest.mock import Mock from tevatron.retriever.modeling.dense import DenseModel - # Create a mock encoder mock_encoder = Mock() mock_encoder.config.hidden_size = 8 - - # Create DenseModel model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) model.passage_chunk_size = 32 - # Test Case 1: Simple case with known positions - # Batch size=2, seq_len=10, hidden_size=8 - # Passage 0: eos at positions [2, 5, 8] (3 chunks) - # Passage 1: eos at positions [3, 7] (2 chunks) - batch_size = 2 - seq_len = 10 - hidden_size = 8 - - # Create hidden states with known values - each position has a unique pattern - # We'll use position index as part of the embedding to make verification easy + batch_size, seq_len, hidden_size = 2, 10, 8 hidden_states = torch.zeros(batch_size, seq_len, hidden_size) for i in range(batch_size): for j in range(seq_len): - # Set embedding at position j to have value j*100 + i*10 in first dimension - # This makes it easy to verify we're extracting the right positions hidden_states[i, j, 0] = j * 100 + i * 10 - # Fill other dimensions with position-dependent values for k in range(1, hidden_size): hidden_states[i, j, k] = j * 10 + k eos_positions = [[2, 5, 8], [3, 7]] - - # Call _pooling_chunked chunk_reps, chunk_mask = model._pooling_chunked(hidden_states, eos_positions) - # Verify output shapes - assert chunk_reps.shape == (batch_size, 3, hidden_size) # max_chunks = 3 + assert chunk_reps.shape == (batch_size, 3, hidden_size) assert chunk_mask.shape == (batch_size, 3) - # Verify Passage 0: should extract positions [2, 5, 8] - # Position 2: should have 2*100 + 0*10 = 200 in first dim - assert torch.allclose(chunk_reps[0, 0, 0], torch.tensor(200.0)) - assert torch.allclose(chunk_reps[0, 0, 1], torch.tensor(21.0)) # 2*10 + 1 - - # Position 5: should have 5*100 + 0*10 = 500 in first dim - assert torch.allclose(chunk_reps[0, 1, 0], 
torch.tensor(500.0)) - assert torch.allclose(chunk_reps[0, 1, 1], torch.tensor(51.0)) # 5*10 + 1 - - # Position 8: should have 8*100 + 0*10 = 800 in first dim - assert torch.allclose(chunk_reps[0, 2, 0], torch.tensor(800.0)) - assert torch.allclose(chunk_reps[0, 2, 1], torch.tensor(81.0)) # 8*10 + 1 - - # Verify Passage 1: should extract positions [3, 7] - # Position 3: should have 3*100 + 1*10 = 310 in first dim - assert torch.allclose(chunk_reps[1, 0, 0], torch.tensor(310.0)) - assert torch.allclose(chunk_reps[1, 0, 1], torch.tensor(31.0)) # 3*10 + 1 - - # Position 7: should have 7*100 + 1*10 = 710 in first dim - assert torch.allclose(chunk_reps[1, 1, 0], torch.tensor(710.0)) - assert torch.allclose(chunk_reps[1, 1, 1], torch.tensor(71.0)) # 7*10 + 1 - - # Verify chunk_mask - assert chunk_mask[0, 0] == 1.0 # Passage 0, chunk 0 (pos 2) - assert chunk_mask[0, 1] == 1.0 # Passage 0, chunk 1 (pos 5) - assert chunk_mask[0, 2] == 1.0 # Passage 0, chunk 2 (pos 8) - assert chunk_mask[1, 0] == 1.0 # Passage 1, chunk 0 (pos 3) - assert chunk_mask[1, 1] == 1.0 # Passage 1, chunk 1 (pos 7) - assert chunk_mask[1, 2] == 0.0 # Passage 1, chunk 2 (no chunk, should be 0) - - # Test Case 2: Verify exact tensor equality (not just close) - # Create hidden states where each position has a unique embedding - hidden_states_2 = torch.arange(batch_size * seq_len * hidden_size, dtype=torch.float32) - hidden_states_2 = hidden_states_2.reshape(batch_size, seq_len, hidden_size) + # Verify correct positions extracted + assert torch.allclose(chunk_reps[0, 0, 0], torch.tensor(200.0)) # pos 2 + assert torch.allclose(chunk_reps[0, 1, 0], torch.tensor(500.0)) # pos 5 + assert torch.allclose(chunk_reps[0, 2, 0], torch.tensor(800.0)) # pos 8 + assert torch.allclose(chunk_reps[1, 0, 0], torch.tensor(310.0)) # pos 3 + assert torch.allclose(chunk_reps[1, 1, 0], torch.tensor(710.0)) # pos 7 - # Extract embeddings manually for comparison - expected_chunk_0_0 = hidden_states_2[0, 2] # Passage 0, position 
2 - expected_chunk_0_1 = hidden_states_2[0, 5] # Passage 0, position 5 - expected_chunk_0_2 = hidden_states_2[0, 8] # Passage 0, position 8 - expected_chunk_1_0 = hidden_states_2[1, 3] # Passage 1, position 3 - expected_chunk_1_1 = hidden_states_2[1, 7] # Passage 1, position 7 + # Verify chunk mask + assert (chunk_mask[0, :3] == 1.0).all() + assert (chunk_mask[1, :2] == 1.0).all() + assert chunk_mask[1, 2] == 0.0 - chunk_reps_2, chunk_mask_2 = model._pooling_chunked(hidden_states_2, eos_positions) + # Test exact equality with sequential hidden states + hidden_states_2 = torch.arange(batch_size * seq_len * hidden_size, dtype=torch.float32) + hidden_states_2 = hidden_states_2.reshape(batch_size, seq_len, hidden_size) + chunk_reps_2, _ = model._pooling_chunked(hidden_states_2, eos_positions) - # Verify exact equality - assert torch.equal(chunk_reps_2[0, 0], expected_chunk_0_0) - assert torch.equal(chunk_reps_2[0, 1], expected_chunk_0_1) - assert torch.equal(chunk_reps_2[0, 2], expected_chunk_0_2) - assert torch.equal(chunk_reps_2[1, 0], expected_chunk_1_0) - assert torch.equal(chunk_reps_2[1, 1], expected_chunk_1_1) + assert torch.equal(chunk_reps_2[0, 0], hidden_states_2[0, 2]) + assert torch.equal(chunk_reps_2[0, 1], hidden_states_2[0, 5]) + assert torch.equal(chunk_reps_2[0, 2], hidden_states_2[0, 8]) + assert torch.equal(chunk_reps_2[1, 0], hidden_states_2[1, 3]) + assert torch.equal(chunk_reps_2[1, 1], hidden_states_2[1, 7]) - # Test Case 3: Edge case - empty eos_positions + # Test edge cases chunk_reps_empty, chunk_mask_empty = model._pooling_chunked(hidden_states, []) assert chunk_reps_empty.shape == (batch_size, 0, hidden_size) - assert chunk_mask_empty.shape == (batch_size, 0) - # Test Case 4: Edge case - out of bounds position (should be handled gracefully) - eos_positions_oob = [[2, 5, 15], [3, 7]] # 15 is out of bounds for seq_len=10 + eos_positions_oob = [[2, 5, 15], [3, 7]] chunk_reps_oob, chunk_mask_oob = model._pooling_chunked(hidden_states, 
eos_positions_oob) - - # Should still extract valid positions - assert chunk_reps_oob.shape == (batch_size, 3, hidden_size) - assert torch.allclose(chunk_reps_oob[0, 0], hidden_states[0, 2]) # Valid - assert torch.allclose(chunk_reps_oob[0, 1], hidden_states[0, 5]) # Valid - # Position 15 is out of bounds, so chunk_reps[0, 2] should be zeros assert torch.allclose(chunk_reps_oob[0, 2], torch.zeros(hidden_size)) - assert chunk_mask_oob[0, 2] == 0.0 # Should be masked out + assert chunk_mask_oob[0, 2] == 0.0 - # Test Case 5: Normalize=True + # Test normalization model.normalize = True - chunk_reps_norm, chunk_mask_norm = model._pooling_chunked(hidden_states_2, eos_positions) - - # Verify normalization (L2 norm should be 1 for non-zero chunks) - for i in range(batch_size): - for j in range(len(eos_positions[i])): - norm = torch.norm(chunk_reps_norm[i, j]) - assert torch.allclose(norm, torch.tensor(1.0), atol=1e-6) - - # Verify the normalized embeddings are proportional to original - model.normalize = False - chunk_reps_no_norm, _ = model._pooling_chunked(hidden_states_2, eos_positions) + chunk_reps_norm, _ = model._pooling_chunked(hidden_states_2, eos_positions) for i in range(batch_size): for j in range(len(eos_positions[i])): - # Normalized version should be original / norm - expected_norm = torch.norm(chunk_reps_no_norm[i, j]) - normalized_manual = chunk_reps_no_norm[i, j] / expected_norm - assert torch.allclose(chunk_reps_norm[i, j], normalized_manual, atol=1e-6) - - # Test Case 6: Single chunk per passage - eos_positions_single = [[4], [6]] - chunk_reps_single, chunk_mask_single = model._pooling_chunked(hidden_states_2, eos_positions_single) - - assert chunk_reps_single.shape == (batch_size, 1, hidden_size) - assert torch.equal(chunk_reps_single[0, 0], hidden_states_2[0, 4]) - assert torch.equal(chunk_reps_single[1, 0], hidden_states_2[1, 6]) - assert chunk_mask_single[0, 0] == 1.0 - assert chunk_mask_single[1, 0] == 1.0 - - # Test Case 7: Verify positions are 
extracted in correct order - # Use sequential positions to verify order - eos_positions_ordered = [[1, 3, 5], [2, 4]] - chunk_reps_ordered, _ = model._pooling_chunked(hidden_states_2, eos_positions_ordered) - - # Passage 0: should be in order [1, 3, 5] - assert torch.equal(chunk_reps_ordered[0, 0], hidden_states_2[0, 1]) - assert torch.equal(chunk_reps_ordered[0, 1], hidden_states_2[0, 3]) - assert torch.equal(chunk_reps_ordered[0, 2], hidden_states_2[0, 5]) - - # Passage 1: should be in order [2, 4] - assert torch.equal(chunk_reps_ordered[1, 0], hidden_states_2[1, 2]) - assert torch.equal(chunk_reps_ordered[1, 1], hidden_states_2[1, 4]) + assert torch.allclose(torch.norm(chunk_reps_norm[i, j]), torch.tensor(1.0), atol=1e-6) @pytest.mark.unit def test_pooling_chunked_real_tokenizer_alignment(train_tokenizer): - """ - Integration test: Verify that eos_positions from ChunkedEncodeCollator - correctly align with hidden states when using _pooling_chunked. - This uses real tokenizer to ensure end-to-end correctness. 
- """ + """Integration test: eos_positions from collator correctly align with hidden states.""" import torch from unittest.mock import Mock from tevatron.retriever.arguments import DataArguments from tevatron.retriever.collator import ChunkedEncodeCollator from tevatron.retriever.modeling.dense import DenseModel - # Setup data arguments data_args = DataArguments( passage_chunk_size=32, passage_max_len=128, pad_to_multiple_of=16, padding_side="right", - passage_prefix="", append_eos_token=False, ) - - # Create collator collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) - - # Test passages - passages = [ - REAL_TEXT, # Long passage that will be chunked - "Short passage for testing.", # Short passage - ] - - # Get tokenized and chunked data + passages = [REAL_TEXT, "Short passage for testing."] d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages(passages) - # Verify eos_positions are valid input_ids = d_collated['input_ids'] seq_len = input_ids.shape[1] + # Verify eos_positions are valid for i, eos_pos_list in enumerate(eos_positions): - assert len(eos_pos_list) > 0, f"Passage {i} should have at least one chunk" + assert len(eos_pos_list) > 0 for pos in eos_pos_list: - assert 0 <= pos < seq_len, f"EOS position {pos} out of bounds for seq_len {seq_len}" - # Verify that the position actually contains EOS token - assert input_ids[i, pos] == train_tokenizer.eos_token_id, \ - f"Position {pos} should contain EOS token {train_tokenizer.eos_token_id}, got {input_ids[i, pos]}" + assert 0 <= pos < seq_len + assert input_ids[i, pos] == train_tokenizer.eos_token_id - # Create mock encoder that returns hidden states based on input_ids - # This allows us to verify exact alignment + # Create mock encoder hidden_size = 64 class MockEncoderOutput: @@ -446,51 +334,33 @@ def __init__(self, last_hidden_state): def mock_encoder_forward(**kwargs): input_ids = kwargs['input_ids'] batch_size, seq_len = input_ids.shape - - # Create hidden 
states where each position's embedding encodes its position - # This makes it easy to verify we're extracting the right positions hidden_states = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) for i in range(batch_size): for j in range(seq_len): - # Encode position j in the embedding - # Use input_ids[i, j] as seed to make it unique per token - hidden_states[i, j, 0] = float(j) # Position index - hidden_states[i, j, 1] = float(input_ids[i, j]) # Token ID - # Fill rest with position-dependent values + hidden_states[i, j, 0] = float(j) + hidden_states[i, j, 1] = float(input_ids[i, j]) for k in range(2, hidden_size): hidden_states[i, j, k] = float(j * hidden_size + k) - return MockEncoderOutput(last_hidden_state=hidden_states) mock_encoder = Mock(side_effect=mock_encoder_forward) mock_encoder.config = Mock() mock_encoder.config.hidden_size = hidden_size - # Create model model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) model.passage_chunk_size = data_args.passage_chunk_size - # Convert BatchEncoding to dict for model batch_inputs = { 'input_ids': d_collated['input_ids'], 'attention_mask': d_collated['attention_mask'], } - - # Encode with chunking chunk_reps, chunk_mask = model.encode_passage(batch_inputs, eos_positions) - # Verify shapes batch_size = len(passages) max_chunks = max(len(pos_list) for pos_list in eos_positions) assert chunk_reps.shape == (batch_size, max_chunks, hidden_size) - assert chunk_mask.shape == (batch_size, max_chunks) - - # Verify that extracted embeddings match the eos_positions - # We need to get the hidden states that were generated - # Since we can't easily access them, we'll verify by checking the mock was called correctly - # and that the extracted positions match what we expect - # Re-create hidden states with the same logic to verify + # Re-create expected hidden states hidden_states_expected = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) for i in range(batch_size): for j 
in range(seq_len): @@ -499,34 +369,15 @@ def mock_encoder_forward(**kwargs): for k in range(2, hidden_size): hidden_states_expected[i, j, k] = float(j * hidden_size + k) - # Verify each extracted chunk embedding matches the expected position + # Verify extracted embeddings match expected positions for i, eos_pos_list in enumerate(eos_positions): for j, pos in enumerate(eos_pos_list): - # The extracted embedding should match the hidden state at position pos - expected_embedding = hidden_states_expected[i, pos] - extracted_embedding = chunk_reps[i, j] - - # Verify exact match (they should be identical) - assert torch.equal(extracted_embedding, expected_embedding), \ - f"Passage {i}, chunk {j} (eos_pos={pos}): extracted embedding doesn't match hidden state at position {pos}" - - # Verify chunk mask is set correctly - assert chunk_mask[i, j] == 1.0, f"Chunk mask should be 1.0 for valid chunk" + assert torch.equal(chunk_reps[i, j], hidden_states_expected[i, pos]) + assert chunk_mask[i, j] == 1.0 + assert torch.allclose(chunk_reps[i, j, 0], torch.tensor(float(pos))) - # Verify that invalid chunks (beyond actual chunks) have mask=0 + # Verify invalid chunks are masked for i in range(batch_size): num_chunks = len(eos_positions[i]) for j in range(num_chunks, max_chunks): - assert chunk_mask[i, j] == 0.0, f"Invalid chunk should have mask=0" - - # Verify that the first dimension of extracted embeddings contains position indices - for i, eos_pos_list in enumerate(eos_positions): - for j, pos in enumerate(eos_pos_list): - # First dimension should equal the position - assert torch.allclose(chunk_reps[i, j, 0], torch.tensor(float(pos))), \ - f"First dim should equal position {pos}, got {chunk_reps[i, j, 0]}" - - # Second dimension should equal the token ID at that position - expected_token_id = float(input_ids[i, pos]) - assert torch.allclose(chunk_reps[i, j, 1], torch.tensor(expected_token_id)), \ - f"Second dim should equal token ID {expected_token_id}, got {chunk_reps[i, j, 
1]}" + assert chunk_mask[i, j] == 0.0 From 1fa3c1e7db76e4c2775609aef15e75fd2a8116f9 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Tue, 23 Dec 2025 14:47:53 -0500 Subject: [PATCH 20/31] added padding helper --- tests/test_padding_helper.py | 297 +++++++++++++++++++++-------------- 1 file changed, 175 insertions(+), 122 deletions(-) diff --git a/tests/test_padding_helper.py b/tests/test_padding_helper.py index f496ea00..45c699e4 100644 --- a/tests/test_padding_helper.py +++ b/tests/test_padding_helper.py @@ -5,7 +5,6 @@ from pathlib import Path import pytest import torch -from unittest.mock import Mock, MagicMock def _tevatron_root() -> Path: @@ -17,177 +16,181 @@ def _add_tevatron_src_to_path(): sys.path.insert(0, str(src)) -@pytest.fixture -def mock_tokenizer(): - """Create a mock tokenizer for testing.""" - tokenizer = Mock() - tokenizer.pad_token_id = 0 - tokenizer.eos_token_id = 99 - - def pad_fn(encodings, padding=True, pad_to_multiple_of=None, return_attention_mask=True, return_tensors=None): - """Mock pad function that simulates tokenizer.pad behavior.""" - input_ids = encodings['input_ids'] - max_len = max(len(ids) for ids in input_ids) if input_ids else 0 - - # Round up to multiple of pad_to_multiple_of - if pad_to_multiple_of: - max_len = ((max_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of - - padded_ids = [] - attention_masks = [] - - for ids in input_ids: - if tokenizer.padding_side == 'right': - pad_length = max_len - len(ids) - padded = ids + [tokenizer.pad_token_id] * pad_length - mask = [1] * len(ids) + [0] * pad_length - else: # left padding - pad_length = max_len - len(ids) - padded = [tokenizer.pad_token_id] * pad_length + ids - mask = [0] * pad_length + [1] * len(ids) - - padded_ids.append(padded) - attention_masks.append(mask) - - result = {'input_ids': padded_ids, 'attention_mask': attention_masks} - - if return_tensors == 'pt': - result['input_ids'] = torch.tensor(result['input_ids']) - result['attention_mask'] = 
torch.tensor(result['attention_mask']) - - return result - - tokenizer.pad = MagicMock(side_effect=pad_fn) - tokenizer.padding_side = 'right' - return tokenizer +@pytest.fixture(scope="session") +def train_tokenizer(): + """Use the Qwen 0.6B tokenizer.""" + _add_tevatron_src_to_path() + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + tok.padding_side = "right" + return tok @pytest.mark.unit -def test_pad_and_adjust_eos_positions_right_padding(mock_tokenizer): +def test_pad_and_adjust_eos_positions_right_padding(train_tokenizer): """Test padding with right padding (no EOS position adjustment needed).""" _add_tevatron_src_to_path() from tevatron.retriever.collator import _pad_and_adjust_eos_positions + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + all_input_ids = [ - [1, 2, 3, 99], # Passage 0: 4 tokens, EOS at position 3 - [4, 5, 99], # Passage 1: 3 tokens, EOS at position 2 + [1, 2, 3, eos_id], # Passage 0: 4 tokens, EOS at position 3 + [4, 5, eos_id], # Passage 1: 3 tokens, EOS at position 2 ] all_eos_positions = [[3], [2]] padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( all_input_ids=all_input_ids, all_eos_positions=all_eos_positions, - tokenizer=mock_tokenizer, + tokenizer=train_tokenizer, padding_side='right', pad_to_multiple_of=4, ) - # With pad_to_multiple_of=4, max_len=4 -> padded to 4 - assert padded_dict['input_ids'].shape == (2, 4) - assert padded_dict['attention_mask'].shape == (2, 4) - - # EOS positions should not change for right padding - assert adjusted_eos_positions == [[3], [2]] - - # Verify EOS tokens are at correct positions - assert padded_dict['input_ids'][0][3].item() == 99 - assert padded_dict['input_ids'][1][2].item() == 99 + # Hardcoded golden output + expected_input_ids = torch.tensor([ + [1, 2, 3, eos_id], # Passage 0: padded to 4 (no padding needed) + [4, 
5, eos_id, pad_id], # Passage 1: padded to 4 (1 padding token) + ]) + expected_attention_mask = torch.tensor([ + [1, 1, 1, 1], # Passage 0: all tokens valid + [1, 1, 1, 0], # Passage 1: last token is padding + ]) + expected_eos_positions = [[3], [2]] # EOS positions unchanged for right padding + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions @pytest.mark.unit -def test_pad_and_adjust_eos_positions_left_padding(mock_tokenizer): +def test_pad_and_adjust_eos_positions_left_padding(train_tokenizer): """Test padding with left padding (EOS positions should be shifted).""" _add_tevatron_src_to_path() from tevatron.retriever.collator import _pad_and_adjust_eos_positions + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + all_input_ids = [ - [1, 2, 3, 99], # Passage 0: 4 tokens, EOS at position 3 - [4, 5, 99], # Passage 1: 3 tokens, EOS at position 2 + [1, 2, 3, eos_id], # Passage 0: 4 tokens, EOS at position 3 + [4, 5, eos_id], # Passage 1: 3 tokens, EOS at position 2 ] all_eos_positions = [[3], [2]] padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( all_input_ids=all_input_ids, all_eos_positions=all_eos_positions, - tokenizer=mock_tokenizer, + tokenizer=train_tokenizer, padding_side='left', pad_to_multiple_of=4, ) - # With pad_to_multiple_of=4, max_len=4 -> padded to 4 - assert padded_dict['input_ids'].shape == (2, 4) - + # Hardcoded golden output + expected_input_ids = torch.tensor([ + [1, 2, 3, eos_id], # Passage 0: padded to 4 (no padding needed) + [pad_id, 4, 5, eos_id], # Passage 1: padded to 4 (1 padding token on left) + ]) + expected_attention_mask = torch.tensor([ + [1, 1, 1, 1], # Passage 0: all tokens valid + [0, 1, 1, 1], # Passage 1: first token is padding + ]) # Passage 0: original length 4, padded length 4, padding_length=0, EOS stays at 3 # Passage 1: 
original length 3, padded length 4, padding_length=1, EOS shifts from 2 to 3 - assert adjusted_eos_positions == [[3], [3]] + expected_eos_positions = [[3], [3]] - # Verify EOS tokens are at correct positions after padding - assert padded_dict['input_ids'][0][3].item() == 99 - assert padded_dict['input_ids'][1][3].item() == 99 + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions @pytest.mark.unit -def test_pad_and_adjust_eos_positions_multiple_eos(mock_tokenizer): +def test_pad_and_adjust_eos_positions_multiple_eos(train_tokenizer): """Test padding with multiple EOS positions per passage.""" _add_tevatron_src_to_path() from tevatron.retriever.collator import _pad_and_adjust_eos_positions + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + all_input_ids = [ - [1, 2, 99, 3, 4, 99], # Passage 0: 6 tokens, EOS at positions 2, 5 - [5, 99], # Passage 1: 2 tokens, EOS at position 1 + [1, 2, eos_id, 3, 4, eos_id], # Passage 0: 6 tokens, EOS at positions 2, 5 + [5, eos_id], # Passage 1: 2 tokens, EOS at position 1 ] all_eos_positions = [[2, 5], [1]] padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( all_input_ids=all_input_ids, all_eos_positions=all_eos_positions, - tokenizer=mock_tokenizer, + tokenizer=train_tokenizer, padding_side='left', pad_to_multiple_of=8, ) - # With pad_to_multiple_of=8, max_len=6 -> padded to 8 - assert padded_dict['input_ids'].shape == (2, 8) - + # Hardcoded golden output + expected_input_ids = torch.tensor([ + [pad_id, pad_id, 1, 2, eos_id, 3, 4, eos_id], # Passage 0: padded to 8 (2 padding tokens on left) + [pad_id, pad_id, pad_id, pad_id, pad_id, pad_id, 5, eos_id], # Passage 1: padded to 8 (6 padding tokens on left) + ]) + expected_attention_mask = torch.tensor([ + [0, 0, 1, 1, 1, 1, 1, 1], # Passage 0: first 2 tokens are padding + [0, 0, 0, 0, 0, 0, 
1, 1], # Passage 1: first 6 tokens are padding + ]) # Passage 0: original length 6, padded length 8, padding_length=2, EOS shift from [2,5] to [4,7] # Passage 1: original length 2, padded length 8, padding_length=6, EOS shift from 1 to 7 - assert adjusted_eos_positions == [[4, 7], [7]] + expected_eos_positions = [[4, 7], [7]] - # Verify EOS tokens are at correct positions - assert padded_dict['input_ids'][0][4].item() == 99 - assert padded_dict['input_ids'][0][7].item() == 99 - assert padded_dict['input_ids'][1][7].item() == 99 + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions @pytest.mark.unit -def test_pad_and_adjust_eos_positions_no_padding_needed(mock_tokenizer): +def test_pad_and_adjust_eos_positions_no_padding_needed(train_tokenizer): """Test when sequences are already the same length.""" _add_tevatron_src_to_path() from tevatron.retriever.collator import _pad_and_adjust_eos_positions + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + all_input_ids = [ - [1, 2, 99], - [3, 4, 99], + [1, 2, eos_id], + [3, 4, eos_id], ] all_eos_positions = [[2], [2]] padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( all_input_ids=all_input_ids, all_eos_positions=all_eos_positions, - tokenizer=mock_tokenizer, + tokenizer=train_tokenizer, padding_side='right', pad_to_multiple_of=4, ) - # With pad_to_multiple_of=4, max_len=3 -> padded to 4 - assert padded_dict['input_ids'].shape == (2, 4) - - # EOS positions unchanged for right padding - assert adjusted_eos_positions == [[2], [2]] + # Hardcoded golden output + expected_input_ids = torch.tensor([ + [1, 2, eos_id, pad_id], # Padded to 4 (1 padding token) + [3, 4, eos_id, pad_id], # Padded to 4 (1 padding token) + ]) + expected_attention_mask = torch.tensor([ + [1, 1, 1, 0], # Last token is padding + [1, 1, 1, 0], # Last token is 
padding + ]) + expected_eos_positions = [[2], [2]] # EOS positions unchanged for right padding + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions @pytest.mark.unit -def test_pad_and_adjust_eos_positions_empty_input(mock_tokenizer): +def test_pad_and_adjust_eos_positions_empty_input(train_tokenizer): """Test with empty input.""" _add_tevatron_src_to_path() from tevatron.retriever.collator import _pad_and_adjust_eos_positions @@ -198,117 +201,167 @@ def test_pad_and_adjust_eos_positions_empty_input(mock_tokenizer): padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( all_input_ids=all_input_ids, all_eos_positions=all_eos_positions, - tokenizer=mock_tokenizer, + tokenizer=train_tokenizer, padding_side='right', pad_to_multiple_of=4, ) - assert len(adjusted_eos_positions) == 0 - assert padded_dict['input_ids'].shape[0] == 0 + # Hardcoded golden output for empty input + expected_eos_positions = [] + + assert adjusted_eos_positions == expected_eos_positions + # When input is empty, tokenizer.pad may return list or tensor depending on implementation + if isinstance(padded_dict['input_ids'], torch.Tensor): + assert padded_dict['input_ids'].shape[0] == 0 + assert padded_dict['attention_mask'].shape[0] == 0 + else: + assert len(padded_dict['input_ids']) == 0 + assert len(padded_dict['attention_mask']) == 0 @pytest.mark.unit -def test_pad_and_adjust_eos_positions_single_passage(mock_tokenizer): +def test_pad_and_adjust_eos_positions_single_passage(train_tokenizer): """Test with single passage.""" _add_tevatron_src_to_path() from tevatron.retriever.collator import _pad_and_adjust_eos_positions - all_input_ids = [[1, 2, 3, 99]] + eos_id = train_tokenizer.eos_token_id + + all_input_ids = [[1, 2, 3, eos_id]] all_eos_positions = [[3]] padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( 
all_input_ids=all_input_ids, all_eos_positions=all_eos_positions, - tokenizer=mock_tokenizer, + tokenizer=train_tokenizer, padding_side='right', pad_to_multiple_of=4, ) - assert padded_dict['input_ids'].shape == (1, 4) - assert adjusted_eos_positions == [[3]] + # Hardcoded golden output + expected_input_ids = torch.tensor([ + [1, 2, 3, eos_id], # Already length 4, no padding needed + ]) + expected_attention_mask = torch.tensor([ + [1, 1, 1, 1], # All tokens valid + ]) + expected_eos_positions = [[3]] + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions @pytest.mark.unit -def test_pad_and_adjust_eos_positions_pad_to_multiple_of_one(mock_tokenizer): +def test_pad_and_adjust_eos_positions_pad_to_multiple_of_one(train_tokenizer): """Test with pad_to_multiple_of=1 (no rounding).""" _add_tevatron_src_to_path() from tevatron.retriever.collator import _pad_and_adjust_eos_positions + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + all_input_ids = [ - [1, 2, 99], - [3, 99], + [1, 2, eos_id], + [3, eos_id], ] all_eos_positions = [[2], [1]] padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( all_input_ids=all_input_ids, all_eos_positions=all_eos_positions, - tokenizer=mock_tokenizer, + tokenizer=train_tokenizer, padding_side='right', pad_to_multiple_of=1, ) - # Should pad to max_len=3 (no rounding needed) - assert padded_dict['input_ids'].shape == (2, 3) - assert adjusted_eos_positions == [[2], [1]] + # Hardcoded golden output + expected_input_ids = torch.tensor([ + [1, 2, eos_id], # Padded to max_len=3 (no rounding needed with pad_to_multiple_of=1) + [3, eos_id, pad_id], # Padded to max_len=3 (1 padding token) + ]) + expected_attention_mask = torch.tensor([ + [1, 1, 1], # All tokens valid + [1, 1, 0], # Last token is padding + ]) + expected_eos_positions = [[2], [1]] # EOS 
positions unchanged for right padding + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions @pytest.mark.unit -def test_pad_and_adjust_eos_positions_left_padding_multiple_chunks(mock_tokenizer): +def test_pad_and_adjust_eos_positions_left_padding_multiple_chunks(train_tokenizer): """Test left padding with multiple chunks per passage.""" _add_tevatron_src_to_path() from tevatron.retriever.collator import _pad_and_adjust_eos_positions + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + all_input_ids = [ - [1, 99, 2, 3, 99], # Passage 0: 5 tokens, EOS at positions 1, 4 - [4, 5, 99], # Passage 1: 3 tokens, EOS at position 2 + [1, eos_id, 2, 3, eos_id], # Passage 0: 5 tokens, EOS at positions 1, 4 + [4, 5, eos_id], # Passage 1: 3 tokens, EOS at position 2 ] all_eos_positions = [[1, 4], [2]] padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( all_input_ids=all_input_ids, all_eos_positions=all_eos_positions, - tokenizer=mock_tokenizer, + tokenizer=train_tokenizer, padding_side='left', pad_to_multiple_of=8, ) - # With pad_to_multiple_of=8, max_len=5 -> padded to 8 - assert padded_dict['input_ids'].shape == (2, 8) - + # Hardcoded golden output + expected_input_ids = torch.tensor([ + [pad_id, pad_id, pad_id, 1, eos_id, 2, 3, eos_id], # Passage 0: padded to 8 (3 padding tokens on left) + [pad_id, pad_id, pad_id, pad_id, pad_id, 4, 5, eos_id], # Passage 1: padded to 8 (5 padding tokens on left) + ]) + expected_attention_mask = torch.tensor([ + [0, 0, 0, 1, 1, 1, 1, 1], # Passage 0: first 3 tokens are padding + [0, 0, 0, 0, 0, 1, 1, 1], # Passage 1: first 5 tokens are padding + ]) # Passage 0: original length 5, padded length 8, padding_length=3, EOS shift from [1,4] to [4,7] # Passage 1: original length 3, padded length 8, padding_length=5, EOS shift from 2 to 7 - assert 
adjusted_eos_positions == [[4, 7], [7]] + expected_eos_positions = [[4, 7], [7]] + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions @pytest.mark.unit -def test_pad_and_adjust_eos_positions_tokenizer_padding_side_set(mock_tokenizer): +def test_pad_and_adjust_eos_positions_tokenizer_padding_side_set(train_tokenizer): """Test that tokenizer.padding_side is set correctly.""" _add_tevatron_src_to_path() from tevatron.retriever.collator import _pad_and_adjust_eos_positions - all_input_ids = [[1, 2, 99]] + eos_id = train_tokenizer.eos_token_id + + all_input_ids = [[1, 2, eos_id]] all_eos_positions = [[2]] # Test right padding - mock_tokenizer.padding_side = 'right' + train_tokenizer.padding_side = 'right' _pad_and_adjust_eos_positions( all_input_ids=all_input_ids, all_eos_positions=all_eos_positions, - tokenizer=mock_tokenizer, + tokenizer=train_tokenizer, padding_side='right', pad_to_multiple_of=4, ) - assert mock_tokenizer.padding_side == 'right' + assert train_tokenizer.padding_side == 'right' # Test left padding - mock_tokenizer.padding_side = 'left' + train_tokenizer.padding_side = 'left' _pad_and_adjust_eos_positions( all_input_ids=all_input_ids, all_eos_positions=all_eos_positions, - tokenizer=mock_tokenizer, + tokenizer=train_tokenizer, padding_side='left', pad_to_multiple_of=4, ) - assert mock_tokenizer.padding_side == 'left' + assert train_tokenizer.padding_side == 'left' From 224e18599af9df6fd583420b7085e8e48bcb1832 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Tue, 23 Dec 2025 16:49:59 -0500 Subject: [PATCH 21/31] Tested the collator --- src/tevatron/retriever/collator.py | 109 +- tests/test_chunking.py | 1066 +++++++++++++++++--- tests/test_chunking_helper.py | 229 ----- tests/test_chunking_pooling_equivalence.py | 101 -- tests/test_padding_helper.py | 367 ------- 5 files changed, 949 insertions(+), 923 
deletions(-) delete mode 100644 tests/test_chunking_helper.py delete mode 100644 tests/test_chunking_pooling_equivalence.py delete mode 100644 tests/test_padding_helper.py diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index 65397a39..a656e058 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -20,48 +20,50 @@ def _chunk_tokens( max_length: int = None, ) -> Tuple[List[int], List[int]]: """ - Chunk a list of tokens into chunks of specified size, adding EOS token after each chunk. + Chunk tokens into fixed-size chunks with EOS separators. - :param tokens: List of token IDs to chunk - :param chunk_size: Maximum size of each chunk (before adding EOS). Must be >= 2. + :param tokens: Token IDs to chunk + :param chunk_size: Max chunk size (before EOS). Must be >= 2. :param eos_token_id: EOS token ID to append after each chunk - :param max_length: Optional maximum total length (including EOS tokens). If None, no limit. - :return: Tuple of (chunked_ids, eos_positions) where: - - chunked_ids: List of token IDs with EOS separators between chunks - - eos_positions: List of positions where EOS tokens were inserted + :param max_length: Optional max total length (including EOS). If None, no limit. 
+ :return: (chunked_ids, eos_positions) - token IDs with EOS separators and EOS positions """ if chunk_size < 2: - # chunk_size must be at least 2 to fit at least 1 token + 1 EOS return [], [] chunk_len = chunk_size - 1 # Reserve 1 slot for EOS + + # Truncate tokens to fit within max_length + # Each chunk: chunk_len tokens + 1 EOS = chunk_size total + if max_length and max_length > 0: + max_tokens_to_use = 0 + remaining_length = max_length + + while remaining_length > 1 and max_tokens_to_use < len(tokens): + if remaining_length >= chunk_size: + max_tokens_to_use += chunk_len + remaining_length -= chunk_size + else: + max_tokens_to_use += remaining_length - 1 + break + + tokens = tokens[:max_tokens_to_use] + + # Chunk tokens and add EOS after each chunk ids = [] eos_pos = [] i = 0 while i < len(tokens): - if max_length and max_length > 0: - remaining = max_length - len(ids) - # Need at least 1 slot for EOS; otherwise stop (don't add empty chunks). - if remaining <= 1: - break - take = min(chunk_len, len(tokens) - i, remaining - 1) - if take <= 0: - break - else: - take = min(chunk_len, len(tokens) - i) - if take <= 0: - break - - chunk = tokens[i:i + take] # up to chunk_len tokens + take = min(chunk_len, len(tokens) - i) + chunk = tokens[i:i + take] ids.extend(chunk) - ids.append(eos_token_id) # EOS at end of this chunk - eos_pos.append(len(ids) - 1) # position of EOS (pooling position) + ids.append(eos_token_id) + eos_pos.append(len(ids) - 1) # EOS position for pooling i += take return ids, eos_pos - def _pad_and_adjust_eos_positions( all_input_ids: List[List[int]], all_eos_positions: List[List[int]], @@ -70,26 +72,19 @@ def _pad_and_adjust_eos_positions( pad_to_multiple_of: int, ) -> Tuple[dict, List[List[int]]]: """ - Pad input IDs and adjust EOS positions based on padding side. + Pad input IDs and adjust EOS positions for left padding. 
- :param all_input_ids: List of lists of token IDs (one per passage) - :param all_eos_positions: List of lists of EOS positions (one per passage) - :param tokenizer: Tokenizer to use for padding - :param padding_side: 'left' or 'right' - side to pad on - :param pad_to_multiple_of: Pad sequences to multiple of this value - :return: Tuple of (padded_dict, adjusted_eos_positions) where: - - padded_dict: dict with 'input_ids' and 'attention_mask' tensors - - adjusted_eos_positions: List of lists with EOS positions adjusted for padding + :param all_input_ids: List of token ID lists (one per passage) + :param all_eos_positions: List of EOS position lists (one per passage) + :param tokenizer: Tokenizer for padding + :param padding_side: 'left' or 'right' + :param pad_to_multiple_of: Pad to multiple of this value + :return: (padded_dict, adjusted_eos_positions) - padded tensors and adjusted EOS positions """ d_collated = {'input_ids': all_input_ids} - - # Store original lengths before padding to adjust eos_positions for left padding original_lengths = [len(ids) for ids in all_input_ids] - - # Set tokenizer padding_side before padding tokenizer.padding_side = padding_side - # Padding d_collated = tokenizer.pad( d_collated, padding=True, @@ -98,16 +93,12 @@ def _pad_and_adjust_eos_positions( return_tensors='pt', ) - # Adjust eos_positions for left padding - # When padding_side is 'left', padding tokens are added at the beginning, - # so EOS positions need to be shifted by the padding length - # Create a deep copy to avoid modifying the original + # Shift EOS positions for left padding adjusted_eos_positions = [list(eos_pos_list) for eos_pos_list in all_eos_positions] if padding_side == 'left': - padded_lengths = d_collated['input_ids'].shape[1] # All sequences have same length after padding + padded_lengths = d_collated['input_ids'].shape[1] for i, eos_pos_list in enumerate(adjusted_eos_positions): padding_length = padded_lengths - original_lengths[i] - # Shift each EOS 
position by the padding length adjusted_eos_positions[i] = [pos + padding_length for pos in eos_pos_list] return d_collated, adjusted_eos_positions @@ -119,18 +110,12 @@ def _tokenize_and_pad_chunked_passages( data_args: DataArguments, ) -> Tuple[dict, List[List[int]]]: """ - Tokenize passages with EOS separators between chunks. - Each chunk ends with EOS, enabling extraction of chunk embeddings from EOS positions. - - Uses the same token that tokenizer.add_special_tokens adds (e.g., <|endoftext|>) - so that query and passage use the same pooling token automatically. + Tokenize and chunk passages with EOS separators. Each chunk ends with EOS for embedding extraction. - :param passages: List of passage texts to tokenize and chunk - :param tokenizer: Tokenizer to use for encoding - :param data_args: DataArguments containing chunk_size, max_len, pad_to_multiple_of - :return: Tuple of (collated_dict, eos_positions) where: - - collated_dict: dict with 'input_ids' and 'attention_mask' tensors - - eos_positions: list of lists, one per passage, containing EOS token positions + :param passages: Passage texts to tokenize and chunk + :param tokenizer: Tokenizer for encoding + :param data_args: DataArguments with chunk_size, max_len, pad_to_multiple_of + :return: (collated_dict, eos_positions) - padded tensors and EOS positions per passage """ eos_id = tokenizer.eos_token_id if eos_id is None: @@ -185,7 +170,6 @@ def __call__(self, features: List[Tuple[str, List[str]]]): all_queries = [q[0] for q in all_queries] all_passages = [p[0] for p in all_passages] - # Query tokenization q_collated = self.tokenizer( all_queries, padding=False, @@ -205,7 +189,6 @@ def __call__(self, features: List[Tuple[str, List[str]]]): return_tensors='pt', ) - # Passage tokenization if self.data_args.passage_chunk_size > 0: d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages) return q_collated, d_collated, eos_positions @@ -387,18 +370,14 @@ def __call__(self, features): 
@dataclass class ChunkedEncodeCollator: - """ - Collator for chunked passage encoding (inference/search). - Splits passages into chunks with EOS separators, similar to training. - Uses the same chunking logic as TrainCollator._tokenize_and_pad_chunked_passages. - """ + """Collator for chunked passage encoding (inference/search). Uses same chunking logic as training.""" data_args: DataArguments tokenizer: PreTrainedTokenizer def __call__(self, features): """ - Collate function for chunked passage encoding. - :param features: list of (doc_id, text, image, video, audio) tuples + Collate chunked passage encoding features. + :param features: List of (doc_id, text, image, video, audio) tuples :return: (doc_ids, collated_inputs, eos_positions) """ doc_ids = [x[0] for x in features] diff --git a/tests/test_chunking.py b/tests/test_chunking.py index 10502cb0..4f939dad 100644 --- a/tests/test_chunking.py +++ b/tests/test_chunking.py @@ -2,6 +2,7 @@ from pathlib import Path import pytest +import torch def _tevatron_root() -> Path: @@ -36,7 +37,7 @@ def _strictly_increasing(xs): "quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural " "development in cerebral white matter in living infants" ) -EOS_TOKEN_ID = 151645 +EOS_TOKEN_ID = 151643 PADDING_TOKEN_ID = 151643 @pytest.fixture(scope="session") @@ -50,10 +51,731 @@ def train_tokenizer(): tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") if tok.pad_token_id is None: tok.pad_token_id = tok.eos_token_id + tok.eos_token_id = tok.pad_token_id tok.padding_side = "right" # finetune_with_chunk.sh uses --padding_side right return tok +# ============================================================================ +# Unit tests for _chunk_tokens helper function +# ============================================================================ + +@pytest.mark.unit +def test_chunk_tokens_basic(): + """Test basic chunking functionality.""" + _add_tevatron_src_to_path() + from 
tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(10)) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + eos_id = 99 + chunk_size = 4 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + # chunk_size=4 means chunk_len=3, so chunks are: + # [0,1,2,99], [3,4,5,99], [6,7,8,99], [9,99] + expected_ids = [0, 1, 2, 99, 3, 4, 5, 99, 6, 7, 8, 99, 9, 99] + expected_eos_pos = [3, 7, 11, 13] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_with_max_length(): + """Test chunking with max_length constraint.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(20)) + eos_id = 99 + chunk_size = 5 + max_length = 12 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + # Hardcoded golden output: chunk_size=5 means chunk_len=4 + # First chunk: [0,1,2,3,99] = 5 tokens + # Second chunk: [4,5,6,7,99] = 5 tokens + # Third chunk: [8,99] = 2 tokens (partial, fits in remaining 2 tokens) + # Total: 12 tokens + expected_ids = [0, 1, 2, 3, 99, 4, 5, 6, 7, 99, 8, 99] + expected_eos_pos = [4, 9, 11] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_max_length_exact_fit(): + """Test chunking when max_length exactly fits chunks.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(10)) + eos_id = 99 + chunk_size = 4 + max_length = 14 # Exactly fits 3 chunks: 3*4 + 2 = 14 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + expected_ids = [0, 1, 2, 99, 3, 4, 5, 99, 6, 7, 8, 99, 9, 99] + expected_eos_pos = [3, 7, 11, 13] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_max_length_too_small(): + """Test chunking when max_length is too small for even one chunk.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator 
import _chunk_tokens + + tokens = list(range(10)) + eos_id = 99 + chunk_size = 4 + max_length = 1 # Too small for even one chunk (need at least 2: 1 token + EOS) + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + # Should return empty since we can't fit even one chunk + assert ids == [] + assert eos_pos == [] + + +@pytest.mark.unit +def test_chunk_tokens_empty_input(): + """Test chunking with empty token list.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = [] + eos_id = 99 + chunk_size = 4 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + assert ids == [] + assert eos_pos == [] + +@pytest.mark.unit +def test_chunk_tokens_same_length_as_chunk_size(): + """Test chunking when tokens are the same length as chunk_size.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(20)) + eos_id = 99 + chunk_size = 16 + max_length = 16 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + expected_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99] + expected_eos_pos = [15] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_single_token(): + """Test chunking with single token.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = [42] + eos_id = 99 + chunk_size = 4 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + assert ids == [42, 99] + assert eos_pos == [1] + + +@pytest.mark.unit +def test_chunk_tokens_no_max_length(): + """Test chunking without max_length constraint.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(15)) + eos_id = 99 + chunk_size = 5 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length=None) + + # Hardcoded golden output: chunk_size=5 means chunk_len=4 + # Chunks: 
[0-3,99], [4-7,99], [8-11,99], [12-14,99] + expected_ids = [0, 1, 2, 3, 99, 4, 5, 6, 7, 99, 8, 9, 10, 11, 99, 12, 13, 14, 99] + expected_eos_pos = [4, 9, 14, 18] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_chunk_size_one(): + """Test chunking with chunk_size=1 (invalid, should return empty).""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = [1, 2, 3] + eos_id = 99 + chunk_size = 1 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + # chunk_size=1 is invalid (need at least 2: 1 token + 1 EOS) + # Should return empty + assert ids == [] + assert eos_pos == [] + + +@pytest.mark.unit +def test_chunk_tokens_chunk_size_two(): + """Test chunking with chunk_size=2.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = [1, 2, 3, 4, 5] + eos_id = 99 + chunk_size = 2 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + # chunk_size=2 means chunk_len=1 + # Chunks: [1,99], [2,99], [3,99], [4,99], [5,99] + expected_ids = [1, 99, 2, 99, 3, 99, 4, 99, 5, 99] + expected_eos_pos = [1, 3, 5, 7, 9] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_max_length_stops_at_boundary(): + """Test that max_length stops chunking at chunk boundary.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(20)) + eos_id = 99 + chunk_size = 5 + max_length = 10 # Exactly 2 chunks: 2*5 = 10 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + expected_ids = [0, 1, 2, 3, 99, 4, 5, 6, 7, 99] + expected_eos_pos = [4, 9] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_chunk_size_greater_than_max_length(): + """Test chunking when chunk_size > max_length (only one partial chunk fits).""" + 
_add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(20)) + eos_id = 99 + chunk_size = 10 # chunk_size > max_length + max_length = 5 # max_length < chunk_size + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + # Hardcoded golden output: chunk_size=10 means chunk_len=9, but max_length=5 + # Can only fit: 4 tokens + 1 EOS = 5 tokens (exactly max_length) + expected_ids = [0, 1, 2, 3, 99] + expected_eos_pos = [4] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_truncation_takes_from_front(): + """Test that truncation when tokens exceed max_length takes from the front (beginning) of the list.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + # Create tokens with distinct values at front and back to verify truncation direction + tokens = list(range(20)) # [0, 1, 2, ..., 19] + eos_id = 99 + chunk_size = 5 # chunk_len = 4 + max_length = 8 # Can fit: 1 full chunk (4 tokens + 1 EOS = 5) + 1 partial (2 tokens + 1 EOS = 3) = 8 total + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + # Hardcoded golden output: truncation takes from front, so we get [0,1,2,3,99,4,5,99] + # If it took from back, we'd get [16,17,18,19,99,...] 
or similar
+    expected_ids = [0, 1, 2, 3, 99, 4, 5, 99]
+    expected_eos_pos = [4, 7]
+
+    assert ids == expected_ids
+    assert eos_pos == expected_eos_pos
+
+    # Verify it's taking from the front: first token should be 0 (beginning of original list)
+    assert ids[0] == 0
+    # Verify it's NOT taking from the back: last content token should be 5, not 19
+    assert ids[-2] == 5  # Last content token before final EOS
+    assert ids[-2] != 19  # Confirms we're not taking from the end
+
+
+@pytest.mark.unit
+def test_chunk_tokens_truncation_then_padding_complex_case(train_tokenizer):
+    """Test complex case: tokens exceed max_length (truncation from front), then padding is applied."""
+    _add_tevatron_src_to_path()
+    from tevatron.retriever.collator import _chunk_tokens, _pad_and_adjust_eos_positions
+
+    # Create a long token sequence that will be truncated
+    # Use distinct values to clearly see truncation direction
+    tokens = list(range(100, 200))  # [100, 101, 102, ..., 199] - 100 tokens
+    eos_id = train_tokenizer.eos_token_id
+    pad_id = train_tokenizer.pad_token_id
+    chunk_size = 10  # chunk_len = 9
+    max_length = 20  # Can fit: 2 full chunks (9 tokens + 1 EOS = 10 each) = 20 total
+
+    # Step 1: Chunk with truncation (takes from front)
+    chunked_ids, eos_positions = _chunk_tokens(tokens, chunk_size, eos_id, max_length)
+
+    # Verify truncation takes from front: should start with 100, not 199
+    assert chunked_ids[0] == 100  # First token from original list
+    assert chunked_ids[-2] == 117  # Last content token (not 199) - second chunk ends at 117
+    assert len(chunked_ids) == 20  # Exactly max_length
+
+    # Hardcoded golden output: truncated from front
+    # Original: 100 tokens [100-199]
+    # After truncation (front): 18 tokens [100-117] + 2 EOS = 20 tokens
+    expected_chunked_ids = [
+        100, 101, 102, 103, 104, 105, 106, 107, 108, eos_id,  # First chunk: 9 tokens + EOS
+        109, 110, 111, 112, 113, 114, 115, 116, 117, eos_id   # Second chunk: 9 tokens + EOS
+    ]
+
expected_eos_positions = [9, 19] # EOS positions before padding (list, not list of lists) + + assert chunked_ids == expected_chunked_ids + assert eos_positions == expected_eos_positions + + # Step 2: Test left padding with truncation + all_input_ids = [chunked_ids] + all_eos_positions = [eos_positions] + + # Apply our padding function + padded_dict_left, adjusted_eos_positions_left = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=train_tokenizer, + padding_side='left', + pad_to_multiple_of=8, + ) + expected_padded_ids_left = [ + pad_id, pad_id, pad_id, pad_id, # 4 padding tokens + 100, 101, 102, 103, 104, 105, 106, 107, 108, eos_id, # First chunk: 9 tokens + EOS + 109, 110, 111, 112, 113, 114, 115, 116, 117, eos_id # Second chunk: 9 tokens + EOS + ] + expected_attention_mask_left = [0, 0, 0, 0] + [1] * 20 # 4 padding + 20 content + expected_adjusted_eos_positions_left = [[13, 23]] + + assert padded_dict_left['input_ids'][0].tolist() == expected_padded_ids_left + assert padded_dict_left['attention_mask'][0].tolist() == expected_attention_mask_left + assert adjusted_eos_positions_left == expected_adjusted_eos_positions_left + +# ============================================================================ +# Unit tests for _pad_and_adjust_eos_positions helper function +# ============================================================================ + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_right_padding(train_tokenizer): + """Test padding with right padding (no EOS position adjustment needed).""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + + all_input_ids = [ + [1, 2, 3, eos_id], # Passage 0: 4 tokens, EOS at position 3 + [4, 5, eos_id], # Passage 1: 3 tokens, EOS at position 2 + ] + all_eos_positions = [[3], [2]] + + padded_dict, 
adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=train_tokenizer, + padding_side='right', + pad_to_multiple_of=4, + ) + + # Hardcoded golden output + expected_input_ids = torch.tensor([ + [1, 2, 3, eos_id], # Passage 0: padded to 4 (no padding needed) + [4, 5, eos_id, pad_id], # Passage 1: padded to 4 (1 padding token) + ]) + expected_attention_mask = torch.tensor([ + [1, 1, 1, 1], # Passage 0: all tokens valid + [1, 1, 1, 0], # Passage 1: last token is padding + ]) + expected_eos_positions = [[3], [2]] # EOS positions unchanged for right padding + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions + + # Verify behavior matches tokenizer.pad directly + train_tokenizer.padding_side = 'right' + direct_padded = train_tokenizer.pad( + {'input_ids': all_input_ids}, + padding=True, + pad_to_multiple_of=4, + return_attention_mask=True, + return_tensors='pt', + ) + assert torch.equal(padded_dict['input_ids'], direct_padded['input_ids']) + assert torch.equal(padded_dict['attention_mask'], direct_padded['attention_mask']) + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_left_padding(train_tokenizer): + """Test padding with left padding (EOS positions should be shifted).""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + + all_input_ids = [ + [1, 2, 3, eos_id], # Passage 0: 4 tokens, EOS at position 3 + [4, 5, eos_id], # Passage 1: 3 tokens, EOS at position 2 + ] + all_eos_positions = [[3], [2]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=train_tokenizer, + padding_side='left', + 
pad_to_multiple_of=4, + ) + + # Hardcoded golden output + expected_input_ids = torch.tensor([ + [1, 2, 3, eos_id], # Passage 0: padded to 4 (no padding needed) + [pad_id, 4, 5, eos_id], # Passage 1: padded to 4 (1 padding token on left) + ]) + expected_attention_mask = torch.tensor([ + [1, 1, 1, 1], # Passage 0: all tokens valid + [0, 1, 1, 1], # Passage 1: first token is padding + ]) + # Passage 0: original length 4, padded length 4, padding_length=0, EOS stays at 3 + # Passage 1: original length 3, padded length 4, padding_length=1, EOS shifts from 2 to 3 + expected_eos_positions = [[3], [3]] + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions + + # Verify behavior matches tokenizer.pad directly + train_tokenizer.padding_side = 'left' + direct_padded = train_tokenizer.pad( + {'input_ids': all_input_ids}, + padding=True, + pad_to_multiple_of=4, + return_attention_mask=True, + return_tensors='pt', + ) + assert torch.equal(padded_dict['input_ids'], direct_padded['input_ids']) + assert torch.equal(padded_dict['attention_mask'], direct_padded['attention_mask']) + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_multiple_eos(train_tokenizer): + """Test padding with multiple EOS positions per passage.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + + all_input_ids = [ + [1, 2, eos_id, 3, 4, eos_id], # Passage 0: 6 tokens, EOS at positions 2, 5 + [5, eos_id], # Passage 1: 2 tokens, EOS at position 1 + ] + all_eos_positions = [[2, 5], [1]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=train_tokenizer, + padding_side='left', + pad_to_multiple_of=8, + ) + + # Hardcoded 
golden output + expected_input_ids = torch.tensor([ + [pad_id, pad_id, 1, 2, eos_id, 3, 4, eos_id], # Passage 0: padded to 8 (2 padding tokens on left) + [pad_id, pad_id, pad_id, pad_id, pad_id, pad_id, 5, eos_id], # Passage 1: padded to 8 (6 padding tokens on left) + ]) + expected_attention_mask = torch.tensor([ + [0, 0, 1, 1, 1, 1, 1, 1], # Passage 0: first 2 tokens are padding + [0, 0, 0, 0, 0, 0, 1, 1], # Passage 1: first 6 tokens are padding + ]) + # Passage 0: original length 6, padded length 8, padding_length=2, EOS shift from [2,5] to [4,7] + # Passage 1: original length 2, padded length 8, padding_length=6, EOS shift from 1 to 7 + expected_eos_positions = [[4, 7], [7]] + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions + + # Verify behavior matches tokenizer.pad directly + train_tokenizer.padding_side = 'left' + direct_padded = train_tokenizer.pad( + {'input_ids': all_input_ids}, + padding=True, + pad_to_multiple_of=8, + return_attention_mask=True, + return_tensors='pt', + ) + assert torch.equal(padded_dict['input_ids'], direct_padded['input_ids']) + assert torch.equal(padded_dict['attention_mask'], direct_padded['attention_mask']) + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_no_padding_needed(train_tokenizer): + """Test when sequences are already the same length.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + + all_input_ids = [ + [1, 2, eos_id], + [3, 4, eos_id], + ] + all_eos_positions = [[2], [2]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=train_tokenizer, + padding_side='right', + pad_to_multiple_of=3, + ) + + # Hardcoded golden output 
+ expected_input_ids = torch.tensor([ + [1, 2, eos_id], + [3, 4, eos_id], + ]) + expected_attention_mask = torch.tensor([ + [1, 1, 1], + [1, 1, 1], + ]) + expected_eos_positions = [[2], [2]] # EOS positions unchanged for right padding + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions + + # Verify behavior matches tokenizer.pad directly + train_tokenizer.padding_side = 'right' + direct_padded = train_tokenizer.pad( + {'input_ids': all_input_ids}, + padding=True, + pad_to_multiple_of=3, + return_attention_mask=True, + return_tensors='pt', + ) + assert torch.equal(padded_dict['input_ids'], direct_padded['input_ids']) + assert torch.equal(padded_dict['attention_mask'], direct_padded['attention_mask']) + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_empty_input(train_tokenizer): + """Test with empty input.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + all_input_ids = [] + all_eos_positions = [] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=train_tokenizer, + padding_side='right', + pad_to_multiple_of=4, + ) + + # Hardcoded golden output for empty input + expected_eos_positions = [] + + assert adjusted_eos_positions == expected_eos_positions + # When input is empty, tokenizer.pad may return list or tensor depending on implementation + if isinstance(padded_dict['input_ids'], torch.Tensor): + assert padded_dict['input_ids'].shape[0] == 0 + assert padded_dict['attention_mask'].shape[0] == 0 + else: + assert len(padded_dict['input_ids']) == 0 + assert len(padded_dict['attention_mask']) == 0 + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_single_passage(train_tokenizer): + """Test with single passage.""" + 
_add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + eos_id = train_tokenizer.eos_token_id + + all_input_ids = [[1, 2, 3, eos_id]] + all_eos_positions = [[3]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=train_tokenizer, + padding_side='right', + pad_to_multiple_of=4, + ) + + # Hardcoded golden output + expected_input_ids = torch.tensor([ + [1, 2, 3, eos_id], # Already length 4, no padding needed + ]) + expected_attention_mask = torch.tensor([ + [1, 1, 1, 1], # All tokens valid + ]) + expected_eos_positions = [[3]] + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions + + # Verify behavior matches tokenizer.pad directly + train_tokenizer.padding_side = 'right' + direct_padded = train_tokenizer.pad( + {'input_ids': all_input_ids}, + padding=True, + pad_to_multiple_of=4, + return_attention_mask=True, + return_tensors='pt', + ) + assert torch.equal(padded_dict['input_ids'], direct_padded['input_ids']) + assert torch.equal(padded_dict['attention_mask'], direct_padded['attention_mask']) + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_pad_to_multiple_of_one(train_tokenizer): + """Test with pad_to_multiple_of=1 (no rounding).""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + + all_input_ids = [ + [1, 2, eos_id], + [3, eos_id], + ] + all_eos_positions = [[2], [1]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=train_tokenizer, + padding_side='right', + pad_to_multiple_of=1, + ) + + # Hardcoded golden output + 
expected_input_ids = torch.tensor([ + [1, 2, eos_id], # Padded to max_len=3 (no rounding needed with pad_to_multiple_of=1) + [3, eos_id, pad_id], # Padded to max_len=3 (1 padding token) + ]) + expected_attention_mask = torch.tensor([ + [1, 1, 1], # All tokens valid + [1, 1, 0], # Last token is padding + ]) + expected_eos_positions = [[2], [1]] # EOS positions unchanged for right padding + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions + + # Verify behavior matches tokenizer.pad directly + train_tokenizer.padding_side = 'right' + direct_padded = train_tokenizer.pad( + {'input_ids': all_input_ids}, + padding=True, + pad_to_multiple_of=1, + return_attention_mask=True, + return_tensors='pt', + ) + assert torch.equal(padded_dict['input_ids'], direct_padded['input_ids']) + assert torch.equal(padded_dict['attention_mask'], direct_padded['attention_mask']) + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_left_padding_multiple_chunks(train_tokenizer): + """Test left padding with multiple chunks per passage.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + + all_input_ids = [ + [1, eos_id, 2, 3, eos_id], # Passage 0: 5 tokens, EOS at positions 1, 4 + [4, 5, eos_id], # Passage 1: 3 tokens, EOS at position 2 + ] + all_eos_positions = [[1, 4], [2]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=train_tokenizer, + padding_side='left', + pad_to_multiple_of=8, + ) + + # Hardcoded golden output + expected_input_ids = torch.tensor([ + [pad_id, pad_id, pad_id, 1, eos_id, 2, 3, eos_id], # Passage 0: padded to 8 (3 padding tokens on left) + [pad_id, pad_id, pad_id, pad_id, 
pad_id, 4, 5, eos_id], # Passage 1: padded to 8 (5 padding tokens on left) + ]) + expected_attention_mask = torch.tensor([ + [0, 0, 0, 1, 1, 1, 1, 1], # Passage 0: first 3 tokens are padding + [0, 0, 0, 0, 0, 1, 1, 1], # Passage 1: first 5 tokens are padding + ]) + # Passage 0: original length 5, padded length 8, padding_length=3, EOS shift from [1,4] to [4,7] + # Passage 1: original length 3, padded length 8, padding_length=5, EOS shift from 2 to 7 + expected_eos_positions = [[4, 7], [7]] + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions + + # Verify behavior matches tokenizer.pad directly + train_tokenizer.padding_side = 'left' + direct_padded = train_tokenizer.pad( + {'input_ids': all_input_ids}, + padding=True, + pad_to_multiple_of=8, + return_attention_mask=True, + return_tensors='pt', + ) + assert torch.equal(padded_dict['input_ids'], direct_padded['input_ids']) + assert torch.equal(padded_dict['attention_mask'], direct_padded['attention_mask']) + + + + @pytest.mark.unit def test_train_collator_chunked_passages(train_tokenizer): """Test chunking with passage_max_len=512, passage_chunk_size=256.""" @@ -73,15 +795,105 @@ def test_train_collator_chunked_passages(train_tokenizer): got_ids = d_collated["input_ids"][0].tolist() got_mask = d_collated["attention_mask"][0].tolist() + # Hardcoded golden output: 2 chunks (255 tokens + EOS, 174 tokens + EOS) = 431 tokens, padded to 432 + expected_ids = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, + 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, 349, + 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 
304, 855, 4991, 320, 77, 284, 220, 16, + 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239, 315, 6811, + 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434, 320, 77, 284, + 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13, 758, 279, 8622, 4158, 4925, + 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, 73760, 572, 1550, 11, 220, 16, 13, 23, + 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, 311, 220, 16, 13, 17, 19197, 441, 17, + 58634, 13, 758, 279, 44900, 47594, 315, 279, 5306, 47639, 11, 279, 3076, 9981, 57330, + 36829, 518, 2176, 3039, 1033, 4428, 320, 16, 13, 17, 19041, 220, 16, 13, 16, 19197, 441, + 17, 58634, 568, 39402, 458, 285, 354, 17764, 572, 5080, 279, 12128, 7194, 572, 311, 4647, + 448, 7046, 10740, 2750, 304, 279, 5306, 47639, 1091, 304, 279, 8622, 4158, 4925, 13, 4968, + 4991, 41434, 518, 4647, 8542, 5080, 3076, 57330, 36829, 304, 279, 8622, 4158, 4925, 320, + 16, 13, 19, 51615, 220, 15, 13, 17, 19, 19041, 220, 16, 13, 16, EOS_TOKEN_ID, 20, 51615, + 220, 15, 13, 15, 24, 19197, 441, 17, 58634, 11, 281, 284, 220, 15, 13, 15, 16, 21, 8, 323, + 4722, 8674, 458, 285, 354, 17764, 304, 2176, 5671, 7707, 448, 2480, 9663, 41434, 320, + 5782, 4925, 11, 220, 16, 15, 13, 24, 51615, 220, 15, 13, 21, 19041, 220, 17, 17, 13, 24, + 51615, 220, 18, 13, 15, 13384, 281, 284, 220, 15, 13, 15, 15, 16, 26, 5306, 47639, 11, + 220, 17, 19, 13, 15, 51615, 220, 19, 13, 19, 19, 19041, 220, 18, 18, 13, 16, 51615, 220, + 15, 13, 21, 4, 281, 284, 220, 15, 13, 15, 15, 21, 568, 11581, 2408, 301, 15479, 48674, + 304, 279, 42094, 1620, 385, 1242, 1033, 9434, 553, 57330, 15626, 51360, 438, 4124, 438, + 220, 17, 23, 73760, 26, 2480, 9663, 323, 855, 4991, 41434, 518, 4647, 8542, 12864, 11799, + 304, 4158, 4925, 23788, 7321, 13, 576, 821, 13216, 429, 46516, 15449, 315, 3015, 57330, + 553, 57330, 15626, 51360, 5707, 20017, 1119, 8003, 95697, 4401, 304, 59645, 4158, 4925, + 304, 5382, 41434, EOS_TOKEN_ID, PADDING_TOKEN_ID + ] + 
expected_mask = [1] * 431 + [0] # 431 ones + 1 zero + expected_eos_positions = [[255, 430]] + assert sum(got_mask) == 431 - assert eos_positions == [[255, 430]] + assert len(got_ids) == 432 # Padded to multiple of 16 + assert eos_positions == expected_eos_positions + assert got_ids == expected_ids + assert got_mask == expected_mask assert got_ids[255] == train_tokenizer.eos_token_id assert got_ids[430] == train_tokenizer.eos_token_id - assert len(got_ids) == 432 # Padded to multiple of 16 assert got_mask[255] == 1 assert got_mask[430] == 1 +@pytest.mark.unit +def test_train_collator_chunked_passages_left_padding(train_tokenizer): + """Test chunking with passage_max_len=512, passage_chunk_size=256, left padding.""" + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + data_args = DataArguments( + passage_max_len=512, + passage_chunk_size=256, + pad_to_multiple_of=16, + padding_side="left", + append_eos_token=False, + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([REAL_TEXT]) + + got_ids = d_collated["input_ids"][0].tolist() + got_mask = d_collated["attention_mask"][0].tolist() + expected_ids = [ PADDING_TOKEN_ID, + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, + 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, 349, + 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284, 220, 16, + 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239, 315, 6811, + 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434, 320, 77, 284, + 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 
13, 758, 279, 8622, 4158, 4925, + 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, 73760, 572, 1550, 11, 220, 16, 13, 23, + 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, 311, 220, 16, 13, 17, 19197, 441, 17, + 58634, 13, 758, 279, 44900, 47594, 315, 279, 5306, 47639, 11, 279, 3076, 9981, 57330, + 36829, 518, 2176, 3039, 1033, 4428, 320, 16, 13, 17, 19041, 220, 16, 13, 16, 19197, 441, + 17, 58634, 568, 39402, 458, 285, 354, 17764, 572, 5080, 279, 12128, 7194, 572, 311, 4647, + 448, 7046, 10740, 2750, 304, 279, 5306, 47639, 1091, 304, 279, 8622, 4158, 4925, 13, 4968, + 4991, 41434, 518, 4647, 8542, 5080, 3076, 57330, 36829, 304, 279, 8622, 4158, 4925, 320, + 16, 13, 19, 51615, 220, 15, 13, 17, 19, 19041, 220, 16, 13, 16, EOS_TOKEN_ID, 20, 51615, + 220, 15, 13, 15, 24, 19197, 441, 17, 58634, 11, 281, 284, 220, 15, 13, 15, 16, 21, 8, 323, + 4722, 8674, 458, 285, 354, 17764, 304, 2176, 5671, 7707, 448, 2480, 9663, 41434, 320, + 5782, 4925, 11, 220, 16, 15, 13, 24, 51615, 220, 15, 13, 21, 19041, 220, 17, 17, 13, 24, + 51615, 220, 18, 13, 15, 13384, 281, 284, 220, 15, 13, 15, 15, 16, 26, 5306, 47639, 11, + 220, 17, 19, 13, 15, 51615, 220, 19, 13, 19, 19, 19041, 220, 18, 18, 13, 16, 51615, 220, + 15, 13, 21, 4, 281, 284, 220, 15, 13, 15, 15, 21, 568, 11581, 2408, 301, 15479, 48674, + 304, 279, 42094, 1620, 385, 1242, 1033, 9434, 553, 57330, 15626, 51360, 438, 4124, 438, + 220, 17, 23, 73760, 26, 2480, 9663, 323, 855, 4991, 41434, 518, 4647, 8542, 12864, 11799, + 304, 4158, 4925, 23788, 7321, 13, 576, 821, 13216, 429, 46516, 15449, 315, 3015, 57330, + 553, 57330, 15626, 51360, 5707, 20017, 1119, 8003, 95697, 4401, 304, 59645, 4158, 4925, + 304, 5382, 41434, EOS_TOKEN_ID + ] + expected_mask = [0] + [1] * 431 # 1 padding + 431 content + expected_eos_positions = [[256, 431]] + + assert got_ids == expected_ids + assert got_mask == expected_mask + assert eos_positions == expected_eos_positions + + @pytest.mark.unit def 
test_chunked_collator_with_multiple_passages(train_tokenizer): """Test TrainCollator with chunking enabled returns (q_batch, p_batch, eos_positions).""" @@ -104,9 +916,42 @@ def test_chunked_collator_with_multiple_passages(train_tokenizer): q_batch, p_batch, eos_positions = collator(features) + # Hardcoded golden output: both passages have 2 chunks (31 tokens + EOS, 31 tokens + EOS) = 64 tokens each + expected_ids_0 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + EOS_TOKEN_ID, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, EOS_TOKEN_ID + ] + expected_mask_0 = [1] * 64 + expected_eos_0 = [31, 63] + + expected_ids_1 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + EOS_TOKEN_ID, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, EOS_TOKEN_ID + ] + expected_mask_1 = [1] * 64 + expected_eos_1 = [31, 63] + assert p_batch["input_ids"].shape[0] == 2 assert len(eos_positions) == 2 + got_ids_0 = p_batch["input_ids"][0].tolist() + got_mask_0 = p_batch["attention_mask"][0].tolist() + got_ids_1 = p_batch["input_ids"][1].tolist() + got_mask_1 = p_batch["attention_mask"][1].tolist() + + assert got_ids_0 == expected_ids_0 + assert got_mask_0 == expected_mask_0 + assert eos_positions[0] == expected_eos_0 + assert got_ids_1 == expected_ids_1 + assert got_mask_1 == expected_mask_1 + assert eos_positions[1] == expected_eos_1 + for i in range(p_batch["input_ids"].shape[0]): got_ids = p_batch["input_ids"][i].tolist() got_mask = 
p_batch["attention_mask"][i].tolist() @@ -120,15 +965,14 @@ def test_chunked_collator_with_multiple_passages(train_tokenizer): @pytest.mark.unit -@pytest.mark.parametrize("chunk_size", [64, 128]) -def test_chunking_capped_to_maxlen(train_tokenizer, chunk_size): - """When chunk_size >= max_len, chunking is capped to max_len with one EOS.""" +def test_chunking_capped_to_maxlen_chunk_size_64(train_tokenizer): + """When chunk_size >= max_len, chunking is capped to max_len with one EOS (chunk_size=64).""" from tevatron.retriever.arguments import DataArguments from tevatron.retriever.collator import TrainCollator long_text = (REAL_TEXT + " ") * 20 data_args = DataArguments( - passage_chunk_size=chunk_size, + passage_chunk_size=64, passage_max_len=64, pad_to_multiple_of=16, padding_side="right", @@ -139,9 +983,62 @@ def test_chunking_capped_to_maxlen(train_tokenizer, chunk_size): ids = d_collated["input_ids"][0].tolist() mask = d_collated["attention_mask"][0].tolist() + # Hardcoded golden output: 63 tokens + 1 EOS = 64 tokens + expected_ids = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, + 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, 349, + 2326, EOS_TOKEN_ID + ] + expected_mask = [1] * 64 + expected_eos_positions = [[63]] + assert sum(mask) == 64 assert len(ids) == 64 - assert eos_positions == [[63]] + assert eos_positions == expected_eos_positions + assert ids == expected_ids + assert mask == expected_mask + assert ids[63] == EOS_TOKEN_ID + assert EOS_TOKEN_ID not in ids[:63] + assert _strictly_increasing(eos_positions[0]) + + +@pytest.mark.unit +def test_chunking_capped_to_maxlen_chunk_size_128(train_tokenizer): + """When chunk_size >= max_len, chunking is capped to max_len with one EOS 
(chunk_size=128).""" + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + long_text = (REAL_TEXT + " ") * 20 + data_args = DataArguments( + passage_chunk_size=128, + passage_max_len=64, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([long_text]) + ids = d_collated["input_ids"][0].tolist() + mask = d_collated["attention_mask"][0].tolist() + + # Hardcoded golden output: 63 tokens + 1 EOS = 64 tokens + expected_ids = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, + 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, 349, + 2326, EOS_TOKEN_ID + ] + expected_mask = [1] * 64 + expected_eos_positions = [[63]] + + assert sum(mask) == 64 + assert len(ids) == 64 + assert eos_positions == expected_eos_positions + assert ids == expected_ids + assert mask == expected_mask assert ids[63] == EOS_TOKEN_ID assert EOS_TOKEN_ID not in ids[:63] assert _strictly_increasing(eos_positions[0]) @@ -329,156 +1226,3 @@ def test_chunking_multiple_passages_different_lengths(train_tokenizer): assert eos_positions[3] == expected_eos_3 assert ids_3 == expected_ids_3 assert mask_3 == expected_mask_3 - - -@pytest.mark.unit -def test_non_chunked_padding_side_behavior(train_tokenizer): - """Test that padding_side affects pooling position extraction.""" - import torch - from unittest.mock import Mock - from tevatron.retriever.arguments import DataArguments - from tevatron.retriever.collator import TrainCollator - from tevatron.retriever.modeling.dense import DenseModel - - test_passage = REAL_TEXT - - 
# Right padding - data_args_right = DataArguments( - passage_max_len=64, - passage_chunk_size=0, - pad_to_multiple_of=16, - padding_side="right", - append_eos_token=False, - ) - collator_right = TrainCollator(data_args=data_args_right, tokenizer=train_tokenizer) - _, p_batch_right = collator_right([("query", [test_passage], [])]) - attention_mask_right = p_batch_right['attention_mask'][0] - last_valid_pos_right = attention_mask_right.sum().item() - 1 - - # Left padding - data_args_left = DataArguments( - passage_max_len=64, - passage_chunk_size=0, - pad_to_multiple_of=16, - padding_side="left", - append_eos_token=False, - ) - collator_left = TrainCollator(data_args=data_args_left, tokenizer=train_tokenizer) - _, p_batch_left = collator_left([("query", [test_passage], [])]) - attention_mask_left = p_batch_left['attention_mask'][0] - num_valid_left = attention_mask_left.sum().item() - is_left_padding = (attention_mask_left[-1] == 1).item() - - # Verify content tokens are identical - content_right = p_batch_right['input_ids'][0][attention_mask_right.bool()].tolist() - content_left = p_batch_left['input_ids'][0][attention_mask_left.bool()].tolist() - assert content_right == content_left - - # Test pooling with mock model - hidden_size = 64 - - class MockEncoderOutput: - def __init__(self, last_hidden_state): - self.last_hidden_state = last_hidden_state - - def mock_encoder_forward(**kwargs): - input_ids = kwargs['input_ids'] - batch_size, seq_len = input_ids.shape - hidden_states = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) - for i in range(batch_size): - for j in range(seq_len): - hidden_states[i, j, 0] = float(j) - return MockEncoderOutput(last_hidden_state=hidden_states) - - mock_encoder = Mock(side_effect=mock_encoder_forward) - mock_encoder.config = Mock() - mock_encoder.config.hidden_size = hidden_size - - model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) - model.passage_chunk_size = 0 - - p_reps_right = 
model.encode_passage(p_batch_right) - p_reps_left = model.encode_passage(p_batch_left) - - assert torch.allclose(p_reps_right[0, 0], torch.tensor(float(last_valid_pos_right))) - expected_pos_left = len(attention_mask_left) - 1 if is_left_padding else num_valid_left - 1 - assert torch.allclose(p_reps_left[0, 0], torch.tensor(float(expected_pos_left))) - - -@pytest.mark.unit -def test_chunked_passages_left_padding(train_tokenizer): - """Test that EOS positions are correctly adjusted for left padding.""" - import torch - from unittest.mock import Mock - from tevatron.retriever.arguments import DataArguments - from tevatron.retriever.collator import TrainCollator - from tevatron.retriever.modeling.dense import DenseModel - - test_passage = REAL_TEXT - - # Right padding (baseline) - data_args_right = DataArguments( - passage_max_len=128, - passage_chunk_size=64, - pad_to_multiple_of=16, - padding_side="right", - append_eos_token=False, - ) - collator_right = TrainCollator(data_args=data_args_right, tokenizer=train_tokenizer) - _, p_batch_right, eos_positions_right = collator_right([("query", [test_passage], [])]) - - # Left padding - data_args_left = DataArguments( - passage_max_len=128, - passage_chunk_size=64, - pad_to_multiple_of=16, - padding_side="left", - append_eos_token=False, - ) - collator_left = TrainCollator(data_args=data_args_left, tokenizer=train_tokenizer) - _, p_batch_left, eos_positions_left = collator_left([("query", [test_passage], [])]) - - attention_mask_left = p_batch_left['attention_mask'][0] - num_valid_tokens = attention_mask_left.sum().item() - padding_length = len(attention_mask_left) - num_valid_tokens - - # Verify EOS positions are shifted by padding_length - assert len(eos_positions_right[0]) == len(eos_positions_left[0]) - for eos_right, eos_left in zip(eos_positions_right[0], eos_positions_left[0]): - assert eos_left == eos_right + padding_length - assert p_batch_left['input_ids'][0][eos_left] == train_tokenizer.eos_token_id - - # Test 
pooling extracts from correct positions - hidden_size = 64 - - class MockEncoderOutput: - def __init__(self, last_hidden_state): - self.last_hidden_state = last_hidden_state - - def mock_encoder_forward(**kwargs): - input_ids = kwargs['input_ids'] - batch_size, seq_len = input_ids.shape - hidden_states = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) - for i in range(batch_size): - for j in range(seq_len): - hidden_states[i, j, 0] = float(j) - return MockEncoderOutput(last_hidden_state=hidden_states) - - mock_encoder = Mock(side_effect=mock_encoder_forward) - mock_encoder.config = Mock() - mock_encoder.config.hidden_size = hidden_size - - model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) - model.passage_chunk_size = 64 - - chunk_reps_right, _ = model.encode_passage(p_batch_right, eos_positions_right) - chunk_reps_left, _ = model.encode_passage(p_batch_left, eos_positions_left) - - # Verify embeddings differ by padding_length - for i, (eos_right, eos_left) in enumerate(zip(eos_positions_right[0], eos_positions_left[0])): - assert torch.allclose(chunk_reps_right[0, i, 0], torch.tensor(float(eos_right))) - assert torch.allclose(chunk_reps_left[0, i, 0], torch.tensor(float(eos_left))) - assert torch.allclose( - chunk_reps_left[0, i, 0] - chunk_reps_right[0, i, 0], - torch.tensor(float(padding_length)) - ) diff --git a/tests/test_chunking_helper.py b/tests/test_chunking_helper.py deleted file mode 100644 index 32448ef0..00000000 --- a/tests/test_chunking_helper.py +++ /dev/null @@ -1,229 +0,0 @@ -""" -Unit tests for _chunk_tokens helper function. 
-""" -import sys -from pathlib import Path -import pytest - - -def _tevatron_root() -> Path: - return Path(__file__).resolve().parents[1] - - -def _add_tevatron_src_to_path(): - src = _tevatron_root() / "src" - sys.path.insert(0, str(src)) - - -@pytest.mark.unit -def test_chunk_tokens_basic(): - """Test basic chunking functionality.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _chunk_tokens - - tokens = list(range(10)) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - eos_id = 99 - chunk_size = 4 - - ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) - - # chunk_size=4 means chunk_len=3, so chunks are: - # [0,1,2,99], [3,4,5,99], [6,7,8,99], [9,99] - expected_ids = [0, 1, 2, 99, 3, 4, 5, 99, 6, 7, 8, 99, 9, 99] - expected_eos_pos = [3, 7, 11, 13] - - assert ids == expected_ids - assert eos_pos == expected_eos_pos - - -@pytest.mark.unit -def test_chunk_tokens_with_max_length(): - """Test chunking with max_length constraint.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _chunk_tokens - - tokens = list(range(20)) - eos_id = 99 - chunk_size = 5 - max_length = 12 - - ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) - - # chunk_size=5 means chunk_len=4 - # First chunk: [0,1,2,3,99] = 5 tokens - # Second chunk: [4,5,6,7,99] = 5 tokens - # Total: 10 tokens, but max_length=12 allows one more EOS - # Third chunk would need at least 1 token + 1 EOS = 2 tokens, but we only have 2 left - # So we can fit: [8,99] = 2 tokens - # Total: 12 tokens - assert len(ids) == 12 - assert ids[-1] == eos_id # Last token should be EOS - assert len(eos_pos) == 3 - assert all(ids[pos] == eos_id for pos in eos_pos) - - -@pytest.mark.unit -def test_chunk_tokens_max_length_exact_fit(): - """Test chunking when max_length exactly fits chunks.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _chunk_tokens - - tokens = list(range(10)) - eos_id = 99 - chunk_size = 4 - max_length = 14 # Exactly fits 3 chunks: 
3*4 + 2 = 14 - - ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) - - # Should have 3 chunks: [0,1,2,99], [3,4,5,99], [6,7,8,99] = 12 tokens - # Plus [9,99] = 2 tokens, total 14 - assert len(ids) == 14 - assert len(eos_pos) == 4 - - -@pytest.mark.unit -def test_chunk_tokens_max_length_too_small(): - """Test chunking when max_length is too small for even one chunk.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _chunk_tokens - - tokens = list(range(10)) - eos_id = 99 - chunk_size = 4 - max_length = 1 # Too small for even one chunk (need at least 2: 1 token + EOS) - - ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) - - # Should return empty since we can't fit even one chunk - assert ids == [] - assert eos_pos == [] - - -@pytest.mark.unit -def test_chunk_tokens_empty_input(): - """Test chunking with empty token list.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _chunk_tokens - - tokens = [] - eos_id = 99 - chunk_size = 4 - - ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) - - assert ids == [] - assert eos_pos == [] - - -@pytest.mark.unit -def test_chunk_tokens_single_token(): - """Test chunking with single token.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _chunk_tokens - - tokens = [42] - eos_id = 99 - chunk_size = 4 - - ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) - - assert ids == [42, 99] - assert eos_pos == [1] - - -@pytest.mark.unit -def test_chunk_tokens_no_max_length(): - """Test chunking without max_length constraint.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _chunk_tokens - - tokens = list(range(15)) - eos_id = 99 - chunk_size = 5 - - ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length=None) - - # chunk_size=5 means chunk_len=4 - # Should have 4 chunks: [0-3,99], [4-7,99], [8-11,99], [12-14,99] - assert len(ids) == 19 # 15 tokens + 4 EOS tokens - assert 
len(eos_pos) == 4 - assert all(ids[pos] == eos_id for pos in eos_pos) - - -@pytest.mark.unit -def test_chunk_tokens_chunk_size_one(): - """Test chunking with chunk_size=1 (invalid, should return empty).""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _chunk_tokens - - tokens = [1, 2, 3] - eos_id = 99 - chunk_size = 1 - - ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) - - # chunk_size=1 is invalid (need at least 2: 1 token + 1 EOS) - # Should return empty - assert ids == [] - assert eos_pos == [] - - -@pytest.mark.unit -def test_chunk_tokens_chunk_size_two(): - """Test chunking with chunk_size=2.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _chunk_tokens - - tokens = [1, 2, 3, 4, 5] - eos_id = 99 - chunk_size = 2 - - ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) - - # chunk_size=2 means chunk_len=1 - # Chunks: [1,99], [2,99], [3,99], [4,99], [5,99] - expected_ids = [1, 99, 2, 99, 3, 99, 4, 99, 5, 99] - expected_eos_pos = [1, 3, 5, 7, 9] - - assert ids == expected_ids - assert eos_pos == expected_eos_pos - - -@pytest.mark.unit -def test_chunk_tokens_eos_positions_are_correct(): - """Test that EOS positions correctly point to EOS tokens.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _chunk_tokens - - tokens = list(range(10)) - eos_id = 99 - chunk_size = 4 - - ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) - - # Verify all EOS positions contain EOS token - for pos in eos_pos: - assert ids[pos] == eos_id - - # Verify EOS positions are strictly increasing - assert all(eos_pos[i] < eos_pos[i + 1] for i in range(len(eos_pos) - 1)) - - -@pytest.mark.unit -def test_chunk_tokens_max_length_stops_at_boundary(): - """Test that max_length stops chunking at chunk boundary.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _chunk_tokens - - tokens = list(range(20)) - eos_id = 99 - chunk_size = 5 - max_length = 10 # Exactly 2 chunks: 2*5 = 10 
- - ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) - - assert len(ids) == 10 - assert len(eos_pos) == 2 - # Should have exactly 2 chunks: [0,1,2,3,99], [4,5,6,7,99] - assert ids == [0, 1, 2, 3, 99, 4, 5, 6, 7, 99] - diff --git a/tests/test_chunking_pooling_equivalence.py b/tests/test_chunking_pooling_equivalence.py deleted file mode 100644 index 268d0cbf..00000000 --- a/tests/test_chunking_pooling_equivalence.py +++ /dev/null @@ -1,101 +0,0 @@ -""" -Test to verify that when chunk_size == passage_max_len and there's only one chunk, -chunked and non-chunked modes extract embeddings from different positions. -""" -import sys -from pathlib import Path -import pytest -import torch -from transformers import AutoTokenizer - - -def _tevatron_root() -> Path: - return Path(__file__).resolve().parents[1] - - -def _add_tevatron_src_to_path(): - src = _tevatron_root() / "src" - sys.path.insert(0, str(src)) - - -@pytest.fixture -def train_tokenizer(): - from transformers import AutoTokenizer - return AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") - - -@pytest.mark.unit -def test_chunked_vs_non_chunked_when_chunk_size_equals_max_len(train_tokenizer): - """When chunk_size == passage_max_len, chunked mode adds EOS and extracts from EOS position.""" - _add_tevatron_src_to_path() - from tevatron.retriever.arguments import DataArguments - from tevatron.retriever.collator import TrainCollator - from tevatron.retriever.modeling.dense import DenseModel - from unittest.mock import Mock - - test_passage = "This is a test passage that will fit in one chunk." 
- passage_max_len = chunk_size = 64 - - # Non-chunked mode - data_args_non_chunked = DataArguments( - passage_max_len=passage_max_len, - passage_chunk_size=0, - pad_to_multiple_of=16, - padding_side="right", - append_eos_token=False, - ) - collator_non_chunked = TrainCollator(data_args=data_args_non_chunked, tokenizer=train_tokenizer) - _, p_batch_non_chunked = collator_non_chunked([("query", [test_passage], [])]) - - # Chunked mode - data_args_chunked = DataArguments( - passage_max_len=passage_max_len, - passage_chunk_size=chunk_size, - pad_to_multiple_of=16, - padding_side="right", - append_eos_token=False, - ) - collator_chunked = TrainCollator(data_args=data_args_chunked, tokenizer=train_tokenizer) - _, p_batch_chunked, eos_positions = collator_chunked([("query", [test_passage], [])]) - - # Verify tokenization: chunked adds EOS, non-chunked doesn't - non_chunked_content = p_batch_non_chunked['input_ids'][0][p_batch_non_chunked['attention_mask'][0].bool()].tolist() - chunked_content = p_batch_chunked['input_ids'][0][p_batch_chunked['attention_mask'][0].bool()].tolist() - - assert chunked_content[-1] == train_tokenizer.eos_token_id - assert non_chunked_content[-1] != train_tokenizer.eos_token_id - assert non_chunked_content == chunked_content[:-1] - - # Test pooling: different positions yield different embeddings - hidden_size = 64 - - class MockEncoderOutput: - def __init__(self, last_hidden_state): - self.last_hidden_state = last_hidden_state - - def mock_encoder_forward(**kwargs): - input_ids = kwargs['input_ids'] - batch_size, seq_len = input_ids.shape - hidden_states = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) - for i in range(batch_size): - for j in range(seq_len): - hidden_states[i, j, 0] = float(j) - return MockEncoderOutput(last_hidden_state=hidden_states) - - mock_encoder = Mock(side_effect=mock_encoder_forward) - mock_encoder.config = Mock() - mock_encoder.config.hidden_size = hidden_size - - model_non_chunked = 
DenseModel(encoder=mock_encoder, pooling='last', normalize=False) - model_non_chunked.passage_chunk_size = 0 - model_chunked = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) - model_chunked.passage_chunk_size = chunk_size - - p_reps_non_chunked = model_non_chunked.encode_passage(p_batch_non_chunked) - p_reps_chunked, _ = model_chunked.encode_passage(p_batch_chunked, eos_positions) - - last_valid_pos = p_batch_non_chunked['attention_mask'][0].sum().item() - 1 - eos_pos = eos_positions[0][0] - - assert eos_pos == last_valid_pos + 1 - assert not torch.allclose(p_reps_non_chunked[0], p_reps_chunked[0, 0]) diff --git a/tests/test_padding_helper.py b/tests/test_padding_helper.py deleted file mode 100644 index 45c699e4..00000000 --- a/tests/test_padding_helper.py +++ /dev/null @@ -1,367 +0,0 @@ -""" -Unit tests for _pad_and_adjust_eos_positions helper function. -""" -import sys -from pathlib import Path -import pytest -import torch - - -def _tevatron_root() -> Path: - return Path(__file__).resolve().parents[1] - - -def _add_tevatron_src_to_path(): - src = _tevatron_root() / "src" - sys.path.insert(0, str(src)) - - -@pytest.fixture(scope="session") -def train_tokenizer(): - """Use the Qwen 0.6B tokenizer.""" - _add_tevatron_src_to_path() - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") - if tok.pad_token_id is None: - tok.pad_token_id = tok.eos_token_id - tok.padding_side = "right" - return tok - - -@pytest.mark.unit -def test_pad_and_adjust_eos_positions_right_padding(train_tokenizer): - """Test padding with right padding (no EOS position adjustment needed).""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _pad_and_adjust_eos_positions - - eos_id = train_tokenizer.eos_token_id - pad_id = train_tokenizer.pad_token_id - - all_input_ids = [ - [1, 2, 3, eos_id], # Passage 0: 4 tokens, EOS at position 3 - [4, 5, eos_id], # Passage 1: 3 tokens, EOS at position 2 - ] - 
all_eos_positions = [[3], [2]] - - padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( - all_input_ids=all_input_ids, - all_eos_positions=all_eos_positions, - tokenizer=train_tokenizer, - padding_side='right', - pad_to_multiple_of=4, - ) - - # Hardcoded golden output - expected_input_ids = torch.tensor([ - [1, 2, 3, eos_id], # Passage 0: padded to 4 (no padding needed) - [4, 5, eos_id, pad_id], # Passage 1: padded to 4 (1 padding token) - ]) - expected_attention_mask = torch.tensor([ - [1, 1, 1, 1], # Passage 0: all tokens valid - [1, 1, 1, 0], # Passage 1: last token is padding - ]) - expected_eos_positions = [[3], [2]] # EOS positions unchanged for right padding - - assert torch.equal(padded_dict['input_ids'], expected_input_ids) - assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) - assert adjusted_eos_positions == expected_eos_positions - - -@pytest.mark.unit -def test_pad_and_adjust_eos_positions_left_padding(train_tokenizer): - """Test padding with left padding (EOS positions should be shifted).""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _pad_and_adjust_eos_positions - - eos_id = train_tokenizer.eos_token_id - pad_id = train_tokenizer.pad_token_id - - all_input_ids = [ - [1, 2, 3, eos_id], # Passage 0: 4 tokens, EOS at position 3 - [4, 5, eos_id], # Passage 1: 3 tokens, EOS at position 2 - ] - all_eos_positions = [[3], [2]] - - padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( - all_input_ids=all_input_ids, - all_eos_positions=all_eos_positions, - tokenizer=train_tokenizer, - padding_side='left', - pad_to_multiple_of=4, - ) - - # Hardcoded golden output - expected_input_ids = torch.tensor([ - [1, 2, 3, eos_id], # Passage 0: padded to 4 (no padding needed) - [pad_id, 4, 5, eos_id], # Passage 1: padded to 4 (1 padding token on left) - ]) - expected_attention_mask = torch.tensor([ - [1, 1, 1, 1], # Passage 0: all tokens valid - [0, 1, 1, 1], # Passage 1: first token is 
padding - ]) - # Passage 0: original length 4, padded length 4, padding_length=0, EOS stays at 3 - # Passage 1: original length 3, padded length 4, padding_length=1, EOS shifts from 2 to 3 - expected_eos_positions = [[3], [3]] - - assert torch.equal(padded_dict['input_ids'], expected_input_ids) - assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) - assert adjusted_eos_positions == expected_eos_positions - - -@pytest.mark.unit -def test_pad_and_adjust_eos_positions_multiple_eos(train_tokenizer): - """Test padding with multiple EOS positions per passage.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _pad_and_adjust_eos_positions - - eos_id = train_tokenizer.eos_token_id - pad_id = train_tokenizer.pad_token_id - - all_input_ids = [ - [1, 2, eos_id, 3, 4, eos_id], # Passage 0: 6 tokens, EOS at positions 2, 5 - [5, eos_id], # Passage 1: 2 tokens, EOS at position 1 - ] - all_eos_positions = [[2, 5], [1]] - - padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( - all_input_ids=all_input_ids, - all_eos_positions=all_eos_positions, - tokenizer=train_tokenizer, - padding_side='left', - pad_to_multiple_of=8, - ) - - # Hardcoded golden output - expected_input_ids = torch.tensor([ - [pad_id, pad_id, 1, 2, eos_id, 3, 4, eos_id], # Passage 0: padded to 8 (2 padding tokens on left) - [pad_id, pad_id, pad_id, pad_id, pad_id, pad_id, 5, eos_id], # Passage 1: padded to 8 (6 padding tokens on left) - ]) - expected_attention_mask = torch.tensor([ - [0, 0, 1, 1, 1, 1, 1, 1], # Passage 0: first 2 tokens are padding - [0, 0, 0, 0, 0, 0, 1, 1], # Passage 1: first 6 tokens are padding - ]) - # Passage 0: original length 6, padded length 8, padding_length=2, EOS shift from [2,5] to [4,7] - # Passage 1: original length 2, padded length 8, padding_length=6, EOS shift from 1 to 7 - expected_eos_positions = [[4, 7], [7]] - - assert torch.equal(padded_dict['input_ids'], expected_input_ids) - assert 
torch.equal(padded_dict['attention_mask'], expected_attention_mask) - assert adjusted_eos_positions == expected_eos_positions - - -@pytest.mark.unit -def test_pad_and_adjust_eos_positions_no_padding_needed(train_tokenizer): - """Test when sequences are already the same length.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _pad_and_adjust_eos_positions - - eos_id = train_tokenizer.eos_token_id - pad_id = train_tokenizer.pad_token_id - - all_input_ids = [ - [1, 2, eos_id], - [3, 4, eos_id], - ] - all_eos_positions = [[2], [2]] - - padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( - all_input_ids=all_input_ids, - all_eos_positions=all_eos_positions, - tokenizer=train_tokenizer, - padding_side='right', - pad_to_multiple_of=4, - ) - - # Hardcoded golden output - expected_input_ids = torch.tensor([ - [1, 2, eos_id, pad_id], # Padded to 4 (1 padding token) - [3, 4, eos_id, pad_id], # Padded to 4 (1 padding token) - ]) - expected_attention_mask = torch.tensor([ - [1, 1, 1, 0], # Last token is padding - [1, 1, 1, 0], # Last token is padding - ]) - expected_eos_positions = [[2], [2]] # EOS positions unchanged for right padding - - assert torch.equal(padded_dict['input_ids'], expected_input_ids) - assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) - assert adjusted_eos_positions == expected_eos_positions - - -@pytest.mark.unit -def test_pad_and_adjust_eos_positions_empty_input(train_tokenizer): - """Test with empty input.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _pad_and_adjust_eos_positions - - all_input_ids = [] - all_eos_positions = [] - - padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( - all_input_ids=all_input_ids, - all_eos_positions=all_eos_positions, - tokenizer=train_tokenizer, - padding_side='right', - pad_to_multiple_of=4, - ) - - # Hardcoded golden output for empty input - expected_eos_positions = [] - - assert adjusted_eos_positions == 
expected_eos_positions - # When input is empty, tokenizer.pad may return list or tensor depending on implementation - if isinstance(padded_dict['input_ids'], torch.Tensor): - assert padded_dict['input_ids'].shape[0] == 0 - assert padded_dict['attention_mask'].shape[0] == 0 - else: - assert len(padded_dict['input_ids']) == 0 - assert len(padded_dict['attention_mask']) == 0 - - -@pytest.mark.unit -def test_pad_and_adjust_eos_positions_single_passage(train_tokenizer): - """Test with single passage.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _pad_and_adjust_eos_positions - - eos_id = train_tokenizer.eos_token_id - - all_input_ids = [[1, 2, 3, eos_id]] - all_eos_positions = [[3]] - - padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( - all_input_ids=all_input_ids, - all_eos_positions=all_eos_positions, - tokenizer=train_tokenizer, - padding_side='right', - pad_to_multiple_of=4, - ) - - # Hardcoded golden output - expected_input_ids = torch.tensor([ - [1, 2, 3, eos_id], # Already length 4, no padding needed - ]) - expected_attention_mask = torch.tensor([ - [1, 1, 1, 1], # All tokens valid - ]) - expected_eos_positions = [[3]] - - assert torch.equal(padded_dict['input_ids'], expected_input_ids) - assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) - assert adjusted_eos_positions == expected_eos_positions - - -@pytest.mark.unit -def test_pad_and_adjust_eos_positions_pad_to_multiple_of_one(train_tokenizer): - """Test with pad_to_multiple_of=1 (no rounding).""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _pad_and_adjust_eos_positions - - eos_id = train_tokenizer.eos_token_id - pad_id = train_tokenizer.pad_token_id - - all_input_ids = [ - [1, 2, eos_id], - [3, eos_id], - ] - all_eos_positions = [[2], [1]] - - padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( - all_input_ids=all_input_ids, - all_eos_positions=all_eos_positions, - tokenizer=train_tokenizer, - 
padding_side='right', - pad_to_multiple_of=1, - ) - - # Hardcoded golden output - expected_input_ids = torch.tensor([ - [1, 2, eos_id], # Padded to max_len=3 (no rounding needed with pad_to_multiple_of=1) - [3, eos_id, pad_id], # Padded to max_len=3 (1 padding token) - ]) - expected_attention_mask = torch.tensor([ - [1, 1, 1], # All tokens valid - [1, 1, 0], # Last token is padding - ]) - expected_eos_positions = [[2], [1]] # EOS positions unchanged for right padding - - assert torch.equal(padded_dict['input_ids'], expected_input_ids) - assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) - assert adjusted_eos_positions == expected_eos_positions - - -@pytest.mark.unit -def test_pad_and_adjust_eos_positions_left_padding_multiple_chunks(train_tokenizer): - """Test left padding with multiple chunks per passage.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _pad_and_adjust_eos_positions - - eos_id = train_tokenizer.eos_token_id - pad_id = train_tokenizer.pad_token_id - - all_input_ids = [ - [1, eos_id, 2, 3, eos_id], # Passage 0: 5 tokens, EOS at positions 1, 4 - [4, 5, eos_id], # Passage 1: 3 tokens, EOS at position 2 - ] - all_eos_positions = [[1, 4], [2]] - - padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( - all_input_ids=all_input_ids, - all_eos_positions=all_eos_positions, - tokenizer=train_tokenizer, - padding_side='left', - pad_to_multiple_of=8, - ) - - # Hardcoded golden output - expected_input_ids = torch.tensor([ - [pad_id, pad_id, pad_id, 1, eos_id, 2, 3, eos_id], # Passage 0: padded to 8 (3 padding tokens on left) - [pad_id, pad_id, pad_id, pad_id, pad_id, 4, 5, eos_id], # Passage 1: padded to 8 (5 padding tokens on left) - ]) - expected_attention_mask = torch.tensor([ - [0, 0, 0, 1, 1, 1, 1, 1], # Passage 0: first 3 tokens are padding - [0, 0, 0, 0, 0, 1, 1, 1], # Passage 1: first 5 tokens are padding - ]) - # Passage 0: original length 5, padded length 8, padding_length=3, EOS shift 
from [1,4] to [4,7] - # Passage 1: original length 3, padded length 8, padding_length=5, EOS shift from 2 to 7 - expected_eos_positions = [[4, 7], [7]] - - assert torch.equal(padded_dict['input_ids'], expected_input_ids) - assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) - assert adjusted_eos_positions == expected_eos_positions - - -@pytest.mark.unit -def test_pad_and_adjust_eos_positions_tokenizer_padding_side_set(train_tokenizer): - """Test that tokenizer.padding_side is set correctly.""" - _add_tevatron_src_to_path() - from tevatron.retriever.collator import _pad_and_adjust_eos_positions - - eos_id = train_tokenizer.eos_token_id - - all_input_ids = [[1, 2, eos_id]] - all_eos_positions = [[2]] - - # Test right padding - train_tokenizer.padding_side = 'right' - _pad_and_adjust_eos_positions( - all_input_ids=all_input_ids, - all_eos_positions=all_eos_positions, - tokenizer=train_tokenizer, - padding_side='right', - pad_to_multiple_of=4, - ) - assert train_tokenizer.padding_side == 'right' - - # Test left padding - train_tokenizer.padding_side = 'left' - _pad_and_adjust_eos_positions( - all_input_ids=all_input_ids, - all_eos_positions=all_eos_positions, - tokenizer=train_tokenizer, - padding_side='left', - pad_to_multiple_of=4, - ) - assert train_tokenizer.padding_side == 'left' - From a60b20ba3528a40cb7911f5cd7a919e5abe3ba41 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Tue, 23 Dec 2025 18:37:28 -0500 Subject: [PATCH 22/31] Reviewed forward and maxsim --- src/tevatron/retriever/modeling/dense.py | 2 + src/tevatron/retriever/modeling/encoder.py | 45 ++++++++++++---------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/src/tevatron/retriever/modeling/dense.py b/src/tevatron/retriever/modeling/dense.py index 3904854b..dd0d077a 100644 --- a/src/tevatron/retriever/modeling/dense.py +++ b/src/tevatron/retriever/modeling/dense.py @@ -54,7 +54,9 @@ def _pooling_chunked(self, last_hidden_state, eos_positions): for i, positions in 
enumerate(eos_positions): for j, pos in enumerate(positions): if 0 <= pos < seq_len: + # i is the batch index, j is the chunk index, pos is the eos position chunk_reps[i, j] = last_hidden_state[i, pos] + # chunk_mask is 1.0 for valid chunks, 0.0 for padding chunks chunk_mask[i, j] = 1.0 else: logger.warning(f"Position {pos} out of bounds for sequence length {seq_len} in batch {i}, chunk {j}") diff --git a/src/tevatron/retriever/modeling/encoder.py b/src/tevatron/retriever/modeling/encoder.py index 8a25c556..443a004b 100644 --- a/src/tevatron/retriever/modeling/encoder.py +++ b/src/tevatron/retriever/modeling/encoder.py @@ -89,12 +89,16 @@ def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = else: # print(f"start compute similarity==========================") scores = self.compute_similarity(q_reps, p_reps) + # view the scores as [Q, P] where Q is the number of queries and P is the number of passages scores = scores.view(q_reps.size(0), -1) num_psg_per_query = scores.size(1) // q_reps.size(0) target = torch.arange(q_reps.size(0), device=scores.device, dtype=torch.long) target = target * num_psg_per_query - + # target contains the indices of the positive passages in this batch; target.shape = [Q] + # so the target is [0, 4] for batch_size = 2, group_size = 4, chunk_size = 64 + print(f"target: {target}") + print(f"target.shape: {target.shape}") loss = self.compute_loss(scores / self.temperature, target) if self.is_ddp: loss = loss * self.world_size # counter average weight reduction @@ -137,39 +141,40 @@ def compute_maxsim_similarity(self, q_reps, p_reps, chunk_mask): chunk_scores = chunk_scores.masked_fill(padding_mask, float('-inf')) max_vals, max_idx = chunk_scores.max(dim=-1) # [Q, P], [Q, P] - # Print argmax chunk index + (optional) original token position from eos_positions + # Log maxsim info: read chunk indices directly from max_idx if True: # only log from rank-0 if DDP if (not getattr(self, "is_ddp", False)) or getattr(self, 
"process_rank", 0) == 0: eos_positions = getattr(self, "eos_positions", None) - # If DDP gathered passages, eos_positions may not align; only use when sizes match. eos_ok = ( isinstance(eos_positions, (list, tuple)) and len(eos_positions) == p_reps.size(0) ) - qn, pn = max_idx.size(0), max_idx.size(1) - for qi in range(qn): - for pi in range(pn): - ci = int(max_idx[qi, pi].item()) - # last valid chunk index for this passage (by mask) - if chunk_mask is not None: - valid = int(chunk_mask[pi].sum().item()) - last_ci = max(valid - 1, 0) - else: - last_ci = p_reps.size(1) - 1 - - if eos_ok and eos_positions[pi]: - pos_list = eos_positions[pi] - best_pos = pos_list[ci] if 0 <= ci < len(pos_list) else None - last_pos = pos_list[-1] + + # Compute last valid chunk indices for all passages + if chunk_mask is not None: + last_ci_per_passage = (chunk_mask.sum(dim=1) - 1).clamp(min=0) # [P] + else: + last_ci_per_passage = torch.full((p_reps.size(0),), p_reps.size(1) - 1, dtype=torch.long) + + # Log for each query-passage pair + for qi in range(max_idx.size(0)): + for pi in range(max_idx.size(1)): + ci = int(max_idx[qi, pi].item()) # best chunk index from max_idx + last_ci = int(last_ci_per_passage[pi].item()) + score = float(max_vals[qi, pi].item()) + + if eos_ok and eos_positions[pi] and ci < len(eos_positions[pi]): + best_pos = eos_positions[pi][ci] + last_pos = eos_positions[pi][-1] logger.info( f"[maxsim] q={qi} p={pi} best_chunk={ci} best_pos={best_pos} " - f"last_chunk={last_ci} last_pos={last_pos} best_score={float(max_vals[qi, pi].item()):.6f}" + f"last_chunk={last_ci} last_pos={last_pos} best_score={score:.6f}" ) else: logger.info( f"[maxsim] q={qi} p={pi} best_chunk={ci} last_chunk={last_ci} " - f"best_score={float(max_vals[qi, pi].item()):.6f}" + f"best_score={score:.6f}" ) return max_vals From 32a19768a2fae37bd81092798423e02e0d6ed69b Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Tue, 23 Dec 2025 21:16:55 -0500 Subject: [PATCH 23/31] dataset uses random negative for 
cases have less negatives --- src/tevatron/retriever/dataset.py | 4 +- src/tevatron/retriever/driver/encode.py | 2 +- tests/test_pooling.py | 383 ------------------------ 3 files changed, 4 insertions(+), 385 deletions(-) delete mode 100644 tests/test_pooling.py diff --git a/src/tevatron/retriever/dataset.py b/src/tevatron/retriever/dataset.py index ae3fdb57..b02fa827 100644 --- a/src/tevatron/retriever/dataset.py +++ b/src/tevatron/retriever/dataset.py @@ -130,6 +130,7 @@ def __getitem__(self, item): # Select negative documents negative_size = self.data_args.train_group_size - 1 if len(group['negative_passages']) < negative_size: + print(f"selected_negatives: Randomly selected!!!!!!!!!!!!!!!!!!!!!!!!!!") selected_negatives = random.choices(group['negative_passages'], k=negative_size) elif self.data_args.train_group_size == 1: selected_negatives = [] @@ -394,7 +395,8 @@ def __getitem__(self, item): negative_size = self.data_args.train_group_size - 1 if len(negative_document_ids) < negative_size: - selected_negative_document_ids = random.choices(negative_document_ids, k=negative_size) + rng = random.Random(_hashed_seed) + selected_negative_document_ids = rng.choices(negative_document_ids, k=negative_size) elif self.data_args.train_group_size == 1: selected_negative_document_ids = [] else: diff --git a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py index 59a92e97..3d15e7a9 100644 --- a/src/tevatron/retriever/driver/encode.py +++ b/src/tevatron/retriever/driver/encode.py @@ -148,7 +148,7 @@ def main(): logger.info(f"Encoded {len(set(d for d, c in lookup_indices))} docs into {len(lookup_indices)} chunks") print(f"encoded.shape: {encoded.shape}") print(f"length of encoded: {len(encoded)}") - input("Press Enter to continue...") + # input("Press Enter to continue...") else: encoded = np.concatenate(encoded) diff --git a/tests/test_pooling.py b/tests/test_pooling.py deleted file mode 100644 index 2c65eca3..00000000 --- 
a/tests/test_pooling.py +++ /dev/null @@ -1,383 +0,0 @@ -import sys -from pathlib import Path - -import pytest - - -def _tevatron_root() -> Path: - return Path(__file__).resolve().parents[1] - - -def _add_tevatron_src_to_path(): - # tevatron/tests/test_pooling.py -> tevatron/ -> tevatron/src - src = _tevatron_root() / "src" - sys.path.insert(0, str(src)) - - -REAL_TEXT = ( - "Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical " - "development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging " - "(MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to " - "calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in " - "preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter " - "development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white " - "matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to " - "1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both " - "times were similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with " - "greater absolute values in the internal capsule than in the central white matter. Preterm infants at term showed " - "higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 versus 1.15 +/- 0.09 microm2/ms, " - "p = 0.016) and lower relative anisotropy in both areas compared with full-term infants (white matter, 10.9 +/- " - "0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- 4.44 versus 33.1 +/- 0.6% p = 0.006). 
" - "Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term and " - "preterm infants at term showed marked differences in white matter fiber organization. The data indicate that " - "quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural " - "development in cerebral white matter in living infants" -) -EOS_TOKEN_ID = 151645 -PADDING_TOKEN_ID = 151643 - -@pytest.fixture(scope="session") -def train_tokenizer(): - """ - Use the Qwen 0.6B tokenizer. - """ - _add_tevatron_src_to_path() - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") - if tok.pad_token_id is None: - tok.pad_token_id = tok.eos_token_id - tok.padding_side = "right" # finetune_with_chunk.sh uses --padding_side right - return tok - - -@pytest.mark.unit -def test_encode_with_chunking(train_tokenizer, tmp_path): - """ - Test the full encode functionality with chunking enabled. 
- This tests the integration of: - - EncodeDataset loading JSONL data - - ChunkedEncodeCollator creating batches with eos_positions - - DenseModel.encode_passage with chunking - - Output shape and lookup_indices creation - """ - import json - import numpy as np - import torch - from torch.utils.data import DataLoader - from unittest.mock import Mock - - from tevatron.retriever.arguments import DataArguments, TevatronTrainingArguments as TrainingArguments - from tevatron.retriever.dataset import EncodeDataset - from tevatron.retriever.collator import ChunkedEncodeCollator - from tevatron.retriever.modeling.dense import DenseModel - - # Create temporary JSONL file with test passages - test_passages = [ - {"docid": "doc1", "text": REAL_TEXT}, # Long passage that will be chunked - {"docid": "doc2", "text": "Short passage."}, # Short passage - ] - - jsonl_file = tmp_path / "test_corpus.jsonl" - with open(jsonl_file, 'w') as f: - for passage in test_passages: - f.write(json.dumps(passage) + '\n') - - # Setup data arguments for chunked encoding - data_args = DataArguments( - dataset_name='json', - dataset_path=str(jsonl_file), - dataset_split='train', - passage_chunk_size=32, - passage_max_len=128, - pad_to_multiple_of=16, - padding_side="right", - passage_prefix="", - encode_is_query=False, - ) - - # Setup training arguments - training_args = TrainingArguments( - output_dir=str(tmp_path / "output"), - per_device_eval_batch_size=2, - dataloader_num_workers=0, - fp16=False, - bf16=False, - ) - - # Create dataset - encode_dataset = EncodeDataset(data_args=data_args) - assert len(encode_dataset) == 2 - - # Create chunked collator - encode_collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) - - # Create data loader - encode_loader = DataLoader( - encode_dataset, - batch_size=training_args.per_device_eval_batch_size, - collate_fn=encode_collator, - shuffle=False, - drop_last=False, - num_workers=training_args.dataloader_num_workers, - ) - - # 
Create a mock encoder model - hidden_size = 64 - - # Create a proper mock that returns an object with last_hidden_state - class MockEncoderOutput: - def __init__(self, last_hidden_state): - self.last_hidden_state = last_hidden_state - - # Mock the encoder forward pass to return hidden states - def mock_encoder_forward(**kwargs): - input_ids = kwargs['input_ids'] - batch_size, seq_len = input_ids.shape - # Create dummy hidden states with positional encoding for testing - hidden_states = torch.arange(batch_size * seq_len * hidden_size, dtype=torch.float32) - hidden_states = hidden_states.reshape(batch_size, seq_len, hidden_size) - # Add some variation based on input_ids for testing - hidden_states = hidden_states + input_ids.unsqueeze(-1).float() * 0.01 - return MockEncoderOutput(last_hidden_state=hidden_states) - - mock_encoder = Mock(side_effect=mock_encoder_forward) - mock_encoder.config = Mock() - mock_encoder.config.hidden_size = hidden_size - - # Create DenseModel with mock encoder - model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) - model.passage_chunk_size = data_args.passage_chunk_size - model.eval() - - # Simulate the encode loop - encoded = [] - lookup_indices = [] - - for batch in encode_loader: - doc_ids, batch_inputs, eos_positions = batch - - # Verify batch structure - assert isinstance(doc_ids, list) - # batch_inputs is a BatchEncoding (from tokenizer.pad), which behaves like a dict - assert hasattr(batch_inputs, '__getitem__') # Check if it's dict-like - assert 'input_ids' in batch_inputs - assert 'attention_mask' in batch_inputs - assert isinstance(eos_positions, list) - assert len(eos_positions) == len(doc_ids) - - # Verify eos_positions structure - for i, eos_pos_list in enumerate(eos_positions): - assert isinstance(eos_pos_list, list) - assert len(eos_pos_list) > 0 # Should have at least one chunk - # Verify eos_positions are within sequence length - seq_len = batch_inputs['input_ids'].shape[1] - for pos in eos_pos_list: 
- assert 0 <= pos < seq_len - - # Encode with chunking - with torch.no_grad(): - chunk_embs, chunk_mask = model.encode_passage(batch_inputs, eos_positions) - - # Verify output shapes - batch_size, max_chunks, hidden_size_out = chunk_embs.shape - assert batch_size == len(doc_ids) - assert hidden_size_out == hidden_size - assert chunk_mask.shape == (batch_size, max_chunks) - - # Verify chunk_mask values (should be 0 or 1) - assert torch.all((chunk_mask == 0) | (chunk_mask == 1)) - - # Process chunks and create lookup indices - for i, doc_id in enumerate(doc_ids): - for chunk_idx in range(max_chunks): - if chunk_mask[i, chunk_idx] > 0: # Valid chunk - encoded.append(chunk_embs[i, chunk_idx].cpu().detach().numpy()) - lookup_indices.append((doc_id, chunk_idx)) - - # Verify results - assert len(encoded) > 0 - assert len(lookup_indices) == len(encoded) - - # Stack encoded embeddings - encoded_array = np.stack(encoded) - assert encoded_array.shape[0] == len(encoded) - assert encoded_array.shape[1] == hidden_size - - # Verify lookup_indices structure - unique_docs = set(doc_id for doc_id, _ in lookup_indices) - assert len(unique_docs) == 2 # Should have both doc1 and doc2 - - # Verify doc1 has multiple chunks (it's a long passage) - doc1_chunks = [chunk_idx for doc_id, chunk_idx in lookup_indices if doc_id == "doc1"] - assert len(doc1_chunks) > 1 # Should have multiple chunks - - # Verify doc2 has at least one chunk - doc2_chunks = [chunk_idx for doc_id, chunk_idx in lookup_indices if doc_id == "doc2"] - assert len(doc2_chunks) >= 1 - - # Verify chunk indices are sequential starting from 0 - for doc_id in unique_docs: - doc_chunks = sorted([chunk_idx for d, chunk_idx in lookup_indices if d == doc_id]) - assert doc_chunks == list(range(len(doc_chunks))) # Should be 0, 1, 2, ... 
- - # Verify embeddings are not all zeros (they should have been computed) - assert not np.allclose(encoded_array, 0) - - # Verify embeddings have reasonable values (not NaN or Inf) - assert np.all(np.isfinite(encoded_array)) - - -@pytest.mark.unit -def test_pooling_chunked_eos_positions_alignment(): - """Test _pooling_chunked extracts embeddings from correct EOS positions.""" - import torch - from unittest.mock import Mock - from tevatron.retriever.modeling.dense import DenseModel - - mock_encoder = Mock() - mock_encoder.config.hidden_size = 8 - model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) - model.passage_chunk_size = 32 - - batch_size, seq_len, hidden_size = 2, 10, 8 - hidden_states = torch.zeros(batch_size, seq_len, hidden_size) - for i in range(batch_size): - for j in range(seq_len): - hidden_states[i, j, 0] = j * 100 + i * 10 - for k in range(1, hidden_size): - hidden_states[i, j, k] = j * 10 + k - - eos_positions = [[2, 5, 8], [3, 7]] - chunk_reps, chunk_mask = model._pooling_chunked(hidden_states, eos_positions) - - assert chunk_reps.shape == (batch_size, 3, hidden_size) - assert chunk_mask.shape == (batch_size, 3) - - # Verify correct positions extracted - assert torch.allclose(chunk_reps[0, 0, 0], torch.tensor(200.0)) # pos 2 - assert torch.allclose(chunk_reps[0, 1, 0], torch.tensor(500.0)) # pos 5 - assert torch.allclose(chunk_reps[0, 2, 0], torch.tensor(800.0)) # pos 8 - assert torch.allclose(chunk_reps[1, 0, 0], torch.tensor(310.0)) # pos 3 - assert torch.allclose(chunk_reps[1, 1, 0], torch.tensor(710.0)) # pos 7 - - # Verify chunk mask - assert (chunk_mask[0, :3] == 1.0).all() - assert (chunk_mask[1, :2] == 1.0).all() - assert chunk_mask[1, 2] == 0.0 - - # Test exact equality with sequential hidden states - hidden_states_2 = torch.arange(batch_size * seq_len * hidden_size, dtype=torch.float32) - hidden_states_2 = hidden_states_2.reshape(batch_size, seq_len, hidden_size) - chunk_reps_2, _ = 
model._pooling_chunked(hidden_states_2, eos_positions) - - assert torch.equal(chunk_reps_2[0, 0], hidden_states_2[0, 2]) - assert torch.equal(chunk_reps_2[0, 1], hidden_states_2[0, 5]) - assert torch.equal(chunk_reps_2[0, 2], hidden_states_2[0, 8]) - assert torch.equal(chunk_reps_2[1, 0], hidden_states_2[1, 3]) - assert torch.equal(chunk_reps_2[1, 1], hidden_states_2[1, 7]) - - # Test edge cases - chunk_reps_empty, chunk_mask_empty = model._pooling_chunked(hidden_states, []) - assert chunk_reps_empty.shape == (batch_size, 0, hidden_size) - - eos_positions_oob = [[2, 5, 15], [3, 7]] - chunk_reps_oob, chunk_mask_oob = model._pooling_chunked(hidden_states, eos_positions_oob) - assert torch.allclose(chunk_reps_oob[0, 2], torch.zeros(hidden_size)) - assert chunk_mask_oob[0, 2] == 0.0 - - # Test normalization - model.normalize = True - chunk_reps_norm, _ = model._pooling_chunked(hidden_states_2, eos_positions) - for i in range(batch_size): - for j in range(len(eos_positions[i])): - assert torch.allclose(torch.norm(chunk_reps_norm[i, j]), torch.tensor(1.0), atol=1e-6) - - -@pytest.mark.unit -def test_pooling_chunked_real_tokenizer_alignment(train_tokenizer): - """Integration test: eos_positions from collator correctly align with hidden states.""" - import torch - from unittest.mock import Mock - from tevatron.retriever.arguments import DataArguments - from tevatron.retriever.collator import ChunkedEncodeCollator - from tevatron.retriever.modeling.dense import DenseModel - - data_args = DataArguments( - passage_chunk_size=32, - passage_max_len=128, - pad_to_multiple_of=16, - padding_side="right", - append_eos_token=False, - ) - collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) - passages = [REAL_TEXT, "Short passage for testing."] - d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages(passages) - - input_ids = d_collated['input_ids'] - seq_len = input_ids.shape[1] - - # Verify eos_positions are valid - for i, eos_pos_list 
in enumerate(eos_positions): - assert len(eos_pos_list) > 0 - for pos in eos_pos_list: - assert 0 <= pos < seq_len - assert input_ids[i, pos] == train_tokenizer.eos_token_id - - # Create mock encoder - hidden_size = 64 - - class MockEncoderOutput: - def __init__(self, last_hidden_state): - self.last_hidden_state = last_hidden_state - - def mock_encoder_forward(**kwargs): - input_ids = kwargs['input_ids'] - batch_size, seq_len = input_ids.shape - hidden_states = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) - for i in range(batch_size): - for j in range(seq_len): - hidden_states[i, j, 0] = float(j) - hidden_states[i, j, 1] = float(input_ids[i, j]) - for k in range(2, hidden_size): - hidden_states[i, j, k] = float(j * hidden_size + k) - return MockEncoderOutput(last_hidden_state=hidden_states) - - mock_encoder = Mock(side_effect=mock_encoder_forward) - mock_encoder.config = Mock() - mock_encoder.config.hidden_size = hidden_size - - model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) - model.passage_chunk_size = data_args.passage_chunk_size - - batch_inputs = { - 'input_ids': d_collated['input_ids'], - 'attention_mask': d_collated['attention_mask'], - } - chunk_reps, chunk_mask = model.encode_passage(batch_inputs, eos_positions) - - batch_size = len(passages) - max_chunks = max(len(pos_list) for pos_list in eos_positions) - assert chunk_reps.shape == (batch_size, max_chunks, hidden_size) - - # Re-create expected hidden states - hidden_states_expected = torch.zeros(batch_size, seq_len, hidden_size, dtype=torch.float32) - for i in range(batch_size): - for j in range(seq_len): - hidden_states_expected[i, j, 0] = float(j) - hidden_states_expected[i, j, 1] = float(input_ids[i, j]) - for k in range(2, hidden_size): - hidden_states_expected[i, j, k] = float(j * hidden_size + k) - - # Verify extracted embeddings match expected positions - for i, eos_pos_list in enumerate(eos_positions): - for j, pos in enumerate(eos_pos_list): - 
assert torch.equal(chunk_reps[i, j], hidden_states_expected[i, pos]) - assert chunk_mask[i, j] == 1.0 - assert torch.allclose(chunk_reps[i, j, 0], torch.tensor(float(pos))) - - # Verify invalid chunks are masked - for i in range(batch_size): - num_chunks = len(eos_positions[i]) - for j in range(num_chunks, max_chunks): - assert chunk_mask[i, j] == 0.0 From 1b4163ee8d24aeb58375ab8cdce407d237b01a7f Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Tue, 23 Dec 2025 23:18:51 -0500 Subject: [PATCH 24/31] added some prints --- src/tevatron/retriever/dataset.py | 3 +- src/tevatron/retriever/driver/search.py | 30 +- tests/test_search.py | 646 ++++++++++++++++++++++++ 3 files changed, 670 insertions(+), 9 deletions(-) diff --git a/src/tevatron/retriever/dataset.py b/src/tevatron/retriever/dataset.py index b02fa827..8b2bf25b 100644 --- a/src/tevatron/retriever/dataset.py +++ b/src/tevatron/retriever/dataset.py @@ -395,8 +395,7 @@ def __getitem__(self, item): negative_size = self.data_args.train_group_size - 1 if len(negative_document_ids) < negative_size: - rng = random.Random(_hashed_seed) - selected_negative_document_ids = rng.choices(negative_document_ids, k=negative_size) + selected_negative_document_ids = random.choices(negative_document_ids, k=negative_size) elif self.data_args.train_group_size == 1: selected_negative_document_ids = [] else: diff --git a/src/tevatron/retriever/driver/search.py b/src/tevatron/retriever/driver/search.py index 5a7f84e8..9b601c09 100644 --- a/src/tevatron/retriever/driver/search.py +++ b/src/tevatron/retriever/driver/search.py @@ -35,33 +35,50 @@ def search_queries_chunked(retriever, q_reps, p_lookup, args): Search with chunked passages and aggregate by document using MaxSim. 
""" # Search more chunks to ensure good recall after aggregation - search_depth = args.depth * args.chunk_multiplier + search_depth = args.depth + print(f"q_reps: {len(q_reps)}") + print(f"p_lookup: {p_lookup}") + print(f"len(p_lookup): {len(p_lookup)}") + print(f"args.batch_size: {args.batch_size}") if args.batch_size > 0: + # all_scores.shape = [Q, search_depth] all_scores, all_indices = retriever.batch_search(q_reps, search_depth, args.batch_size, args.quiet) else: + # all_scores.shape = [search_depth] all_scores, all_indices = retriever.search(q_reps, search_depth) + print(f"all_scores: {all_scores}") + print(f"all_indices: {all_indices}") + print(f"all_scores.shape: {all_scores.shape}") # [Q, search_depth] + print(f"all_indices.shape: {all_indices.shape}") # [Q, search_depth] + # input("Press Enter to continue...") # Aggregate by document ID using MaxSim aggregated_results = [] for q_idx in range(len(q_reps)): scores = all_scores[q_idx] indices = all_indices[q_idx] - doc_max_scores = defaultdict(lambda: float('-inf')) - for score, idx in zip(scores, indices): if idx < 0: # FAISS returns -1 for insufficient results continue + if idx >= len(p_lookup): # Boundary check: prevent IndexError + logger.warning(f"Index {idx} out of bounds for p_lookup (length {len(p_lookup)}), skipping") + continue + + try: + doc_id, chunk_idx = p_lookup[idx] + except (ValueError, TypeError) as e: + logger.error(f"p_lookup[{idx}] is not a tuple (doc_id, chunk_idx): {p_lookup[idx]}, error: {e}") + continue - doc_id, chunk_idx = p_lookup[idx] # MaxSim: keep the maximum score for each document doc_max_scores[doc_id] = max(doc_max_scores[doc_id], score) - # Sort by score and take top-depth sorted_docs = sorted(doc_max_scores.items(), key=lambda x: x[1], reverse=True)[:args.depth] aggregated_results.append(sorted_docs) - + print(f"aggregated_results: {aggregated_results[0]}") + input("Press Enter to continue...") return aggregated_results @@ -129,7 +146,6 @@ def main(): # Auto-detect 
chunked format: lookup entries are tuples (doc_id, chunk_idx) is_chunked = args.chunked or (len(look_up) > 0 and isinstance(look_up[0], tuple)) - if is_chunked: unique_docs = len(set(doc_id for doc_id, _ in look_up)) logger.info(f"Chunked mode: {len(look_up)} chunks from {unique_docs} documents") diff --git a/tests/test_search.py b/tests/test_search.py index 155ecf89..75c7b28d 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -209,3 +209,649 @@ class MockArgs: for doc_id, score in results_single_chunk[q_idx]: assert isinstance(doc_id, str) assert isinstance(score, (int, float, np.floating)) + + +@pytest.mark.unit +def test_write_ranking(): + """Test write_ranking function for non-chunked search results.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import write_ranking + import tempfile + import os + + # Create mock data + q_lookup = ["q1", "q2", "q3"] + corpus_scores = [ + [0.9, 0.8, 0.7, 0.6, 0.5], + [0.95, 0.85, 0.75, 0.65, 0.55], + [0.88, 0.78, 0.68, 0.58, 0.48] + ] + corpus_indices = [ + ["doc_1", "doc_2", "doc_3", "doc_4", "doc_5"], + ["doc_10", "doc_20", "doc_30", "doc_40", "doc_50"], + ["doc_100", "doc_200", "doc_300", "doc_400", "doc_500"] + ] + + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + temp_path = f.name + + try: + write_ranking(corpus_indices, corpus_scores, q_lookup, temp_path) + + # Verify file contents + with open(temp_path, 'r') as f: + lines = f.readlines() + + assert len(lines) == 15 # 3 queries * 5 results + + # Check first query results (should be sorted by score descending) + first_query_lines = lines[:5] + scores = [float(line.strip().split('\t')[2]) for line in first_query_lines] + assert scores == sorted(scores, reverse=True), "Scores should be in descending order" + + # Verify format: qid\tidx\tscore + for line in lines: + parts = line.strip().split('\t') + assert len(parts) == 3, "Each line should have 3 parts: qid, idx, score" + assert parts[0] in q_lookup, 
"Query ID should be in q_lookup" + assert float(parts[2]) >= 0, "Score should be a number" + + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + + +@pytest.mark.unit +def test_write_ranking_chunked(): + """Test write_ranking_chunked function for chunked search results.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import write_ranking_chunked + import tempfile + import os + + # Create mock chunked results + q_lookup = ["q1", "q2"] + results = [ + [("doc_1", 0.95), ("doc_2", 0.85), ("doc_3", 0.75)], + [("doc_10", 0.92), ("doc_20", 0.82), ("doc_30", 0.72)] + ] + + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + temp_path = f.name + + try: + write_ranking_chunked(results, q_lookup, temp_path) + + # Verify file contents + with open(temp_path, 'r') as f: + lines = f.readlines() + + assert len(lines) == 6 # 2 queries * 3 results + + # Verify format: qid\tdoc_id\tscore + for i, line in enumerate(lines): + parts = line.strip().split('\t') + assert len(parts) == 3, "Each line should have 3 parts: qid, doc_id, score" + + # Check query ID + if i < 3: + assert parts[0] == "q1" + else: + assert parts[0] == "q2" + + # Check score is a number + assert float(parts[2]) >= 0, "Score should be a number" + + # Verify scores are in descending order for each query + q1_scores = [float(lines[i].strip().split('\t')[2]) for i in range(3)] + q2_scores = [float(lines[i].strip().split('\t')[2]) for i in range(3, 6)] + assert q1_scores == sorted(q1_scores, reverse=True) + assert q2_scores == sorted(q2_scores, reverse=True) + + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + + +@pytest.mark.unit +def test_pickle_load_save(): + """Test pickle_load and pickle_save functions.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import pickle_load, pickle_save + import tempfile + import os + + # Create test data + test_reps = np.random.randn(10, 64).astype(np.float32) + test_lookup = 
[f"doc_{i}" for i in range(10)] + test_data = (test_reps, test_lookup) + + with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as f: + temp_path = f.name + + try: + # Save + pickle_save(test_data, temp_path) + assert os.path.exists(temp_path), "Pickle file should be created" + + # Load + loaded_reps, loaded_lookup = pickle_load(temp_path) + + # Verify data integrity + assert np.array_equal(loaded_reps, test_reps), "Embeddings should match" + assert loaded_lookup == test_lookup, "Lookup should match" + assert isinstance(loaded_reps, np.ndarray), "Loaded reps should be numpy array" + + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + + +@pytest.mark.unit +def test_search_batch_size(): + """Test that batch_size parameter works correctly.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries + from tevatron.retriever.searcher import FaissFlatSearcher + + num_queries = 10 + num_docs = 20 + hidden_size = 64 + + q_reps = np.random.randn(num_queries, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_docs, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [f"doc_{i}" for i in range(num_docs)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 5 + quiet = True + chunk_multiplier = 10 + + # Test with batch_size = 0 (no batching) + args_no_batch = MockArgs() + args_no_batch.batch_size = 0 + scores_no_batch, indices_no_batch = search_queries(retriever, q_reps, p_lookup, args_no_batch) + + # Test with batch_size > 0 (batching) + args_batch = MockArgs() + args_batch.batch_size = 3 + scores_batch, indices_batch = search_queries(retriever, q_reps, p_lookup, args_batch) + + # Results should be the same regardless of batching + assert len(scores_no_batch) == len(scores_batch) == num_queries + assert len(indices_no_batch) == 
len(indices_batch) == num_queries + + # Scores should match (allowing for small numerical differences) + for q_idx in range(num_queries): + assert len(scores_no_batch[q_idx]) == len(scores_batch[q_idx]) == args_no_batch.depth + # Scores should be very similar (allowing for floating point precision) + np.testing.assert_allclose(scores_no_batch[q_idx], scores_batch[q_idx], rtol=1e-5) + + +@pytest.mark.unit +def test_search_chunked_with_negative_indices(): + """Test chunked search handles FAISS -1 indices correctly.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + from unittest.mock import Mock, patch + + hidden_size = 64 + num_docs = 3 + num_chunks = 5 + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_chunks, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [ + ("doc_0", 0), + ("doc_0", 1), + ("doc_1", 0), + ("doc_2", 0), + ("doc_2", 1), + ] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 10 # Request more than available + batch_size = 0 + quiet = True + chunk_multiplier = 1 + + args = MockArgs() + + # Mock search to return -1 for insufficient results + original_search = retriever.search + + def mock_search(q_reps, k): + scores, indices = original_search(q_reps, k) + # Simulate FAISS returning -1 for insufficient results + if k > num_chunks: + # Pad with -1 indices + padded_indices = np.full((scores.shape[0], k), -1, dtype=indices.dtype) + padded_scores = np.full((scores.shape[0], k), -np.inf, dtype=scores.dtype) + padded_indices[:, :indices.shape[1]] = indices + padded_scores[:, :scores.shape[1]] = scores + return padded_scores, padded_indices + return scores, indices + + retriever.search = mock_search + + results = 
search_queries_chunked(retriever, q_reps, p_lookup, args) + + # Should handle -1 indices gracefully + assert len(results) == 1 + assert len(results[0]) <= num_docs # Should aggregate to unique documents + # All results should be valid (doc_id, score) tuples + for doc_id, score in results[0]: + assert isinstance(doc_id, str) + assert isinstance(score, (int, float, np.floating)) + assert not np.isinf(score), "Scores should not be infinite" + + +@pytest.mark.unit +def test_search_single_query(): + """Test search with a single query.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries, search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + num_docs = 10 + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_docs, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [f"doc_{i}" for i in range(num_docs)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 10 + + args = MockArgs() + + # Non-chunked search + scores, indices = search_queries(retriever, q_reps, p_lookup, args) + assert len(scores) == 1 + assert len(indices) == 1 + assert len(scores[0]) == args.depth + assert len(indices[0]) == args.depth + + # Chunked search + p_lookup_chunked = [(f"doc_{i}", 0) for i in range(num_docs)] + results = search_queries_chunked(retriever, q_reps, p_lookup_chunked, args) + assert len(results) == 1 + assert len(results[0]) <= args.depth + assert len(results[0]) > 0 + + +@pytest.mark.unit +def test_search_empty_results(): + """Test search behavior with edge cases.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + 
hidden_size = 64 + + # Single query, no passages + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + # Empty passage index + p_reps = np.random.randn(0, hidden_size).astype(np.float32) + p_lookup = [] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 10 + + args = MockArgs() + + # Should handle empty index gracefully + results = search_queries_chunked(retriever, q_reps, p_lookup, args) + assert len(results) == 1 + assert len(results[0]) == 0, "Should return empty results for empty index" + + +@pytest.mark.unit +def test_search_depth_larger_than_documents(): + """Test search when depth is larger than available documents.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries, search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + num_docs = 5 + + q_reps = np.random.randn(2, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_docs, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [f"doc_{i}" for i in range(num_docs)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 20 # Larger than num_docs + batch_size = 0 + quiet = True + chunk_multiplier = 10 + + args = MockArgs() + + # Non-chunked: should return depth results (with padding if needed) + scores, indices = search_queries(retriever, q_reps, p_lookup, args) + assert len(scores) == 2 + assert len(scores[0]) == args.depth # FAISS will pad with -1 indices + + # Chunked: should return at most num_docs results + p_lookup_chunked = [(f"doc_{i}", 0) for i in range(num_docs)] + results = search_queries_chunked(retriever, q_reps, p_lookup_chunked, args) + assert len(results) 
== 2 + for q_result in results: + assert len(q_result) <= num_docs, "Should not return more documents than available" + + +@pytest.mark.unit +def test_search_chunked_multiplier_effect(): + """Test that chunk_multiplier affects search depth correctly.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + num_docs = 10 + chunks_per_doc = 3 + num_chunks = num_docs * chunks_per_doc + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_chunks, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [(f"doc_{i}", j) for i in range(num_docs) for j in range(chunks_per_doc)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + + # Test with different multipliers + for multiplier in [1, 5, 10]: + args = MockArgs() + args.chunk_multiplier = multiplier + + results = search_queries_chunked(retriever, q_reps, p_lookup, args) + + # Should search depth * multiplier chunks + # After MaxSim aggregation, should return at most depth documents + assert len(results) == 1 + assert len(results[0]) <= args.depth, f"With multiplier {multiplier}, should return at most {args.depth} docs" + assert len(results[0]) > 0, "Should have some results" + + +@pytest.mark.unit +def test_index_boundary_check(): + """Verify index boundary check - ensure no out-of-bounds access to p_lookup""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + num_chunks = 10 + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = 
np.random.randn(num_chunks, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [(f"doc_{i}", 0) for i in range(num_chunks)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 10 # Will search 5 * 10 = 50 chunks, but only 10 available + + args = MockArgs() + + # Should not raise IndexError, FAISS will return -1 or valid indices + results = search_queries_chunked(retriever, q_reps, p_lookup, args) + + assert len(results) == 1 + # Should handle gracefully without out-of-bounds + assert len(results[0]) <= num_chunks + + +@pytest.mark.unit +def test_p_lookup_format_validation(): + """Verify p_lookup format - must be (doc_id, chunk_idx) tuples""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + num_chunks = 5 + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_chunks, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + + # Correct format: tuples + p_lookup_correct = [(f"doc_{i}", i % 2) for i in range(num_chunks)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 1 + + args = MockArgs() + + # Should work correctly + results = search_queries_chunked(retriever, q_reps, p_lookup_correct, args) + assert len(results) == 1 + + # Wrong format: strings (non-chunked format) + p_lookup_wrong = [f"doc_{i}" for i in range(num_chunks)] + + # Function will catch errors and continue, won't raise exception + # but will log error messages + results = search_queries_chunked(retriever, q_reps, p_lookup_wrong, args) + # Due to format error, should 
return empty or partial results + assert len(results) == 1 + + +@pytest.mark.unit +def test_maxsim_aggregation_correctness(): + """Verify MaxSim aggregation correctness""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + + # Create a query + q_rep = np.random.randn(1, hidden_size).astype(np.float32) + q_rep = q_rep / np.linalg.norm(q_rep, axis=1, keepdims=True) + + # Create documents: doc_0 has 3 chunks, doc_1 has 2 chunks + # Make doc_0's chunk 0 most similar, chunk 1 second most, chunk 2 less similar + # Make doc_1's chunks less similar + p_reps = np.random.randn(5, hidden_size).astype(np.float32) + + # doc_0's chunk 0: most similar + p_reps[0] = q_rep[0] * 0.95 + np.random.randn(hidden_size) * 0.05 + # doc_0's chunk 1: second most similar + p_reps[1] = q_rep[0] * 0.85 + np.random.randn(hidden_size) * 0.15 + # doc_0's chunk 2: less similar + p_reps[2] = q_rep[0] * 0.50 + np.random.randn(hidden_size) * 0.50 + # doc_1's chunks: less similar + p_reps[3] = q_rep[0] * 0.40 + np.random.randn(hidden_size) * 0.60 + p_reps[4] = q_rep[0] * 0.35 + np.random.randn(hidden_size) * 0.65 + + # Normalize + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + + p_lookup = [ + ("doc_0", 0), # Most similar + ("doc_0", 1), # Second most similar + ("doc_0", 2), # Less similar + ("doc_1", 0), # Less similar + ("doc_1", 1), # Less similar + ] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 10 + batch_size = 0 + quiet = True + chunk_multiplier = 1 + + args = MockArgs() + + results = search_queries_chunked(retriever, q_rep, p_lookup, args) + + assert len(results) == 1 + assert len(results[0]) >= 1 + + # doc_0 should be ranked first (because its max score is chunk 0's score, highest) + top_doc = results[0][0][0] + assert top_doc == "doc_0", f"doc_0 should be top (has best chunk), but got {top_doc}" 
+ + # Verify each document appears only once (MaxSim aggregation) + doc_ids = [doc_id for doc_id, _ in results[0]] + assert len(doc_ids) == len(set(doc_ids)), "Each document should appear only once" + + # Verify scores are in descending order + scores = [score for _, score in results[0]] + assert scores == sorted(scores, reverse=True), "Scores should be in descending order" + + +@pytest.mark.unit +def test_empty_doc_max_scores(): + """Test edge case when all results are -1""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(1, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [("doc_0", 0)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + # Mock search to return all -1 + original_search = retriever.search + + def mock_search_all_negative(q_reps, k): + scores = np.array([[-np.inf] * k]) + indices = np.array([[-1] * k]) + return scores, indices + + retriever.search = mock_search_all_negative + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 1 + + args = MockArgs() + + results = search_queries_chunked(retriever, q_reps, p_lookup, args) + + # Should return empty results, not crash + assert len(results) == 1 + assert len(results[0]) == 0, "Should return empty list when all indices are -1" + + +@pytest.mark.unit +def test_index_out_of_bounds_protection(): + """Test index out-of-bounds protection - if FAISS returns out-of-range indices""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + num_chunks = 5 + + q_reps = np.random.randn(1, 
hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_chunks, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [(f"doc_{i}", 0) for i in range(num_chunks)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + # Mock search to return out-of-bounds indices + original_search = retriever.search + + def mock_search_out_of_bounds(q_reps, k): + # Return some valid indices and some out-of-bounds indices + scores = np.array([[0.9, 0.8, 0.7, 0.6, 0.5]]) + indices = np.array([[0, 1, 2, 10, 20]]) # 10 and 20 are out of bounds + return scores, indices + + retriever.search = mock_search_out_of_bounds + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 1 + + args = MockArgs() + + # Function will catch out-of-bounds indices and log warnings, won't raise exception + results = search_queries_chunked(retriever, q_reps, p_lookup, args) + # Should handle gracefully, only using valid indices + assert len(results) == 1 + # Since we have 3 valid indices (0, 1, 2), should have some results + assert len(results[0]) <= 3 # At most 3 documents (corresponding to indices 0, 1, 2) From fc16311bdbe901bfa7d284650c5cd9c32592bcfa Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Wed, 24 Dec 2025 11:53:44 -0500 Subject: [PATCH 25/31] removed one breakpoint --- src/tevatron/retriever/driver/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tevatron/retriever/driver/search.py b/src/tevatron/retriever/driver/search.py index 9b601c09..dac952e5 100644 --- a/src/tevatron/retriever/driver/search.py +++ b/src/tevatron/retriever/driver/search.py @@ -78,7 +78,7 @@ def search_queries_chunked(retriever, q_reps, p_lookup, args): sorted_docs = sorted(doc_max_scores.items(), key=lambda x: x[1], reverse=True)[:args.depth] aggregated_results.append(sorted_docs) print(f"aggregated_results: 
{aggregated_results[0]}") - input("Press Enter to continue...") + # input("Press Enter to continue...") return aggregated_results From 22868d2ea1eb67729dd96defacea60a212ad1128 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Sun, 28 Dec 2025 20:24:50 -0500 Subject: [PATCH 26/31] Added random chunks --- src/tevatron/retriever/arguments.py | 5 ++ src/tevatron/retriever/collator.py | 49 ++++++++++++-- src/tevatron/retriever/driver/search.py | 15 +---- src/tevatron/retriever/driver/train.py | 11 +++- tests/test_forward.py | 88 +------------------------ 5 files changed, 60 insertions(+), 108 deletions(-) diff --git a/src/tevatron/retriever/arguments.py b/src/tevatron/retriever/arguments.py index cce3285f..4775f0e1 100644 --- a/src/tevatron/retriever/arguments.py +++ b/src/tevatron/retriever/arguments.py @@ -208,6 +208,11 @@ class DataArguments: metadata={"help": "Chunk size for chunked passage encoding with MaxSim. 0=disabled, >0=chunk size in tokens"} ) + passage_chunk_size_range: Optional[str] = field( + default=None, + metadata={"help": "Chunk size range for random chunking during training (e.g., '64,128'). Randomly selects chunk size in [min, max] range per passage. 
Only for training."} + ) + @dataclass class TevatronTrainingArguments(TrainingArguments): diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index a656e058..1b7a8a35 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -1,6 +1,7 @@ import logging +import random import torch -from typing import List, Tuple +from typing import List, Tuple, Optional from dataclasses import dataclass from transformers import PreTrainedTokenizer, ProcessorMixin from qwen_omni_utils import process_mm_info @@ -108,6 +109,7 @@ def _tokenize_and_pad_chunked_passages( passages: List[str], tokenizer: PreTrainedTokenizer, data_args: DataArguments, + chunk_sizes: Optional[List[int]] = None, ) -> Tuple[dict, List[List[int]]]: """ Tokenize and chunk passages with EOS separators. Each chunk ends with EOS for embedding extraction. @@ -115,23 +117,28 @@ def _tokenize_and_pad_chunked_passages( :param passages: Passage texts to tokenize and chunk :param tokenizer: Tokenizer for encoding :param data_args: DataArguments with chunk_size, max_len, pad_to_multiple_of + :param chunk_sizes: Optional list of chunk sizes (one per passage). If None, uses data_args.passage_chunk_size :return: (collated_dict, eos_positions) - padded tensors and EOS positions per passage """ eos_id = tokenizer.eos_token_id if eos_id is None: raise ValueError("tokenizer.eos_token_id is None; cannot chunk passages with EOS separators.") + if chunk_sizes is not None and len(chunk_sizes) != len(passages): + raise ValueError(f"chunk_sizes length ({len(chunk_sizes)}) must match passages length ({len(passages)})") max_length = data_args.passage_max_len # cap total length (incl. 
EOS per chunk) all_input_ids = [] all_eos_positions = [] - for passage in passages: + for idx, passage in enumerate(passages): if passage is None: passage = "" tokens = tokenizer.encode(passage, add_special_tokens=False) + # Use per-passage chunk size if provided, otherwise use fixed chunk size + chunk_size = chunk_sizes[idx] if chunk_sizes is not None else data_args.passage_chunk_size ids, eos_pos = _chunk_tokens( tokens=tokens, - chunk_size=data_args.passage_chunk_size, + chunk_size=chunk_size, eos_token_id=eos_id, max_length=max_length, ) @@ -189,7 +196,34 @@ def __call__(self, features: List[Tuple[str, List[str]]]): return_tensors='pt', ) - if self.data_args.passage_chunk_size > 0: + # Check if we should use chunking (fixed or random) + use_fixed_chunking = self.data_args.passage_chunk_size > 0 + + if self.data_args.passage_chunk_size_range is not None: + # Parse range string (e.g., "64, 128" or "64,128") + try: + parts = [p.strip() for p in self.data_args.passage_chunk_size_range.split(',')] + if len(parts) != 2: + raise ValueError(f"passage_chunk_size_range must contain exactly 2 values separated by comma, got: {self.data_args.passage_chunk_size_range}") + chunk_size_min = int(parts[0]) + chunk_size_max = int(parts[1]) + except ValueError as e: + raise ValueError(f"Invalid passage_chunk_size_range format '{self.data_args.passage_chunk_size_range}'. 
Expected format: 'min,max' (e.g., '64,128')") from e + + # Validate range + if chunk_size_min < 2: + raise ValueError(f"Minimum chunk size must be >= 2, got {chunk_size_min}") + if chunk_size_max < chunk_size_min: + raise ValueError(f"Maximum chunk size ({chunk_size_max}) must be >= minimum chunk size ({chunk_size_min})") + + # Generate random chunk sizes for each passage + chunk_sizes = [ + random.randint(chunk_size_min, chunk_size_max) + for _ in all_passages + ] + d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages, chunk_sizes=chunk_sizes) + return q_collated, d_collated, eos_positions + elif use_fixed_chunking: d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages) return q_collated, d_collated, eos_positions else: @@ -213,8 +247,8 @@ def __call__(self, features: List[Tuple[str, List[str]]]): ) return q_collated, d_collated - def _tokenize_and_pad_chunked_passages(self, passages: List[str]): - return _tokenize_and_pad_chunked_passages(passages, self.tokenizer, self.data_args) + def _tokenize_and_pad_chunked_passages(self, passages: List[str], chunk_sizes: Optional[List[int]] = None): + return _tokenize_and_pad_chunked_passages(passages, self.tokenizer, self.data_args, chunk_sizes=chunk_sizes) @dataclass @@ -370,7 +404,7 @@ def __call__(self, features): @dataclass class ChunkedEncodeCollator: - """Collator for chunked passage encoding (inference/search). Uses same chunking logic as training.""" + """Collator for chunked passage encoding (inference/search). 
Uses fixed chunk size (passage_chunk_size), not random chunking.""" data_args: DataArguments tokenizer: PreTrainedTokenizer @@ -383,6 +417,7 @@ def __call__(self, features): doc_ids = [x[0] for x in features] texts = [x[1] for x in features] + # Always use fixed chunking for inference (no random chunk sizes) d_collated, all_eos_positions = self._tokenize_and_pad_chunked_passages(texts) return doc_ids, d_collated, all_eos_positions diff --git a/src/tevatron/retriever/driver/search.py b/src/tevatron/retriever/driver/search.py index dac952e5..28fc8c8a 100644 --- a/src/tevatron/retriever/driver/search.py +++ b/src/tevatron/retriever/driver/search.py @@ -35,24 +35,15 @@ def search_queries_chunked(retriever, q_reps, p_lookup, args): Search with chunked passages and aggregate by document using MaxSim. """ # Search more chunks to ensure good recall after aggregation - search_depth = args.depth + chunk_multiplier = getattr(args, 'chunk_multiplier', 10) + search_depth = args.depth * chunk_multiplier - print(f"q_reps: {len(q_reps)}") - print(f"p_lookup: {p_lookup}") - print(f"len(p_lookup): {len(p_lookup)}") - print(f"args.batch_size: {args.batch_size}") if args.batch_size > 0: # all_scores.shape = [Q, search_depth] all_scores, all_indices = retriever.batch_search(q_reps, search_depth, args.batch_size, args.quiet) else: # all_scores.shape = [search_depth] all_scores, all_indices = retriever.search(q_reps, search_depth) - - print(f"all_scores: {all_scores}") - print(f"all_indices: {all_indices}") - print(f"all_scores.shape: {all_scores.shape}") # [Q, search_depth] - print(f"all_indices.shape: {all_indices.shape}") # [Q, search_depth] - # input("Press Enter to continue...") # Aggregate by document ID using MaxSim aggregated_results = [] for q_idx in range(len(q_reps)): @@ -77,8 +68,6 @@ def search_queries_chunked(retriever, q_reps, p_lookup, args): # Sort by score and take top-depth sorted_docs = sorted(doc_max_scores.items(), key=lambda x: x[1], reverse=True)[:args.depth] 
aggregated_results.append(sorted_docs) - print(f"aggregated_results: {aggregated_results[0]}") - # input("Press Enter to continue...") return aggregated_results diff --git a/src/tevatron/retriever/driver/train.py b/src/tevatron/retriever/driver/train.py index aaa6c163..a570231d 100644 --- a/src/tevatron/retriever/driver/train.py +++ b/src/tevatron/retriever/driver/train.py @@ -88,7 +88,16 @@ def main(): torch_dtype=torch_dtype, attn_implementation=model_args.attn_implementation, ) - model.passage_chunk_size = data_args.passage_chunk_size + # Set passage_chunk_size: use fixed chunk size if set, otherwise set to 1 if using random chunking + # (model uses passage_chunk_size > 0 as signal to use chunked encoding) + if data_args.passage_chunk_size > 0: + model.passage_chunk_size = data_args.passage_chunk_size + elif data_args.passage_chunk_size_range is not None: + # For random chunking, set to a positive value to enable chunked encoding + # The actual chunk sizes will be determined per-passage by the collator + model.passage_chunk_size = 1 + else: + model.passage_chunk_size = 0 train_dataset = TrainDataset(data_args) collator = TrainCollator(data_args, tokenizer) diff --git a/tests/test_forward.py b/tests/test_forward.py index b57782fb..e8b54e3b 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -27,6 +27,7 @@ def train_tokenizer(): tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") if tok.pad_token_id is None: tok.pad_token_id = tok.eos_token_id + tok.eos_token_id = tok.pad_token_id # Match training setup tok.padding_side = "right" return tok @@ -134,90 +135,3 @@ def encode_passage(self, psg): assert torch.allclose(scores_uneven[q_idx, p_idx], torch.tensor(expected_score)) -@pytest.mark.unit -def test_forward_with_chunking(train_tokenizer): - """Test model forward with chunked passages: encode_query, encode_passage, compute_maxsim_similarity.""" - _add_tevatron_src_to_path() - from tevatron.retriever.arguments import DataArguments - from 
tevatron.retriever.collator import TrainCollator - from tevatron.retriever.modeling.dense import DenseModel - - REAL_TEXT = ( - "Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical " - "development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging " - "(MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient." - ) - - data_args = DataArguments( - passage_chunk_size=32, - passage_max_len=128, - pad_to_multiple_of=16, - padding_side="right", - append_eos_token=False, - ) - collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) - - queries = ["What is cerebral white matter?", "What is MRI?"] - passages = [REAL_TEXT, "MRI stands for Magnetic Resonance Imaging."] - q_batch, p_batch, eos_positions = collator([(q, [p], []) for q, p in zip(queries, passages)]) - - hidden_size = 64 - - class MockEncoderOutput: - def __init__(self, last_hidden_state): - self.last_hidden_state = last_hidden_state - - def mock_encoder_forward(**kwargs): - input_ids = kwargs['input_ids'] - batch_size, seq_len = input_ids.shape - return MockEncoderOutput(last_hidden_state=torch.randn(batch_size, seq_len, hidden_size)) - - mock_encoder = Mock(side_effect=mock_encoder_forward) - mock_encoder.config = Mock() - mock_encoder.config.hidden_size = hidden_size - - model = DenseModel(encoder=mock_encoder, pooling='last', normalize=False) - model.passage_chunk_size = data_args.passage_chunk_size - model.eos_positions = eos_positions - model.training = True - - output = model(query=q_batch, passage=p_batch) - - assert hasattr(output, 'q_reps') - assert hasattr(output, 'p_reps') - assert hasattr(output, 'scores') - assert hasattr(output, 'loss') - assert output.q_reps.shape == (len(queries), hidden_size) - - chunk_reps, chunk_mask = output.p_reps - assert chunk_reps.shape[0] == len(passages) - assert chunk_reps.shape[2] == hidden_size - assert 
output.scores.shape == (len(queries), len(passages)) - assert output.loss.item() >= 0 - - # Test MaxSim with known embeddings - model.eval() - with torch.no_grad(): - q_reps_test = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) - p_reps_test = torch.tensor([ - [[1.0, 0.0], [0.0, 1.0]], - [[0.0, 1.0], [1.0, 0.0]], - ], dtype=torch.float32) - chunk_mask_test = torch.ones(2, 2) - - scores_test = model.compute_maxsim_similarity(q_reps_test, p_reps_test, chunk_mask_test) - assert torch.allclose(scores_test, torch.ones(2, 2)) - - # Test padding chunks are ignored - p_reps_padded = torch.randn(2, 3, hidden_size) - chunk_mask_padded = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]) - scores_padded = model.compute_maxsim_similarity(output.q_reps, p_reps_padded, chunk_mask_padded) - - for q_idx in range(len(queries)): - for p_idx in range(len(passages)): - chunk_scores = torch.einsum('h,ch->c', output.q_reps[q_idx], p_reps_padded[p_idx]) - valid_mask = chunk_mask_padded[p_idx].bool() - chunk_scores_masked = chunk_scores.clone() - chunk_scores_masked[~valid_mask] = float('-inf') - expected_score = chunk_scores_masked.max().item() - assert torch.allclose(scores_padded[q_idx, p_idx], torch.tensor(expected_score)) From d389bc6eec3f93fc0b0237355394b99bf471e8b2 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Sun, 28 Dec 2025 20:46:02 -0500 Subject: [PATCH 27/31] Added full randomization --- src/tevatron/retriever/arguments.py | 5 ++ src/tevatron/retriever/collator.py | 93 +++++++++++++++++++++++------ 2 files changed, 79 insertions(+), 19 deletions(-) diff --git a/src/tevatron/retriever/arguments.py b/src/tevatron/retriever/arguments.py index 4775f0e1..d27bb79a 100644 --- a/src/tevatron/retriever/arguments.py +++ b/src/tevatron/retriever/arguments.py @@ -213,6 +213,11 @@ class DataArguments: metadata={"help": "Chunk size range for random chunking during training (e.g., '64,128'). Randomly selects chunk size in [min, max] range per passage. 
Only for training."} ) + passage_chunk_size_variable: bool = field( + default=False, + metadata={"help": "If True and passage_chunk_size_range is set, each chunk within a passage gets a random size from the range. If False, all chunks in a passage use the same random size. Only for training."} + ) + @dataclass class TevatronTrainingArguments(TrainingArguments): diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index 1b7a8a35..26391f28 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -19,31 +19,52 @@ def _chunk_tokens( chunk_size: int, eos_token_id: int, max_length: int = None, + chunk_size_range: Optional[Tuple[int, int]] = None, ) -> Tuple[List[int], List[int]]: """ - Chunk tokens into fixed-size chunks with EOS separators. + Chunk tokens into chunks with EOS separators. :param tokens: Token IDs to chunk - :param chunk_size: Max chunk size (before EOS). Must be >= 2. + :param chunk_size: Max chunk size (before EOS). Must be >= 2. Used when chunk_size_range is None. :param eos_token_id: EOS token ID to append after each chunk :param max_length: Optional max total length (including EOS). If None, no limit. + :param chunk_size_range: Optional (min, max) tuple for variable chunk sizes. If set, each chunk uses a random size in [min, max]. 
:return: (chunked_ids, eos_positions) - token IDs with EOS separators and EOS positions """ - if chunk_size < 2: - return [], [] + # Determine chunk size parameters + if chunk_size_range is not None: + chunk_size_min, chunk_size_max = chunk_size_range + if chunk_size_min < 2: + raise ValueError(f"Minimum chunk size must be >= 2, got {chunk_size_min}") + if chunk_size_max < chunk_size_min: + raise ValueError(f"Maximum chunk size ({chunk_size_max}) must be >= minimum chunk size ({chunk_size_min})") + use_variable_sizes = True + # For max_length calculation, use average chunk size + avg_chunk_size = (chunk_size_min + chunk_size_max) // 2 + effective_chunk_size = avg_chunk_size + else: + if chunk_size < 2: + return [], [] + use_variable_sizes = False + effective_chunk_size = chunk_size - chunk_len = chunk_size - 1 # Reserve 1 slot for EOS + chunk_len = effective_chunk_size - 1 # Reserve 1 slot for EOS (for estimation) - # Truncate tokens to fit within max_length - # Each chunk: chunk_len tokens + 1 EOS = chunk_size total + # Truncate tokens to fit within max_length (conservative estimate) if max_length and max_length > 0: max_tokens_to_use = 0 remaining_length = max_length + if use_variable_sizes: + # For variable sizes, use min chunk size for conservative estimation + est_chunk_size = chunk_size_min + else: + est_chunk_size = chunk_size + while remaining_length > 1 and max_tokens_to_use < len(tokens): - if remaining_length >= chunk_size: - max_tokens_to_use += chunk_len - remaining_length -= chunk_size + if remaining_length >= est_chunk_size: + max_tokens_to_use += (est_chunk_size - 1) + remaining_length -= est_chunk_size else: max_tokens_to_use += remaining_length - 1 break @@ -55,12 +76,33 @@ def _chunk_tokens( eos_pos = [] i = 0 + total_length = 0 while i < len(tokens): - take = min(chunk_len, len(tokens) - i) + # Pick chunk size for this chunk + if use_variable_sizes: + current_chunk_size = random.randint(chunk_size_min, chunk_size_max) + else: + 
current_chunk_size = chunk_size + + # Check if we have space for this chunk (including EOS) + if max_length and total_length + current_chunk_size > max_length: + # Use remaining space (leave 1 for EOS if possible) + remaining = max_length - total_length - 1 + if remaining > 0: + take = min(remaining, len(tokens) - i) + chunk = tokens[i:i + take] + ids.extend(chunk) + ids.append(eos_token_id) + eos_pos.append(len(ids) - 1) + break + + current_chunk_len = current_chunk_size - 1 # Reserve 1 slot for EOS + take = min(current_chunk_len, len(tokens) - i) chunk = tokens[i:i + take] ids.extend(chunk) ids.append(eos_token_id) eos_pos.append(len(ids) - 1) # EOS position for pooling + total_length += current_chunk_size i += take return ids, eos_pos @@ -110,6 +152,7 @@ def _tokenize_and_pad_chunked_passages( tokenizer: PreTrainedTokenizer, data_args: DataArguments, chunk_sizes: Optional[List[int]] = None, + chunk_size_range: Optional[Tuple[int, int]] = None, ) -> Tuple[dict, List[List[int]]]: """ Tokenize and chunk passages with EOS separators. Each chunk ends with EOS for embedding extraction. @@ -118,6 +161,7 @@ def _tokenize_and_pad_chunked_passages( :param tokenizer: Tokenizer for encoding :param data_args: DataArguments with chunk_size, max_len, pad_to_multiple_of :param chunk_sizes: Optional list of chunk sizes (one per passage). If None, uses data_args.passage_chunk_size + :param chunk_size_range: Optional (min, max) tuple for variable chunk sizes per chunk. If set, each chunk within a passage uses a random size. 
:return: (collated_dict, eos_positions) - padded tensors and EOS positions per passage """ eos_id = tokenizer.eos_token_id @@ -141,6 +185,7 @@ def _tokenize_and_pad_chunked_passages( chunk_size=chunk_size, eos_token_id=eos_id, max_length=max_length, + chunk_size_range=chunk_size_range, ) all_input_ids.append(ids) all_eos_positions.append(eos_pos) @@ -216,12 +261,22 @@ def __call__(self, features: List[Tuple[str, List[str]]]): if chunk_size_max < chunk_size_min: raise ValueError(f"Maximum chunk size ({chunk_size_max}) must be >= minimum chunk size ({chunk_size_min})") - # Generate random chunk sizes for each passage - chunk_sizes = [ - random.randint(chunk_size_min, chunk_size_max) - for _ in all_passages - ] - d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages, chunk_sizes=chunk_sizes) + if self.data_args.passage_chunk_size_variable: + # Variable chunk sizes: each chunk within a passage gets a random size + # Pass the range to _chunk_tokens, which will randomly pick a size for each chunk + chunk_size_range = (chunk_size_min, chunk_size_max) + d_collated, eos_positions = self._tokenize_and_pad_chunked_passages( + all_passages, + chunk_size_range=chunk_size_range + ) + else: + # Fixed random chunk size per passage: all chunks in a passage use the same random size + # Generate random chunk sizes for each passage + chunk_sizes = [ + random.randint(chunk_size_min, chunk_size_max) + for _ in all_passages + ] + d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages, chunk_sizes=chunk_sizes) return q_collated, d_collated, eos_positions elif use_fixed_chunking: d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages) @@ -247,8 +302,8 @@ def __call__(self, features: List[Tuple[str, List[str]]]): ) return q_collated, d_collated - def _tokenize_and_pad_chunked_passages(self, passages: List[str], chunk_sizes: Optional[List[int]] = None): - return _tokenize_and_pad_chunked_passages(passages, 
self.tokenizer, self.data_args, chunk_sizes=chunk_sizes) + def _tokenize_and_pad_chunked_passages(self, passages: List[str], chunk_sizes: Optional[List[int]] = None, chunk_size_range: Optional[Tuple[int, int]] = None): + return _tokenize_and_pad_chunked_passages(passages, self.tokenizer, self.data_args, chunk_sizes=chunk_sizes, chunk_size_range=chunk_size_range) @dataclass From 338fe92cd0750d635e34c9abd2649474224bace0 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Mon, 29 Dec 2025 13:46:43 -0500 Subject: [PATCH 28/31] removed useless variables --- src/tevatron/retriever/collator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index 26391f28..ca9debd2 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -80,6 +80,7 @@ def _chunk_tokens( while i < len(tokens): # Pick chunk size for this chunk if use_variable_sizes: + # Randomly pick a chunk size between min and max current_chunk_size = random.randint(chunk_size_min, chunk_size_max) else: current_chunk_size = chunk_size @@ -102,7 +103,9 @@ def _chunk_tokens( ids.extend(chunk) ids.append(eos_token_id) eos_pos.append(len(ids) - 1) # EOS position for pooling - total_length += current_chunk_size + # Use actual chunk size (take + 1 for EOS) for total_length tracking + actual_chunk_size = take + 1 + total_length += actual_chunk_size i += take return ids, eos_pos @@ -179,6 +182,7 @@ def _tokenize_and_pad_chunked_passages( passage = "" tokens = tokenizer.encode(passage, add_special_tokens=False) # Use per-passage chunk size if provided, otherwise use fixed chunk size + # Note: chunk_size is ignored in _chunk_tokens when chunk_size_range is provided chunk_size = chunk_sizes[idx] if chunk_sizes is not None else data_args.passage_chunk_size ids, eos_pos = _chunk_tokens( tokens=tokens, From 190f9377997513f8933c84cfea9510a00b184460 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Mon, 29 Dec 2025 
13:58:20 -0500 Subject: [PATCH 29/31] Refactored the randomization --- src/tevatron/retriever/collator.py | 68 ++++++------------------------ 1 file changed, 14 insertions(+), 54 deletions(-) diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index ca9debd2..c40bf115 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -25,87 +25,53 @@ def _chunk_tokens( Chunk tokens into chunks with EOS separators. :param tokens: Token IDs to chunk - :param chunk_size: Max chunk size (before EOS). Must be >= 2. Used when chunk_size_range is None. + :param chunk_size: Fixed chunk size (before EOS). Must be >= 2. Used when chunk_size_range is None. :param eos_token_id: EOS token ID to append after each chunk :param max_length: Optional max total length (including EOS). If None, no limit. :param chunk_size_range: Optional (min, max) tuple for variable chunk sizes. If set, each chunk uses a random size in [min, max]. :return: (chunked_ids, eos_positions) - token IDs with EOS separators and EOS positions """ - # Determine chunk size parameters - if chunk_size_range is not None: + # Validate and set up chunk size parameters + if chunk_size_range: chunk_size_min, chunk_size_max = chunk_size_range - if chunk_size_min < 2: - raise ValueError(f"Minimum chunk size must be >= 2, got {chunk_size_min}") - if chunk_size_max < chunk_size_min: - raise ValueError(f"Maximum chunk size ({chunk_size_max}) must be >= minimum chunk size ({chunk_size_min})") use_variable_sizes = True - # For max_length calculation, use average chunk size - avg_chunk_size = (chunk_size_min + chunk_size_max) // 2 - effective_chunk_size = avg_chunk_size else: if chunk_size < 2: return [], [] use_variable_sizes = False - effective_chunk_size = chunk_size - - chunk_len = effective_chunk_size - 1 # Reserve 1 slot for EOS (for estimation) - - # Truncate tokens to fit within max_length (conservative estimate) - if max_length and max_length > 0: - 
max_tokens_to_use = 0 - remaining_length = max_length - - if use_variable_sizes: - # For variable sizes, use min chunk size for conservative estimation - est_chunk_size = chunk_size_min - else: - est_chunk_size = chunk_size - - while remaining_length > 1 and max_tokens_to_use < len(tokens): - if remaining_length >= est_chunk_size: - max_tokens_to_use += (est_chunk_size - 1) - remaining_length -= est_chunk_size - else: - max_tokens_to_use += remaining_length - 1 - break - - tokens = tokens[:max_tokens_to_use] # Chunk tokens and add EOS after each chunk ids = [] eos_pos = [] - i = 0 total_length = 0 + while i < len(tokens): # Pick chunk size for this chunk if use_variable_sizes: - # Randomly pick a chunk size between min and max current_chunk_size = random.randint(chunk_size_min, chunk_size_max) else: current_chunk_size = chunk_size - # Check if we have space for this chunk (including EOS) + # Check if we would exceed max_length with this chunk if max_length and total_length + current_chunk_size > max_length: # Use remaining space (leave 1 for EOS if possible) remaining = max_length - total_length - 1 if remaining > 0: take = min(remaining, len(tokens) - i) - chunk = tokens[i:i + take] - ids.extend(chunk) + ids.extend(tokens[i:i + take]) ids.append(eos_token_id) eos_pos.append(len(ids) - 1) break - current_chunk_len = current_chunk_size - 1 # Reserve 1 slot for EOS + # Take tokens for this chunk (reserve 1 slot for EOS) + current_chunk_len = current_chunk_size - 1 take = min(current_chunk_len, len(tokens) - i) - chunk = tokens[i:i + take] - ids.extend(chunk) + ids.extend(tokens[i:i + take]) ids.append(eos_token_id) - eos_pos.append(len(ids) - 1) # EOS position for pooling - # Use actual chunk size (take + 1 for EOS) for total_length tracking - actual_chunk_size = take + 1 - total_length += actual_chunk_size + eos_pos.append(len(ids) - 1) + + total_length += take + 1 # +1 for EOS i += take return ids, eos_pos @@ -269,17 +235,11 @@ def __call__(self, features: 
List[Tuple[str, List[str]]]): # Variable chunk sizes: each chunk within a passage gets a random size # Pass the range to _chunk_tokens, which will randomly pick a size for each chunk chunk_size_range = (chunk_size_min, chunk_size_max) - d_collated, eos_positions = self._tokenize_and_pad_chunked_passages( - all_passages, - chunk_size_range=chunk_size_range - ) + d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages, chunk_size_range=chunk_size_range) else: # Fixed random chunk size per passage: all chunks in a passage use the same random size # Generate random chunk sizes for each passage - chunk_sizes = [ - random.randint(chunk_size_min, chunk_size_max) - for _ in all_passages - ] + chunk_sizes = [random.randint(chunk_size_min, chunk_size_max) for _ in all_passages] d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages, chunk_sizes=chunk_sizes) return q_collated, d_collated, eos_positions elif use_fixed_chunking: From 0a9626ff870c240b0e4fec9700cd734d5e03cf67 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Mon, 5 Jan 2026 20:53:18 -0500 Subject: [PATCH 30/31] added search time passage encoding with chunks --- src/tevatron/retriever/arguments.py | 5 ++ src/tevatron/retriever/collator.py | 98 +++++++++++++++++++++++++ src/tevatron/retriever/dataset.py | 26 +++++-- src/tevatron/retriever/driver/encode.py | 16 +++- 4 files changed, 136 insertions(+), 9 deletions(-) diff --git a/src/tevatron/retriever/arguments.py b/src/tevatron/retriever/arguments.py index d27bb79a..1a719a75 100644 --- a/src/tevatron/retriever/arguments.py +++ b/src/tevatron/retriever/arguments.py @@ -218,6 +218,11 @@ class DataArguments: metadata={"help": "If True and passage_chunk_size_range is set, each chunk within a passage gets a random size from the range. If False, all chunks in a passage use the same random size. 
Only for training."} ) + encode_use_pre_chunked: bool = field( + default=False, + metadata={"help": "If True, expects dataset with 'chunks' field (list of pre-chunked passage strings). EOS tokens will be added between chunks. If False, uses regular 'text' field. Only for encoding (not training)."} + ) + @dataclass class TevatronTrainingArguments(TrainingArguments): diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index c40bf115..fbcad83a 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -445,6 +445,104 @@ def _tokenize_and_pad_chunked_passages(self, passages: List[str]): return _tokenize_and_pad_chunked_passages(passages, self.tokenizer, self.data_args) +@dataclass +class PreChunkedEncodeCollator: + """ + Collator for pre-chunked passage encoding (inference/search). + Expects passages as lists of pre-chunked strings and adds EOS tokens between chunks. + """ + data_args: DataArguments + tokenizer: PreTrainedTokenizer + + def __call__(self, features): + """ + Collate pre-chunked passage encoding features. + :param features: List of (doc_id, chunks_list, image, video, audio) tuples + where chunks_list is a list of pre-chunked passage strings + :return: (doc_ids, collated_inputs, eos_positions) + """ + doc_ids = [x[0] for x in features] + chunks_lists = [x[1] for x in features] # List of lists of strings + + # Process pre-chunked passages: tokenize each chunk and add EOS between them + d_collated, all_eos_positions = self._tokenize_and_pad_pre_chunked_passages(chunks_lists) + + return doc_ids, d_collated, all_eos_positions + + def _tokenize_and_pad_pre_chunked_passages(self, chunks_lists: List[List[str]]): + """ + Tokenize pre-chunked passages and add EOS tokens between chunks. + + This is used when you have pre-chunked passages (e.g., from ChatGPT or manual chunking). + Each chunk is tokenized separately, and EOS tokens are inserted between chunks. 
+ + :param chunks_lists: List of lists, where each inner list contains pre-chunked passage strings + Example: [["chunk1", "chunk2"], ["chunk3"]] for 2 passages + :return: (collated_dict, eos_positions) - padded tensors and EOS positions per passage + """ + eos_id = self.tokenizer.eos_token_id + if eos_id is None: + raise ValueError("tokenizer.eos_token_id is None; cannot add EOS tokens between chunks.") + + max_length = self.data_args.passage_max_len + all_input_ids = [] + all_eos_positions = [] + + for chunks in chunks_lists: + if chunks is None: + chunks = [] + if not isinstance(chunks, list): + raise ValueError(f"Expected list of chunks, got {type(chunks)}") + if len(chunks) == 0: + # Empty chunks list - create empty passage with no EOS positions + all_input_ids.append([]) + all_eos_positions.append([]) + continue + + # Tokenize each chunk and concatenate with EOS between them + ids = [] + eos_pos = [] + total_length = 0 + + for chunk_idx, chunk in enumerate(chunks): + if chunk is None: + chunk = "" + # Tokenize this chunk (without special tokens, we'll add EOS manually) + chunk_tokens = self.tokenizer.encode(chunk, add_special_tokens=False) + + # Check if adding this chunk + EOS would exceed max_length + chunk_size = len(chunk_tokens) + if max_length and total_length + chunk_size + 1 > max_length: + # Use remaining space (leave 1 for EOS if possible) + remaining = max_length - total_length - 1 + if remaining > 0: + chunk_tokens = chunk_tokens[:remaining] + ids.extend(chunk_tokens) + ids.append(eos_id) + eos_pos.append(len(ids) - 1) + break + + # Add chunk tokens + ids.extend(chunk_tokens) + # Add EOS after each chunk + ids.append(eos_id) + eos_pos.append(len(ids) - 1) + total_length += chunk_size + 1 + + all_input_ids.append(ids) + all_eos_positions.append(eos_pos) + + d_collated, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=self.tokenizer, + 
padding_side=self.data_args.padding_side, + pad_to_multiple_of=self.data_args.pad_to_multiple_of, + ) + + return d_collated, adjusted_eos_positions + + @dataclass class MultiModalEncodeCollator: """ diff --git a/src/tevatron/retriever/dataset.py b/src/tevatron/retriever/dataset.py index 8b2bf25b..18c8fc68 100644 --- a/src/tevatron/retriever/dataset.py +++ b/src/tevatron/retriever/dataset.py @@ -293,10 +293,22 @@ def __getitem__(self, item): content_audio = content.get('query_audio', None) else: content_id = content['docid'] - content_text = content.get('text', '') - if 'title' in content: - content_text = content['title'] + ' ' + content_text - content_text = self.data_args.passage_prefix + content_text.strip() + # Support pre-chunked passages (for custom chunking with ChatGPT, etc.) + if self.data_args.encode_use_pre_chunked and 'chunks' in content: + # Pre-chunked: return chunks as a list + chunks = content['chunks'] + if not isinstance(chunks, list): + raise ValueError(f"Expected 'chunks' to be a list, got {type(chunks)}") + # Apply prefix to each chunk if needed + if self.data_args.passage_prefix: + chunks = [self.data_args.passage_prefix + chunk if chunk else chunk for chunk in chunks] + content_text = chunks # Return as list for pre-chunked collator + else: + # Regular text field + content_text = content.get('text', '') + if 'title' in content: + content_text = content['title'] + ' ' + content_text + content_text = self.data_args.passage_prefix + content_text.strip() content_image = content.get('image', None) content_video = content.get('video', None) content_audio = content.get('audio', None) @@ -321,7 +333,11 @@ def __getitem__(self, item): content_audio = None if not self.data_args.encode_text: - content_text = None + # For pre-chunked mode, set to empty list instead of None + if self.data_args.encode_use_pre_chunked and isinstance(content_text, list): + content_text = [] + else: + content_text = None if not self.data_args.encode_image: content_image = 
None if not self.data_args.encode_video: diff --git a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py index 3d15e7a9..8b1234d5 100644 --- a/src/tevatron/retriever/driver/encode.py +++ b/src/tevatron/retriever/driver/encode.py @@ -21,7 +21,7 @@ from tevatron.retriever.arguments import ModelArguments, DataArguments, \ TevatronTrainingArguments as TrainingArguments from tevatron.retriever.dataset import EncodeDataset -from tevatron.retriever.collator import EncodeCollator, ChunkedEncodeCollator +from tevatron.retriever.collator import EncodeCollator, ChunkedEncodeCollator, PreChunkedEncodeCollator from tevatron.retriever.modeling import EncoderOutput, DenseModel logger = logging.getLogger(__name__) @@ -83,10 +83,18 @@ def main(): ) use_chunked = not data_args.encode_is_query and data_args.passage_chunk_size > 0 + use_pre_chunked = not data_args.encode_is_query and data_args.encode_use_pre_chunked print("data_args.encode_is_query: ", data_args.encode_is_query) print("data_args.passage_chunk_size: ", data_args.passage_chunk_size) + print("data_args.encode_use_pre_chunked: ", data_args.encode_use_pre_chunked) print("use_chunked: ", use_chunked) - if use_chunked: + print("use_pre_chunked: ", use_pre_chunked) + + if use_pre_chunked: + logger.info("Using pre-chunked passage encoding (custom EOS positions from pre-chunked data)") + model.passage_chunk_size = 1 # Signal to use chunked encoding + encode_collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=tokenizer) + elif use_chunked: logger.info(f"Using chunked passage encoding with chunk_size={data_args.passage_chunk_size}") model.passage_chunk_size = data_args.passage_chunk_size encode_collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=tokenizer) @@ -109,7 +117,7 @@ def main(): for batch in tqdm(encode_loader): with torch.amp.autocast('cuda') if training_args.fp16 or training_args.bf16 else nullcontext(): with torch.no_grad(): - if use_chunked: + if 
use_pre_chunked or use_chunked: doc_ids, batch_inputs, eos_positions = batch # batch_inputs: input_ids, attention_mask for k, v in batch_inputs.items(): @@ -137,7 +145,7 @@ def main(): else: model_output: EncoderOutput = model(passage=batch_inputs) encoded.append(model_output.p_reps.cpu().detach().numpy()) - if use_chunked: + if use_pre_chunked or use_chunked: print("use_chunked: ", use_chunked) print(f"encoded: {encoded}") print(f"lookup_indices: {lookup_indices}") From b88e5c1e029df4cbfca0d2c4e1e073732a4822f5 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Sat, 10 Jan 2026 18:01:43 -0500 Subject: [PATCH 31/31] Added random chunk size for eval, and tests for prechunked passages and random chunking --- src/tevatron/retriever/arguments.py | 4 +- src/tevatron/retriever/collator.py | 38 +- src/tevatron/retriever/driver/encode.py | 27 +- tests/test_chunking.py | 940 ++++++++++++++++++++++++ 4 files changed, 997 insertions(+), 12 deletions(-) diff --git a/src/tevatron/retriever/arguments.py b/src/tevatron/retriever/arguments.py index 1a719a75..a0c6ce5f 100644 --- a/src/tevatron/retriever/arguments.py +++ b/src/tevatron/retriever/arguments.py @@ -210,12 +210,12 @@ class DataArguments: passage_chunk_size_range: Optional[str] = field( default=None, - metadata={"help": "Chunk size range for random chunking during training (e.g., '64,128'). Randomly selects chunk size in [min, max] range per passage. Only for training."} + metadata={"help": "Chunk size range for random chunking (e.g., '64,128'). Randomly selects chunk size in [min, max] range per passage. Works for both training and inference."} ) passage_chunk_size_variable: bool = field( default=False, - metadata={"help": "If True and passage_chunk_size_range is set, each chunk within a passage gets a random size from the range. If False, all chunks in a passage use the same random size. 
Only for training."} + metadata={"help": "If True and passage_chunk_size_range is set, each chunk within a passage gets a random size from the range. If False, all chunks in a passage use the same random size. Works for both training and inference."} ) encode_use_pre_chunked: bool = field( diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index fbcad83a..ca739f65 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -423,7 +423,7 @@ def __call__(self, features): @dataclass class ChunkedEncodeCollator: - """Collator for chunked passage encoding (inference/search). Uses fixed chunk size (passage_chunk_size), not random chunking.""" + """Collator for chunked passage encoding (inference/search). Supports fixed or random chunk sizes.""" data_args: DataArguments tokenizer: PreTrainedTokenizer @@ -436,13 +436,41 @@ def __call__(self, features): doc_ids = [x[0] for x in features] texts = [x[1] for x in features] - # Always use fixed chunking for inference (no random chunk sizes) - d_collated, all_eos_positions = self._tokenize_and_pad_chunked_passages(texts) + # Check if we should use random chunking + if self.data_args.passage_chunk_size_range is not None: + # Parse range string (e.g., "64, 128" or "64,128") + try: + parts = [p.strip() for p in self.data_args.passage_chunk_size_range.split(',')] + if len(parts) != 2: + raise ValueError(f"passage_chunk_size_range must contain exactly 2 values separated by comma, got: {self.data_args.passage_chunk_size_range}") + chunk_size_min = int(parts[0]) + chunk_size_max = int(parts[1]) + except ValueError as e: + raise ValueError(f"Invalid passage_chunk_size_range format '{self.data_args.passage_chunk_size_range}'. 
Expected format: 'min,max' (e.g., '64,128')") from e + + # Validate range + if chunk_size_min < 2: + raise ValueError(f"Minimum chunk size must be >= 2, got {chunk_size_min}") + if chunk_size_max < chunk_size_min: + raise ValueError(f"Maximum chunk size ({chunk_size_max}) must be >= minimum chunk size ({chunk_size_min})") + + if self.data_args.passage_chunk_size_variable: + # Variable chunk sizes: each chunk within a passage gets a random size + chunk_size_range = (chunk_size_min, chunk_size_max) + d_collated, all_eos_positions = self._tokenize_and_pad_chunked_passages(texts, chunk_size_range=chunk_size_range) + else: + # Fixed random chunk size per passage: all chunks in a passage use the same random size + # Generate random chunk sizes for each passage + chunk_sizes = [random.randint(chunk_size_min, chunk_size_max) for _ in texts] + d_collated, all_eos_positions = self._tokenize_and_pad_chunked_passages(texts, chunk_sizes=chunk_sizes) + else: + # Use fixed chunking for inference + d_collated, all_eos_positions = self._tokenize_and_pad_chunked_passages(texts) return doc_ids, d_collated, all_eos_positions - def _tokenize_and_pad_chunked_passages(self, passages: List[str]): - return _tokenize_and_pad_chunked_passages(passages, self.tokenizer, self.data_args) + def _tokenize_and_pad_chunked_passages(self, passages: List[str], chunk_sizes: Optional[List[int]] = None, chunk_size_range: Optional[Tuple[int, int]] = None): + return _tokenize_and_pad_chunked_passages(passages, self.tokenizer, self.data_args, chunk_sizes=chunk_sizes, chunk_size_range=chunk_size_range) @dataclass diff --git a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py index 8b1234d5..d816e5da 100644 --- a/src/tevatron/retriever/driver/encode.py +++ b/src/tevatron/retriever/driver/encode.py @@ -84,19 +84,36 @@ def main(): use_chunked = not data_args.encode_is_query and data_args.passage_chunk_size > 0 use_pre_chunked = not data_args.encode_is_query and 
data_args.encode_use_pre_chunked + use_random_chunking = not data_args.encode_is_query and data_args.passage_chunk_size_range is not None print("data_args.encode_is_query: ", data_args.encode_is_query) print("data_args.passage_chunk_size: ", data_args.passage_chunk_size) + print("data_args.passage_chunk_size_range: ", data_args.passage_chunk_size_range) + print("data_args.passage_chunk_size_variable: ", data_args.passage_chunk_size_variable) print("data_args.encode_use_pre_chunked: ", data_args.encode_use_pre_chunked) print("use_chunked: ", use_chunked) print("use_pre_chunked: ", use_pre_chunked) + print("use_random_chunking: ", use_random_chunking) if use_pre_chunked: logger.info("Using pre-chunked passage encoding (custom EOS positions from pre-chunked data)") model.passage_chunk_size = 1 # Signal to use chunked encoding encode_collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=tokenizer) - elif use_chunked: - logger.info(f"Using chunked passage encoding with chunk_size={data_args.passage_chunk_size}") - model.passage_chunk_size = data_args.passage_chunk_size + elif use_chunked or use_random_chunking: + if use_random_chunking: + logger.info(f"Using random chunked passage encoding with chunk_size_range={data_args.passage_chunk_size_range}, variable={data_args.passage_chunk_size_variable}") + else: + logger.info(f"Using chunked passage encoding with chunk_size={data_args.passage_chunk_size}") + # For random chunking, we still need a base chunk_size for the model + # Use the minimum of the range if random chunking is enabled + if use_random_chunking: + try: + parts = [p.strip() for p in data_args.passage_chunk_size_range.split(',')] + chunk_size_min = int(parts[0]) + model.passage_chunk_size = chunk_size_min + except: + model.passage_chunk_size = data_args.passage_chunk_size if data_args.passage_chunk_size > 0 else 64 + else: + model.passage_chunk_size = data_args.passage_chunk_size encode_collator = ChunkedEncodeCollator(data_args=data_args, 
tokenizer=tokenizer) else: encode_collator = EncodeCollator(data_args=data_args, tokenizer=tokenizer) @@ -117,7 +134,7 @@ def main(): for batch in tqdm(encode_loader): with torch.amp.autocast('cuda') if training_args.fp16 or training_args.bf16 else nullcontext(): with torch.no_grad(): - if use_pre_chunked or use_chunked: + if use_pre_chunked or use_chunked or use_random_chunking: doc_ids, batch_inputs, eos_positions = batch # batch_inputs: input_ids, attention_mask for k, v in batch_inputs.items(): @@ -145,7 +162,7 @@ def main(): else: model_output: EncoderOutput = model(passage=batch_inputs) encoded.append(model_output.p_reps.cpu().detach().numpy()) - if use_pre_chunked or use_chunked: + if use_pre_chunked or use_chunked or use_random_chunking: print("use_chunked: ", use_chunked) print(f"encoded: {encoded}") print(f"lookup_indices: {lookup_indices}") diff --git a/tests/test_chunking.py b/tests/test_chunking.py index 4f939dad..d3a0bacf 100644 --- a/tests/test_chunking.py +++ b/tests/test_chunking.py @@ -1,5 +1,6 @@ import sys from pathlib import Path +import random import pytest import torch @@ -37,6 +38,45 @@ def _strictly_increasing(xs): "quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural " "development in cerebral white matter in living infants" ) + +# Semantically chunked version of REAL_TEXT - split into meaningful semantic units +REAL_TEXT_SEMANTIC_CHUNKS = [ + # Chunk 1: Introduction - Background on white matter alterations + "Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical " + "development and result in functional disabilities.", + + # Chunk 2: Methodology - MRI technique description + "A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was " + "applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate " + "three-dimensional fiber architecture in 
cerebral white matter in preterm (n = 17) and full-term infants (n = 7).", + + # Chunk 3: Study design - Longitudinal follow-up + "To assess effects of prematurity on cerebral white matter development, early gestation preterm infants " + "(n = 10) were studied a second time at term.", + + # Chunk 4: Results - Central white matter findings + "In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and " + "decreased toward term to 1.2 microm2/ms.", + + # Chunk 5: Results - Internal capsule findings + "In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were " + "similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with greater " + "absolute values in the internal capsule than in the central white matter.", + + # Chunk 6: Results - Preterm vs full-term comparisons + "Preterm infants at term showed higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 " + "versus 1.15 +/- 0.09 microm2/ms, p = 0.016) and lower relative anisotropy in both areas compared with " + "full-term infants (white matter, 10.9 +/- 0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- " + "4.44 versus 33.1 +/- 0.6% p = 0.006).", + + # Chunk 7: Results - Corpus callosum findings + "Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term " + "and preterm infants at term showed marked differences in white matter fiber organization.", + + # Chunk 8: Conclusion + "The data indicate that quantitative assessment of water diffusion by diffusion tensor MRI provides insight into " + "microstructural development in cerebral white matter in living infants" +] EOS_TOKEN_ID = 151643 PADDING_TOKEN_ID = 151643 @@ -1226,3 +1266,903 @@ def test_chunking_multiple_passages_different_lengths(train_tokenizer): assert eos_positions[3] == expected_eos_3 assert ids_3 == expected_ids_3 assert mask_3 
== expected_mask_3
+
+
+# ============================================================================
+# Unit tests for random chunk sizes within a range
+# ============================================================================
+
+@pytest.mark.unit
+def test_chunk_tokens_random_chunk_size_range_fixed_per_passage(train_tokenizer):
+    """Test chunking with random chunk size range, fixed per passage (all chunks in a passage use same random size)."""
+    _add_tevatron_src_to_path()
+    from tevatron.retriever.collator import _chunk_tokens
+
+    # Set seed for deterministic results
+    random.seed(42)
+
+    tokens = list(range(100))  # 100 tokens
+    eos_id = 99
+    chunk_size_range = (10, 20)  # Random chunk size between 10 and 20
+
+    ids, eos_pos = _chunk_tokens(tokens, chunk_size=10, eos_token_id=eos_id, chunk_size_range=chunk_size_range)
+
+    # Hardcoded golden output with seed=42 and chunk_size_range=(10, 20)
+    # With seed=42, random.randint(10, 20) draws one size per chunk: 20, 11, 10, 14, 13, 13, 12, 11 (only 4 leftover tokens remain for the last chunk)
+    # Chunk sizes (including EOS): 20, 11, 10, 14, 13, 13, 12, 11, 5
+    # Chunk lengths (tokens per chunk, excluding EOS): 19, 10, 9, 13, 12, 12, 11, 10, 4
+    expected_ids = [
+        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 99,  # Chunk 1: 20 tokens (19 + EOS)
+        19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 99,  # Chunk 2: 11 tokens (10 + EOS)
+        29, 30, 31, 32, 33, 34, 35, 36, 37, 99,  # Chunk 3: 10 tokens (9 + EOS)
+        38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 99,  # Chunk 4: 14 tokens (13 + EOS)
+        51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 99,  # Chunk 5: 13 tokens (12 + EOS)
+        63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 99,  # Chunk 6: 13 tokens (12 + EOS)
+        75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 99,  # Chunk 7: 12 tokens (11 + EOS)
+        86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 99,  # Chunk 8: 11 tokens (10 + EOS)
+        96, 97, 98, 99, 99  # Chunk 9: 5 tokens (4 + EOS)
+    ]
+    expected_eos_pos = [19, 30, 40, 54, 67, 80, 92, 103, 108]
+
+    assert ids == expected_ids
+    assert 
eos_pos == expected_eos_pos + + # Verify structure: each chunk should end with EOS + for eos_pos_val in eos_pos: + assert ids[eos_pos_val] == eos_id + + +@pytest.mark.unit +def test_chunk_tokens_random_chunk_size_range_with_max_length(train_tokenizer): + """Test random chunk size range with max_length constraint.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + random.seed(123) + + tokens = list(range(200)) + eos_id = 99 + chunk_size_range = (15, 25) + max_length = 50 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size=15, eos_token_id=eos_id, max_length=max_length, chunk_size_range=chunk_size_range) + + # Hardcoded golden output with seed=123, chunk_size_range=(15, 25), max_length=50 + # With seed=123, random.randint(15, 25) generates: 15, 20, 16 (for chunks) + # Chunk sizes (before EOS): 15, 20, 16 + # Chunk lengths (tokens per chunk): 14, 19, 15 + # Total: 14 + 1 + 19 + 1 + 15 + 1 = 50 tokens (exactly max_length) + expected_ids = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99, # Chunk 1: 15 tokens (14 + EOS) + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 99, # Chunk 2: 20 tokens (19 + EOS) + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 99 # Chunk 3: 16 tokens (15 + EOS, truncated to fit max_length) + ] + expected_eos_pos = [14, 33, 49] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + assert len(ids) == max_length # Exactly max_length + + # Verify all EOS positions are valid + for eos_pos_val in eos_pos: + assert ids[eos_pos_val] == eos_id + assert eos_pos_val < len(ids) + + +@pytest.mark.unit +def test_train_collator_random_chunk_size_range_fixed_per_passage(train_tokenizer): + """Test TrainCollator with random chunk size range, fixed per passage.""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + random.seed(42) + + data_args = DataArguments( + 
query_max_len=32, + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + train_group_size=2, + passage_chunk_size_range="32,64", # Random chunk size between 32 and 64 + passage_chunk_size_variable=False, # Fixed random size per passage + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + + features = [ + (("q1", None, None, None), [(REAL_TEXT, None, None, None), (REAL_TEXT, None, None, None)]), + ] + + q_batch, p_batch, eos_positions = collator(features) + + # Hardcoded golden output with seed=42, passage_chunk_size_range="32,64", passage_chunk_size_variable=False + # With seed=42, random.randint(32, 64) generates: 40 for passage 0, 34 for passage 1 + # Passage 0: chunk_size=40 (chunk_len=39), produces 4 chunks: [38, 77, 116, 127] + # Passage 1: chunk_size=34 (chunk_len=33), produces 4 chunks: [32, 65, 98, 127] + expected_ids_0 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, EOS_TOKEN_ID, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284, + EOS_TOKEN_ID, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, + 8552, 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, + 41434, 320, 77, 284, 220, 16, 15, EOS_TOKEN_ID, 8, 1033, 19476, 264, 2086, 882, 518, 4647, + 13, 758, EOS_TOKEN_ID + ] + expected_mask_0 = [1] * 128 + expected_eos_positions_0 = [38, 77, 116, 127] + + expected_ids_1 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, EOS_TOKEN_ID, 
31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, 2326, 32420, EOS_TOKEN_ID, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, + 320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, + 8552, 6239, 315, 6811, 37854, EOS_TOKEN_ID, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, + 367, 855, 4991, 41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, + 13, 758, EOS_TOKEN_ID + ] + expected_mask_1 = [1] * 128 + expected_eos_positions_1 = [32, 65, 98, 127] + + # Verify structure + assert p_batch["input_ids"].shape[0] == 2 + assert len(eos_positions) == 2 + + # Verify passage 0 + got_ids_0 = p_batch["input_ids"][0].tolist() + got_mask_0 = p_batch["attention_mask"][0].tolist() + assert got_ids_0 == expected_ids_0 + assert got_mask_0 == expected_mask_0 + assert eos_positions[0] == expected_eos_positions_0 + assert _strictly_increasing(eos_positions[0]) + for eos_pos in eos_positions[0]: + assert got_ids_0[eos_pos] == train_tokenizer.eos_token_id + assert got_mask_0[eos_pos] == 1 + + # Verify passage 1 + got_ids_1 = p_batch["input_ids"][1].tolist() + got_mask_1 = p_batch["attention_mask"][1].tolist() + assert got_ids_1 == expected_ids_1 + assert got_mask_1 == expected_mask_1 + assert eos_positions[1] == expected_eos_positions_1 + assert _strictly_increasing(eos_positions[1]) + for eos_pos in eos_positions[1]: + assert got_ids_1[eos_pos] == train_tokenizer.eos_token_id + assert got_mask_1[eos_pos] == 1 + + +@pytest.mark.unit +def test_train_collator_random_chunk_size_range_variable_per_chunk(train_tokenizer): + """Test TrainCollator with random chunk size range, variable per chunk.""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + random.seed(42) + + data_args = DataArguments( + query_max_len=32, + 
passage_max_len=256, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + train_group_size=1, + passage_chunk_size_range="32,64", # Random chunk size between 32 and 64 + passage_chunk_size_variable=True, # Variable chunk size per chunk + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + + features = [ + (("q1", None, None, None), [(REAL_TEXT, None, None, None)]), + ] + + q_batch, p_batch, eos_positions = collator(features) + + # Hardcoded golden output with seed=42, passage_chunk_size_range="32,64", passage_chunk_size_variable=True + # With seed=42 and variable chunk sizes, each chunk gets a random size from [32, 64] + # Chunk sizes generated: 40, 34, 50, 48, 47, 41, 3 (last partial chunk) + # EOS positions: [38, 71, 120, 167, 213, 253, 255] + expected_ids = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, EOS_TOKEN_ID, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, EOS_TOKEN_ID, 304, 855, 4991, + 320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, + 8552, 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, + 41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, EOS_TOKEN_ID, 2086, 882, 518, 4647, + 13, 758, 279, 8622, 4158, 4925, 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, 73760, 572, + 1550, 11, 220, 16, 13, 23, 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, 311, 220, 16, + 13, 17, 19197, 441, 17, 58634, 13, EOS_TOKEN_ID, 758, 279, 44900, 47594, 315, 279, 5306, + 47639, 11, 279, 3076, 9981, 57330, 36829, 518, 2176, 3039, 1033, 4428, 320, 16, 13, 17, + 19041, 220, 16, 13, 16, 19197, 441, 17, 58634, 568, 39402, 
458, 285, 354, 17764, 572, 5080, + 279, 12128, 7194, 572, 311, EOS_TOKEN_ID, 4647, 448, 7046, 10740, 2750, 304, 279, 5306, + 47639, 1091, 304, 279, 8622, 4158, 4925, 13, 4968, 4991, 41434, 518, 4647, 8542, 5080, + 3076, 57330, 36829, 304, 279, 8622, 4158, 4925, 320, 16, 13, 19, 51615, 220, 15, 13, + EOS_TOKEN_ID, 17, EOS_TOKEN_ID + ] + expected_mask = [1] * 256 + expected_eos_positions = [38, 71, 120, 167, 213, 253, 255] + + # Verify structure + assert p_batch["input_ids"].shape[0] == 1 + assert len(eos_positions) == 1 + + got_ids = p_batch["input_ids"][0].tolist() + got_mask = p_batch["attention_mask"][0].tolist() + + assert got_ids == expected_ids + assert got_mask == expected_mask + assert eos_positions[0] == expected_eos_positions + assert _strictly_increasing(eos_positions[0]) + + # Verify each EOS position is valid + for eos_pos in eos_positions[0]: + assert got_ids[eos_pos] == train_tokenizer.eos_token_id + assert got_mask[eos_pos] == 1 + + +@pytest.mark.unit +def test_train_collator_random_chunk_size_range_hardcoded_output(train_tokenizer): + """Test TrainCollator with random chunk size range - hardcoded golden output.""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + random.seed(42) + + data_args = DataArguments( + query_max_len=32, + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + train_group_size=1, + passage_chunk_size_range="32,48", # Random chunk size between 32 and 48 + passage_chunk_size_variable=False, # Fixed random size per passage + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + + short_text = "Hello world this is a test passage" + features = [ + (("q1", None, None, None), [(short_text, None, None, None)]), + ] + + q_batch, p_batch, eos_positions = collator(features) + + got_ids = p_batch["input_ids"][0].tolist() + got_mask = p_batch["attention_mask"][0].tolist() + + 
# Hardcoded golden output with seed=42 and chunk_size_range=(32,48) + # short_text tokenizes to: [9707, 1879, 419, 374, 264, 1273, 21085] + # With seed=42, random.randint(32, 48) = 40 (first call) + # So chunk_len = 39, but we only have 7 tokens, so we get: [7 tokens] + EOS + expected_ids = [9707, 1879, 419, 374, 264, 1273, 21085, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 8 + expected_mask = [1] * 8 + [0] * 8 + expected_eos_positions = [[7]] + + assert got_ids == expected_ids + assert got_mask == expected_mask + assert eos_positions == expected_eos_positions + + +# ============================================================================ +# Unit tests for prechunked passages +# ============================================================================ + +@pytest.mark.unit +def test_prechunked_encode_collator_basic(train_tokenizer): + """Test PreChunkedEncodeCollator with basic pre-chunked passages.""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import PreChunkedEncodeCollator + + data_args = DataArguments( + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + # Pre-chunked passages: each passage is a list of chunk strings + features = [ + ("doc1", ["Hello world", "This is chunk 2", "Final chunk"], None, None, None), + ("doc2", ["Single chunk passage"], None, None, None), + ] + + doc_ids, d_collated, eos_positions = collator(features) + + # Hardcoded golden output: + # doc1: "Hello world" -> [9707, 1879] + EOS, "This is chunk 2" -> [1986, 374, 11879, 220, 17] + EOS, "Final chunk" -> [19357, 11879] + EOS + # Total: 12 tokens (11 content + 3 EOS), padded to 16 + expected_ids_0 = [9707, 1879, EOS_TOKEN_ID, 1986, 374, 11879, 220, 17, EOS_TOKEN_ID, 19357, 11879, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 4 + expected_mask_0 = [1] * 12 + [0] * 4 + 
expected_eos_positions_0 = [2, 8, 11] + + # doc2: "Single chunk passage" -> [10888, 11879, 21085] + EOS + # Total: 4 tokens (3 content + 1 EOS), padded to 16 + expected_ids_1 = [10888, 11879, 21085, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 12 + expected_mask_1 = [1] * 4 + [0] * 12 + expected_eos_positions_1 = [3] + + assert doc_ids == ["doc1", "doc2"] + assert d_collated["input_ids"].shape[0] == 2 + assert len(eos_positions) == 2 + + # Verify doc1 + got_ids_0 = d_collated["input_ids"][0].tolist() + got_mask_0 = d_collated["attention_mask"][0].tolist() + assert got_ids_0 == expected_ids_0 + assert got_mask_0 == expected_mask_0 + assert eos_positions[0] == expected_eos_positions_0 + assert len(eos_positions[0]) == 3 + assert _strictly_increasing(eos_positions[0]) + for eos_pos in eos_positions[0]: + assert got_ids_0[eos_pos] == train_tokenizer.eos_token_id + assert got_mask_0[eos_pos] == 1 + + # Verify doc2 + got_ids_1 = d_collated["input_ids"][1].tolist() + got_mask_1 = d_collated["attention_mask"][1].tolist() + assert got_ids_1 == expected_ids_1 + assert got_mask_1 == expected_mask_1 + assert eos_positions[1] == expected_eos_positions_1 + assert len(eos_positions[1]) == 1 + for eos_pos in eos_positions[1]: + assert got_ids_1[eos_pos] == train_tokenizer.eos_token_id + assert got_mask_1[eos_pos] == 1 + + +@pytest.mark.unit +def test_prechunked_encode_collator_hardcoded_output(train_tokenizer): + """Test PreChunkedEncodeCollator with hardcoded golden output.""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import PreChunkedEncodeCollator + + data_args = DataArguments( + passage_max_len=64, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + # Pre-chunked passages + features = [ + ("doc1", ["Hello", "world"], None, None, None), + ] + + doc_ids, d_collated, eos_positions = 
collator(features) + + got_ids = d_collated["input_ids"][0].tolist() + got_mask = d_collated["attention_mask"][0].tolist() + + # Hardcoded golden output: + # "Hello" -> [9707] + EOS + # "world" -> [14615] + EOS (tokenized separately, different from "Hello world") + # Total: 4 tokens, padded to 16 + expected_ids = [9707, EOS_TOKEN_ID, 14615, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 12 + expected_mask = [1] * 4 + [0] * 12 + expected_eos_positions = [[1, 3]] + + assert got_ids == expected_ids + assert got_mask == expected_mask + assert eos_positions == expected_eos_positions + + +@pytest.mark.unit +def test_prechunked_encode_collator_max_length_truncation(train_tokenizer): + """Test PreChunkedEncodeCollator with max_length truncation.""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import PreChunkedEncodeCollator + + data_args = DataArguments( + passage_max_len=20, # Small max length to trigger truncation + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + # Create chunks that will exceed max_length + long_chunk = REAL_TEXT[:200] # Long chunk + features = [ + ("doc1", [long_chunk, "Second chunk", "Third chunk"], None, None, None), + ] + + doc_ids, d_collated, eos_positions = collator(features) + + # Hardcoded golden output with max_length=20: + # First chunk (long_chunk) tokenizes to 19 tokens, then EOS is added at position 19 + # Total: 20 tokens (19 content + 1 EOS), which exactly fills max_length + # Second and third chunks are not included due to truncation + # Padded to 32 (multiple of 16) + expected_ids = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, + 7802, 82519, 4401, 323, EOS_TOKEN_ID + ] + [PADDING_TOKEN_ID] * 12 + expected_mask = [1] * 20 + [0] * 12 + expected_eos_positions = [19] + + got_ids = 
d_collated["input_ids"][0].tolist() + got_mask = d_collated["attention_mask"][0].tolist() + + assert got_ids == expected_ids + assert got_mask == expected_mask + assert eos_positions[0] == expected_eos_positions + assert len(got_ids) == 32 # Padded to multiple of 16 + assert sum(got_mask) == 20 # Exactly 20 tokens (19 content + 1 EOS) + + # Verify EOS positions are valid + for eos_pos in eos_positions[0]: + assert got_ids[eos_pos] == train_tokenizer.eos_token_id + assert got_mask[eos_pos] == 1 + assert eos_pos < len(got_ids) + + # Verify truncation: only first chunk fits, second and third chunks are not included + assert len(eos_positions[0]) == 1 # Only one EOS (from first chunk) + + +@pytest.mark.unit +def test_prechunked_encode_collator_left_padding(train_tokenizer): + """Test PreChunkedEncodeCollator with left padding.""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import PreChunkedEncodeCollator + + data_args = DataArguments( + passage_max_len=64, + pad_to_multiple_of=16, + padding_side="left", + append_eos_token=False, + ) + collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + features = [ + ("doc1", ["Hello", "world"], None, None, None), + ("doc2", ["Short"], None, None, None), + ] + + doc_ids, d_collated, eos_positions = collator(features) + + got_ids_0 = d_collated["input_ids"][0].tolist() + got_mask_0 = d_collated["attention_mask"][0].tolist() + got_ids_1 = d_collated["input_ids"][1].tolist() + got_mask_1 = d_collated["attention_mask"][1].tolist() + + # Both should be padded to same length (64, rounded to 64) + assert len(got_ids_0) == len(got_ids_1) + + # Verify EOS positions are adjusted for left padding + # doc1: [9707, EOS, 1879, EOS] = 4 tokens, padded to 64 -> 60 padding tokens + # EOS positions shift from [1, 3] to [61, 63] + assert len(eos_positions[0]) == 2 + assert eos_positions[0][0] > 1 # Should be shifted right + assert 
eos_positions[0][1] > 3 # Should be shifted right + + # Verify EOS tokens are at correct positions + for eos_pos in eos_positions[0]: + assert got_ids_0[eos_pos] == train_tokenizer.eos_token_id + assert got_mask_0[eos_pos] == 1 + + +@pytest.mark.unit +def test_prechunked_encode_collator_empty_chunks(train_tokenizer): + """Test PreChunkedEncodeCollator with empty chunks list.""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import PreChunkedEncodeCollator + + data_args = DataArguments( + passage_max_len=64, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + features = [ + ("doc1", [], None, None, None), # Empty chunks + ("doc2", ["Non-empty"], None, None, None), + ] + + doc_ids, d_collated, eos_positions = collator(features) + + assert doc_ids == ["doc1", "doc2"] + assert len(eos_positions) == 2 + + # Empty chunks should have no EOS positions + assert eos_positions[0] == [] + + # Non-empty should have EOS positions + assert len(eos_positions[1]) > 0 + + +@pytest.mark.unit +def test_prechunked_encode_collator_multiple_passages_different_lengths(train_tokenizer): + """Test PreChunkedEncodeCollator with multiple passages of different chunk counts.""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import PreChunkedEncodeCollator + + data_args = DataArguments( + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + features = [ + ("doc1", ["Chunk 1", "Chunk 2"], None, None, None), # 2 chunks + ("doc2", ["Single chunk"], None, None, None), # 1 chunk + ("doc3", ["A", "B", "C", "D"], None, None, None), # 4 chunks + ] + + doc_ids, d_collated, eos_positions = collator(features) 
+ + assert doc_ids == ["doc1", "doc2", "doc3"] + assert d_collated["input_ids"].shape[0] == 3 + assert len(eos_positions) == 3 + + # Verify each passage has correct number of EOS positions + assert len(eos_positions[0]) == 2 # doc1: 2 chunks + assert len(eos_positions[1]) == 1 # doc2: 1 chunk + assert len(eos_positions[2]) == 4 # doc3: 4 chunks + + # Verify all EOS positions are valid + for i in range(3): + got_ids = d_collated["input_ids"][i].tolist() + got_mask = d_collated["attention_mask"][i].tolist() + + assert _strictly_increasing(eos_positions[i]) + for eos_pos in eos_positions[i]: + assert got_ids[eos_pos] == train_tokenizer.eos_token_id + assert got_mask[eos_pos] == 1 + + +@pytest.mark.unit +def test_prechunked_encode_collator_semantic_chunks(train_tokenizer): + """Test PreChunkedEncodeCollator with semantically chunked REAL_TEXT.""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import PreChunkedEncodeCollator + + data_args = DataArguments( + passage_max_len=512, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + # Use semantically chunked version of REAL_TEXT + features = [ + ("doc1", REAL_TEXT_SEMANTIC_CHUNKS, None, None, None), + ] + + doc_ids, d_collated, eos_positions = collator(features) + + # Hardcoded golden output with semantically chunked REAL_TEXT (8 chunks) + # Each semantic chunk is tokenized and separated by EOS tokens + # Total: 437 content tokens + 8 EOS tokens = 445 tokens, padded to 448 (multiple of 16) + expected_ids = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, EOS_TOKEN_ID, 32, 1555, 8569, 57330, 12635, + 291, 23970, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 
458, 285, 354, 17764, 11, 323, 311, + 90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284, + 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, EOS_TOKEN_ID, 1249, 8552, + 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, + 41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13, + EOS_TOKEN_ID, 641, 279, 8622, 4158, 4925, 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, + 73760, 572, 1550, 11, 220, 16, 13, 23, 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, + 311, 220, 16, 13, 17, 19197, 441, 17, 58634, 13, EOS_TOKEN_ID, 641, 279, 44900, 47594, 315, + 279, 5306, 47639, 11, 279, 3076, 9981, 57330, 36829, 518, 2176, 3039, 1033, 4428, 320, 16, + 13, 17, 19041, 220, 16, 13, 16, 19197, 441, 17, 58634, 568, 39402, 458, 285, 354, 17764, 572, + 5080, 279, 12128, 7194, 572, 311, 4647, 448, 7046, 10740, 2750, 304, 279, 5306, 47639, 1091, + 304, 279, 8622, 4158, 4925, 13, EOS_TOKEN_ID, 4703, 4991, 41434, 518, 4647, 8542, 5080, 3076, + 57330, 36829, 304, 279, 8622, 4158, 4925, 320, 16, 13, 19, 51615, 220, 15, 13, 17, 19, 19041, + 220, 16, 13, 16, 20, 51615, 220, 15, 13, 15, 24, 19197, 441, 17, 58634, 11, 281, 284, 220, + 15, 13, 15, 16, 21, 8, 323, 4722, 8674, 458, 285, 354, 17764, 304, 2176, 5671, 7707, 448, + 2480, 9663, 41434, 320, 5782, 4925, 11, 220, 16, 15, 13, 24, 51615, 220, 15, 13, 21, 19041, + 220, 17, 17, 13, 24, 51615, 220, 18, 13, 15, 13384, 281, 284, 220, 15, 13, 15, 15, 16, 26, + 5306, 47639, 11, 220, 17, 19, 13, 15, 51615, 220, 19, 13, 19, 19, 19041, 220, 18, 18, 13, 16, + 51615, 220, 15, 13, 21, 4, 281, 284, 220, 15, 13, 15, 15, 21, 568, EOS_TOKEN_ID, 8121, 2408, + 301, 15479, 48674, 304, 279, 42094, 1620, 385, 1242, 1033, 9434, 553, 57330, 15626, 51360, + 438, 4124, 438, 220, 17, 23, 73760, 26, 2480, 9663, 323, 855, 4991, 41434, 518, 4647, 8542, + 12864, 11799, 304, 4158, 4925, 23788, 7321, 13, EOS_TOKEN_ID, 785, 821, 13216, 429, 46516, + 
15449, 315, 3015, 57330, 553, 57330, 15626, 51360, 5707, 20017, 1119, 8003, 95697, 4401, 304, + 59645, 4158, 4925, 304, 5382, 41434, EOS_TOKEN_ID + ] + [PADDING_TOKEN_ID] * 11 + expected_mask = [1] * 437 + [0] * 11 + expected_eos_positions = [24, 91, 125, 167, 229, 366, 409, 436] + + got_ids = d_collated["input_ids"][0].tolist() + got_mask = d_collated["attention_mask"][0].tolist() + + # Verify structure: should have 8 EOS positions (one per semantic chunk) + assert doc_ids == ["doc1"] + assert d_collated["input_ids"].shape[0] == 1 + assert len(eos_positions) == 1 + assert got_ids == expected_ids + assert got_mask == expected_mask + assert eos_positions[0] == expected_eos_positions + assert len(eos_positions[0]) == 8 # 8 semantic chunks + assert _strictly_increasing(eos_positions[0]) + + # Verify all EOS positions are valid + for eos_pos in eos_positions[0]: + assert got_ids[eos_pos] == train_tokenizer.eos_token_id + assert got_mask[eos_pos] == 1 + assert eos_pos < len(got_ids) + + # Verify that semantic chunks are preserved (each chunk ends with EOS) + # Check that we have content tokens between EOS positions + for i in range(len(eos_positions[0]) - 1): + chunk_start = eos_positions[0][i] + 1 # Start after EOS + chunk_end = eos_positions[0][i + 1] # End at next EOS + assert chunk_end > chunk_start # Should have content tokens between EOS markers + + # Verify total length is reasonable (should fit within max_length=512) + assert len(got_ids) == 448 # Padded to multiple of 16 + assert sum(got_mask) == 437 # 437 content tokens + assert len(got_ids) % 16 == 0 # Padded to multiple of 16 + + +# ============================================================================ +# Unit tests for random chunking in ChunkedEncodeCollator (inference/search) +# ============================================================================ + +@pytest.mark.unit +def test_chunked_encode_collator_random_chunk_size_range_fixed_per_passage(train_tokenizer): + """Test ChunkedEncodeCollator 
with random chunk size range, fixed per passage (inference).""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import ChunkedEncodeCollator + + random.seed(42) + + data_args = DataArguments( + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + passage_chunk_size_range="32,64", # Random chunk size between 32 and 64 + passage_chunk_size_variable=False, # Fixed random size per passage + ) + collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + features = [ + ("doc1", REAL_TEXT, None, None, None), + ("doc2", REAL_TEXT, None, None, None), + ] + + doc_ids, d_collated, eos_positions = collator(features) + + # Hardcoded golden output with seed=42, passage_chunk_size_range="32,64", passage_chunk_size_variable=False + # With seed=42, random.randint(32, 64) generates: 39 for doc1, 33 for doc2 + # doc1: chunk_size=39 (chunk_len=38), produces 4 chunks: [38, 77, 116, 127] + # - Chunk 1: 38 tokens (0-37) + EOS at 38 + # - Chunk 2: 38 tokens (39-76) + EOS at 77 + # - Chunk 3: 38 tokens (78-115) + EOS at 116 + # - Chunk 4: 10 tokens (117-126) + EOS at 127 + # doc2: chunk_size=33 (chunk_len=32), produces 4 chunks: [32, 65, 98, 127] + # - Chunk 1: 32 tokens (0-31) + EOS at 32 + # - Chunk 2: 32 tokens (33-64) + EOS at 65 + # - Chunk 3: 32 tokens (66-97) + EOS at 98 + # - Chunk 4: 28 tokens (99-126) + EOS at 127 + expected_ids_0 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, EOS_TOKEN_ID, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284, + EOS_TOKEN_ID, 220, 16, 22, 8, 323, 2480, 
9663, 41434, 320, 77, 284, 220, 22, 568, 2014, + 8552, 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, + 41434, 320, 77, 284, 220, 16, 15, EOS_TOKEN_ID, 8, 1033, 19476, 264, 2086, 882, 518, 4647, + 13, 758, EOS_TOKEN_ID + ] + expected_mask_0 = [1] * 128 + expected_eos_positions_0 = [38, 77, 116, 127] + + expected_ids_1 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, EOS_TOKEN_ID, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, 2326, 32420, EOS_TOKEN_ID, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, + 320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, + 8552, 6239, 315, 6811, 37854, EOS_TOKEN_ID, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, + 367, 855, 4991, 41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, + 13, 758, EOS_TOKEN_ID + ] + expected_mask_1 = [1] * 128 + expected_eos_positions_1 = [32, 65, 98, 127] + + # Verify structure + assert doc_ids == ["doc1", "doc2"] + assert d_collated["input_ids"].shape[0] == 2 + assert len(eos_positions) == 2 + + # Verify doc1 + got_ids_0 = d_collated["input_ids"][0].tolist() + got_mask_0 = d_collated["attention_mask"][0].tolist() + assert got_ids_0 == expected_ids_0 + assert got_mask_0 == expected_mask_0 + assert eos_positions[0] == expected_eos_positions_0 + assert _strictly_increasing(eos_positions[0]) + for eos_pos in eos_positions[0]: + assert got_ids_0[eos_pos] == train_tokenizer.eos_token_id + assert got_mask_0[eos_pos] == 1 + + # Verify doc2 + got_ids_1 = d_collated["input_ids"][1].tolist() + got_mask_1 = d_collated["attention_mask"][1].tolist() + assert got_ids_1 == expected_ids_1 + assert got_mask_1 == expected_mask_1 + assert 
eos_positions[1] == expected_eos_positions_1 + assert _strictly_increasing(eos_positions[1]) + for eos_pos in eos_positions[1]: + assert got_ids_1[eos_pos] == train_tokenizer.eos_token_id + assert got_mask_1[eos_pos] == 1 + + +@pytest.mark.unit +def test_chunked_encode_collator_random_chunk_size_range_variable_per_chunk(train_tokenizer): + """Test ChunkedEncodeCollator with random chunk size range, variable per chunk (inference).""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import ChunkedEncodeCollator + + random.seed(42) + + data_args = DataArguments( + passage_max_len=256, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + passage_chunk_size_range="32,64", # Random chunk size between 32 and 64 + passage_chunk_size_variable=True, # Variable chunk size per chunk + ) + collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + features = [ + ("doc1", REAL_TEXT, None, None, None), + ] + + doc_ids, d_collated, eos_positions = collator(features) + + # Hardcoded golden output with seed=42, passage_chunk_size_range="32,64", passage_chunk_size_variable=True + # With seed=42 and variable chunk sizes, each chunk gets a random size from [32, 64] + # Chunk sizes generated: 40, 34, 50, 48, 47, 41, 3 (last partial chunk) + # EOS positions: [38, 71, 120, 167, 213, 253, 255] + expected_ids = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, EOS_TOKEN_ID, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, EOS_TOKEN_ID, 304, 855, 4991, + 320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, + 8552, 
6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991,
+        41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, EOS_TOKEN_ID, 2086, 882, 518, 4647,
+        13, 758, 279, 8622, 4158, 4925, 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, 73760, 572,
+        1550, 11, 220, 16, 13, 23, 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, 311, 220, 16,
+        13, 17, 19197, 441, 17, 58634, 13, EOS_TOKEN_ID, 758, 279, 44900, 47594, 315, 279, 5306,
+        47639, 11, 279, 3076, 9981, 57330, 36829, 518, 2176, 3039, 1033, 4428, 320, 16, 13, 17,
+        19041, 220, 16, 13, 16, 19197, 441, 17, 58634, 568, 39402, 458, 285, 354, 17764, 572, 5080,
+        279, 12128, 7194, 572, 311, EOS_TOKEN_ID, 4647, 448, 7046, 10740, 2750, 304, 279, 5306,
+        47639, 1091, 304, 279, 8622, 4158, 4925, 13, 4968, 4991, 41434, 518, 4647, 8542, 5080,
+        3076, 57330, 36829, 304, 279, 8622, 4158, 4925, 320, 16, 13, 19, 51615, 220, 15, 13,
+        EOS_TOKEN_ID, 17, EOS_TOKEN_ID
+    ]
+    # NOTE(review): the ids above are a golden sequence tied to this specific
+    # tokenizer's vocabulary, with EOS_TOKEN_ID interleaved at chunk boundaries;
+    # regenerate them if the tokenizer or the chunking logic changes.
+    expected_mask = [1] * 256
+    expected_eos_positions = [38, 71, 120, 167, 213, 253, 255]
+
+    # Verify structure: single doc in, single row out, one EOS-position list per doc.
+    assert doc_ids == ["doc1"]
+    assert d_collated["input_ids"].shape[0] == 1
+    assert len(eos_positions) == 1
+
+    got_ids = d_collated["input_ids"][0].tolist()
+    got_mask = d_collated["attention_mask"][0].tolist()
+
+    assert got_ids == expected_ids
+    assert got_mask == expected_mask
+    assert eos_positions[0] == expected_eos_positions
+    assert _strictly_increasing(eos_positions[0])
+
+    # Verify each EOS position is valid: it must index a real (unpadded) EOS token.
+    for eos_pos in eos_positions[0]:
+        assert got_ids[eos_pos] == train_tokenizer.eos_token_id
+        assert got_mask[eos_pos] == 1
+
+
+@pytest.mark.unit
+def test_chunked_encode_collator_random_chunk_size_range_hardcoded_output(train_tokenizer):
+    """Test ChunkedEncodeCollator with random chunk size range - hardcoded golden output (inference)."""
+    _add_tevatron_src_to_path()
+    from tevatron.retriever.arguments import DataArguments
+    from tevatron.retriever.collator import ChunkedEncodeCollator
+
+    # Seed the global RNG so the collator's random chunk-size draws match the
+    # golden values below.  NOTE(review): this assumes the collator draws from
+    # the `random` module's global state and consumes no draws before chunking
+    # — confirm against ChunkedEncodeCollator's implementation.
+    random.seed(42)
+
+    data_args = DataArguments(
+        passage_max_len=128,
+        pad_to_multiple_of=16,
+        padding_side="right",
+        append_eos_token=False,
+        passage_chunk_size_range="32,48",  # Random chunk size between 32 and 48
+        passage_chunk_size_variable=False,  # Fixed random size per passage
+    )
+    collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer)
+
+    short_text = "Hello world this is a test passage"
+    features = [
+        ("doc1", short_text, None, None, None),
+    ]
+
+    doc_ids, d_collated, eos_positions = collator(features)
+
+    got_ids = d_collated["input_ids"][0].tolist()
+    got_mask = d_collated["attention_mask"][0].tolist()
+
+    # Hardcoded golden output with seed=42 and chunk_size_range=(32,48)
+    # short_text tokenizes to: [9707, 1879, 419, 374, 264, 1273, 21085]
+    # With seed=42, random.randint(32, 48) = 40 (first call)
+    # So chunk_len = 39, but we only have 7 tokens, so we get: [7 tokens] + EOS
+    # NOTE(review): 8 real tokens padded up to 16 because pad_to_multiple_of=16.
+    expected_ids = [9707, 1879, 419, 374, 264, 1273, 21085, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 8
+    expected_mask = [1] * 8 + [0] * 8
+    expected_eos_positions = [[7]]
+
+    assert doc_ids == ["doc1"]
+    assert got_ids == expected_ids
+    assert got_mask == expected_mask
+    assert eos_positions == expected_eos_positions
+
+
+@pytest.mark.unit
+def test_chunked_encode_collator_fixed_chunk_size_still_works(train_tokenizer):
+    """Test ChunkedEncodeCollator with fixed chunk size (no random chunking) still works."""
+    _add_tevatron_src_to_path()
+    from tevatron.retriever.arguments import DataArguments
+    from tevatron.retriever.collator import ChunkedEncodeCollator
+
+    data_args = DataArguments(
+        passage_max_len=128,
+        pad_to_multiple_of=16,
+        padding_side="right",
+        append_eos_token=False,
+        passage_chunk_size=32,  # Fixed chunk size, no random chunking
+    )
+    collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer)
+
+    features = [
+        ("doc1", REAL_TEXT, None, None, None),
+    ]
+
+    doc_ids, d_collated, eos_positions = collator(features)
+
+    # No golden ids here: REAL_TEXT output is checked structurally
+    # (row count, EOS placement, monotonicity) rather than by exact values.
+    # Verify structure
+    assert doc_ids == ["doc1"]
+    assert d_collated["input_ids"].shape[0] == 1
+    assert len(eos_positions) == 1
+    assert len(eos_positions[0]) > 0
+
+    got_ids = d_collated["input_ids"][0].tolist()
+    got_mask = d_collated["attention_mask"][0].tolist()
+
+    # Verify EOS positions are strictly increasing
+    assert _strictly_increasing(eos_positions[0])
+
+    # Verify each EOS position is valid
+    for eos_pos in eos_positions[0]:
+        assert got_ids[eos_pos] == train_tokenizer.eos_token_id
+        assert got_mask[eos_pos] == 1