diff --git a/finetune.sh b/finetune.sh new file mode 100755 index 00000000..dd983c33 --- /dev/null +++ b/finetune.sh @@ -0,0 +1,27 @@ +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.train \ + --output_dir retriever-qwen3-emb-ft-chunk-1219-no-chunk-4-group-512-passage \ + --model_name_or_path Qwen/Qwen3-Embedding-0.6B \ + --do_train \ + --lora \ + --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ + --save_steps 50 \ + --dataset_name Tevatron/scifact \ + --dataset_split train \ + --query_prefix "Instruct: Given a scientific claim, retrieve documents that support or refute the claim.\nQuery:" \ + --passage_prefix "" \ + --bf16 \ + --pooling last \ + --padding_side left \ + --normalize \ + --temperature 0.01 \ + --per_device_train_batch_size 4 \ + --gradient_checkpointing \ + --train_group_size 16 \ + --learning_rate 1e-4 \ + --query_max_len 32 \ + --passage_max_len 512 \ + --num_train_epochs 10 \ + --logging_steps 10 \ + --overwrite_output_dir \ + --gradient_accumulation_steps 1 \ + --passage_chunk_size 0 diff --git a/finetune_with_chunk.sh b/finetune_with_chunk.sh new file mode 100755 index 00000000..712fdf09 --- /dev/null +++ b/finetune_with_chunk.sh @@ -0,0 +1,27 @@ +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.train \ + --output_dir retriever-qwen3-emb-ft-chunk-1219-1 \ + --model_name_or_path Qwen/Qwen3-Embedding-0.6B \ + --do_train \ + --lora \ + --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ + --save_steps 50 \ + --dataset_name Tevatron/scifact \ + --dataset_split train \ + --query_prefix "Instruct: Given a scientific claim, retrieve documents that support or refute the claim.\nQuery:" \ + --passage_prefix "" \ + --bf16 \ + --pooling last \ + --padding_side right \ + --normalize \ + --temperature 0.01 \ + --per_device_train_batch_size 4 \ + --gradient_checkpointing \ + --train_group_size 16 \ + --learning_rate 1e-4 \ + --query_max_len 32 \ + --passage_max_len 512 \ + 
--num_train_epochs 10 \ + --logging_steps 10 \ + --overwrite_output_dir \ + --gradient_accumulation_steps 1 \ + --passage_chunk_size 256 diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..9b649206 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + unit: fast unit tests (no external downloads) + diff --git a/req.txt b/req.txt new file mode 100644 index 00000000..b033b240 --- /dev/null +++ b/req.txt @@ -0,0 +1,271 @@ +accelerate==1.10.1 +aiohappyeyeballs==2.6.1 +aiohttp==3.13.0 +aiosignal==1.4.0 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.11.0 +astor==0.8.1 +attrs==25.4.0 +audioread==3.0.1 +Authlib==1.6.5 +av==16.0.1 +beautifulsoup4==4.14.2 +beir==2.2.0 +blake3==1.0.8 +blinker==1.9.0 +blis==1.3.0 +cachetools==6.2.1 +catalogue==2.0.10 +cbor==1.0.0 +cbor2==5.7.0 +certifi==2025.10.5 +cffi==2.0.0 +charset-normalizer==3.4.3 +click==8.2.1 +clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1 +cloudpathlib==0.23.0 +cloudpickle==3.1.1 +coloredlogs==15.0.1 +compressed-tensors==0.11.0 +confection==0.1.5 +contourpy==1.3.3 +cramjam==2.11.0 +cryptography==46.0.2 +cupy-cuda12x==13.6.0 +cycler==0.12.1 +cyclopts==3.24.0 +cymem==2.0.11 +Cython==3.1.4 +datasets==2.19.0 +decorator==5.2.1 +decord==0.6.0 +deepspeed==0.18.0 +depyf==0.19.0 +dill==0.3.8 +diskcache==5.6.3 +distro==1.9.0 +dnspython==2.8.0 +docstring_parser==0.17.0 +docutils==0.22.2 +einops==0.8.1 +email-validator==2.3.0 +exceptiongroup==1.3.0 +fairscale==0.4.13 +faiss-cpu==1.12.0 +fastapi==0.119.0 +fastapi-cli==0.0.13 +fastapi-cloud-cli==0.3.1 +fastmcp==2.12.4 +fastparquet==2024.11.0 +fastrlock==0.8.3 +filelock==3.20.0 +flash_attn==2.8.3 +Flask==3.1.2 +flatbuffers==25.9.23 +fonttools==4.60.1 +frozendict==2.4.6 +frozenlist==1.8.0 +fsspec==2024.3.1 +ftfy==6.3.1 +gguf==0.17.1 +h11==0.16.0 +hf-xet==1.1.10 +hjson==3.1.0 +httpcore==1.0.9 +httptools==0.7.1 +httpx==0.28.1 +httpx-sse==0.4.3 +huggingface-hub==0.35.3 
+humanfriendly==10.0 +idna==3.10 +ijson==3.4.0.post0 +iniconfig==2.3.0 +inscriptis==2.6.0 +interegular==0.3.3 +ir_datasets==0.5.11 +isodate==0.7.2 +itsdangerous==2.2.0 +Jinja2==3.1.6 +jiter==0.11.0 +joblib==1.5.2 +jsonschema==4.25.1 +jsonschema-path==0.3.4 +jsonschema-specifications==2025.9.1 +kiwisolver==1.4.9 +langcodes==3.5.0 +language_data==1.3.0 +lark==1.2.2 +lazy-object-proxy==1.12.0 +lazy_loader==0.4 +librosa==0.11.0 +llguidance==0.7.30 +llvmlite==0.44.0 +lm-format-enforcer==0.11.3 +lxml==6.0.2 +lz4==4.4.4 +marisa-trie==1.3.1 +markdown-it-py==4.0.0 +MarkupSafe==3.0.3 +matplotlib==3.10.7 +mcp==1.17.0 +mdurl==0.1.2 +mistral_common==1.8.5 +ml_dtypes==0.5.3 +more-itertools==10.8.0 +mpmath==1.3.0 +msgpack==1.1.2 +msgspec==0.19.0 +multidict==6.7.0 +multiprocess==0.70.16 +murmurhash==1.0.13 +networkx==3.5 +ninja==1.13.0 +numba==0.61.2 +numpy==2.2.6 +nvidia-cublas-cu12==12.8.4.1 +nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cudnn-cu12==9.10.2.21 +nvidia-cufft-cu12==11.3.3.83 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand-cu12==10.3.9.90 +nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusparse-cu12==12.5.8.93 +nvidia-cusparselt-cu12==0.7.1 +nvidia-nccl-cu12==2.27.3 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvtx-cu12==12.8.90 +omegaconf==2.3.0 +onnx==1.19.1 +onnxoptimizer==0.3.13 +onnxruntime==1.23.1 +openai==2.3.0 +openai-harmony==0.0.4 +openapi-core==0.19.5 +openapi-pydantic==0.5.1 +openapi-schema-validator==0.6.3 +openapi-spec-validator==0.7.2 +opencv-python==4.12.0.88 +opencv-python-headless==4.12.0.88 +orjson==3.11.3 +outlines_core==0.2.11 +packaging==25.0 +pandas==2.3.3 +parse==1.20.2 +partial-json-parser==0.2.1.1.post6 +pathable==0.4.4 +peft==0.17.1 +pillow==11.3.0 +platformdirs==4.5.0 +pluggy==1.6.0 +pooch==1.8.2 +preshed==3.0.10 +prometheus-fastapi-instrumentator==7.1.0 +prometheus_client==0.23.1 +propcache==0.4.1 +protobuf==6.32.1 +psutil==7.1.0 +py-cpuinfo==9.0.0 +pyarrow==21.0.0 +pyarrow-hotfix==0.7 
+pybase64==1.4.2 +pybind11==3.0.1 +pycountry==24.6.1 +pycparser==2.23 +pydantic==2.12.0 +pydantic-extra-types==2.10.6 +pydantic-settings==2.11.0 +pydantic_core==2.41.1 +Pygments==2.19.2 +pyjnius==1.7.0 +pynndescent==0.5.13 +pyparsing==3.2.5 +pyperclip==1.11.0 +-e git+ssh://git@github.com/FarmersWrap/pyserini.git@a1995bffa243636c89029735236348c1e5206161#egg=pyserini +pytest==9.0.1 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 +python-json-logger==4.0.0 +python-multipart==0.0.20 +pytrec_eval-terrier==0.5.9 +pytz==2025.2 +PyYAML==6.0.3 +pyzmq==27.1.0 +qwen-omni-utils==0.0.8 +ranx==0.3.21 +ray==2.50.0 +referencing==0.36.2 +regex==2025.9.18 +requests==2.32.5 +rfc3339-validator==0.1.4 +rich==14.2.0 +rich-rst==1.3.1 +rich-toolkit==0.15.1 +rignore==0.7.1 +rpds-py==0.27.1 +safetensors==0.6.2 +scikit-learn==1.7.2 +scipy==1.16.2 +seaborn==0.13.2 +sentence-transformers==5.1.1 +sentencepiece==0.2.1 +sentry-sdk==2.42.0 +setproctitle==1.3.7 +setuptools==80.9.0 +shellingham==1.5.4 +six==1.17.0 +smart_open==7.3.1 +sniffio==1.3.1 +soundfile==0.13.1 +soupsieve==2.8 +soxr==1.0.0 +spacy==3.8.7 +spacy-legacy==3.0.12 +spacy-loggers==1.0.5 +srsly==2.5.1 +sse-starlette==3.0.2 +starlette==0.48.0 +sympy==1.14.0 +tabulate==0.9.0 +-e git+ssh://git@github.com/FarmersWrap/tevatron.git@add3832f2071525e257658cbe42cf9f9bbb3b928#egg=tevatron +thinc==8.3.6 +threadpoolctl==3.6.0 +tiktoken==0.12.0 +timm==1.0.20 +tokenizers==0.22.1 +torch==2.8.0 +torchaudio==2.8.0 +torchvision==0.23.0 +tqdm==4.67.1 +transformers==4.57.0 +trec-car-tools==2.6 +triton==3.4.0 +typeguard==4.4.4 +typer==0.19.2 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +tzdata==2025.2 +umap-learn==0.5.9.post2 +uniir_for_pyserini==0.1.1 +unlzw3==0.2.3 +urllib3==2.5.0 +uvicorn==0.37.0 +uvloop==0.22.1 +vllm==0.11.0 +warc3-wet==0.2.5 +warc3-wet-clueweb09==0.2.5 +wasabi==1.1.3 +watchfiles==1.1.1 +wcwidth==0.2.14 +weasel==0.4.1 +websockets==15.0.1 +Werkzeug==3.1.1 +wheel==0.45.1 +wrapt==1.17.3 
+xformers==0.0.32.post1 +xgrammar==0.1.25 +xxhash==3.6.0 +yarl==1.22.0 +zlib-state==0.1.10 diff --git a/run_retrieval.sh b/run_retrieval.sh new file mode 100755 index 00000000..9ee8d347 --- /dev/null +++ b/run_retrieval.sh @@ -0,0 +1,65 @@ +output_dir=retriever-qwen3-emb-ft-chunk-1219-no-chunk-4-group-512-passage +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path Qwen/Qwen3-Embedding-0.6B \ + --bf16 \ + --per_device_eval_batch_size 4 \ + --normalize \ + --pooling last \ + --padding_side right \ + --query_prefix "Instruct: Given a scientific claim, retrieve documents that support or refute the claim.\nQuery:" \ + --query_max_len 512 \ + --dataset_name Tevatron/beir \ + --dataset_config scifact \ + --dataset_split test \ + --encode_output_path ${output_dir}/queries_scifact.pkl \ + --encode_is_query + + +# Encode corpus +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path Qwen/Qwen3-Embedding-0.6B \ + --bf16 \ + --per_device_eval_batch_size 4 \ + --normalize \ + --pooling last \ + --padding_side right \ + --passage_prefix "" \ + --passage_max_len 512 \ + --dataset_name Tevatron/beir-corpus \ + --dataset_config scifact \ + --dataset_split train \ + --encode_output_path ${output_dir}/corpus_scifact.pkl \ + --passage_chunk_size 0 + +python -m tevatron.retriever.driver.search \ + --query_reps ${output_dir}/queries_scifact.pkl \ + --passage_reps ${output_dir}/corpus_scifact.pkl \ + --depth 100 \ + --batch_size 64 \ + --save_text \ + --save_ranking_to ${output_dir}/rank.scifact.txt + +# Convert to TREC format +python -m tevatron.utils.format.convert_result_to_trec --input ${output_dir}/rank.scifact.txt \ + --output ${output_dir}/rank.scifact.trec \ + --remove_query + +python -m tevatron.retriever.driver.search \ + --query_reps ${output_dir}/queries_scifact.pkl \ + --passage_reps ${output_dir}/corpus_scifact.pkl \ + --depth 100 \ + --batch_size 64 \ + 
--save_text \ + --save_ranking_to ${output_dir}/rank.scifact.txt + +# Convert to TREC format +python -m tevatron.utils.format.convert_result_to_trec --input ${output_dir}/rank.scifact.txt \ + --output ${output_dir}/rank.scifact.trec \ + --remove_query +python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 beir-v1.0.0-scifact-test ${output_dir}/rank.scifact.trec + +# recall_100 all 0.9767 +# ndcg_cut_10 all 0.7801 + diff --git a/run_retrieval_chunked.sh b/run_retrieval_chunked.sh new file mode 100755 index 00000000..b80ae37d --- /dev/null +++ b/run_retrieval_chunked.sh @@ -0,0 +1,65 @@ +output_dir=retriever-qwen3-emb-ft-chunk-1219-1 +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path Qwen/Qwen3-Embedding-0.6B \ + --bf16 \ + --per_device_eval_batch_size 4 \ + --normalize \ + --pooling last \ + --padding_side right \ + --query_prefix "Instruct: Given a scientific claim, retrieve documents that support or refute the claim.\nQuery:" \ + --query_max_len 512 \ + --dataset_name Tevatron/beir \ + --dataset_config scifact \ + --dataset_split test \ + --encode_output_path ${output_dir}/queries_scifact.pkl \ + --encode_is_query + + +# Encode corpus +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path Qwen/Qwen3-Embedding-0.6B \ + --bf16 \ + --per_device_eval_batch_size 4 \ + --normalize \ + --pooling last \ + --padding_side right \ + --passage_prefix "" \ + --passage_max_len 512 \ + --dataset_name Tevatron/beir-corpus \ + --dataset_config scifact \ + --dataset_split train \ + --encode_output_path ${output_dir}/corpus_scifact.pkl \ + --passage_chunk_size 256 + +python -m tevatron.retriever.driver.search \ + --query_reps ${output_dir}/queries_scifact.pkl \ + --passage_reps ${output_dir}/corpus_scifact.pkl \ + --depth 100 \ + --batch_size 64 \ + --save_text \ + --save_ranking_to ${output_dir}/rank.scifact.txt + +# Convert to TREC format +python -m 
tevatron.utils.format.convert_result_to_trec --input ${output_dir}/rank.scifact.txt \ + --output ${output_dir}/rank.scifact.trec \ + --remove_query + +python -m tevatron.retriever.driver.search \ + --query_reps ${output_dir}/queries_scifact.pkl \ + --passage_reps ${output_dir}/corpus_scifact.pkl \ + --depth 1000 \ + --batch_size 64 \ + --save_text \ + --save_ranking_to ${output_dir}/rank.scifact.txt + +# Convert to TREC format +python -m tevatron.utils.format.convert_result_to_trec --input ${output_dir}/rank.scifact.txt \ + --output ${output_dir}/rank.scifact.trec \ + --remove_query +python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 beir-v1.0.0-scifact-test ${output_dir}/rank.scifact.trec + +# recall_100 all 0.9767 +# ndcg_cut_10 all 0.7801 + diff --git a/src/tevatron/retriever/arguments.py b/src/tevatron/retriever/arguments.py index 00034903..a0c6ce5f 100644 --- a/src/tevatron/retriever/arguments.py +++ b/src/tevatron/retriever/arguments.py @@ -203,6 +203,25 @@ class DataArguments: metadata={"help": "padding side for the tokenizer, can be 'left' or 'right'"} ) + passage_chunk_size: int = field( + default=0, + metadata={"help": "Chunk size for chunked passage encoding with MaxSim. 0=disabled, >0=chunk size in tokens"} + ) + + passage_chunk_size_range: Optional[str] = field( + default=None, + metadata={"help": "Chunk size range for random chunking (e.g., '64,128'). Randomly selects chunk size in [min, max] range per passage. Works for both training and inference."} + ) + + passage_chunk_size_variable: bool = field( + default=False, + metadata={"help": "If True and passage_chunk_size_range is set, each chunk within a passage gets a random size from the range. If False, all chunks in a passage use the same random size. Works for both training and inference."} + ) + + encode_use_pre_chunked: bool = field( + default=False, + metadata={"help": "If True, expects dataset with 'chunks' field (list of pre-chunked passage strings). 
EOS tokens will be added between chunks. If False, uses regular 'text' field. Only for encoding (not training)."} + ) @dataclass diff --git a/src/tevatron/retriever/collator.py b/src/tevatron/retriever/collator.py index 20e02ef5..ca739f65 100644 --- a/src/tevatron/retriever/collator.py +++ b/src/tevatron/retriever/collator.py @@ -1,17 +1,176 @@ import logging +import random import torch -from typing import List, Tuple +from typing import List, Tuple, Optional from dataclasses import dataclass from transformers import PreTrainedTokenizer, ProcessorMixin from qwen_omni_utils import process_mm_info from PIL import Image +from rich import print from tevatron.retriever.arguments import DataArguments - +torch.set_printoptions(threshold=float('inf'), linewidth=10000) logger = logging.getLogger(__name__) +def _chunk_tokens( + tokens: List[int], + chunk_size: int, + eos_token_id: int, + max_length: int = None, + chunk_size_range: Optional[Tuple[int, int]] = None, +) -> Tuple[List[int], List[int]]: + """ + Chunk tokens into chunks with EOS separators. + + :param tokens: Token IDs to chunk + :param chunk_size: Fixed chunk size (before EOS). Must be >= 2. Used when chunk_size_range is None. + :param eos_token_id: EOS token ID to append after each chunk + :param max_length: Optional max total length (including EOS). If None, no limit. + :param chunk_size_range: Optional (min, max) tuple for variable chunk sizes. If set, each chunk uses a random size in [min, max]. 
+ :return: (chunked_ids, eos_positions) - token IDs with EOS separators and EOS positions + """ + # Validate and set up chunk size parameters + if chunk_size_range: + chunk_size_min, chunk_size_max = chunk_size_range + use_variable_sizes = True + else: + if chunk_size < 2: + return [], [] + use_variable_sizes = False + + # Chunk tokens and add EOS after each chunk + ids = [] + eos_pos = [] + i = 0 + total_length = 0 + + while i < len(tokens): + # Pick chunk size for this chunk + if use_variable_sizes: + current_chunk_size = random.randint(chunk_size_min, chunk_size_max) + else: + current_chunk_size = chunk_size + + # Check if we would exceed max_length with this chunk + if max_length and total_length + current_chunk_size > max_length: + # Use remaining space (leave 1 for EOS if possible) + remaining = max_length - total_length - 1 + if remaining > 0: + take = min(remaining, len(tokens) - i) + ids.extend(tokens[i:i + take]) + ids.append(eos_token_id) + eos_pos.append(len(ids) - 1) + break + + # Take tokens for this chunk (reserve 1 slot for EOS) + current_chunk_len = current_chunk_size - 1 + take = min(current_chunk_len, len(tokens) - i) + ids.extend(tokens[i:i + take]) + ids.append(eos_token_id) + eos_pos.append(len(ids) - 1) + + total_length += take + 1 # +1 for EOS + i += take + + return ids, eos_pos + +def _pad_and_adjust_eos_positions( + all_input_ids: List[List[int]], + all_eos_positions: List[List[int]], + tokenizer: PreTrainedTokenizer, + padding_side: str, + pad_to_multiple_of: int, +) -> Tuple[dict, List[List[int]]]: + """ + Pad input IDs and adjust EOS positions for left padding. 
+ + :param all_input_ids: List of token ID lists (one per passage) + :param all_eos_positions: List of EOS position lists (one per passage) + :param tokenizer: Tokenizer for padding + :param padding_side: 'left' or 'right' + :param pad_to_multiple_of: Pad to multiple of this value + :return: (padded_dict, adjusted_eos_positions) - padded tensors and adjusted EOS positions + """ + d_collated = {'input_ids': all_input_ids} + original_lengths = [len(ids) for ids in all_input_ids] + tokenizer.padding_side = padding_side + + d_collated = tokenizer.pad( + d_collated, + padding=True, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=True, + return_tensors='pt', + ) + + # Shift EOS positions for left padding + adjusted_eos_positions = [list(eos_pos_list) for eos_pos_list in all_eos_positions] + if padding_side == 'left': + padded_lengths = d_collated['input_ids'].shape[1] + for i, eos_pos_list in enumerate(adjusted_eos_positions): + padding_length = padded_lengths - original_lengths[i] + adjusted_eos_positions[i] = [pos + padding_length for pos in eos_pos_list] + + return d_collated, adjusted_eos_positions + + +def _tokenize_and_pad_chunked_passages( + passages: List[str], + tokenizer: PreTrainedTokenizer, + data_args: DataArguments, + chunk_sizes: Optional[List[int]] = None, + chunk_size_range: Optional[Tuple[int, int]] = None, +) -> Tuple[dict, List[List[int]]]: + """ + Tokenize and chunk passages with EOS separators. Each chunk ends with EOS for embedding extraction. + + :param passages: Passage texts to tokenize and chunk + :param tokenizer: Tokenizer for encoding + :param data_args: DataArguments with chunk_size, max_len, pad_to_multiple_of + :param chunk_sizes: Optional list of chunk sizes (one per passage). If None, uses data_args.passage_chunk_size + :param chunk_size_range: Optional (min, max) tuple for variable chunk sizes per chunk. If set, each chunk within a passage uses a random size. 
+ :return: (collated_dict, eos_positions) - padded tensors and EOS positions per passage + """ + eos_id = tokenizer.eos_token_id + if eos_id is None: + raise ValueError("tokenizer.eos_token_id is None; cannot chunk passages with EOS separators.") + if chunk_sizes is not None and len(chunk_sizes) != len(passages): + raise ValueError(f"chunk_sizes length ({len(chunk_sizes)}) must match passages length ({len(passages)})") + max_length = data_args.passage_max_len # cap total length (incl. EOS per chunk) + + all_input_ids = [] + all_eos_positions = [] + + for idx, passage in enumerate(passages): + if passage is None: + passage = "" + tokens = tokenizer.encode(passage, add_special_tokens=False) + # Use per-passage chunk size if provided, otherwise use fixed chunk size + # Note: chunk_size is ignored in _chunk_tokens when chunk_size_range is provided + chunk_size = chunk_sizes[idx] if chunk_sizes is not None else data_args.passage_chunk_size + ids, eos_pos = _chunk_tokens( + tokens=tokens, + chunk_size=chunk_size, + eos_token_id=eos_id, + max_length=max_length, + chunk_size_range=chunk_size_range, + ) + all_input_ids.append(ids) + all_eos_positions.append(eos_pos) + + d_collated, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=tokenizer, + padding_side=data_args.padding_side, + pad_to_multiple_of=data_args.pad_to_multiple_of, + ) + + return d_collated, adjusted_eos_positions + + @dataclass class TrainCollator: """ @@ -24,7 +183,7 @@ def __call__(self, features: List[Tuple[str, List[str]]]): """ Collate function for training. 
:param features: list of (query, passages) tuples - :return: tokenized query_ids, passage_ids + :return: tokenized query_ids, passage_ids, [eos_positions if chunked] """ all_queries = [f[0] for f in features] all_passages = [] @@ -32,6 +191,7 @@ def __call__(self, features: List[Tuple[str, List[str]]]): all_passages.extend(f[1]) all_queries = [q[0] for q in all_queries] all_passages = [p[0] for p in all_passages] + q_collated = self.tokenizer( all_queries, padding=False, @@ -41,20 +201,8 @@ def __call__(self, features: List[Tuple[str, List[str]]]): return_token_type_ids=False, add_special_tokens=True, ) - d_collated = self.tokenizer( - all_passages, - padding=False, - truncation=True, - max_length=self.data_args.passage_max_len-1 if self.data_args.append_eos_token else self.data_args.passage_max_len, - return_attention_mask=False, - return_token_type_ids=False, - add_special_tokens=True, - ) - if self.data_args.append_eos_token: q_collated['input_ids'] = [q + [self.tokenizer.eos_token_id] for q in q_collated['input_ids']] - d_collated['input_ids'] = [d + [self.tokenizer.eos_token_id] for d in d_collated['input_ids']] - q_collated = self.tokenizer.pad( q_collated, padding=True, @@ -62,14 +210,64 @@ def __call__(self, features: List[Tuple[str, List[str]]]): return_attention_mask=True, return_tensors='pt', ) - d_collated = self.tokenizer.pad( - d_collated, - padding=True, - pad_to_multiple_of=self.data_args.pad_to_multiple_of, - return_attention_mask=True, - return_tensors='pt', - ) - return q_collated, d_collated + + # Check if we should use chunking (fixed or random) + use_fixed_chunking = self.data_args.passage_chunk_size > 0 + + if self.data_args.passage_chunk_size_range is not None: + # Parse range string (e.g., "64, 128" or "64,128") + try: + parts = [p.strip() for p in self.data_args.passage_chunk_size_range.split(',')] + if len(parts) != 2: + raise ValueError(f"passage_chunk_size_range must contain exactly 2 values separated by comma, got: 
{self.data_args.passage_chunk_size_range}") + chunk_size_min = int(parts[0]) + chunk_size_max = int(parts[1]) + except ValueError as e: + raise ValueError(f"Invalid passage_chunk_size_range format '{self.data_args.passage_chunk_size_range}'. Expected format: 'min,max' (e.g., '64,128')") from e + + # Validate range + if chunk_size_min < 2: + raise ValueError(f"Minimum chunk size must be >= 2, got {chunk_size_min}") + if chunk_size_max < chunk_size_min: + raise ValueError(f"Maximum chunk size ({chunk_size_max}) must be >= minimum chunk size ({chunk_size_min})") + + if self.data_args.passage_chunk_size_variable: + # Variable chunk sizes: each chunk within a passage gets a random size + # Pass the range to _chunk_tokens, which will randomly pick a size for each chunk + chunk_size_range = (chunk_size_min, chunk_size_max) + d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages, chunk_size_range=chunk_size_range) + else: + # Fixed random chunk size per passage: all chunks in a passage use the same random size + # Generate random chunk sizes for each passage + chunk_sizes = [random.randint(chunk_size_min, chunk_size_max) for _ in all_passages] + d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages, chunk_sizes=chunk_sizes) + return q_collated, d_collated, eos_positions + elif use_fixed_chunking: + d_collated, eos_positions = self._tokenize_and_pad_chunked_passages(all_passages) + return q_collated, d_collated, eos_positions + else: + d_collated = self.tokenizer( + all_passages, + padding=False, + truncation=True, + max_length=self.data_args.passage_max_len-1 if self.data_args.append_eos_token else self.data_args.passage_max_len, + return_attention_mask=False, + return_token_type_ids=False, + add_special_tokens=True, + ) + if self.data_args.append_eos_token: + d_collated['input_ids'] = [d + [self.tokenizer.eos_token_id] for d in d_collated['input_ids']] + d_collated = self.tokenizer.pad( + d_collated, + padding=True, + 
pad_to_multiple_of=self.data_args.pad_to_multiple_of, + return_attention_mask=True, + return_tensors='pt', + ) + return q_collated, d_collated + + def _tokenize_and_pad_chunked_passages(self, passages: List[str], chunk_sizes: Optional[List[int]] = None, chunk_size_range: Optional[Tuple[int, int]] = None): + return _tokenize_and_pad_chunked_passages(passages, self.tokenizer, self.data_args, chunk_sizes=chunk_sizes, chunk_size_range=chunk_size_range) @dataclass @@ -222,6 +420,157 @@ def __call__(self, features): ) return content_ids, collated_inputs + +@dataclass +class ChunkedEncodeCollator: + """Collator for chunked passage encoding (inference/search). Supports fixed or random chunk sizes.""" + data_args: DataArguments + tokenizer: PreTrainedTokenizer + + def __call__(self, features): + """ + Collate chunked passage encoding features. + :param features: List of (doc_id, text, image, video, audio) tuples + :return: (doc_ids, collated_inputs, eos_positions) + """ + doc_ids = [x[0] for x in features] + texts = [x[1] for x in features] + + # Check if we should use random chunking + if self.data_args.passage_chunk_size_range is not None: + # Parse range string (e.g., "64, 128" or "64,128") + try: + parts = [p.strip() for p in self.data_args.passage_chunk_size_range.split(',')] + if len(parts) != 2: + raise ValueError(f"passage_chunk_size_range must contain exactly 2 values separated by comma, got: {self.data_args.passage_chunk_size_range}") + chunk_size_min = int(parts[0]) + chunk_size_max = int(parts[1]) + except ValueError as e: + raise ValueError(f"Invalid passage_chunk_size_range format '{self.data_args.passage_chunk_size_range}'. 
Expected format: 'min,max' (e.g., '64,128')") from e + + # Validate range + if chunk_size_min < 2: + raise ValueError(f"Minimum chunk size must be >= 2, got {chunk_size_min}") + if chunk_size_max < chunk_size_min: + raise ValueError(f"Maximum chunk size ({chunk_size_max}) must be >= minimum chunk size ({chunk_size_min})") + + if self.data_args.passage_chunk_size_variable: + # Variable chunk sizes: each chunk within a passage gets a random size + chunk_size_range = (chunk_size_min, chunk_size_max) + d_collated, all_eos_positions = self._tokenize_and_pad_chunked_passages(texts, chunk_size_range=chunk_size_range) + else: + # Fixed random chunk size per passage: all chunks in a passage use the same random size + # Generate random chunk sizes for each passage + chunk_sizes = [random.randint(chunk_size_min, chunk_size_max) for _ in texts] + d_collated, all_eos_positions = self._tokenize_and_pad_chunked_passages(texts, chunk_sizes=chunk_sizes) + else: + # Use fixed chunking for inference + d_collated, all_eos_positions = self._tokenize_and_pad_chunked_passages(texts) + + return doc_ids, d_collated, all_eos_positions + + def _tokenize_and_pad_chunked_passages(self, passages: List[str], chunk_sizes: Optional[List[int]] = None, chunk_size_range: Optional[Tuple[int, int]] = None): + return _tokenize_and_pad_chunked_passages(passages, self.tokenizer, self.data_args, chunk_sizes=chunk_sizes, chunk_size_range=chunk_size_range) + + +@dataclass +class PreChunkedEncodeCollator: + """ + Collator for pre-chunked passage encoding (inference/search). + Expects passages as lists of pre-chunked strings and adds EOS tokens between chunks. + """ + data_args: DataArguments + tokenizer: PreTrainedTokenizer + + def __call__(self, features): + """ + Collate pre-chunked passage encoding features. 
+ :param features: List of (doc_id, chunks_list, image, video, audio) tuples + where chunks_list is a list of pre-chunked passage strings + :return: (doc_ids, collated_inputs, eos_positions) + """ + doc_ids = [x[0] for x in features] + chunks_lists = [x[1] for x in features] # List of lists of strings + + # Process pre-chunked passages: tokenize each chunk and add EOS between them + d_collated, all_eos_positions = self._tokenize_and_pad_pre_chunked_passages(chunks_lists) + + return doc_ids, d_collated, all_eos_positions + + def _tokenize_and_pad_pre_chunked_passages(self, chunks_lists: List[List[str]]): + """ + Tokenize pre-chunked passages and add EOS tokens between chunks. + + This is used when you have pre-chunked passages (e.g., from ChatGPT or manual chunking). + Each chunk is tokenized separately, and EOS tokens are inserted between chunks. + + :param chunks_lists: List of lists, where each inner list contains pre-chunked passage strings + Example: [["chunk1", "chunk2"], ["chunk3"]] for 2 passages + :return: (collated_dict, eos_positions) - padded tensors and EOS positions per passage + """ + eos_id = self.tokenizer.eos_token_id + if eos_id is None: + raise ValueError("tokenizer.eos_token_id is None; cannot add EOS tokens between chunks.") + + max_length = self.data_args.passage_max_len + all_input_ids = [] + all_eos_positions = [] + + for chunks in chunks_lists: + if chunks is None: + chunks = [] + if not isinstance(chunks, list): + raise ValueError(f"Expected list of chunks, got {type(chunks)}") + if len(chunks) == 0: + # Empty chunks list - create empty passage with no EOS positions + all_input_ids.append([]) + all_eos_positions.append([]) + continue + + # Tokenize each chunk and concatenate with EOS between them + ids = [] + eos_pos = [] + total_length = 0 + + for chunk_idx, chunk in enumerate(chunks): + if chunk is None: + chunk = "" + # Tokenize this chunk (without special tokens, we'll add EOS manually) + chunk_tokens = self.tokenizer.encode(chunk, 
add_special_tokens=False) + + # Check if adding this chunk + EOS would exceed max_length + chunk_size = len(chunk_tokens) + if max_length and total_length + chunk_size + 1 > max_length: + # Use remaining space (leave 1 for EOS if possible) + remaining = max_length - total_length - 1 + if remaining > 0: + chunk_tokens = chunk_tokens[:remaining] + ids.extend(chunk_tokens) + ids.append(eos_id) + eos_pos.append(len(ids) - 1) + break + + # Add chunk tokens + ids.extend(chunk_tokens) + # Add EOS after each chunk + ids.append(eos_id) + eos_pos.append(len(ids) - 1) + total_length += chunk_size + 1 + + all_input_ids.append(ids) + all_eos_positions.append(eos_pos) + + d_collated, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=self.tokenizer, + padding_side=self.data_args.padding_side, + pad_to_multiple_of=self.data_args.pad_to_multiple_of, + ) + + return d_collated, adjusted_eos_positions + + @dataclass class MultiModalEncodeCollator: """ diff --git a/src/tevatron/retriever/dataset.py b/src/tevatron/retriever/dataset.py index ae3fdb57..18c8fc68 100644 --- a/src/tevatron/retriever/dataset.py +++ b/src/tevatron/retriever/dataset.py @@ -130,6 +130,7 @@ def __getitem__(self, item): # Select negative documents negative_size = self.data_args.train_group_size - 1 if len(group['negative_passages']) < negative_size: + print(f"selected_negatives: Randomly selected!!!!!!!!!!!!!!!!!!!!!!!!!!") selected_negatives = random.choices(group['negative_passages'], k=negative_size) elif self.data_args.train_group_size == 1: selected_negatives = [] @@ -292,10 +293,22 @@ def __getitem__(self, item): content_audio = content.get('query_audio', None) else: content_id = content['docid'] - content_text = content.get('text', '') - if 'title' in content: - content_text = content['title'] + ' ' + content_text - content_text = self.data_args.passage_prefix + content_text.strip() + # Support pre-chunked passages (for 
custom chunking with ChatGPT, etc.) + if self.data_args.encode_use_pre_chunked and 'chunks' in content: + # Pre-chunked: return chunks as a list + chunks = content['chunks'] + if not isinstance(chunks, list): + raise ValueError(f"Expected 'chunks' to be a list, got {type(chunks)}") + # Apply prefix to each chunk if needed + if self.data_args.passage_prefix: + chunks = [self.data_args.passage_prefix + chunk if chunk else chunk for chunk in chunks] + content_text = chunks # Return as list for pre-chunked collator + else: + # Regular text field + content_text = content.get('text', '') + if 'title' in content: + content_text = content['title'] + ' ' + content_text + content_text = self.data_args.passage_prefix + content_text.strip() content_image = content.get('image', None) content_video = content.get('video', None) content_audio = content.get('audio', None) @@ -320,7 +333,11 @@ def __getitem__(self, item): content_audio = None if not self.data_args.encode_text: - content_text = None + # For pre-chunked mode, set to empty list instead of None + if self.data_args.encode_use_pre_chunked and isinstance(content_text, list): + content_text = [] + else: + content_text = None if not self.data_args.encode_image: content_image = None if not self.data_args.encode_video: diff --git a/src/tevatron/retriever/driver/encode.py b/src/tevatron/retriever/driver/encode.py index 8749dfda..d816e5da 100644 --- a/src/tevatron/retriever/driver/encode.py +++ b/src/tevatron/retriever/driver/encode.py @@ -8,6 +8,9 @@ from tqdm import tqdm import torch +import torch.nn.functional as F + +from rich import print from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -18,7 +21,7 @@ from tevatron.retriever.arguments import ModelArguments, DataArguments, \ TevatronTrainingArguments as TrainingArguments from tevatron.retriever.dataset import EncodeDataset -from tevatron.retriever.collator import EncodeCollator +from tevatron.retriever.collator import EncodeCollator, 
ChunkedEncodeCollator, PreChunkedEncodeCollator from tevatron.retriever.modeling import EncoderOutput, DenseModel logger = logging.getLogger(__name__) @@ -51,7 +54,8 @@ def main(): ) if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id - + + tokenizer.eos_token_id = tokenizer.pad_token_id if data_args.padding_side == 'right': tokenizer.padding_side = 'right' else: @@ -78,10 +82,41 @@ def main(): data_args=data_args, ) - encode_collator = EncodeCollator( - data_args=data_args, - tokenizer=tokenizer, - ) + use_chunked = not data_args.encode_is_query and data_args.passage_chunk_size > 0 + use_pre_chunked = not data_args.encode_is_query and data_args.encode_use_pre_chunked + use_random_chunking = not data_args.encode_is_query and data_args.passage_chunk_size_range is not None + print("data_args.encode_is_query: ", data_args.encode_is_query) + print("data_args.passage_chunk_size: ", data_args.passage_chunk_size) + print("data_args.passage_chunk_size_range: ", data_args.passage_chunk_size_range) + print("data_args.passage_chunk_size_variable: ", data_args.passage_chunk_size_variable) + print("data_args.encode_use_pre_chunked: ", data_args.encode_use_pre_chunked) + print("use_chunked: ", use_chunked) + print("use_pre_chunked: ", use_pre_chunked) + print("use_random_chunking: ", use_random_chunking) + + if use_pre_chunked: + logger.info("Using pre-chunked passage encoding (custom EOS positions from pre-chunked data)") + model.passage_chunk_size = 1 # Signal to use chunked encoding + encode_collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=tokenizer) + elif use_chunked or use_random_chunking: + if use_random_chunking: + logger.info(f"Using random chunked passage encoding with chunk_size_range={data_args.passage_chunk_size_range}, variable={data_args.passage_chunk_size_variable}") + else: + logger.info(f"Using chunked passage encoding with chunk_size={data_args.passage_chunk_size}") + # For random chunking, we still need a base 
chunk_size for the model + # Use the minimum of the range if random chunking is enabled + if use_random_chunking: + try: + parts = [p.strip() for p in data_args.passage_chunk_size_range.split(',')] + chunk_size_min = int(parts[0]) + model.passage_chunk_size = chunk_size_min + except: + model.passage_chunk_size = data_args.passage_chunk_size if data_args.passage_chunk_size > 0 else 64 + else: + model.passage_chunk_size = data_args.passage_chunk_size + encode_collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=tokenizer) + else: + encode_collator = EncodeCollator(data_args=data_args, tokenizer=tokenizer) encode_loader = DataLoader( encode_dataset, @@ -96,23 +131,56 @@ def main(): model = model.to(training_args.device) model.eval() - for (batch_ids, batch) in tqdm(encode_loader): - lookup_indices.extend(batch_ids) + for batch in tqdm(encode_loader): with torch.amp.autocast('cuda') if training_args.fp16 or training_args.bf16 else nullcontext(): with torch.no_grad(): - for k, v in batch.items(): - batch[k] = v.to(training_args.device) - if data_args.encode_is_query: - model_output: EncoderOutput = model(query=batch) - encoded.append(model_output.q_reps.cpu().detach().numpy()) + if use_pre_chunked or use_chunked or use_random_chunking: + doc_ids, batch_inputs, eos_positions = batch + # batch_inputs: input_ids, attention_mask + for k, v in batch_inputs.items(): + batch_inputs[k] = v.to(training_args.device) + print(f"eos_positions: {eos_positions}") + chunk_embs, chunk_mask = model.encode_passage(batch_inputs, eos_positions) + # chunk_embs: [batch_size, max_chunks, hidden_size] + # chunk_mask: [batch_size, max_chunks] + batch_size, max_chunks, hidden_size = chunk_embs.shape + for i, doc_id in enumerate(doc_ids): + for chunk_idx in range(max_chunks): + if chunk_mask[i, chunk_idx] > 0: # Valid chunk + encoded.append(chunk_embs[i, chunk_idx].cpu().detach().numpy()) + lookup_indices.append((doc_id, chunk_idx)) else: - model_output: EncoderOutput = 
model(passage=batch) - encoded.append(model_output.p_reps.cpu().detach().numpy()) - - encoded = np.concatenate(encoded) + batch_ids, batch_inputs = batch + lookup_indices.extend(batch_ids) + + for k, v in batch_inputs.items(): + batch_inputs[k] = v.to(training_args.device) + + if data_args.encode_is_query: + model_output: EncoderOutput = model(query=batch_inputs) + encoded.append(model_output.q_reps.cpu().detach().numpy()) + else: + model_output: EncoderOutput = model(passage=batch_inputs) + encoded.append(model_output.p_reps.cpu().detach().numpy()) + if use_pre_chunked or use_chunked or use_random_chunking: + print("use_chunked: ", use_chunked) + print(f"encoded: {encoded}") + print(f"lookup_indices: {lookup_indices}") + print(f"length of encoded: {len(encoded)}") + print(f"length of lookup_indices: {len(lookup_indices)}") + # Combine encoded embeddings + encoded = np.stack(encoded) + logger.info(f"Encoded {len(set(d for d, c in lookup_indices))} docs into {len(lookup_indices)} chunks") + print(f"encoded.shape: {encoded.shape}") + print(f"length of encoded: {len(encoded)}") + # input("Press Enter to continue...") + else: + encoded = np.concatenate(encoded) with open(data_args.encode_output_path, 'wb') as f: pickle.dump((encoded, lookup_indices), f) + + logger.info(f"Saved embeddings to {data_args.encode_output_path}, shape: {encoded.shape}") if __name__ == "__main__": diff --git a/src/tevatron/retriever/driver/search.py b/src/tevatron/retriever/driver/search.py index 1f374eac..28fc8c8a 100644 --- a/src/tevatron/retriever/driver/search.py +++ b/src/tevatron/retriever/driver/search.py @@ -3,6 +3,7 @@ import numpy as np import glob from argparse import ArgumentParser +from collections import defaultdict from itertools import chain from tqdm import tqdm import faiss @@ -29,6 +30,47 @@ def search_queries(retriever, q_reps, p_lookup, args): return all_scores, psg_indices +def search_queries_chunked(retriever, q_reps, p_lookup, args): + """ + Search with chunked passages 
and aggregate by document using MaxSim. + """ + # Search more chunks to ensure good recall after aggregation + chunk_multiplier = getattr(args, 'chunk_multiplier', 10) + search_depth = args.depth * chunk_multiplier + + if args.batch_size > 0: + # all_scores.shape = [Q, search_depth] + all_scores, all_indices = retriever.batch_search(q_reps, search_depth, args.batch_size, args.quiet) + else: + # all_scores.shape = [search_depth] + all_scores, all_indices = retriever.search(q_reps, search_depth) + # Aggregate by document ID using MaxSim + aggregated_results = [] + for q_idx in range(len(q_reps)): + scores = all_scores[q_idx] + indices = all_indices[q_idx] + doc_max_scores = defaultdict(lambda: float('-inf')) + for score, idx in zip(scores, indices): + if idx < 0: # FAISS returns -1 for insufficient results + continue + if idx >= len(p_lookup): # Boundary check: prevent IndexError + logger.warning(f"Index {idx} out of bounds for p_lookup (length {len(p_lookup)}), skipping") + continue + + try: + doc_id, chunk_idx = p_lookup[idx] + except (ValueError, TypeError) as e: + logger.error(f"p_lookup[{idx}] is not a tuple (doc_id, chunk_idx): {p_lookup[idx]}, error: {e}") + continue + + # MaxSim: keep the maximum score for each document + doc_max_scores[doc_id] = max(doc_max_scores[doc_id], score) + # Sort by score and take top-depth + sorted_docs = sorted(doc_max_scores.items(), key=lambda x: x[1], reverse=True)[:args.depth] + aggregated_results.append(sorted_docs) + return aggregated_results + + def write_ranking(corpus_indices, corpus_scores, q_lookup, ranking_save_file): with open(ranking_save_file, 'w') as f: for qid, q_doc_scores, q_doc_indices in zip(q_lookup, corpus_scores, corpus_indices): @@ -38,6 +80,17 @@ def write_ranking(corpus_indices, corpus_scores, q_lookup, ranking_save_file): f.write(f'{qid}\t{idx}\t{s}\n') +def write_ranking_chunked(results, q_lookup, ranking_save_file): + """ + Write ranking results from chunked search. 
+ results: List[List[Tuple[doc_id, score]]] + """ + with open(ranking_save_file, 'w') as f: + for qid, doc_scores in zip(q_lookup, results): + for doc_id, score in doc_scores: + f.write(f'{qid}\t{doc_id}\t{score}\n') + + def pickle_load(path): with open(path, 'rb') as f: reps, lookup = pickle.load(f) @@ -58,6 +111,11 @@ def main(): parser.add_argument('--save_ranking_to', required=True) parser.add_argument('--save_text', action='store_true') parser.add_argument('--quiet', action='store_true') + # Chunked search arguments + parser.add_argument('--chunked', action='store_true', + help='Enable chunked search with document-level MaxSim aggregation') + parser.add_argument('--chunk_multiplier', type=int, default=10, + help='Multiply search depth by this factor for chunked search to ensure recall') args = parser.parse_args() @@ -75,6 +133,13 @@ def main(): retriever.add(p_reps) look_up += p_lookup + # Auto-detect chunked format: lookup entries are tuples (doc_id, chunk_idx) + is_chunked = args.chunked or (len(look_up) > 0 and isinstance(look_up[0], tuple)) + if is_chunked: + unique_docs = len(set(doc_id for doc_id, _ in look_up)) + logger.info(f"Chunked mode: {len(look_up)} chunks from {unique_docs} documents") + logger.info(f"Search depth: {args.depth} docs, chunk search depth: {args.depth * args.chunk_multiplier}") + q_reps, q_lookup = pickle_load(args.query_reps) q_reps = q_reps @@ -96,14 +161,33 @@ def main(): ngpu=num_gpus) logger.info('Index Search Start') - all_scores, psg_indices = search_queries(retriever, q_reps, look_up, args) - logger.info('Index Search Finished') - - if args.save_text: - write_ranking(psg_indices, all_scores, q_lookup, args.save_ranking_to) + + if is_chunked: + # Chunked search with MaxSim aggregation + aggregated_results = search_queries_chunked(retriever, q_reps, look_up, args) + logger.info('Index Search Finished (chunked mode with MaxSim aggregation)') + + if args.save_text: + write_ranking_chunked(aggregated_results, q_lookup, 
args.save_ranking_to) + else: + # Convert to arrays for pickle + all_scores = [] + all_doc_ids = [] + for doc_scores in aggregated_results: + scores = [s for _, s in doc_scores] + doc_ids = [d for d, _ in doc_scores] + all_scores.append(scores) + all_doc_ids.append(doc_ids) + pickle_save((all_scores, all_doc_ids), args.save_ranking_to) else: - pickle_save((all_scores, psg_indices), args.save_ranking_to) + # Standard search + all_scores, psg_indices = search_queries(retriever, q_reps, look_up, args) + logger.info('Index Search Finished') + if args.save_text: + write_ranking(psg_indices, all_scores, q_lookup, args.save_ranking_to) + else: + pickle_save((all_scores, psg_indices), args.save_ranking_to) if __name__ == '__main__': main() diff --git a/src/tevatron/retriever/driver/train.py b/src/tevatron/retriever/driver/train.py index 39abab45..a570231d 100644 --- a/src/tevatron/retriever/driver/train.py +++ b/src/tevatron/retriever/driver/train.py @@ -64,6 +64,7 @@ def main(): model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, ) + tokenizer.eos_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id @@ -87,6 +88,16 @@ def main(): torch_dtype=torch_dtype, attn_implementation=model_args.attn_implementation, ) + # Set passage_chunk_size: use fixed chunk size if set, otherwise set to 1 if using random chunking + # (model uses passage_chunk_size > 0 as signal to use chunked encoding) + if data_args.passage_chunk_size > 0: + model.passage_chunk_size = data_args.passage_chunk_size + elif data_args.passage_chunk_size_range is not None: + # For random chunking, set to a positive value to enable chunked encoding + # The actual chunk sizes will be determined per-passage by the collator + model.passage_chunk_size = 1 + else: + model.passage_chunk_size = 0 train_dataset = TrainDataset(data_args) collator = TrainCollator(data_args, tokenizer) diff --git 
a/src/tevatron/retriever/modeling/dense.py b/src/tevatron/retriever/modeling/dense.py index 8bc50106..dd0d077a 100644 --- a/src/tevatron/retriever/modeling/dense.py +++ b/src/tevatron/retriever/modeling/dense.py @@ -1,21 +1,70 @@ import torch +import torch.nn.functional as F import logging from transformers import Qwen2_5OmniThinkerForConditionalGeneration from .encoder import EncoderModel logger = logging.getLogger(__name__) - +EOS_TOKEN_ID = 151643 class DenseModel(EncoderModel): + def __init__(self, encoder, pooling='cls', normalize=False, temperature=1.0): + super().__init__(encoder, pooling, normalize, temperature) + self.passage_chunk_size = 0 + self.eos_positions = None + def encode_query(self, qry): query_hidden_states = self.encoder(**qry, return_dict=True) query_hidden_states = query_hidden_states.last_hidden_state return self._pooling(query_hidden_states, qry['attention_mask']) - def encode_passage(self, psg): - # encode passage is the same as encode query - return self.encode_query(psg) + def encode_passage(self, psg, eos_positions=None): + print(f"eos_positions: {eos_positions}") + hidden_states = self.encoder(**psg, return_dict=True).last_hidden_state + if self.passage_chunk_size > 0 and eos_positions: + for i, ep in enumerate(eos_positions): + for eos_pos in ep: + assert psg['input_ids'][i][eos_pos] == EOS_TOKEN_ID + + return self._pooling_chunked(hidden_states, eos_positions) + + return self._pooling(hidden_states, psg['attention_mask']) + + def _pooling_chunked(self, last_hidden_state, eos_positions): + batch_size, seq_len, hidden_size = last_hidden_state.shape + print(f"last_hidden_state.shape: {last_hidden_state.shape}") + print(f"eos_positions: {eos_positions}") + + if not eos_positions: + # No chunks, return empty + return torch.zeros(batch_size, 0, hidden_size, device=last_hidden_state.device, dtype=last_hidden_state.dtype), \ + torch.zeros(batch_size, 0, device=last_hidden_state.device) + + # Find max number of chunks across all passages + 
for eos_pos in eos_positions: + print(f"eos_pos: {eos_pos}") + print(f"type(eos_pos): {type(eos_pos)}") + max_chunks = max(len(pos_list) for pos_list in eos_positions) + + chunk_reps = torch.zeros(batch_size, max_chunks, hidden_size, device=last_hidden_state.device, dtype=last_hidden_state.dtype) + chunk_mask = torch.zeros(batch_size, max_chunks, device=last_hidden_state.device, dtype=torch.float) + + # Extract embeddings at eos_positions (this is the pooling operation for chunked passages) + for i, positions in enumerate(eos_positions): + for j, pos in enumerate(positions): + if 0 <= pos < seq_len: + # i is the batch index, j is the chunk index, pos is the eos position + chunk_reps[i, j] = last_hidden_state[i, pos] + # chunk_mask is 1.0 for valid chunks, 0.0 for padding chunks + chunk_mask[i, j] = 1.0 + else: + logger.warning(f"Position {pos} out of bounds for sequence length {seq_len} in batch {i}, chunk {j}") + + if self.normalize: + chunk_reps = F.normalize(chunk_reps, p=2, dim=-1) + + return chunk_reps, chunk_mask def _pooling(self, last_hidden_state, attention_mask): @@ -67,4 +116,4 @@ def encode_query(self, qry): def encode_passage(self, psg): # encode passage is the same as encode query - return self.encode_query(psg) \ No newline at end of file + return self.encode_query(psg) diff --git a/src/tevatron/retriever/modeling/encoder.py b/src/tevatron/retriever/modeling/encoder.py index c3eedc35..443a004b 100644 --- a/src/tevatron/retriever/modeling/encoder.py +++ b/src/tevatron/retriever/modeling/encoder.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import Dict, Optional +import os import torch import torch.distributed as dist from torch import nn, Tensor @@ -38,6 +39,7 @@ def __init__(self, self.pooling = pooling self.normalize = normalize self.temperature = temperature + self.passage_chunk_size = 0 self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') self.is_ddp = dist.is_initialized() if self.is_ddp: @@ -46,7 +48,25 @@ def 
__init__(self, def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None): q_reps = self.encode_query(query) if query else None - p_reps = self.encode_passage(passage) if passage else None + p_reps, chunk_mask = None, None + if passage: + # If training with chunked passages, eos_positions is produced by the collator and + # attached to the model by TevatronTrainer.compute_loss(). Forward() needs to pass it + # into encode_passage() to actually get chunk reps/masks. + eos_positions = getattr(self, "eos_positions", None) + if self.passage_chunk_size > 0 and eos_positions is not None: + # print(f"eos_positions: {eos_positions}") + try: + p_reps = self.encode_passage(passage, eos_positions=eos_positions) + except TypeError: + # Some models (e.g., multimodal) don't accept eos_positions. + p_reps = self.encode_passage(passage) + else: + p_reps = self.encode_passage(passage) + # print(f"p_reps: {p_reps}") + # print(f"type(p_reps): {type(p_reps)}") + if self.passage_chunk_size > 0 and isinstance(p_reps, tuple): + p_reps, chunk_mask = p_reps # for inference if q_reps is None or p_reps is None: @@ -60,19 +80,34 @@ def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = if self.is_ddp: q_reps = self._dist_gather_tensor(q_reps) p_reps = self._dist_gather_tensor(p_reps) - - scores = self.compute_similarity(q_reps, p_reps) + # print(f"passage_chunk_size: {self.passage_chunk_size}") + # print(f"chunk_mask: {chunk_mask}") + if self.passage_chunk_size > 0 and chunk_mask is not None: + # print(f"start compute maxsim similarity==========================") + scores = self.compute_maxsim_similarity(q_reps, p_reps, chunk_mask) + # print(f"end compute maxsim similarity==========================") + else: + # print(f"start compute similarity==========================") + scores = self.compute_similarity(q_reps, p_reps) + # view the scores as [Q, P] where Q is the number of queries and P is the number of passages scores = 
scores.view(q_reps.size(0), -1) - target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long) - target = target * (p_reps.size(0) // q_reps.size(0)) - + num_psg_per_query = scores.size(1) // q_reps.size(0) + target = torch.arange(q_reps.size(0), device=scores.device, dtype=torch.long) + target = target * num_psg_per_query + # target contains the indices of the positive passages in this batch target.shape = [Q] + # so the target is [0, 4, 8, 12] for batch_size = 2, group_size = 4, chunk_size = 64 + print(f"target: {target}") + print(f"target.shape: {target.shape}") loss = self.compute_loss(scores / self.temperature, target) if self.is_ddp: - loss = loss * self.world_size # counter average weight reduction + loss = loss * self.world_size # counter average weight reduction # for eval else: - scores = self.compute_similarity(q_reps, p_reps) + if self.passage_chunk_size > 0 and chunk_mask is not None: + scores = self.compute_maxsim_similarity(q_reps, p_reps, chunk_mask) + else: + scores = self.compute_similarity(q_reps, p_reps) loss = None return EncoderOutput( loss=loss, @@ -90,6 +125,60 @@ def encode_query(self, qry): def compute_similarity(self, q_reps, p_reps): return torch.matmul(q_reps, p_reps.transpose(0, 1)) + def compute_maxsim_similarity(self, q_reps, p_reps, chunk_mask): + """ + MaxSim: max similarity between query and passage chunks. 
+ q_reps: [Q, H], p_reps: [P, C, H], chunk_mask: [P, C] + Q: number of queries + P: number of passages + C: number of chunks per passage + H: dimension of the embeddings + Returns: [Q, P] + """ + chunk_scores = torch.einsum('qh,pch->qpc', q_reps, p_reps) # 第 q 个 query 和第 p 个 passage 的第 c 个 chunk 的相似度 + if chunk_mask is not None: + padding_mask = ~chunk_mask.unsqueeze(0).bool() + chunk_scores = chunk_scores.masked_fill(padding_mask, float('-inf')) + max_vals, max_idx = chunk_scores.max(dim=-1) # [Q, P], [Q, P] + + # Log maxsim info: read chunk indices directly from max_idx + if True: + # only log from rank-0 if DDP + if (not getattr(self, "is_ddp", False)) or getattr(self, "process_rank", 0) == 0: + eos_positions = getattr(self, "eos_positions", None) + eos_ok = ( + isinstance(eos_positions, (list, tuple)) + and len(eos_positions) == p_reps.size(0) + ) + + # Compute last valid chunk indices for all passages + if chunk_mask is not None: + last_ci_per_passage = (chunk_mask.sum(dim=1) - 1).clamp(min=0) # [P] + else: + last_ci_per_passage = torch.full((p_reps.size(0),), p_reps.size(1) - 1, dtype=torch.long) + + # Log for each query-passage pair + for qi in range(max_idx.size(0)): + for pi in range(max_idx.size(1)): + ci = int(max_idx[qi, pi].item()) # best chunk index from max_idx + last_ci = int(last_ci_per_passage[pi].item()) + score = float(max_vals[qi, pi].item()) + + if eos_ok and eos_positions[pi] and ci < len(eos_positions[pi]): + best_pos = eos_positions[pi][ci] + last_pos = eos_positions[pi][-1] + logger.info( + f"[maxsim] q={qi} p={pi} best_chunk={ci} best_pos={best_pos} " + f"last_chunk={last_ci} last_pos={last_pos} best_score={score:.6f}" + ) + else: + logger.info( + f"[maxsim] q={qi} p={pi} best_chunk={ci} last_chunk={last_ci} " + f"best_score={score:.6f}" + ) + + return max_vals + def compute_loss(self, scores, target): return self.cross_entropy(scores, target) diff --git a/src/tevatron/retriever/trainer.py b/src/tevatron/retriever/trainer.py index 
0c6ceb58..22beac16 100644 --- a/src/tevatron/retriever/trainer.py +++ b/src/tevatron/retriever/trainer.py @@ -45,7 +45,11 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - query, passage = inputs + query, passage, *rest = inputs + eos_positions = rest[0] if rest else None + # input(f"trainer.compute_loss: eos_positions: {eos_positions}") + if hasattr(model, 'eos_positions'): + model.eos_positions = eos_positions return model(query=query, passage=passage).loss def training_step(self, *args): diff --git a/tests/test_chunking.py b/tests/test_chunking.py new file mode 100644 index 00000000..d3a0bacf --- /dev/null +++ b/tests/test_chunking.py @@ -0,0 +1,2168 @@ +import sys +from pathlib import Path +import random + +import pytest +import torch + + +def _tevatron_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def _add_tevatron_src_to_path(): + # tevatron/tests/test_chunking.py -> tevatron/ -> tevatron/src + src = _tevatron_root() / "src" + sys.path.insert(0, str(src)) + + +def _strictly_increasing(xs): + return all(xs[i] > xs[i - 1] for i in range(1, len(xs))) + +REAL_TEXT = ( + "Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical " + "development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging " + "(MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to " + "calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in " + "preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter " + "development, early gestation preterm infants (n = 10) were studied a second time at term. 
In the central white " + "matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to " + "1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both " + "times were similar (1.2 versus 1.1 microm2/ms). Relative anisotropy was higher the closer birth was to term with " + "greater absolute values in the internal capsule than in the central white matter. Preterm infants at term showed " + "higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 versus 1.15 +/- 0.09 microm2/ms, " + "p = 0.016) and lower relative anisotropy in both areas compared with full-term infants (white matter, 10.9 +/- " + "0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- 4.44 versus 33.1 +/- 0.6% p = 0.006). " + "Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term and " + "preterm infants at term showed marked differences in white matter fiber organization. 
The data indicate that " + "quantitative assessment of water diffusion by diffusion tensor MRI provides insight into microstructural " + "development in cerebral white matter in living infants" +) + +# Semantically chunked version of REAL_TEXT - split into meaningful semantic units +REAL_TEXT_SEMANTIC_CHUNKS = [ + # Chunk 1: Introduction - Background on white matter alterations + "Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical " + "development and result in functional disabilities.", + + # Chunk 2: Methodology - MRI technique description + "A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was " + "applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate " + "three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7).", + + # Chunk 3: Study design - Longitudinal follow-up + "To assess effects of prematurity on cerebral white matter development, early gestation preterm infants " + "(n = 10) were studied a second time at term.", + + # Chunk 4: Results - Central white matter findings + "In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and " + "decreased toward term to 1.2 microm2/ms.", + + # Chunk 5: Results - Internal capsule findings + "In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were " + "similar (1.2 versus 1.1 microm2/ms). 
Relative anisotropy was higher the closer birth was to term with greater " + "absolute values in the internal capsule than in the central white matter.", + + # Chunk 6: Results - Preterm vs full-term comparisons + "Preterm infants at term showed higher mean diffusion coefficients in the central white matter (1.4 +/- 0.24 " + "versus 1.15 +/- 0.09 microm2/ms, p = 0.016) and lower relative anisotropy in both areas compared with " + "full-term infants (white matter, 10.9 +/- 0.6 versus 22.9 +/- 3.0%, p = 0.001; internal capsule, 24.0 +/- " + "4.44 versus 33.1 +/- 0.6% p = 0.006).", + + # Chunk 7: Results - Corpus callosum findings + "Nonmyelinated fibers in the corpus callosum were visible by diffusion tensor MRI as early as 28 wk; full-term " + "and preterm infants at term showed marked differences in white matter fiber organization.", + + # Chunk 8: Conclusion + "The data indicate that quantitative assessment of water diffusion by diffusion tensor MRI provides insight into " + "microstructural development in cerebral white matter in living infants" +] +EOS_TOKEN_ID = 151643 +PADDING_TOKEN_ID = 151643 + +@pytest.fixture(scope="session") +def train_tokenizer(): + """ + Use the Qwen 0.6B tokenizer. 
+ """ + _add_tevatron_src_to_path() + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + tok.eos_token_id = tok.pad_token_id + tok.padding_side = "right" # finetune_with_chunk.sh uses --padding_side right + return tok + + +# ============================================================================ +# Unit tests for _chunk_tokens helper function +# ============================================================================ + +@pytest.mark.unit +def test_chunk_tokens_basic(): + """Test basic chunking functionality.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(10)) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + eos_id = 99 + chunk_size = 4 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + # chunk_size=4 means chunk_len=3, so chunks are: + # [0,1,2,99], [3,4,5,99], [6,7,8,99], [9,99] + expected_ids = [0, 1, 2, 99, 3, 4, 5, 99, 6, 7, 8, 99, 9, 99] + expected_eos_pos = [3, 7, 11, 13] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_with_max_length(): + """Test chunking with max_length constraint.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(20)) + eos_id = 99 + chunk_size = 5 + max_length = 12 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + # Hardcoded golden output: chunk_size=5 means chunk_len=4 + # First chunk: [0,1,2,3,99] = 5 tokens + # Second chunk: [4,5,6,7,99] = 5 tokens + # Third chunk: [8,99] = 2 tokens (partial, fits in remaining 2 tokens) + # Total: 12 tokens + expected_ids = [0, 1, 2, 3, 99, 4, 5, 6, 7, 99, 8, 99] + expected_eos_pos = [4, 9, 11] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_max_length_exact_fit(): + """Test chunking 
when max_length exactly fits chunks.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(10)) + eos_id = 99 + chunk_size = 4 + max_length = 14 # Exactly fits 3 chunks: 3*4 + 2 = 14 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + expected_ids = [0, 1, 2, 99, 3, 4, 5, 99, 6, 7, 8, 99, 9, 99] + expected_eos_pos = [3, 7, 11, 13] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_max_length_too_small(): + """Test chunking when max_length is too small for even one chunk.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(10)) + eos_id = 99 + chunk_size = 4 + max_length = 1 # Too small for even one chunk (need at least 2: 1 token + EOS) + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + # Should return empty since we can't fit even one chunk + assert ids == [] + assert eos_pos == [] + + +@pytest.mark.unit +def test_chunk_tokens_empty_input(): + """Test chunking with empty token list.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = [] + eos_id = 99 + chunk_size = 4 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id) + + assert ids == [] + assert eos_pos == [] + +@pytest.mark.unit +def test_chunk_tokens_same_length_as_chunk_size(): + """Test chunking when tokens are the same length as chunk_size.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _chunk_tokens + + tokens = list(range(20)) + eos_id = 99 + chunk_size = 16 + max_length = 16 + + ids, eos_pos = _chunk_tokens(tokens, chunk_size, eos_id, max_length) + + expected_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99] + expected_eos_pos = [15] + + assert ids == expected_ids + assert eos_pos == expected_eos_pos + + +@pytest.mark.unit +def test_chunk_tokens_single_token(): + """Test 
@pytest.mark.unit
def test_chunk_tokens_no_max_length():
    """Without a max_length constraint every input token is kept."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _chunk_tokens

    eos = 99
    source = list(range(15))

    chunk_ids, eos_idx = _chunk_tokens(source, 5, eos, max_length=None)

    # chunk_size=5 -> 4 content tokens per chunk:
    # [0-3,99] [4-7,99] [8-11,99] [12-14,99]
    assert chunk_ids == [0, 1, 2, 3, 99, 4, 5, 6, 7, 99,
                         8, 9, 10, 11, 99, 12, 13, 14, 99]
    assert eos_idx == [4, 9, 14, 18]


@pytest.mark.unit
def test_chunk_tokens_chunk_size_one():
    """chunk_size=1 leaves no room for content + EOS, so nothing is produced."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _chunk_tokens

    chunk_ids, eos_idx = _chunk_tokens([1, 2, 3], 1, 99)

    # Need at least 2 slots per chunk (1 token + 1 EOS); 1 is invalid.
    assert chunk_ids == []
    assert eos_idx == []


@pytest.mark.unit
def test_chunk_tokens_chunk_size_two():
    """chunk_size=2 pairs every single token with its own EOS."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _chunk_tokens

    chunk_ids, eos_idx = _chunk_tokens([1, 2, 3, 4, 5], 2, 99)

    # chunk_len=1 -> [1,99] [2,99] [3,99] [4,99] [5,99]
    assert chunk_ids == [1, 99, 2, 99, 3, 99, 4, 99, 5, 99]
    assert eos_idx == [1, 3, 5, 7, 9]
@pytest.mark.unit
def test_chunk_tokens_max_length_stops_at_boundary():
    """max_length landing exactly on a chunk boundary stops cleanly there."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _chunk_tokens

    eos = 99
    source = list(range(20))

    # max_length=10 holds exactly two chunk_size=5 chunks.
    chunk_ids, eos_idx = _chunk_tokens(source, 5, eos, 10)

    assert chunk_ids == [0, 1, 2, 3, 99, 4, 5, 6, 7, 99]
    assert eos_idx == [4, 9]


@pytest.mark.unit
def test_chunk_tokens_chunk_size_greater_than_max_length():
    """When chunk_size exceeds max_length only one partial chunk fits."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _chunk_tokens

    eos = 99
    source = list(range(20))

    # chunk_size=10 would allow 9 content tokens, but max_length=5 caps the
    # output at 4 content tokens + 1 EOS.
    chunk_ids, eos_idx = _chunk_tokens(source, 10, eos, 5)

    assert chunk_ids == [0, 1, 2, 3, 99]
    assert eos_idx == [4]


@pytest.mark.unit
def test_chunk_tokens_truncation_takes_from_front():
    """Truncation to max_length keeps the FRONT of the token list."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _chunk_tokens

    eos = 99
    # Distinct values at both ends make the truncation direction observable.
    source = list(range(20))  # [0, 1, ..., 19]

    # chunk_size=5 (chunk_len=4), max_length=8:
    # one full chunk (4 + EOS = 5) plus one partial (2 + EOS = 3) = 8.
    chunk_ids, eos_idx = _chunk_tokens(source, 5, eos, 8)

    # Front-truncation gives [0,1,2,3,99,4,5,99]; back-truncation would have
    # surfaced tokens like 16..19 instead.
    assert chunk_ids == [0, 1, 2, 3, 99, 4, 5, 99]
    assert eos_idx == [4, 7]

    assert chunk_ids[0] == 0       # starts at the beginning of the input
    assert chunk_ids[-2] == 5      # last content token before the final EOS
    assert chunk_ids[-2] != 19     # confirms the tail was dropped, not the head
@pytest.mark.unit
def test_chunk_tokens_truncation_then_padding_complex_case(train_tokenizer):
    """Front-truncate a long passage into chunks, then left-pad the result."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _chunk_tokens, _pad_and_adjust_eos_positions

    eos_id = train_tokenizer.eos_token_id
    pad_id = train_tokenizer.pad_token_id

    # 100 distinct tokens so the truncation direction is unambiguous.
    source = list(range(100, 200))
    chunk_size = 10   # 9 content tokens per chunk
    max_length = 20   # exactly two chunks of (9 content + EOS)

    # Step 1: chunk with truncation — the front of the sequence must survive.
    chunked_ids, eos_positions = _chunk_tokens(source, chunk_size, eos_id, max_length)

    assert chunked_ids[0] == 100      # first original token kept
    assert chunked_ids[-2] == 117     # last kept content token (not 199)
    assert len(chunked_ids) == 20     # exactly max_length

    # Golden output: tokens 100..117 split into two EOS-terminated chunks.
    expected_chunked_ids = (
        list(range(100, 109)) + [eos_id] + list(range(109, 118)) + [eos_id]
    )
    assert chunked_ids == expected_chunked_ids
    assert eos_positions == [9, 19]   # flat list (not list-of-lists) pre-padding

    # Step 2: left padding; pad_to_multiple_of=8 rounds 20 up to 24,
    # so 4 pad tokens land on the left and every EOS shifts right by 4.
    padded_dict_left, adjusted_eos_positions_left = _pad_and_adjust_eos_positions(
        all_input_ids=[chunked_ids],
        all_eos_positions=[eos_positions],
        tokenizer=train_tokenizer,
        padding_side='left',
        pad_to_multiple_of=8,
    )

    expected_padded_ids_left = [pad_id] * 4 + expected_chunked_ids
    assert padded_dict_left['input_ids'][0].tolist() == expected_padded_ids_left
    assert padded_dict_left['attention_mask'][0].tolist() == [0] * 4 + [1] * 20
    assert adjusted_eos_positions_left == [[13, 23]]
# ============================================================================
# Unit tests for _pad_and_adjust_eos_positions helper function
# ============================================================================

@pytest.mark.unit
def test_pad_and_adjust_eos_positions_right_padding(train_tokenizer):
    """Right padding leaves every EOS position untouched."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _pad_and_adjust_eos_positions

    eos_id = train_tokenizer.eos_token_id
    pad_id = train_tokenizer.pad_token_id

    batch_ids = [
        [1, 2, 3, eos_id],  # passage 0: 4 tokens, EOS at 3
        [4, 5, eos_id],     # passage 1: 3 tokens, EOS at 2
    ]
    batch_eos = [[3], [2]]

    padded, adjusted = _pad_and_adjust_eos_positions(
        all_input_ids=batch_ids,
        all_eos_positions=batch_eos,
        tokenizer=train_tokenizer,
        padding_side='right',
        pad_to_multiple_of=4,
    )

    # Golden output: only passage 1 needs padding (one pad on the right).
    want_ids = torch.tensor([
        [1, 2, 3, eos_id],
        [4, 5, eos_id, pad_id],
    ])
    want_mask = torch.tensor([
        [1, 1, 1, 1],
        [1, 1, 1, 0],
    ])
    assert torch.equal(padded['input_ids'], want_ids)
    assert torch.equal(padded['attention_mask'], want_mask)
    assert adjusted == [[3], [2]]  # unchanged under right padding

    # Cross-check against tokenizer.pad itself.
    train_tokenizer.padding_side = 'right'
    reference = train_tokenizer.pad(
        {'input_ids': batch_ids},
        padding=True,
        pad_to_multiple_of=4,
        return_attention_mask=True,
        return_tensors='pt',
    )
    assert torch.equal(padded['input_ids'], reference['input_ids'])
    assert torch.equal(padded['attention_mask'], reference['attention_mask'])
@pytest.mark.unit
def test_pad_and_adjust_eos_positions_left_padding(train_tokenizer):
    """Left padding shifts each passage's EOS positions by its pad count."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _pad_and_adjust_eos_positions

    eos_id = train_tokenizer.eos_token_id
    pad_id = train_tokenizer.pad_token_id

    batch_ids = [
        [1, 2, 3, eos_id],  # passage 0: 4 tokens, EOS at 3
        [4, 5, eos_id],     # passage 1: 3 tokens, EOS at 2
    ]
    batch_eos = [[3], [2]]

    padded, adjusted = _pad_and_adjust_eos_positions(
        all_input_ids=batch_ids,
        all_eos_positions=batch_eos,
        tokenizer=train_tokenizer,
        padding_side='left',
        pad_to_multiple_of=4,
    )

    # Golden output: passage 1 gains one pad token on the left.
    want_ids = torch.tensor([
        [1, 2, 3, eos_id],
        [pad_id, 4, 5, eos_id],
    ])
    want_mask = torch.tensor([
        [1, 1, 1, 1],
        [0, 1, 1, 1],
    ])
    assert torch.equal(padded['input_ids'], want_ids)
    assert torch.equal(padded['attention_mask'], want_mask)
    # Passage 0: no padding, EOS stays at 3.
    # Passage 1: 1 pad on the left, EOS shifts from 2 to 3.
    assert adjusted == [[3], [3]]

    # Cross-check against tokenizer.pad itself.
    train_tokenizer.padding_side = 'left'
    reference = train_tokenizer.pad(
        {'input_ids': batch_ids},
        padding=True,
        pad_to_multiple_of=4,
        return_attention_mask=True,
        return_tensors='pt',
    )
    assert torch.equal(padded['input_ids'], reference['input_ids'])
    assert torch.equal(padded['attention_mask'], reference['attention_mask'])
@pytest.mark.unit
def test_pad_and_adjust_eos_positions_multiple_eos(train_tokenizer):
    """Every EOS in a multi-chunk passage shifts by the same left-pad amount."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _pad_and_adjust_eos_positions

    eos_id = train_tokenizer.eos_token_id
    pad_id = train_tokenizer.pad_token_id

    batch_ids = [
        [1, 2, eos_id, 3, 4, eos_id],  # passage 0: 6 tokens, EOS at 2 and 5
        [5, eos_id],                   # passage 1: 2 tokens, EOS at 1
    ]
    batch_eos = [[2, 5], [1]]

    padded, adjusted = _pad_and_adjust_eos_positions(
        all_input_ids=batch_ids,
        all_eos_positions=batch_eos,
        tokenizer=train_tokenizer,
        padding_side='left',
        pad_to_multiple_of=8,
    )

    # Golden output: both passages are left-padded up to 8 tokens.
    want_ids = torch.tensor([
        [pad_id, pad_id, 1, 2, eos_id, 3, 4, eos_id],
        [pad_id, pad_id, pad_id, pad_id, pad_id, pad_id, 5, eos_id],
    ])
    want_mask = torch.tensor([
        [0, 0, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1],
    ])
    assert torch.equal(padded['input_ids'], want_ids)
    assert torch.equal(padded['attention_mask'], want_mask)
    # Passage 0: 2 pads -> EOS [2,5] becomes [4,7].
    # Passage 1: 6 pads -> EOS [1] becomes [7].
    assert adjusted == [[4, 7], [7]]

    # Cross-check against tokenizer.pad itself.
    train_tokenizer.padding_side = 'left'
    reference = train_tokenizer.pad(
        {'input_ids': batch_ids},
        padding=True,
        pad_to_multiple_of=8,
        return_attention_mask=True,
        return_tensors='pt',
    )
    assert torch.equal(padded['input_ids'], reference['input_ids'])
    assert torch.equal(padded['attention_mask'], reference['attention_mask'])
@pytest.mark.unit
def test_pad_and_adjust_eos_positions_no_padding_needed(train_tokenizer):
    """Sequences already aligned to pad_to_multiple_of pass through as-is."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _pad_and_adjust_eos_positions

    eos_id = train_tokenizer.eos_token_id
    pad_id = train_tokenizer.pad_token_id

    batch_ids = [
        [1, 2, eos_id],
        [3, 4, eos_id],
    ]
    batch_eos = [[2], [2]]

    padded, adjusted = _pad_and_adjust_eos_positions(
        all_input_ids=batch_ids,
        all_eos_positions=batch_eos,
        tokenizer=train_tokenizer,
        padding_side='right',
        pad_to_multiple_of=3,
    )

    # Golden output: both rows are already length 3 — nothing changes.
    want_ids = torch.tensor([
        [1, 2, eos_id],
        [3, 4, eos_id],
    ])
    want_mask = torch.tensor([
        [1, 1, 1],
        [1, 1, 1],
    ])
    assert torch.equal(padded['input_ids'], want_ids)
    assert torch.equal(padded['attention_mask'], want_mask)
    assert adjusted == [[2], [2]]  # unchanged under right padding

    # Cross-check against tokenizer.pad itself.
    train_tokenizer.padding_side = 'right'
    reference = train_tokenizer.pad(
        {'input_ids': batch_ids},
        padding=True,
        pad_to_multiple_of=3,
        return_attention_mask=True,
        return_tensors='pt',
    )
    assert torch.equal(padded['input_ids'], reference['input_ids'])
    assert torch.equal(padded['attention_mask'], reference['attention_mask'])
@pytest.mark.unit
def test_pad_and_adjust_eos_positions_empty_input(train_tokenizer):
    """An empty batch pads to an empty result."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _pad_and_adjust_eos_positions

    padded, adjusted = _pad_and_adjust_eos_positions(
        all_input_ids=[],
        all_eos_positions=[],
        tokenizer=train_tokenizer,
        padding_side='right',
        pad_to_multiple_of=4,
    )

    assert adjusted == []
    # tokenizer.pad may hand back a tensor or a plain list for an empty batch;
    # accept either, as long as it is empty.
    if isinstance(padded['input_ids'], torch.Tensor):
        assert padded['input_ids'].shape[0] == 0
        assert padded['attention_mask'].shape[0] == 0
    else:
        assert len(padded['input_ids']) == 0
        assert len(padded['attention_mask']) == 0


@pytest.mark.unit
def test_pad_and_adjust_eos_positions_single_passage(train_tokenizer):
    """A single already-aligned passage is returned unchanged."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _pad_and_adjust_eos_positions

    eos_id = train_tokenizer.eos_token_id

    batch_ids = [[1, 2, 3, eos_id]]
    batch_eos = [[3]]

    padded, adjusted = _pad_and_adjust_eos_positions(
        all_input_ids=batch_ids,
        all_eos_positions=batch_eos,
        tokenizer=train_tokenizer,
        padding_side='right',
        pad_to_multiple_of=4,
    )

    # Golden output: already length 4, no padding required.
    assert torch.equal(padded['input_ids'], torch.tensor([[1, 2, 3, eos_id]]))
    assert torch.equal(padded['attention_mask'], torch.tensor([[1, 1, 1, 1]]))
    assert adjusted == [[3]]

    # Cross-check against tokenizer.pad itself.
    train_tokenizer.padding_side = 'right'
    reference = train_tokenizer.pad(
        {'input_ids': batch_ids},
        padding=True,
        pad_to_multiple_of=4,
        return_attention_mask=True,
        return_tensors='pt',
    )
    assert torch.equal(padded['input_ids'], reference['input_ids'])
    assert torch.equal(padded['attention_mask'], reference['attention_mask'])
"""Test with single passage.""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + eos_id = train_tokenizer.eos_token_id + + all_input_ids = [[1, 2, 3, eos_id]] + all_eos_positions = [[3]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=train_tokenizer, + padding_side='right', + pad_to_multiple_of=4, + ) + + # Hardcoded golden output + expected_input_ids = torch.tensor([ + [1, 2, 3, eos_id], # Already length 4, no padding needed + ]) + expected_attention_mask = torch.tensor([ + [1, 1, 1, 1], # All tokens valid + ]) + expected_eos_positions = [[3]] + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions + + # Verify behavior matches tokenizer.pad directly + train_tokenizer.padding_side = 'right' + direct_padded = train_tokenizer.pad( + {'input_ids': all_input_ids}, + padding=True, + pad_to_multiple_of=4, + return_attention_mask=True, + return_tensors='pt', + ) + assert torch.equal(padded_dict['input_ids'], direct_padded['input_ids']) + assert torch.equal(padded_dict['attention_mask'], direct_padded['attention_mask']) + + +@pytest.mark.unit +def test_pad_and_adjust_eos_positions_pad_to_multiple_of_one(train_tokenizer): + """Test with pad_to_multiple_of=1 (no rounding).""" + _add_tevatron_src_to_path() + from tevatron.retriever.collator import _pad_and_adjust_eos_positions + + eos_id = train_tokenizer.eos_token_id + pad_id = train_tokenizer.pad_token_id + + all_input_ids = [ + [1, 2, eos_id], + [3, eos_id], + ] + all_eos_positions = [[2], [1]] + + padded_dict, adjusted_eos_positions = _pad_and_adjust_eos_positions( + all_input_ids=all_input_ids, + all_eos_positions=all_eos_positions, + tokenizer=train_tokenizer, + padding_side='right', + pad_to_multiple_of=1, + ) 
@pytest.mark.unit
def test_pad_and_adjust_eos_positions_left_padding_multiple_chunks(train_tokenizer):
    """Left padding a batch where one passage carries several chunk EOS marks."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _pad_and_adjust_eos_positions

    eos_id = train_tokenizer.eos_token_id
    pad_id = train_tokenizer.pad_token_id

    batch_ids = [
        [1, eos_id, 2, 3, eos_id],  # passage 0: 5 tokens, EOS at 1 and 4
        [4, 5, eos_id],             # passage 1: 3 tokens, EOS at 2
    ]
    batch_eos = [[1, 4], [2]]

    padded, adjusted = _pad_and_adjust_eos_positions(
        all_input_ids=batch_ids,
        all_eos_positions=batch_eos,
        tokenizer=train_tokenizer,
        padding_side='left',
        pad_to_multiple_of=8,
    )

    # Golden output: both rows left-padded up to 8 tokens.
    want_ids = torch.tensor([
        [pad_id, pad_id, pad_id, 1, eos_id, 2, 3, eos_id],
        [pad_id, pad_id, pad_id, pad_id, pad_id, 4, 5, eos_id],
    ])
    want_mask = torch.tensor([
        [0, 0, 0, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1],
    ])
    assert torch.equal(padded['input_ids'], want_ids)
    assert torch.equal(padded['attention_mask'], want_mask)
    # Passage 0: 3 pads -> EOS [1,4] becomes [4,7].
    # Passage 1: 5 pads -> EOS [2] becomes [7].
    assert adjusted == [[4, 7], [7]]

    # Cross-check against tokenizer.pad itself.
    train_tokenizer.padding_side = 'left'
    reference = train_tokenizer.pad(
        {'input_ids': batch_ids},
        padding=True,
        pad_to_multiple_of=8,
        return_attention_mask=True,
        return_tensors='pt',
    )
    assert torch.equal(padded['input_ids'], reference['input_ids'])
    assert torch.equal(padded['attention_mask'], reference['attention_mask'])
[pad_id, pad_id, pad_id, pad_id, pad_id, 4, 5, eos_id], # Passage 1: padded to 8 (5 padding tokens on left) + ]) + expected_attention_mask = torch.tensor([ + [0, 0, 0, 1, 1, 1, 1, 1], # Passage 0: first 3 tokens are padding + [0, 0, 0, 0, 0, 1, 1, 1], # Passage 1: first 5 tokens are padding + ]) + # Passage 0: original length 5, padded length 8, padding_length=3, EOS shift from [1,4] to [4,7] + # Passage 1: original length 3, padded length 8, padding_length=5, EOS shift from 2 to 7 + expected_eos_positions = [[4, 7], [7]] + + assert torch.equal(padded_dict['input_ids'], expected_input_ids) + assert torch.equal(padded_dict['attention_mask'], expected_attention_mask) + assert adjusted_eos_positions == expected_eos_positions + + # Verify behavior matches tokenizer.pad directly + train_tokenizer.padding_side = 'left' + direct_padded = train_tokenizer.pad( + {'input_ids': all_input_ids}, + padding=True, + pad_to_multiple_of=8, + return_attention_mask=True, + return_tensors='pt', + ) + assert torch.equal(padded_dict['input_ids'], direct_padded['input_ids']) + assert torch.equal(padded_dict['attention_mask'], direct_padded['attention_mask']) + + + + +@pytest.mark.unit +def test_train_collator_chunked_passages(train_tokenizer): + """Test chunking with passage_max_len=512, passage_chunk_size=256.""" + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + data_args = DataArguments( + passage_max_len=512, + passage_chunk_size=256, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([REAL_TEXT]) + + got_ids = d_collated["input_ids"][0].tolist() + got_mask = d_collated["attention_mask"][0].tolist() + + # Hardcoded golden output: 2 chunks (255 tokens + EOS, 174 tokens + EOS) = 431 tokens, padded to 432 + expected_ids = [ + 74290, 804, 315, 279, 
17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, + 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, 349, + 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284, 220, 16, + 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239, 315, 6811, + 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434, 320, 77, 284, + 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13, 758, 279, 8622, 4158, 4925, + 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, 73760, 572, 1550, 11, 220, 16, 13, 23, + 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, 311, 220, 16, 13, 17, 19197, 441, 17, + 58634, 13, 758, 279, 44900, 47594, 315, 279, 5306, 47639, 11, 279, 3076, 9981, 57330, + 36829, 518, 2176, 3039, 1033, 4428, 320, 16, 13, 17, 19041, 220, 16, 13, 16, 19197, 441, + 17, 58634, 568, 39402, 458, 285, 354, 17764, 572, 5080, 279, 12128, 7194, 572, 311, 4647, + 448, 7046, 10740, 2750, 304, 279, 5306, 47639, 1091, 304, 279, 8622, 4158, 4925, 13, 4968, + 4991, 41434, 518, 4647, 8542, 5080, 3076, 57330, 36829, 304, 279, 8622, 4158, 4925, 320, + 16, 13, 19, 51615, 220, 15, 13, 17, 19, 19041, 220, 16, 13, 16, EOS_TOKEN_ID, 20, 51615, + 220, 15, 13, 15, 24, 19197, 441, 17, 58634, 11, 281, 284, 220, 15, 13, 15, 16, 21, 8, 323, + 4722, 8674, 458, 285, 354, 17764, 304, 2176, 5671, 7707, 448, 2480, 9663, 41434, 320, + 5782, 4925, 11, 220, 16, 15, 13, 24, 51615, 220, 15, 13, 21, 19041, 220, 17, 17, 13, 24, + 51615, 220, 18, 13, 15, 13384, 281, 284, 220, 15, 13, 15, 15, 16, 26, 5306, 47639, 11, + 220, 17, 19, 13, 15, 51615, 220, 19, 13, 19, 19, 19041, 220, 18, 18, 13, 16, 51615, 220, + 15, 13, 21, 4, 281, 284, 220, 15, 13, 15, 15, 21, 568, 11581, 2408, 301, 15479, 48674, + 304, 279, 42094, 
1620, 385, 1242, 1033, 9434, 553, 57330, 15626, 51360, 438, 4124, 438, + 220, 17, 23, 73760, 26, 2480, 9663, 323, 855, 4991, 41434, 518, 4647, 8542, 12864, 11799, + 304, 4158, 4925, 23788, 7321, 13, 576, 821, 13216, 429, 46516, 15449, 315, 3015, 57330, + 553, 57330, 15626, 51360, 5707, 20017, 1119, 8003, 95697, 4401, 304, 59645, 4158, 4925, + 304, 5382, 41434, EOS_TOKEN_ID, PADDING_TOKEN_ID + ] + expected_mask = [1] * 431 + [0] # 431 ones + 1 zero + expected_eos_positions = [[255, 430]] + + assert sum(got_mask) == 431 + assert len(got_ids) == 432 # Padded to multiple of 16 + assert eos_positions == expected_eos_positions + assert got_ids == expected_ids + assert got_mask == expected_mask + assert got_ids[255] == train_tokenizer.eos_token_id + assert got_ids[430] == train_tokenizer.eos_token_id + assert got_mask[255] == 1 + assert got_mask[430] == 1 + + +@pytest.mark.unit +def test_train_collator_chunked_passages_left_padding(train_tokenizer): + """Test chunking with passage_max_len=512, passage_chunk_size=256, left padding.""" + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + data_args = DataArguments( + passage_max_len=512, + passage_chunk_size=256, + pad_to_multiple_of=16, + padding_side="left", + append_eos_token=False, + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([REAL_TEXT]) + + got_ids = d_collated["input_ids"][0].tolist() + got_mask = d_collated["attention_mask"][0].tolist() + expected_ids = [ PADDING_TOKEN_ID, + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279, + 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, 349, + 2326, 32420, 
23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284, 220, 16, + 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239, 315, 6811, + 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434, 320, 77, 284, + 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13, 758, 279, 8622, 4158, 4925, + 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, 73760, 572, 1550, 11, 220, 16, 13, 23, + 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, 311, 220, 16, 13, 17, 19197, 441, 17, + 58634, 13, 758, 279, 44900, 47594, 315, 279, 5306, 47639, 11, 279, 3076, 9981, 57330, + 36829, 518, 2176, 3039, 1033, 4428, 320, 16, 13, 17, 19041, 220, 16, 13, 16, 19197, 441, + 17, 58634, 568, 39402, 458, 285, 354, 17764, 572, 5080, 279, 12128, 7194, 572, 311, 4647, + 448, 7046, 10740, 2750, 304, 279, 5306, 47639, 1091, 304, 279, 8622, 4158, 4925, 13, 4968, + 4991, 41434, 518, 4647, 8542, 5080, 3076, 57330, 36829, 304, 279, 8622, 4158, 4925, 320, + 16, 13, 19, 51615, 220, 15, 13, 17, 19, 19041, 220, 16, 13, 16, EOS_TOKEN_ID, 20, 51615, + 220, 15, 13, 15, 24, 19197, 441, 17, 58634, 11, 281, 284, 220, 15, 13, 15, 16, 21, 8, 323, + 4722, 8674, 458, 285, 354, 17764, 304, 2176, 5671, 7707, 448, 2480, 9663, 41434, 320, + 5782, 4925, 11, 220, 16, 15, 13, 24, 51615, 220, 15, 13, 21, 19041, 220, 17, 17, 13, 24, + 51615, 220, 18, 13, 15, 13384, 281, 284, 220, 15, 13, 15, 15, 16, 26, 5306, 47639, 11, + 220, 17, 19, 13, 15, 51615, 220, 19, 13, 19, 19, 19041, 220, 18, 18, 13, 16, 51615, 220, + 15, 13, 21, 4, 281, 284, 220, 15, 13, 15, 15, 21, 568, 11581, 2408, 301, 15479, 48674, + 304, 279, 42094, 1620, 385, 1242, 1033, 9434, 553, 57330, 15626, 51360, 438, 4124, 438, + 220, 17, 23, 73760, 26, 2480, 9663, 323, 855, 4991, 41434, 518, 4647, 8542, 12864, 11799, + 304, 4158, 4925, 23788, 7321, 13, 576, 821, 13216, 429, 46516, 15449, 315, 3015, 57330, + 553, 57330, 15626, 51360, 5707, 20017, 1119, 8003, 95697, 4401, 304, 59645, 4158, 4925, + 304, 5382, 41434, 
EOS_TOKEN_ID + ] + expected_mask = [0] + [1] * 431 # 1 padding + 431 content + expected_eos_positions = [[256, 431]] + + assert got_ids == expected_ids + assert got_mask == expected_mask + assert eos_positions == expected_eos_positions + + +@pytest.mark.unit +def test_chunked_collator_with_multiple_passages(train_tokenizer): + """Test TrainCollator with chunking enabled returns (q_batch, p_batch, eos_positions).""" + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import TrainCollator + + data_args = DataArguments( + query_max_len=32, + passage_max_len=64, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + train_group_size=2, + passage_chunk_size=32, + ) + collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer) + features = [ + (("q1", None, None, None), [(REAL_TEXT, None, None, None), (REAL_TEXT, None, None, None)]), + ] + + q_batch, p_batch, eos_positions = collator(features) + + # Hardcoded golden output: both passages have 2 chunks (31 tokens + EOS, 31 tokens + EOS) = 64 tokens each + expected_ids_0 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + EOS_TOKEN_ID, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, EOS_TOKEN_ID + ] + expected_mask_0 = [1] * 64 + expected_eos_0 = [31, 63] + + expected_ids_1 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + EOS_TOKEN_ID, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, EOS_TOKEN_ID + ] + 
    expected_mask_1 = [1] * 64
    expected_eos_1 = [31, 63]

    # Batch-level shape checks: two passages collated, one EOS-position list each.
    assert p_batch["input_ids"].shape[0] == 2
    assert len(eos_positions) == 2

    got_ids_0 = p_batch["input_ids"][0].tolist()
    got_mask_0 = p_batch["attention_mask"][0].tolist()
    got_ids_1 = p_batch["input_ids"][1].tolist()
    got_mask_1 = p_batch["attention_mask"][1].tolist()

    assert got_ids_0 == expected_ids_0
    assert got_mask_0 == expected_mask_0
    assert eos_positions[0] == expected_eos_0
    assert got_ids_1 == expected_ids_1
    assert got_mask_1 == expected_mask_1
    assert eos_positions[1] == expected_eos_1

    # Structural invariants for every passage in the batch: EOS positions are
    # non-empty, strictly increasing, point at real EOS tokens, and unpadded.
    for i in range(p_batch["input_ids"].shape[0]):
        got_ids = p_batch["input_ids"][i].tolist()
        got_mask = p_batch["attention_mask"][i].tolist()

        assert len(eos_positions[i]) > 0
        assert _strictly_increasing(eos_positions[i])
        for eos_pos in eos_positions[i]:
            assert got_ids[eos_pos] == train_tokenizer.eos_token_id
            assert got_mask[eos_pos] == 1
        assert len(got_ids) == 64


@pytest.mark.unit
def test_chunking_capped_to_maxlen_chunk_size_64(train_tokenizer):
    """When chunk_size >= max_len, chunking is capped to max_len with one EOS (chunk_size=64)."""
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import TrainCollator

    # Long input guarantees the passage exceeds max_len and gets truncated.
    long_text = (REAL_TEXT + " ") * 20
    data_args = DataArguments(
        passage_chunk_size=64,
        passage_max_len=64,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
    )
    collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer)
    d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([long_text])
    ids = d_collated["input_ids"][0].tolist()
    mask = d_collated["attention_mask"][0].tolist()

    # Hardcoded golden output: 63 tokens + 1 EOS = 64 tokens
    expected_ids = [
        74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802,
        82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970,
        56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279,
        9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, 349,
        2326, EOS_TOKEN_ID
    ]
    expected_mask = [1] * 64
    expected_eos_positions = [[63]]

    assert sum(mask) == 64
    assert len(ids) == 64
    assert eos_positions == expected_eos_positions
    assert ids == expected_ids
    assert mask == expected_mask
    assert ids[63] == EOS_TOKEN_ID
    # Exactly one EOS: the cap to max_len must not leave intermediate chunk EOS tokens.
    assert EOS_TOKEN_ID not in ids[:63]
    assert _strictly_increasing(eos_positions[0])


@pytest.mark.unit
def test_chunking_capped_to_maxlen_chunk_size_128(train_tokenizer):
    """When chunk_size >= max_len, chunking is capped to max_len with one EOS (chunk_size=128)."""
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import TrainCollator

    long_text = (REAL_TEXT + " ") * 20
    data_args = DataArguments(
        passage_chunk_size=128,
        passage_max_len=64,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
    )
    collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer)
    d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([long_text])
    ids = d_collated["input_ids"][0].tolist()
    mask = d_collated["attention_mask"][0].tolist()

    # Hardcoded golden output: 63 tokens + 1 EOS = 64 tokens
    # (identical to the chunk_size=64 case: any chunk_size >= max_len collapses to one capped chunk)
    expected_ids = [
        74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802,
        82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970,
        56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279,
        9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684, 349,
        2326, EOS_TOKEN_ID
    ]
    expected_mask = [1] * 64
    expected_eos_positions = [[63]]

    assert sum(mask) == 64
    assert len(ids) == 64
    assert eos_positions == expected_eos_positions
    assert ids == expected_ids
    assert mask == expected_mask
    assert ids[63] == EOS_TOKEN_ID
    assert EOS_TOKEN_ID not in ids[:63]
    assert _strictly_increasing(eos_positions[0])


@pytest.mark.unit
def test_chunking_short_passage_shorter_than_chunk_size(train_tokenizer):
    """
    When passage is shorter than chunk_size, it should still get one chunk with EOS,
    and padding should be applied to pad_to_multiple_of.
    """
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import TrainCollator

    short_text = "Hello world"
    data_args = DataArguments(
        passage_chunk_size=64,
        passage_max_len=128,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
    )
    collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer)
    d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([short_text])
    ids = d_collated["input_ids"][0].tolist()
    mask = d_collated["attention_mask"][0].tolist()

    # Hardcoded golden output: "Hello world" -> 2 tokens + 1 EOS = 3 tokens, padded to 16
    expected_ids = [9707, 1879, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 13  # 3 content + 13 padding
    expected_eos_positions = [[2]]
    expected_mask = [1, 1, 1] + [0] * 13  # 3 ones + 13 zeros

    assert sum(mask) == 3
    assert len(ids) == 16  # Padded to multiple of 16
    assert eos_positions == expected_eos_positions
    assert ids == expected_ids
    assert ids[2] == EOS_TOKEN_ID  # EOS at position 2
    assert mask == expected_mask
    assert _strictly_increasing(eos_positions[0])


@pytest.mark.unit
def test_chunking_passage_needs_padding_unpadded_not_multiple_of_pad_to_multiple_of(train_tokenizer):
    """
    When unpadded length is not a multiple of pad_to_multiple_of, padding should be added.
    This tests: unpadded_len=50, pad_to_multiple_of=16 -> padded_len=64.
    """
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import TrainCollator

    data_args = DataArguments(
        passage_chunk_size=32,
        passage_max_len=50,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
    )
    collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer)
    d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages([REAL_TEXT])
    ids = d_collated["input_ids"][0].tolist()
    mask = d_collated["attention_mask"][0].tolist()

    # Hardcoded golden output: 50 unpadded tokens (2 chunks: 31+1 EOS, 18+1 EOS), padded to 64
    expected_ids = [
        74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802,
        82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970,
        EOS_TOKEN_ID, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629,
        279, 9981, 57330, EOS_TOKEN_ID
    ] + [PADDING_TOKEN_ID] * 14  # 50 content + 14 padding
    expected_eos_positions = [[31, 49]]
    expected_mask = [1] * 50 + [0] * 14  # 50 ones + 14 zeros
    assert sum(mask) == 50
    assert len(ids) == 64  # Padded to multiple of 16
    assert eos_positions == expected_eos_positions
    assert ids == expected_ids
    assert ids[31] == EOS_TOKEN_ID  # First EOS
    assert ids[49] == EOS_TOKEN_ID  # Second EOS
    assert mask == expected_mask
    assert _strictly_increasing(eos_positions[0])


@pytest.mark.unit
def test_chunking_multiple_passages_different_lengths(train_tokenizer):
    """
    Test batch processing with multiple passages of different lengths:
    - Short passage (2 tokens)
    - Medium passage (18 tokens)
    - Long passage (128 tokens, multiple chunks)
    - Very long passage (158 tokens, multiple chunks)
    All should be padded to the same length (longest unpadded length rounded up to pad_to_multiple_of).
    """
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import TrainCollator

    # Create a passage that will result in ~158 tokens
    # REAL_TEXT is ~431 tokens, so we'll use a portion of it repeated or extended
    long_passage = REAL_TEXT + " " + REAL_TEXT[:200]

    texts = ["Short", REAL_TEXT[:100], REAL_TEXT, long_passage]
    data_args = DataArguments(
        passage_chunk_size=64,
        passage_max_len=128,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
    )
    collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer)
    d_collated, eos_positions = collator._tokenize_and_pad_chunked_passages(texts)

    # Passage 0: "Short" -> 1 token + 1 EOS = 2 tokens, padded to 128
    expected_ids_0 = [12472, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 126
    expected_mask_0 = [1, 1] + [0] * 126
    expected_eos_0 = [1]

    # Passage 1: REAL_TEXT[:100] -> 17 tokens + 1 EOS = 18 tokens, padded to 128
    expected_ids_1 = [
        74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802,
        1062, EOS_TOKEN_ID
    ] + [PADDING_TOKEN_ID] * 110
    expected_mask_1 = [1] * 18 + [0] * 110
    expected_eos_1 = [17]

    # Passage 2: REAL_TEXT -> 2 chunks (63+1 EOS, 63+1 EOS) = 128 tokens (no padding needed)
    expected_ids_2 = [
        74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802,
        82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970,
        56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279,
        9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684,
        349, 2326, EOS_TOKEN_ID, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77,
        284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239,
        315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434,
        320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13, 758, 279, 8622,
        EOS_TOKEN_ID
    ]
    expected_mask_2 = [1] * 128
    expected_eos_2 = [63, 127]

    # Passage 3: long_passage is truncated at max_len, so its golden output is
    # identical to passage 2 (the extra ~30 tokens never make it into the batch).
    expected_ids_3 = [
        74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802,
        82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970,
        56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, 6629, 279,
        9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, 90684,
        349, 2326, EOS_TOKEN_ID, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77,
        284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, 8552, 6239,
        315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, 41434,
        320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13, 758, 279, 8622,
        EOS_TOKEN_ID
    ]
    expected_mask_3 = [1] * 128
    expected_eos_3 = [63, 127]

    ids_0 = d_collated["input_ids"][0].tolist()
    mask_0 = d_collated["attention_mask"][0].tolist()
    ids_1 = d_collated["input_ids"][1].tolist()
    mask_1 = d_collated["attention_mask"][1].tolist()
    ids_2 = d_collated["input_ids"][2].tolist()
    mask_2 = d_collated["attention_mask"][2].tolist()
    ids_3 = d_collated["input_ids"][3].tolist()
    mask_3 = d_collated["attention_mask"][3].tolist()

    # Passage 0 assertions
    assert sum(mask_0) == 2
    assert len(ids_0) == 128
    assert ids_0 == expected_ids_0
    assert mask_0 == expected_mask_0
    assert eos_positions[0] == expected_eos_0

    # Passage 1 assertions
    assert sum(mask_1) == 18
    assert len(ids_1) == 128
    assert ids_1 == expected_ids_1
    assert mask_1 == expected_mask_1
    assert eos_positions[1] == expected_eos_1

    # Passage 2 assertions
    assert sum(mask_2) == 128
    assert len(ids_2) == 128
    assert ids_2 == expected_ids_2
    assert mask_2 == expected_mask_2
    assert eos_positions[2] == expected_eos_2
    assert _strictly_increasing(eos_positions[2])

    # Passage 3 assertions
    assert sum(mask_3) == 128
    assert len(ids_3) == 128
    assert eos_positions[3] == expected_eos_3
    assert ids_3 == expected_ids_3
    assert mask_3 == expected_mask_3


# ============================================================================
# Unit tests for random chunk sizes within a range
# ============================================================================

@pytest.mark.unit
def test_chunk_tokens_random_chunk_size_range_fixed_per_passage(train_tokenizer):
    """Test chunking with random chunk size range, fixed per passage (all chunks in a passage use same random size)."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _chunk_tokens

    # Set seed for deterministic results
    random.seed(42)

    tokens = list(range(100))  # 100 tokens
    eos_id = 99
    chunk_size_range = (10, 20)  # Random chunk size between 10 and 20

    ids, eos_pos = _chunk_tokens(tokens, chunk_size=10, eos_token_id=eos_id, chunk_size_range=chunk_size_range)

    # Hardcoded golden output with seed=42 and chunk_size_range=(10, 20).
    # Per-chunk content lengths implied by the golden data below:
    # 19, 10, 9, 13, 12, 12, 11, 10, 4 — each chunk followed by one EOS
    # (19+10+9+13+12+12+11+10+4 = 100 content tokens + 9 EOS = 109 total).
    expected_ids = [
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 99,  # Chunk 1: 19 + EOS
        19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 99,  # Chunk 2: 10 + EOS
        29, 30, 31, 32, 33, 34, 35, 36, 37, 99,  # Chunk 3: 9 + EOS
        38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 99,  # Chunk 4: 13 + EOS
        51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 99,  # Chunk 5: 12 + EOS
        63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 99,  # Chunk 6: 12 + EOS
        75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 99,  # Chunk 7: 11 + EOS
        86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 99,  # Chunk 8: 10 + EOS
        96, 97, 98, 99, 99  # Chunk 9: 4 + EOS (final partial chunk)
    ]
    expected_eos_pos = [19, 30, 40, 54, 67, 80, 92, 103, 108]

    assert ids == expected_ids
    assert eos_pos == expected_eos_pos

    # Verify structure: each chunk should end with EOS
    for eos_pos_val in eos_pos:
        assert ids[eos_pos_val] == eos_id


@pytest.mark.unit
def test_chunk_tokens_random_chunk_size_range_with_max_length(train_tokenizer):
    """Test random chunk size range with max_length constraint."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.collator import _chunk_tokens

    random.seed(123)

    tokens = list(range(200))
    eos_id = 99
    chunk_size_range = (15, 25)
    max_length = 50

    ids, eos_pos = _chunk_tokens(tokens, chunk_size=15, eos_token_id=eos_id, max_length=max_length, chunk_size_range=chunk_size_range)

    # Hardcoded golden output with seed=123, chunk_size_range=(15, 25), max_length=50.
    # Per-chunk content lengths implied by the golden data: 14, 18, 15
    # Total: 14 + 1 + 18 + 1 + 15 + 1 = 50 tokens (exactly max_length)
    expected_ids = [
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99,  # Chunk 1: 14 + EOS
        14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 99,  # Chunk 2: 18 + EOS
        32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 99  # Chunk 3: 15 + EOS (truncated to fit max_length)
    ]
    expected_eos_pos = [14, 33, 49]

    assert ids == expected_ids
    assert eos_pos == expected_eos_pos
    assert len(ids) == max_length  # Exactly max_length

    # Verify all EOS positions are valid
    for eos_pos_val in eos_pos:
        assert ids[eos_pos_val] == eos_id
        assert eos_pos_val < len(ids)


@pytest.mark.unit
def test_train_collator_random_chunk_size_range_fixed_per_passage(train_tokenizer):
    """Test TrainCollator with random chunk size range, fixed per passage."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import TrainCollator

    random.seed(42)

    data_args = DataArguments(
        query_max_len=32,
        passage_max_len=128,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
        train_group_size=2,
        passage_chunk_size_range="32,64",  # Random chunk size between 32 and 64
        passage_chunk_size_variable=False,  # Fixed random size per passage
    )
    collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer)

    features = [
        (("q1", None, None, None), [(REAL_TEXT, None, None, None), (REAL_TEXT, None, None, None)]),
    ]

    q_batch, p_batch, eos_positions = collator(features)

    # Hardcoded golden output with seed=42, passage_chunk_size_range="32,64",
    # passage_chunk_size_variable=False (one random size drawn per passage).
    # Golden EOS positions imply 38 content tokens per chunk for passage 0
    # (EOS at 38, 77, 116, short final chunk ends at 127) and 32 content tokens
    # per chunk for passage 1 (EOS at 32, 65, 98, final chunk ends at 127).
    expected_ids_0 = [
        74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802,
        82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970,
        56981, 31658, 320, 78670, 8, 8500, 448, EOS_TOKEN_ID, 57330, 15626, 6358, 572, 9251, 311,
        6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311,
        90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284,
        EOS_TOKEN_ID, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014,
        8552, 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991,
        41434, 320, 77, 284, 220, 16, 15, EOS_TOKEN_ID, 8, 1033, 19476, 264, 2086, 882, 518, 4647,
        13, 758, EOS_TOKEN_ID
    ]
    expected_mask_0 = [1] * 128
    expected_eos_positions_0 = [38, 77, 116, 127]

    expected_ids_1 = [
        74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802,
        82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970,
        56981, EOS_TOKEN_ID, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311,
        6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311,
        90684, 349, 2326, 32420, EOS_TOKEN_ID, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991,
        320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014,
        8552, 6239, 315, 6811, 37854, EOS_TOKEN_ID, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743,
        367, 855, 4991, 41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647,
        13, 758, EOS_TOKEN_ID
    ]
    expected_mask_1 = [1] * 128
    expected_eos_positions_1 = [32, 65, 98, 127]

    # Verify structure
    assert p_batch["input_ids"].shape[0] == 2
    assert len(eos_positions) == 2

    # Verify passage 0
    got_ids_0 = p_batch["input_ids"][0].tolist()
    got_mask_0 = p_batch["attention_mask"][0].tolist()
    assert got_ids_0 == expected_ids_0
    assert got_mask_0 == expected_mask_0
    assert eos_positions[0] == expected_eos_positions_0
    assert _strictly_increasing(eos_positions[0])
    for eos_pos in eos_positions[0]:
        assert got_ids_0[eos_pos] == train_tokenizer.eos_token_id
        assert got_mask_0[eos_pos] == 1

    # Verify passage 1
    got_ids_1 = p_batch["input_ids"][1].tolist()
    got_mask_1 = p_batch["attention_mask"][1].tolist()
    assert got_ids_1 == expected_ids_1
    assert got_mask_1 == expected_mask_1
    assert eos_positions[1] == expected_eos_positions_1
    assert _strictly_increasing(eos_positions[1])
    for eos_pos in eos_positions[1]:
        assert got_ids_1[eos_pos] == train_tokenizer.eos_token_id
        assert got_mask_1[eos_pos] == 1


@pytest.mark.unit
def test_train_collator_random_chunk_size_range_variable_per_chunk(train_tokenizer):
    """Test TrainCollator with random chunk size range, variable per chunk."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.arguments import DataArguments
    from
tevatron.retriever.collator import TrainCollator

    random.seed(42)

    data_args = DataArguments(
        query_max_len=32,
        passage_max_len=256,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
        train_group_size=1,
        passage_chunk_size_range="32,64",  # Random chunk size between 32 and 64
        passage_chunk_size_variable=True,  # Variable chunk size per chunk
    )
    collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer)

    features = [
        (("q1", None, None, None), [(REAL_TEXT, None, None, None)]),
    ]

    q_batch, p_batch, eos_positions = collator(features)

    # Hardcoded golden output with seed=42, passage_chunk_size_range="32,64",
    # passage_chunk_size_variable=True (a fresh random size per chunk).
    # Golden EOS positions [38, 71, 120, 167, 213, 253, 255] imply per-chunk
    # content lengths 38, 32, 48, 46, 45, 39, 1 (last partial chunk).
    expected_ids = [
        74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802,
        82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970,
        56981, 31658, 320, 78670, 8, 8500, 448, EOS_TOKEN_ID, 57330, 15626, 6358, 572, 9251, 311,
        6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311,
        90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, EOS_TOKEN_ID, 304, 855, 4991,
        320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014,
        8552, 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991,
        41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, EOS_TOKEN_ID, 2086, 882, 518, 4647,
        13, 758, 279, 8622, 4158, 4925, 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, 73760, 572,
        1550, 11, 220, 16, 13, 23, 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, 311, 220, 16,
        13, 17, 19197, 441, 17, 58634, 13, EOS_TOKEN_ID, 758, 279, 44900, 47594, 315, 279, 5306,
        47639, 11, 279, 3076, 9981, 57330, 36829, 518, 2176, 3039, 1033, 4428, 320, 16, 13, 17,
        19041, 220, 16, 13, 16, 19197, 441, 17, 58634, 568, 39402, 458, 285, 354, 17764, 572, 5080,
        279, 12128, 7194, 572, 311, EOS_TOKEN_ID, 4647, 448, 7046, 10740, 2750, 304, 279, 5306,
        47639, 1091, 304, 279, 8622, 4158, 4925, 13, 4968, 4991, 41434, 518, 4647, 8542, 5080,
        3076, 57330, 36829, 304, 279, 8622, 4158, 4925, 320, 16, 13, 19, 51615, 220, 15, 13,
        EOS_TOKEN_ID, 17, EOS_TOKEN_ID
    ]
    expected_mask = [1] * 256
    expected_eos_positions = [38, 71, 120, 167, 213, 253, 255]

    # Verify structure
    assert p_batch["input_ids"].shape[0] == 1
    assert len(eos_positions) == 1

    got_ids = p_batch["input_ids"][0].tolist()
    got_mask = p_batch["attention_mask"][0].tolist()

    assert got_ids == expected_ids
    assert got_mask == expected_mask
    assert eos_positions[0] == expected_eos_positions
    assert _strictly_increasing(eos_positions[0])

    # Verify each EOS position is valid
    for eos_pos in eos_positions[0]:
        assert got_ids[eos_pos] == train_tokenizer.eos_token_id
        assert got_mask[eos_pos] == 1


@pytest.mark.unit
def test_train_collator_random_chunk_size_range_hardcoded_output(train_tokenizer):
    """Test TrainCollator with random chunk size range - hardcoded golden output."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import TrainCollator

    random.seed(42)

    data_args = DataArguments(
        query_max_len=32,
        passage_max_len=128,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
        train_group_size=1,
        passage_chunk_size_range="32,48",  # Random chunk size between 32 and 48
        passage_chunk_size_variable=False,  # Fixed random size per passage
    )
    collator = TrainCollator(data_args=data_args, tokenizer=train_tokenizer)

    short_text = "Hello world this is a test passage"
    features = [
        (("q1", None, None, None), [(short_text, None, None, None)]),
    ]

    q_batch, p_batch, eos_positions = collator(features)

    got_ids = p_batch["input_ids"][0].tolist()
    got_mask = p_batch["attention_mask"][0].tolist()

    # Hardcoded golden output with seed=42 and chunk_size_range=(32,48)
    # short_text tokenizes to: [9707, 1879, 419, 374, 264, 1273, 21085]
    # The randomly drawn chunk size exceeds the 7 available tokens,
    # so we get a single chunk: [7 tokens] + EOS, padded to 16.
    expected_ids = [9707, 1879, 419, 374, 264, 1273, 21085, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 8
    expected_mask = [1] * 8 + [0] * 8
    expected_eos_positions = [[7]]

    assert got_ids == expected_ids
    assert got_mask == expected_mask
    assert eos_positions == expected_eos_positions


# ============================================================================
# Unit tests for prechunked passages
# ============================================================================

@pytest.mark.unit
def test_prechunked_encode_collator_basic(train_tokenizer):
    """Test PreChunkedEncodeCollator with basic pre-chunked passages."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import PreChunkedEncodeCollator

    data_args = DataArguments(
        passage_max_len=128,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
    )
    collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer)

    # Pre-chunked passages: each passage is a list of chunk strings
    features = [
        ("doc1", ["Hello world", "This is chunk 2", "Final chunk"], None, None, None),
        ("doc2", ["Single chunk passage"], None, None, None),
    ]

    doc_ids, d_collated, eos_positions = collator(features)

    # Hardcoded golden output:
    # doc1: "Hello world" -> [9707, 1879] + EOS, "This is chunk 2" -> [1986, 374, 11879, 220, 17] + EOS, "Final chunk" -> [19357, 11879] + EOS
    # Total: 12 tokens (9 content + 3 EOS), padded to 16
    expected_ids_0 = [9707, 1879, EOS_TOKEN_ID, 1986, 374, 11879, 220, 17,
                      EOS_TOKEN_ID, 19357, 11879, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 4
    expected_mask_0 = [1] * 12 + [0] * 4
    expected_eos_positions_0 = [2, 8, 11]

    # doc2: "Single chunk passage" -> [10888, 11879, 21085] + EOS
    # Total: 4 tokens (3 content + 1 EOS), padded to 16
    expected_ids_1 = [10888, 11879, 21085, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 12
    expected_mask_1 = [1] * 4 + [0] * 12
    expected_eos_positions_1 = [3]

    assert doc_ids == ["doc1", "doc2"]
    assert d_collated["input_ids"].shape[0] == 2
    assert len(eos_positions) == 2

    # Verify doc1
    got_ids_0 = d_collated["input_ids"][0].tolist()
    got_mask_0 = d_collated["attention_mask"][0].tolist()
    assert got_ids_0 == expected_ids_0
    assert got_mask_0 == expected_mask_0
    assert eos_positions[0] == expected_eos_positions_0
    assert len(eos_positions[0]) == 3
    assert _strictly_increasing(eos_positions[0])
    for eos_pos in eos_positions[0]:
        assert got_ids_0[eos_pos] == train_tokenizer.eos_token_id
        assert got_mask_0[eos_pos] == 1

    # Verify doc2
    got_ids_1 = d_collated["input_ids"][1].tolist()
    got_mask_1 = d_collated["attention_mask"][1].tolist()
    assert got_ids_1 == expected_ids_1
    assert got_mask_1 == expected_mask_1
    assert eos_positions[1] == expected_eos_positions_1
    assert len(eos_positions[1]) == 1
    for eos_pos in eos_positions[1]:
        assert got_ids_1[eos_pos] == train_tokenizer.eos_token_id
        assert got_mask_1[eos_pos] == 1


@pytest.mark.unit
def test_prechunked_encode_collator_hardcoded_output(train_tokenizer):
    """Test PreChunkedEncodeCollator with hardcoded golden output."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import PreChunkedEncodeCollator

    data_args = DataArguments(
        passage_max_len=64,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
    )
    collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer)

    # Pre-chunked passages
    features = [
        ("doc1", ["Hello", "world"], None, None, None),
    ]

    doc_ids, d_collated, eos_positions = collator(features)

    got_ids = d_collated["input_ids"][0].tolist()
    got_mask = d_collated["attention_mask"][0].tolist()

    # Hardcoded golden output:
    # "Hello" -> [9707] + EOS
    # "world" -> [14615] + EOS (tokenized separately, different from "Hello world")
    # Total: 4 tokens, padded to 16
    expected_ids = [9707, EOS_TOKEN_ID, 14615, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 12
    expected_mask = [1] * 4 + [0] * 12
    expected_eos_positions = [[1, 3]]

    assert got_ids == expected_ids
    assert got_mask == expected_mask
    assert eos_positions == expected_eos_positions


@pytest.mark.unit
def test_prechunked_encode_collator_max_length_truncation(train_tokenizer):
    """Test PreChunkedEncodeCollator with max_length truncation."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import PreChunkedEncodeCollator

    data_args = DataArguments(
        passage_max_len=20,  # Small max length to trigger truncation
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
    )
    collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer)

    # Create chunks that will exceed max_length
    long_chunk = REAL_TEXT[:200]  # Long chunk
    features = [
        ("doc1", [long_chunk, "Second chunk", "Third chunk"], None, None, None),
    ]

    doc_ids, d_collated, eos_positions = collator(features)

    # Hardcoded golden output with max_length=20:
    # First chunk (long_chunk) tokenizes to 19 tokens, then EOS is added at position 19
    # Total: 20 tokens (19 content + 1 EOS), which exactly fills max_length
    # Second and third chunks are not included due to truncation
    # Padded to 32 (multiple of 16)
    expected_ids = [
        74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646,
        7802, 82519, 4401, 323, EOS_TOKEN_ID
    ] + [PADDING_TOKEN_ID] * 12
    expected_mask = [1] * 20 + [0] * 12
    expected_eos_positions = [19]

    got_ids = d_collated["input_ids"][0].tolist()
    got_mask = d_collated["attention_mask"][0].tolist()

    assert got_ids == expected_ids
    assert got_mask == expected_mask
    assert eos_positions[0] == expected_eos_positions
    assert len(got_ids) == 32  # Padded to multiple of 16
    assert sum(got_mask) == 20  # Exactly 20 tokens (19 content + 1 EOS)

    # Verify EOS positions are valid
    for eos_pos in eos_positions[0]:
        assert got_ids[eos_pos] == train_tokenizer.eos_token_id
        assert got_mask[eos_pos] == 1
        assert eos_pos < len(got_ids)

    # Verify truncation: only first chunk fits, second and third chunks are not included
    assert len(eos_positions[0]) == 1  # Only one EOS (from first chunk)


@pytest.mark.unit
def test_prechunked_encode_collator_left_padding(train_tokenizer):
    """Test PreChunkedEncodeCollator with left padding."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import PreChunkedEncodeCollator

    data_args = DataArguments(
        passage_max_len=64,
        pad_to_multiple_of=16,
        padding_side="left",
        append_eos_token=False,
    )
    collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer)

    features = [
        ("doc1", ["Hello", "world"], None, None, None),
        ("doc2", ["Short"], None, None, None),
    ]

    doc_ids, d_collated, eos_positions = collator(features)

    got_ids_0 = d_collated["input_ids"][0].tolist()
    got_mask_0 = d_collated["attention_mask"][0].tolist()
    got_ids_1 = d_collated["input_ids"][1].tolist()
    got_mask_1 = d_collated["attention_mask"][1].tolist()

    # Both should be padded to the same batch length
    assert len(got_ids_0) == len(got_ids_1)

    # Verify EOS positions are adjusted for left padding:
    # doc1's unpadded sequence is [9707, EOS, 1879, EOS] (EOS at [1, 3]), so after
    # left padding the positions must shift right by the padding amount
    # (exact target length depends on the collator; only the shift is asserted).
    assert len(eos_positions[0]) == 2
    assert
 eos_positions[0][0] > 1  # Should be shifted right
    assert eos_positions[0][1] > 3  # Should be shifted right

    # Verify EOS tokens are at correct positions
    for eos_pos in eos_positions[0]:
        assert got_ids_0[eos_pos] == train_tokenizer.eos_token_id
        assert got_mask_0[eos_pos] == 1


@pytest.mark.unit
def test_prechunked_encode_collator_empty_chunks(train_tokenizer):
    """Test PreChunkedEncodeCollator with empty chunks list."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import PreChunkedEncodeCollator

    data_args = DataArguments(
        passage_max_len=64,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
    )
    collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer)

    features = [
        ("doc1", [], None, None, None),  # Empty chunks
        ("doc2", ["Non-empty"], None, None, None),
    ]

    doc_ids, d_collated, eos_positions = collator(features)

    assert doc_ids == ["doc1", "doc2"]
    assert len(eos_positions) == 2

    # Empty chunks should have no EOS positions
    assert eos_positions[0] == []

    # Non-empty should have EOS positions
    assert len(eos_positions[1]) > 0


@pytest.mark.unit
def test_prechunked_encode_collator_multiple_passages_different_lengths(train_tokenizer):
    """Test PreChunkedEncodeCollator with multiple passages of different chunk counts."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import PreChunkedEncodeCollator

    data_args = DataArguments(
        passage_max_len=128,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
    )
    collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer)

    features = [
        ("doc1", ["Chunk 1", "Chunk 2"], None, None, None),  # 2 chunks
        ("doc2", ["Single chunk"], None, None, None),  # 1 chunk
        ("doc3", ["A", "B", "C", "D"], None, None, None),  # 4 chunks
    ]

    doc_ids, d_collated, eos_positions = collator(features)

    assert doc_ids == ["doc1", "doc2", "doc3"]
    assert d_collated["input_ids"].shape[0] == 3
    assert len(eos_positions) == 3

    # Verify each passage has correct number of EOS positions
    assert len(eos_positions[0]) == 2  # doc1: 2 chunks
    assert len(eos_positions[1]) == 1  # doc2: 1 chunk
    assert len(eos_positions[2]) == 4  # doc3: 4 chunks

    # Verify all EOS positions are valid
    for i in range(3):
        got_ids = d_collated["input_ids"][i].tolist()
        got_mask = d_collated["attention_mask"][i].tolist()

        assert _strictly_increasing(eos_positions[i])
        for eos_pos in eos_positions[i]:
            assert got_ids[eos_pos] == train_tokenizer.eos_token_id
            assert got_mask[eos_pos] == 1


@pytest.mark.unit
def test_prechunked_encode_collator_semantic_chunks(train_tokenizer):
    """Test PreChunkedEncodeCollator with semantically chunked REAL_TEXT."""
    _add_tevatron_src_to_path()
    from tevatron.retriever.arguments import DataArguments
    from tevatron.retriever.collator import PreChunkedEncodeCollator

    data_args = DataArguments(
        passage_max_len=512,
        pad_to_multiple_of=16,
        padding_side="right",
        append_eos_token=False,
    )
    collator = PreChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer)

    # Use semantically chunked version of REAL_TEXT
    features = [
        ("doc1", REAL_TEXT_SEMANTIC_CHUNKS, None, None, None),
    ]

    doc_ids, d_collated, eos_positions = collator(features)

    # Hardcoded golden output with semantically chunked REAL_TEXT (8 chunks).
    # Each semantic chunk is tokenized and separated by EOS tokens.
    # Total: 429 content tokens + 8 EOS tokens = 437 non-pad tokens,
    # padded to 448 (multiple of 16).
    expected_ids = [
        74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802,
        82519, 4401, 323, 1102, 304, 15629, 35701, 13, EOS_TOKEN_ID, 32, 1555, 8569, 57330, 12635,
        291, 23970, 56981, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311,
        6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311,
        90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284,
        220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, EOS_TOKEN_ID, 1249, 8552,
        6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991,
        41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, 13,
        EOS_TOKEN_ID, 641, 279, 8622, 4158, 4925, 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23,
        73760, 572, 1550, 11, 220, 16, 13, 23, 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647,
        311, 220, 16, 13, 17, 19197, 441, 17, 58634, 13, EOS_TOKEN_ID, 641, 279, 44900, 47594, 315,
        279, 5306, 47639, 11, 279, 3076, 9981, 57330, 36829, 518, 2176, 3039, 1033, 4428, 320, 16,
        13, 17, 19041, 220, 16, 13, 16, 19197, 441, 17, 58634, 568, 39402, 458, 285, 354, 17764, 572,
        5080, 279, 12128, 7194, 572, 311, 4647, 448, 7046, 10740, 2750, 304, 279, 5306, 47639, 1091,
        304, 279, 8622, 4158, 4925, 13, EOS_TOKEN_ID, 4703, 4991, 41434, 518, 4647, 8542, 5080, 3076,
        57330, 36829, 304, 279, 8622, 4158, 4925, 320, 16, 13, 19, 51615, 220, 15, 13, 17, 19, 19041,
        220, 16, 13, 16, 20, 51615, 220, 15, 13, 15, 24, 19197, 441, 17, 58634, 11, 281, 284, 220,
        15, 13, 15, 16, 21, 8, 323, 4722, 8674, 458, 285, 354, 17764, 304, 2176, 5671, 7707, 448,
        2480, 9663, 41434, 320, 5782, 4925, 11, 220, 16, 15, 13, 24, 51615, 220, 15, 13, 21, 19041,
        220, 17, 17, 13, 24, 51615, 220, 18, 13, 15, 13384, 281, 284, 220, 15, 13, 15, 15, 16, 26,
        5306, 47639, 11, 220, 17, 19, 13, 15, 51615, 220, 19, 13, 19, 19, 19041, 220, 18, 18, 13, 16,
        51615, 220, 15, 13, 21, 4, 281, 284, 220, 15, 13, 15, 15, 21, 568, EOS_TOKEN_ID, 8121, 2408,
        301, 15479, 48674, 304, 279, 42094, 1620, 385, 1242, 1033, 9434, 553, 57330, 15626, 51360,
        438, 4124, 438, 220, 17, 23, 73760, 26, 2480, 9663, 323, 855, 4991, 41434, 518, 4647, 8542,
        12864, 11799, 304, 4158, 4925, 23788, 7321, 13, EOS_TOKEN_ID, 785, 821, 13216, 429, 46516,
        15449, 315, 3015, 57330, 553, 57330, 15626, 51360, 5707, 20017, 1119, 8003, 95697, 4401, 304,
        59645, 4158, 4925, 304, 5382, 41434, EOS_TOKEN_ID
    ] + [PADDING_TOKEN_ID] * 11
    expected_mask = [1] * 437 + [0] * 11
    expected_eos_positions = [24, 91, 125, 167, 229, 366, 409, 436]

    got_ids = d_collated["input_ids"][0].tolist()
    got_mask = d_collated["attention_mask"][0].tolist()

    # Verify structure: should have 8 EOS positions (one per semantic chunk)
    assert doc_ids == ["doc1"]
    assert d_collated["input_ids"].shape[0] == 1
    assert len(eos_positions) == 1
    assert got_ids == expected_ids
    assert got_mask == expected_mask
    assert eos_positions[0] == expected_eos_positions
    assert len(eos_positions[0]) == 8  # 8 semantic chunks
    assert _strictly_increasing(eos_positions[0])

    # Verify all EOS positions are valid
    for eos_pos in eos_positions[0]:
        assert got_ids[eos_pos] == train_tokenizer.eos_token_id
        assert got_mask[eos_pos] == 1
        assert eos_pos < len(got_ids)

    # Verify that semantic chunks are preserved (each chunk ends with EOS)
    # Check that we have content tokens between EOS positions
    for i in range(len(eos_positions[0]) - 1):
        chunk_start = eos_positions[0][i] + 1  # Start after EOS
        chunk_end = eos_positions[0][i + 1]  # End at next EOS
        assert chunk_end > chunk_start  # Should have content tokens between EOS markers

    # Verify total length is reasonable (should fit within max_length=512)
    assert len(got_ids) == 448  # Padded to multiple of 16
    assert sum(got_mask) == 437  # 437 non-pad tokens
    assert len(got_ids) % 16 == 0  # Padded to multiple of 16


# ============================================================================
# Unit tests for random chunking in ChunkedEncodeCollator (inference/search)
# ============================================================================

@pytest.mark.unit
def
test_chunked_encode_collator_random_chunk_size_range_fixed_per_passage(train_tokenizer): + """Test ChunkedEncodeCollator with random chunk size range, fixed per passage (inference).""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import ChunkedEncodeCollator + + random.seed(42) + + data_args = DataArguments( + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + passage_chunk_size_range="32,64", # Random chunk size between 32 and 64 + passage_chunk_size_variable=False, # Fixed random size per passage + ) + collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + features = [ + ("doc1", REAL_TEXT, None, None, None), + ("doc2", REAL_TEXT, None, None, None), + ] + + doc_ids, d_collated, eos_positions = collator(features) + + # Hardcoded golden output with seed=42, passage_chunk_size_range="32,64", passage_chunk_size_variable=False + # With seed=42, random.randint(32, 64) generates: 39 for doc1, 33 for doc2 + # doc1: chunk_size=39 (chunk_len=38), produces 4 chunks: [38, 77, 116, 127] + # - Chunk 1: 38 tokens (0-37) + EOS at 38 + # - Chunk 2: 38 tokens (39-76) + EOS at 77 + # - Chunk 3: 38 tokens (78-115) + EOS at 116 + # - Chunk 4: 10 tokens (117-126) + EOS at 127 + # doc2: chunk_size=33 (chunk_len=32), produces 4 chunks: [32, 65, 98, 127] + # - Chunk 1: 32 tokens (0-31) + EOS at 32 + # - Chunk 2: 32 tokens (33-64) + EOS at 65 + # - Chunk 3: 32 tokens (66-97) + EOS at 98 + # - Chunk 4: 28 tokens (99-126) + EOS at 127 + expected_ids_0 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, EOS_TOKEN_ID, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, 2326, 
32420, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, 320, 77, 284, + EOS_TOKEN_ID, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, + 8552, 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, + 41434, 320, 77, 284, 220, 16, 15, EOS_TOKEN_ID, 8, 1033, 19476, 264, 2086, 882, 518, 4647, + 13, 758, EOS_TOKEN_ID + ] + expected_mask_0 = [1] * 128 + expected_eos_positions_0 = [38, 77, 116, 127] + + expected_ids_1 = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, EOS_TOKEN_ID, 31658, 320, 78670, 8, 8500, 448, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, 2326, 32420, EOS_TOKEN_ID, 23788, 17646, 304, 59645, 4158, 4925, 304, 855, 4991, + 320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, + 8552, 6239, 315, 6811, 37854, EOS_TOKEN_ID, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, + 367, 855, 4991, 41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, 2086, 882, 518, 4647, + 13, 758, EOS_TOKEN_ID + ] + expected_mask_1 = [1] * 128 + expected_eos_positions_1 = [32, 65, 98, 127] + + # Verify structure + assert doc_ids == ["doc1", "doc2"] + assert d_collated["input_ids"].shape[0] == 2 + assert len(eos_positions) == 2 + + # Verify doc1 + got_ids_0 = d_collated["input_ids"][0].tolist() + got_mask_0 = d_collated["attention_mask"][0].tolist() + assert got_ids_0 == expected_ids_0 + assert got_mask_0 == expected_mask_0 + assert eos_positions[0] == expected_eos_positions_0 + assert _strictly_increasing(eos_positions[0]) + for eos_pos in eos_positions[0]: + assert got_ids_0[eos_pos] == train_tokenizer.eos_token_id + assert got_mask_0[eos_pos] == 1 + + # Verify doc2 + got_ids_1 = d_collated["input_ids"][1].tolist() + got_mask_1 = 
d_collated["attention_mask"][1].tolist() + assert got_ids_1 == expected_ids_1 + assert got_mask_1 == expected_mask_1 + assert eos_positions[1] == expected_eos_positions_1 + assert _strictly_increasing(eos_positions[1]) + for eos_pos in eos_positions[1]: + assert got_ids_1[eos_pos] == train_tokenizer.eos_token_id + assert got_mask_1[eos_pos] == 1 + + +@pytest.mark.unit +def test_chunked_encode_collator_random_chunk_size_range_variable_per_chunk(train_tokenizer): + """Test ChunkedEncodeCollator with random chunk size range, variable per chunk (inference).""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import ChunkedEncodeCollator + + random.seed(42) + + data_args = DataArguments( + passage_max_len=256, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + passage_chunk_size_range="32,64", # Random chunk size between 32 and 64 + passage_chunk_size_variable=True, # Variable chunk size per chunk + ) + collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + features = [ + ("doc1", REAL_TEXT, None, None, None), + ] + + doc_ids, d_collated, eos_positions = collator(features) + + # Hardcoded golden output with seed=42, passage_chunk_size_range="32,64", passage_chunk_size_variable=True + # With seed=42 and variable chunk sizes, each chunk gets a random size from [32, 64] + # Chunk sizes generated: 40, 34, 50, 48, 47, 41, 3 (last partial chunk) + # EOS positions: [38, 71, 120, 167, 213, 253, 255] + expected_ids = [ + 74290, 804, 315, 279, 17646, 315, 59645, 4158, 4925, 304, 279, 11220, 3738, 8109, 646, 7802, + 82519, 4401, 323, 1102, 304, 15629, 35701, 13, 362, 1555, 8569, 57330, 12635, 291, 23970, + 56981, 31658, 320, 78670, 8, 8500, 448, EOS_TOKEN_ID, 57330, 15626, 6358, 572, 9251, 311, + 6629, 279, 9981, 57330, 35606, 11, 311, 11047, 8674, 458, 285, 354, 17764, 11, 323, 311, + 90684, 349, 2326, 32420, 23788, 17646, 304, 59645, 4158, 4925, 
EOS_TOKEN_ID, 304, 855, 4991, + 320, 77, 284, 220, 16, 22, 8, 323, 2480, 9663, 41434, 320, 77, 284, 220, 22, 568, 2014, + 8552, 6239, 315, 6811, 37854, 389, 59645, 4158, 4925, 4401, 11, 4124, 12743, 367, 855, 4991, + 41434, 320, 77, 284, 220, 16, 15, 8, 1033, 19476, 264, EOS_TOKEN_ID, 2086, 882, 518, 4647, + 13, 758, 279, 8622, 4158, 4925, 279, 3076, 9981, 57330, 35606, 518, 220, 17, 23, 73760, 572, + 1550, 11, 220, 16, 13, 23, 19197, 441, 17, 58634, 11, 323, 24938, 8841, 4647, 311, 220, 16, + 13, 17, 19197, 441, 17, 58634, 13, EOS_TOKEN_ID, 758, 279, 44900, 47594, 315, 279, 5306, + 47639, 11, 279, 3076, 9981, 57330, 36829, 518, 2176, 3039, 1033, 4428, 320, 16, 13, 17, + 19041, 220, 16, 13, 16, 19197, 441, 17, 58634, 568, 39402, 458, 285, 354, 17764, 572, 5080, + 279, 12128, 7194, 572, 311, EOS_TOKEN_ID, 4647, 448, 7046, 10740, 2750, 304, 279, 5306, + 47639, 1091, 304, 279, 8622, 4158, 4925, 13, 4968, 4991, 41434, 518, 4647, 8542, 5080, + 3076, 57330, 36829, 304, 279, 8622, 4158, 4925, 320, 16, 13, 19, 51615, 220, 15, 13, + EOS_TOKEN_ID, 17, EOS_TOKEN_ID + ] + expected_mask = [1] * 256 + expected_eos_positions = [38, 71, 120, 167, 213, 253, 255] + + # Verify structure + assert doc_ids == ["doc1"] + assert d_collated["input_ids"].shape[0] == 1 + assert len(eos_positions) == 1 + + got_ids = d_collated["input_ids"][0].tolist() + got_mask = d_collated["attention_mask"][0].tolist() + + assert got_ids == expected_ids + assert got_mask == expected_mask + assert eos_positions[0] == expected_eos_positions + assert _strictly_increasing(eos_positions[0]) + + # Verify each EOS position is valid + for eos_pos in eos_positions[0]: + assert got_ids[eos_pos] == train_tokenizer.eos_token_id + assert got_mask[eos_pos] == 1 + + +@pytest.mark.unit +def test_chunked_encode_collator_random_chunk_size_range_hardcoded_output(train_tokenizer): + """Test ChunkedEncodeCollator with random chunk size range - hardcoded golden output (inference).""" + _add_tevatron_src_to_path() + from 
tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import ChunkedEncodeCollator + + random.seed(42) + + data_args = DataArguments( + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + passage_chunk_size_range="32,48", # Random chunk size between 32 and 48 + passage_chunk_size_variable=False, # Fixed random size per passage + ) + collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + short_text = "Hello world this is a test passage" + features = [ + ("doc1", short_text, None, None, None), + ] + + doc_ids, d_collated, eos_positions = collator(features) + + got_ids = d_collated["input_ids"][0].tolist() + got_mask = d_collated["attention_mask"][0].tolist() + + # Hardcoded golden output with seed=42 and chunk_size_range=(32,48) + # short_text tokenizes to: [9707, 1879, 419, 374, 264, 1273, 21085] + # With seed=42, random.randint(32, 48) = 40 (first call) + # So chunk_len = 39, but we only have 7 tokens, so we get: [7 tokens] + EOS + expected_ids = [9707, 1879, 419, 374, 264, 1273, 21085, EOS_TOKEN_ID] + [PADDING_TOKEN_ID] * 8 + expected_mask = [1] * 8 + [0] * 8 + expected_eos_positions = [[7]] + + assert doc_ids == ["doc1"] + assert got_ids == expected_ids + assert got_mask == expected_mask + assert eos_positions == expected_eos_positions + + +@pytest.mark.unit +def test_chunked_encode_collator_fixed_chunk_size_still_works(train_tokenizer): + """Test ChunkedEncodeCollator with fixed chunk size (no random chunking) still works.""" + _add_tevatron_src_to_path() + from tevatron.retriever.arguments import DataArguments + from tevatron.retriever.collator import ChunkedEncodeCollator + + data_args = DataArguments( + passage_max_len=128, + pad_to_multiple_of=16, + padding_side="right", + append_eos_token=False, + passage_chunk_size=32, # Fixed chunk size, no random chunking + ) + collator = ChunkedEncodeCollator(data_args=data_args, tokenizer=train_tokenizer) + + 
features = [ + ("doc1", REAL_TEXT, None, None, None), + ] + + doc_ids, d_collated, eos_positions = collator(features) + + # Verify structure + assert doc_ids == ["doc1"] + assert d_collated["input_ids"].shape[0] == 1 + assert len(eos_positions) == 1 + assert len(eos_positions[0]) > 0 + + got_ids = d_collated["input_ids"][0].tolist() + got_mask = d_collated["attention_mask"][0].tolist() + + # Verify EOS positions are strictly increasing + assert _strictly_increasing(eos_positions[0]) + + # Verify each EOS position is valid + for eos_pos in eos_positions[0]: + assert got_ids[eos_pos] == train_tokenizer.eos_token_id + assert got_mask[eos_pos] == 1 diff --git a/tests/test_forward.py b/tests/test_forward.py new file mode 100644 index 00000000..e8b54e3b --- /dev/null +++ b/tests/test_forward.py @@ -0,0 +1,137 @@ +import sys +from pathlib import Path + +import pytest +import torch +from unittest.mock import Mock + + +def _tevatron_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def _add_tevatron_src_to_path(): + # tevatron/tests/test_forward.py -> tevatron/ -> tevatron/src + src = _tevatron_root() / "src" + sys.path.insert(0, str(src)) + + +@pytest.fixture(scope="session") +def train_tokenizer(): + """ + Use the Qwen 0.6B tokenizer. + """ + _add_tevatron_src_to_path() + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + tok.eos_token_id = tok.pad_token_id # Match training setup + tok.padding_side = "right" + return tok + + +@pytest.mark.unit +def test_compute_maxsim_similarity(): + """ + Test compute_maxsim_similarity function to verify MaxSim pooling logic. 
+ """ + _add_tevatron_src_to_path() + from tevatron.retriever.modeling.encoder import EncoderModel + + # Create a concrete implementation for testing + class TestEncoderModel(EncoderModel): + def encode_query(self, qry): + raise NotImplementedError + def encode_passage(self, psg): + raise NotImplementedError + + model = TestEncoderModel(encoder=Mock(), pooling='last', normalize=False) + + # Test Case 1: Basic MaxSim computation + # Q=2 queries, P=3 passages, C=4 chunks per passage, H=8 hidden size + Q, P, C, H = 2, 3, 4, 8 + + q_reps = torch.randn(Q, H) + p_reps = torch.randn(P, C, H) + chunk_mask = torch.ones(P, C) # All chunks valid + + scores = model.compute_maxsim_similarity(q_reps, p_reps, chunk_mask) + + # Verify output shape + assert scores.shape == (Q, P) + + # Verify scores are computed correctly + # For each query-passage pair, score should be max of chunk similarities + for q_idx in range(Q): + for p_idx in range(P): + # Compute chunk scores manually + chunk_scores = torch.einsum('h,ch->c', q_reps[q_idx], p_reps[p_idx]) + expected_score = chunk_scores.max().item() + assert torch.allclose(scores[q_idx, p_idx], torch.tensor(expected_score)) + + # Test Case 2: With padding (some chunks are invalid) + chunk_mask_padded = torch.tensor([ + [1.0, 1.0, 1.0, 0.0], # Passage 0: 3 valid chunks + [1.0, 1.0, 0.0, 0.0], # Passage 1: 2 valid chunks + [1.0, 0.0, 0.0, 0.0], # Passage 2: 1 valid chunk + ]) + + scores_padded = model.compute_maxsim_similarity(q_reps, p_reps, chunk_mask_padded) + + # Verify shape + assert scores_padded.shape == (Q, P) + + # Verify that padding chunks don't affect the max + for q_idx in range(Q): + for p_idx in range(P): + # Compute chunk scores manually, masking out invalid chunks + chunk_scores = torch.einsum('h,ch->c', q_reps[q_idx], p_reps[p_idx]) + # Mask invalid chunks with -inf + valid_mask = chunk_mask_padded[p_idx].bool() + chunk_scores_masked = chunk_scores.clone() + chunk_scores_masked[~valid_mask] = float('-inf') + expected_score 
= chunk_scores_masked.max().item() + assert torch.allclose(scores_padded[q_idx, p_idx], torch.tensor(expected_score)) + + # Test Case 3: Single chunk per passage + P_single, C_single = 2, 1 + p_reps_single = torch.randn(P_single, C_single, H) + chunk_mask_single = torch.ones(P_single, C_single) + + scores_single = model.compute_maxsim_similarity(q_reps, p_reps_single, chunk_mask_single) + assert scores_single.shape == (Q, P_single) + + # With single chunk, MaxSim should equal the single chunk similarity + for q_idx in range(Q): + for p_idx in range(P_single): + expected_score = torch.dot(q_reps[q_idx], p_reps_single[p_idx, 0]).item() + assert torch.allclose(scores_single[q_idx, p_idx], torch.tensor(expected_score)) + + # Test Case 4: Different number of chunks per passage + # This tests that max_chunks is handled correctly + p_reps_uneven = torch.randn(P, C, H) + # Passage 0: all 4 chunks valid + # Passage 1: first 2 chunks valid + # Passage 2: first 1 chunk valid + chunk_mask_uneven = torch.tensor([ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + ]) + + scores_uneven = model.compute_maxsim_similarity(q_reps, p_reps_uneven, chunk_mask_uneven) + assert scores_uneven.shape == (Q, P) + + # Verify that only valid chunks are considered + for q_idx in range(Q): + for p_idx in range(P): + chunk_scores = torch.einsum('h,ch->c', q_reps[q_idx], p_reps_uneven[p_idx]) + valid_mask = chunk_mask_uneven[p_idx].bool() + chunk_scores_masked = chunk_scores.clone() + chunk_scores_masked[~valid_mask] = float('-inf') + expected_score = chunk_scores_masked.max().item() + assert torch.allclose(scores_uneven[q_idx, p_idx], torch.tensor(expected_score)) + + diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 00000000..75c7b28d --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,857 @@ +import sys +from pathlib import Path +import pickle +import numpy as np +import pytest +from collections import defaultdict + + +def _tevatron_root() 
-> Path: + return Path(__file__).resolve().parents[1] + + +def _add_tevatron_src_to_path(): + src = _tevatron_root() / "src" + sys.path.insert(0, str(src)) + + +@pytest.mark.unit +def test_search_chunked_vs_non_chunked(): + """ + Test search behavior differences between chunked and non-chunked modes. + This verifies: + 1. Auto-detection of chunked format + 2. MaxSim aggregation logic + 3. Search depth handling + """ + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked, search_queries + from tevatron.retriever.searcher import FaissFlatSearcher + + # Create mock query and passage embeddings + num_queries = 3 + num_docs = 10 + hidden_size = 64 + + # Query embeddings + q_reps = np.random.randn(num_queries, hidden_size).astype(np.float32) + # Normalize for inner product search + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + # Test Case 1: Non-chunked format + # Each document has one embedding + p_reps_non_chunked = np.random.randn(num_docs, hidden_size).astype(np.float32) + # Normalize for inner product search + p_reps_non_chunked = p_reps_non_chunked / np.linalg.norm(p_reps_non_chunked, axis=1, keepdims=True) + p_lookup_non_chunked = [f"doc_{i}" for i in range(num_docs)] + + retriever_non_chunked = FaissFlatSearcher(p_reps_non_chunked) + # Need to add embeddings to index + retriever_non_chunked.add(p_reps_non_chunked) + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 10 + + args = MockArgs() + + # Non-chunked search + scores_non_chunked, indices_non_chunked = search_queries( + retriever_non_chunked, q_reps, p_lookup_non_chunked, args + ) + + # Verify non-chunked results + assert len(scores_non_chunked) == num_queries + assert len(indices_non_chunked) == num_queries + for q_idx in range(num_queries): + assert len(scores_non_chunked[q_idx]) == args.depth + assert len(indices_non_chunked[q_idx]) == args.depth + # indices_non_chunked contains document IDs (strings), not 
indices + assert all(isinstance(doc_id, (str, np.str_)) for doc_id in indices_non_chunked[q_idx][:5]) + + # Test Case 2: Chunked format - single chunk per document + # This simulates chunk_size == max_passage_size scenario + # Each document has exactly one chunk + p_reps_chunked_single = np.random.randn(num_docs, hidden_size).astype(np.float32) + # Normalize for inner product search + p_reps_chunked_single = p_reps_chunked_single / np.linalg.norm(p_reps_chunked_single, axis=1, keepdims=True) + q_reps_normalized = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_lookup_chunked_single = [(f"doc_{i}", 0) for i in range(num_docs)] + + retriever_chunked_single = FaissFlatSearcher(p_reps_chunked_single) + # Need to add embeddings to index + retriever_chunked_single.add(p_reps_chunked_single) + + # Chunked search with single chunk per doc + results_chunked_single = search_queries_chunked( + retriever_chunked_single, q_reps_normalized, p_lookup_chunked_single, args + ) + + # Verify chunked results + assert len(results_chunked_single) == num_queries + for q_idx in range(num_queries): + # Results might be less than depth if fewer documents exist + assert len(results_chunked_single[q_idx]) <= args.depth + assert len(results_chunked_single[q_idx]) > 0, "Should have at least some results" + # Each result should be (doc_id, score) tuple + for doc_id, score in results_chunked_single[q_idx]: + assert isinstance(doc_id, str) + assert isinstance(score, (int, float, np.floating)) + + # Test Case 3: Chunked format - multiple chunks per document + # Some documents have multiple chunks + num_chunks_total = 20 # More chunks than documents + p_reps_chunked_multi = np.random.randn(num_chunks_total, hidden_size).astype(np.float32) + # Normalize for inner product search + p_reps_chunked_multi = p_reps_chunked_multi / np.linalg.norm(p_reps_chunked_multi, axis=1, keepdims=True) + p_lookup_chunked_multi = [] + # Document 0-4: 2 chunks each (10 chunks) + # Document 5-9: 2 chunks each 
(10 chunks) + for doc_idx in range(num_docs): + for chunk_idx in range(2): + p_lookup_chunked_multi.append((f"doc_{doc_idx}", chunk_idx)) + + retriever_chunked_multi = FaissFlatSearcher(p_reps_chunked_multi) + retriever_chunked_multi.add(p_reps_chunked_multi) + + # Chunked search with multiple chunks per doc + results_chunked_multi = search_queries_chunked( + retriever_chunked_multi, q_reps, p_lookup_chunked_multi, args + ) + + # Verify MaxSim aggregation + assert len(results_chunked_multi) == num_queries + for q_idx in range(num_queries): + assert len(results_chunked_multi[q_idx]) == args.depth + # Verify MaxSim: each document should appear at most once + doc_ids = [doc_id for doc_id, _ in results_chunked_multi[q_idx]] + assert len(doc_ids) == len(set(doc_ids)), "Each document should appear only once (MaxSim aggregation)" + + # Verify scores are in descending order + scores = [score for _, score in results_chunked_multi[q_idx]] + assert scores == sorted(scores, reverse=True), "Scores should be in descending order" + + # Test Case 4: Verify MaxSim logic - same document with multiple chunks + # Create a scenario where one document has the best chunks + q_rep_test = np.random.randn(1, hidden_size).astype(np.float32) + q_rep_test = q_rep_test / np.linalg.norm(q_rep_test, axis=1, keepdims=True) + + # Create embeddings where doc_0 chunks are most similar to query + p_reps_test = np.random.randn(5, hidden_size).astype(np.float32) + # Make doc_0 chunks (indices 0, 1) more similar to query + p_reps_test[0] = q_rep_test[0] * 0.9 + np.random.randn(hidden_size) * 0.1 + p_reps_test[1] = q_rep_test[0] * 0.8 + np.random.randn(hidden_size) * 0.2 + # Other chunks less similar + p_reps_test[2:] = q_rep_test[0] * 0.3 + np.random.randn(3, hidden_size) * 0.7 + # Normalize + p_reps_test = p_reps_test / np.linalg.norm(p_reps_test, axis=1, keepdims=True) + + p_lookup_test = [ + ("doc_0", 0), # Best chunk + ("doc_0", 1), # Second best chunk + ("doc_1", 0), # Less similar + ("doc_2", 0), # 
Less similar + ("doc_3", 0), # Less similar + ] + + retriever_test = FaissFlatSearcher(p_reps_test) + retriever_test.add(p_reps_test) + results_test = search_queries_chunked(retriever_test, q_rep_test, p_lookup_test, args) + + # Verify MaxSim: doc_0 should be ranked first (max of its two chunks) + assert len(results_test) == 1 + assert len(results_test[0]) > 0, "Should have results" + top_doc = results_test[0][0][0] + assert top_doc == "doc_0", "doc_0 should be ranked first due to MaxSim (max of its chunks)" + + # Test Case 5: Verify search depth multiplier + args_large = MockArgs() + args_large.depth = 5 + args_large.chunk_multiplier = 10 + args_large.batch_size = 0 + args_large.quiet = True + + # With chunk_multiplier=10, should search 5 * 10 = 50 chunks + # But we only have 20 chunks, so should get all chunks + results_depth_test = search_queries_chunked( + retriever_chunked_multi, q_reps_normalized, p_lookup_chunked_multi, args_large + ) + + # Should return up to depth documents (after MaxSim aggregation) + assert len(results_depth_test[0]) <= args_large.depth + assert len(results_depth_test[0]) > 0, "Should have some results" + + # Test Case 6: Verify auto-detection logic + # Test that tuple format is detected as chunked + assert isinstance(p_lookup_chunked_single[0], tuple), "Chunked lookup should be tuple" + assert not isinstance(p_lookup_non_chunked[0], tuple), "Non-chunked lookup should be string" + + # Test Case 7: Verify that single chunk per doc behaves correctly + # When chunk_size == max_passage_size, each doc has one chunk + # In this case, MaxSim should give same result as non-chunked (if embeddings are identical) + # But search depth multiplier means we search more candidates + p_reps_single_chunk = p_reps_chunked_single.copy() + q_reps_single = q_reps_normalized.copy() + + # Search with same embeddings but different formats + results_single_chunk = search_queries_chunked( + retriever_chunked_single, q_reps_single, p_lookup_chunked_single, args + ) 
+ + # Verify results structure + assert len(results_single_chunk) == num_queries + for q_idx in range(num_queries): + assert len(results_single_chunk[q_idx]) > 0 + # Each result should be (doc_id, score) + for doc_id, score in results_single_chunk[q_idx]: + assert isinstance(doc_id, str) + assert isinstance(score, (int, float, np.floating)) + + +@pytest.mark.unit +def test_write_ranking(): + """Test write_ranking function for non-chunked search results.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import write_ranking + import tempfile + import os + + # Create mock data + q_lookup = ["q1", "q2", "q3"] + corpus_scores = [ + [0.9, 0.8, 0.7, 0.6, 0.5], + [0.95, 0.85, 0.75, 0.65, 0.55], + [0.88, 0.78, 0.68, 0.58, 0.48] + ] + corpus_indices = [ + ["doc_1", "doc_2", "doc_3", "doc_4", "doc_5"], + ["doc_10", "doc_20", "doc_30", "doc_40", "doc_50"], + ["doc_100", "doc_200", "doc_300", "doc_400", "doc_500"] + ] + + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + temp_path = f.name + + try: + write_ranking(corpus_indices, corpus_scores, q_lookup, temp_path) + + # Verify file contents + with open(temp_path, 'r') as f: + lines = f.readlines() + + assert len(lines) == 15 # 3 queries * 5 results + + # Check first query results (should be sorted by score descending) + first_query_lines = lines[:5] + scores = [float(line.strip().split('\t')[2]) for line in first_query_lines] + assert scores == sorted(scores, reverse=True), "Scores should be in descending order" + + # Verify format: qid\tidx\tscore + for line in lines: + parts = line.strip().split('\t') + assert len(parts) == 3, "Each line should have 3 parts: qid, idx, score" + assert parts[0] in q_lookup, "Query ID should be in q_lookup" + assert float(parts[2]) >= 0, "Score should be a number" + + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + + +@pytest.mark.unit +def test_write_ranking_chunked(): + """Test write_ranking_chunked function for chunked 
search results.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import write_ranking_chunked + import tempfile + import os + + # Create mock chunked results + q_lookup = ["q1", "q2"] + results = [ + [("doc_1", 0.95), ("doc_2", 0.85), ("doc_3", 0.75)], + [("doc_10", 0.92), ("doc_20", 0.82), ("doc_30", 0.72)] + ] + + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + temp_path = f.name + + try: + write_ranking_chunked(results, q_lookup, temp_path) + + # Verify file contents + with open(temp_path, 'r') as f: + lines = f.readlines() + + assert len(lines) == 6 # 2 queries * 3 results + + # Verify format: qid\tdoc_id\tscore + for i, line in enumerate(lines): + parts = line.strip().split('\t') + assert len(parts) == 3, "Each line should have 3 parts: qid, doc_id, score" + + # Check query ID + if i < 3: + assert parts[0] == "q1" + else: + assert parts[0] == "q2" + + # Check score is a number + assert float(parts[2]) >= 0, "Score should be a number" + + # Verify scores are in descending order for each query + q1_scores = [float(lines[i].strip().split('\t')[2]) for i in range(3)] + q2_scores = [float(lines[i].strip().split('\t')[2]) for i in range(3, 6)] + assert q1_scores == sorted(q1_scores, reverse=True) + assert q2_scores == sorted(q2_scores, reverse=True) + + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + + +@pytest.mark.unit +def test_pickle_load_save(): + """Test pickle_load and pickle_save functions.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import pickle_load, pickle_save + import tempfile + import os + + # Create test data + test_reps = np.random.randn(10, 64).astype(np.float32) + test_lookup = [f"doc_{i}" for i in range(10)] + test_data = (test_reps, test_lookup) + + with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as f: + temp_path = f.name + + try: + # Save + pickle_save(test_data, temp_path) + assert os.path.exists(temp_path), "Pickle file should 
be created" + + # Load + loaded_reps, loaded_lookup = pickle_load(temp_path) + + # Verify data integrity + assert np.array_equal(loaded_reps, test_reps), "Embeddings should match" + assert loaded_lookup == test_lookup, "Lookup should match" + assert isinstance(loaded_reps, np.ndarray), "Loaded reps should be numpy array" + + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + + +@pytest.mark.unit +def test_search_batch_size(): + """Test that batch_size parameter works correctly.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries + from tevatron.retriever.searcher import FaissFlatSearcher + + num_queries = 10 + num_docs = 20 + hidden_size = 64 + + q_reps = np.random.randn(num_queries, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_docs, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [f"doc_{i}" for i in range(num_docs)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 5 + quiet = True + chunk_multiplier = 10 + + # Test with batch_size = 0 (no batching) + args_no_batch = MockArgs() + args_no_batch.batch_size = 0 + scores_no_batch, indices_no_batch = search_queries(retriever, q_reps, p_lookup, args_no_batch) + + # Test with batch_size > 0 (batching) + args_batch = MockArgs() + args_batch.batch_size = 3 + scores_batch, indices_batch = search_queries(retriever, q_reps, p_lookup, args_batch) + + # Results should be the same regardless of batching + assert len(scores_no_batch) == len(scores_batch) == num_queries + assert len(indices_no_batch) == len(indices_batch) == num_queries + + # Scores should match (allowing for small numerical differences) + for q_idx in range(num_queries): + assert len(scores_no_batch[q_idx]) == len(scores_batch[q_idx]) == args_no_batch.depth + # Scores should be very similar (allowing for floating 
point precision) + np.testing.assert_allclose(scores_no_batch[q_idx], scores_batch[q_idx], rtol=1e-5) + + +@pytest.mark.unit +def test_search_chunked_with_negative_indices(): + """Test chunked search handles FAISS -1 indices correctly.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + from unittest.mock import Mock, patch + + hidden_size = 64 + num_docs = 3 + num_chunks = 5 + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_chunks, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [ + ("doc_0", 0), + ("doc_0", 1), + ("doc_1", 0), + ("doc_2", 0), + ("doc_2", 1), + ] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 10 # Request more than available + batch_size = 0 + quiet = True + chunk_multiplier = 1 + + args = MockArgs() + + # Mock search to return -1 for insufficient results + original_search = retriever.search + + def mock_search(q_reps, k): + scores, indices = original_search(q_reps, k) + # Simulate FAISS returning -1 for insufficient results + if k > num_chunks: + # Pad with -1 indices + padded_indices = np.full((scores.shape[0], k), -1, dtype=indices.dtype) + padded_scores = np.full((scores.shape[0], k), -np.inf, dtype=scores.dtype) + padded_indices[:, :indices.shape[1]] = indices + padded_scores[:, :scores.shape[1]] = scores + return padded_scores, padded_indices + return scores, indices + + retriever.search = mock_search + + results = search_queries_chunked(retriever, q_reps, p_lookup, args) + + # Should handle -1 indices gracefully + assert len(results) == 1 + assert len(results[0]) <= num_docs # Should aggregate to unique documents + # All results should be valid (doc_id, score) tuples + for doc_id, score in results[0]: + 
assert isinstance(doc_id, str) + assert isinstance(score, (int, float, np.floating)) + assert not np.isinf(score), "Scores should not be infinite" + + +@pytest.mark.unit +def test_search_single_query(): + """Test search with a single query.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries, search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + num_docs = 10 + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_docs, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [f"doc_{i}" for i in range(num_docs)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 10 + + args = MockArgs() + + # Non-chunked search + scores, indices = search_queries(retriever, q_reps, p_lookup, args) + assert len(scores) == 1 + assert len(indices) == 1 + assert len(scores[0]) == args.depth + assert len(indices[0]) == args.depth + + # Chunked search + p_lookup_chunked = [(f"doc_{i}", 0) for i in range(num_docs)] + results = search_queries_chunked(retriever, q_reps, p_lookup_chunked, args) + assert len(results) == 1 + assert len(results[0]) <= args.depth + assert len(results[0]) > 0 + + +@pytest.mark.unit +def test_search_empty_results(): + """Test search behavior with edge cases.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + + # Single query, no passages + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + # Empty passage index + p_reps = np.random.randn(0, hidden_size).astype(np.float32) + p_lookup = [] + + retriever = 
FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 10 + + args = MockArgs() + + # Should handle empty index gracefully + results = search_queries_chunked(retriever, q_reps, p_lookup, args) + assert len(results) == 1 + assert len(results[0]) == 0, "Should return empty results for empty index" + + +@pytest.mark.unit +def test_search_depth_larger_than_documents(): + """Test search when depth is larger than available documents.""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries, search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + num_docs = 5 + + q_reps = np.random.randn(2, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_docs, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [f"doc_{i}" for i in range(num_docs)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 20 # Larger than num_docs + batch_size = 0 + quiet = True + chunk_multiplier = 10 + + args = MockArgs() + + # Non-chunked: should return depth results (with padding if needed) + scores, indices = search_queries(retriever, q_reps, p_lookup, args) + assert len(scores) == 2 + assert len(scores[0]) == args.depth # FAISS will pad with -1 indices + + # Chunked: should return at most num_docs results + p_lookup_chunked = [(f"doc_{i}", 0) for i in range(num_docs)] + results = search_queries_chunked(retriever, q_reps, p_lookup_chunked, args) + assert len(results) == 2 + for q_result in results: + assert len(q_result) <= num_docs, "Should not return more documents than available" + + +@pytest.mark.unit +def test_search_chunked_multiplier_effect(): + """Test that chunk_multiplier affects search depth correctly.""" + _add_tevatron_src_to_path() + from 
tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + num_docs = 10 + chunks_per_doc = 3 + num_chunks = num_docs * chunks_per_doc + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_chunks, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [(f"doc_{i}", j) for i in range(num_docs) for j in range(chunks_per_doc)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + + # Test with different multipliers + for multiplier in [1, 5, 10]: + args = MockArgs() + args.chunk_multiplier = multiplier + + results = search_queries_chunked(retriever, q_reps, p_lookup, args) + + # Should search depth * multiplier chunks + # After MaxSim aggregation, should return at most depth documents + assert len(results) == 1 + assert len(results[0]) <= args.depth, f"With multiplier {multiplier}, should return at most {args.depth} docs" + assert len(results[0]) > 0, "Should have some results" + + +@pytest.mark.unit +def test_index_boundary_check(): + """Verify index boundary check - ensure no out-of-bounds access to p_lookup""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + num_chunks = 10 + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_chunks, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [(f"doc_{i}", 0) for i in range(num_chunks)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 5 + batch_size = 0 + 
quiet = True + chunk_multiplier = 10 # Will search 5 * 10 = 50 chunks, but only 10 available + + args = MockArgs() + + # Should not raise IndexError, FAISS will return -1 or valid indices + results = search_queries_chunked(retriever, q_reps, p_lookup, args) + + assert len(results) == 1 + # Should handle gracefully without out-of-bounds + assert len(results[0]) <= num_chunks + + +@pytest.mark.unit +def test_p_lookup_format_validation(): + """Verify p_lookup format - must be (doc_id, chunk_idx) tuples""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + num_chunks = 5 + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_chunks, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + + # Correct format: tuples + p_lookup_correct = [(f"doc_{i}", i % 2) for i in range(num_chunks)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 1 + + args = MockArgs() + + # Should work correctly + results = search_queries_chunked(retriever, q_reps, p_lookup_correct, args) + assert len(results) == 1 + + # Wrong format: strings (non-chunked format) + p_lookup_wrong = [f"doc_{i}" for i in range(num_chunks)] + + # Function will catch errors and continue, won't raise exception + # but will log error messages + results = search_queries_chunked(retriever, q_reps, p_lookup_wrong, args) + # Due to format error, should return empty or partial results + assert len(results) == 1 + + +@pytest.mark.unit +def test_maxsim_aggregation_correctness(): + """Verify MaxSim aggregation correctness""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from 
tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + + # Create a query + q_rep = np.random.randn(1, hidden_size).astype(np.float32) + q_rep = q_rep / np.linalg.norm(q_rep, axis=1, keepdims=True) + + # Create documents: doc_0 has 3 chunks, doc_1 has 2 chunks + # Make doc_0's chunk 0 most similar, chunk 1 second most, chunk 2 less similar + # Make doc_1's chunks less similar + p_reps = np.random.randn(5, hidden_size).astype(np.float32) + + # doc_0's chunk 0: most similar + p_reps[0] = q_rep[0] * 0.95 + np.random.randn(hidden_size) * 0.05 + # doc_0's chunk 1: second most similar + p_reps[1] = q_rep[0] * 0.85 + np.random.randn(hidden_size) * 0.15 + # doc_0's chunk 2: less similar + p_reps[2] = q_rep[0] * 0.50 + np.random.randn(hidden_size) * 0.50 + # doc_1's chunks: less similar + p_reps[3] = q_rep[0] * 0.40 + np.random.randn(hidden_size) * 0.60 + p_reps[4] = q_rep[0] * 0.35 + np.random.randn(hidden_size) * 0.65 + + # Normalize + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + + p_lookup = [ + ("doc_0", 0), # Most similar + ("doc_0", 1), # Second most similar + ("doc_0", 2), # Less similar + ("doc_1", 0), # Less similar + ("doc_1", 1), # Less similar + ] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + class MockArgs: + depth = 10 + batch_size = 0 + quiet = True + chunk_multiplier = 1 + + args = MockArgs() + + results = search_queries_chunked(retriever, q_rep, p_lookup, args) + + assert len(results) == 1 + assert len(results[0]) >= 1 + + # doc_0 should be ranked first (because its max score is chunk 0's score, highest) + top_doc = results[0][0][0] + assert top_doc == "doc_0", f"doc_0 should be top (has best chunk), but got {top_doc}" + + # Verify each document appears only once (MaxSim aggregation) + doc_ids = [doc_id for doc_id, _ in results[0]] + assert len(doc_ids) == len(set(doc_ids)), "Each document should appear only once" + + # Verify scores are in descending order + scores = [score for _, score in 
results[0]] + assert scores == sorted(scores, reverse=True), "Scores should be in descending order" + + +@pytest.mark.unit +def test_empty_doc_max_scores(): + """Test edge case when all results are -1""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(1, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [("doc_0", 0)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + # Mock search to return all -1 + original_search = retriever.search + + def mock_search_all_negative(q_reps, k): + scores = np.array([[-np.inf] * k]) + indices = np.array([[-1] * k]) + return scores, indices + + retriever.search = mock_search_all_negative + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 1 + + args = MockArgs() + + results = search_queries_chunked(retriever, q_reps, p_lookup, args) + + # Should return empty results, not crash + assert len(results) == 1 + assert len(results[0]) == 0, "Should return empty list when all indices are -1" + + +@pytest.mark.unit +def test_index_out_of_bounds_protection(): + """Test index out-of-bounds protection - if FAISS returns out-of-range indices""" + _add_tevatron_src_to_path() + from tevatron.retriever.driver.search import search_queries_chunked + from tevatron.retriever.searcher import FaissFlatSearcher + + hidden_size = 64 + num_chunks = 5 + + q_reps = np.random.randn(1, hidden_size).astype(np.float32) + q_reps = q_reps / np.linalg.norm(q_reps, axis=1, keepdims=True) + + p_reps = np.random.randn(num_chunks, hidden_size).astype(np.float32) + p_reps = p_reps / np.linalg.norm(p_reps, axis=1, keepdims=True) + p_lookup = [(f"doc_{i}", 0) for i in 
range(num_chunks)] + + retriever = FaissFlatSearcher(p_reps) + retriever.add(p_reps) + + # Mock search to return out-of-bounds indices + original_search = retriever.search + + def mock_search_out_of_bounds(q_reps, k): + # Return some valid indices and some out-of-bounds indices + scores = np.array([[0.9, 0.8, 0.7, 0.6, 0.5]]) + indices = np.array([[0, 1, 2, 10, 20]]) # 10 and 20 are out of bounds + return scores, indices + + retriever.search = mock_search_out_of_bounds + + class MockArgs: + depth = 5 + batch_size = 0 + quiet = True + chunk_multiplier = 1 + + args = MockArgs() + + # Function will catch out-of-bounds indices and log warnings, won't raise exception + results = search_queries_chunked(retriever, q_reps, p_lookup, args) + # Should handle gracefully, only using valid indices + assert len(results) == 1 + # Since we have 3 valid indices (0, 1, 2), should have some results + assert len(results[0]) <= 3 # At most 3 documents (corresponding to indices 0, 1, 2)