|
8 | 8 | Date: 2025/03/10 |
9 | 9 | """ |
10 | 10 | import torch |
11 | | -from sentence_transformers import CrossEncoder |
| 11 | + |
12 | 12 | from Researcher.types import RetrievalState |
13 | 13 | from Researcher.utils import logger, config |
14 | 14 |
|
15 | 15 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
16 | 16 |
|
17 | | -cross_encoder_config = config.get("reranking", {}) |
18 | | -top_k = cross_encoder_config.get("top_k", 5) |
19 | | -similarity_threshold = cross_encoder_config.get("similarity_threshold", None) |
20 | | -reranker_type = cross_encoder_config.get("reranker_type", "cross-encoder") |
21 | | -model = cross_encoder_config.get("model", None) |
22 | | -use_reranker = cross_encoder_config.get("use_reranker", False) |
23 | | -use_fp16 = cross_encoder_config.get("use_fp16", False) |
24 | | -cutoff_layers = cross_encoder_config.get("cutoff_layers", None) |
| 17 | +reranker_config = config.get("reranking", {}) |
| 18 | +top_k = reranker_config.get("top_k", 5) |
| 19 | +similarity_threshold = reranker_config.get("similarity_threshold", None) |
| 20 | +reranker_type = reranker_config.get("reranker_type", "cross-encoder") |
| 21 | +model = reranker_config.get("model", None) |
| 22 | +use_reranker = reranker_config.get("use_reranker", False) |
| 23 | +use_fp16 = reranker_config.get("use_fp16", False) |
| 24 | +cutoff_layers = reranker_config.get("cutoff_layers", None) |
25 | 25 |
|
26 | 26 | # Import and initialize appropriate reranker based on configuration |
27 | | -reranker = None |
28 | | -if reranker_type == "cross-encoder": |
29 | | - reranker = CrossEncoder(model, default_activation_function=torch.nn.Sigmoid()) |
30 | | - logger.info(f"Initialized CrossEncoder reranker with model {model}") |
31 | | -elif reranker_type == "flag-reranker": |
32 | | - try: |
33 | | - from FlagEmbedding import FlagReranker |
34 | | - reranker = FlagReranker(model, use_fp16=use_fp16) |
35 | | - logger.info(f"Initialized FlagReranker with model {model} (use_fp16={use_fp16})") |
36 | | - except ImportError: |
37 | | - logger.error("Failed to import FlagReranker. Make sure FlagEmbedding is installed.") |
38 | | - raise |
39 | | -elif reranker_type == "llm-reranker": |
40 | | - try: |
41 | | - from FlagEmbedding import LayerWiseFlagLLMReranker |
42 | | - reranker = LayerWiseFlagLLMReranker(model, use_fp16=use_fp16) |
43 | | - logger.info(f"Initialized FlagReranker with model {model} (use_fp16={use_fp16})") |
44 | | - except ImportError: |
45 | | - logger.error("Failed to import FlagReranker. Make sure FlagEmbedding is installed.") |
46 | | - raise |
47 | | -else: |
48 | | - raise ValueError(f"Unknown reranker type: {reranker_type}") |
| 27 | +if use_reranker: |
| 28 | + reranker = None |
| 29 | + if reranker_type == "cross-encoder": |
| 30 | + try: |
| 31 | + from sentence_transformers import CrossEncoder |
| 32 | + reranker = CrossEncoder(model, default_activation_function=torch.nn.Sigmoid(), trust_remote_code=True) |
| 33 | + logger.info(f"Initialized CrossEncoder reranker with model {model}") |
| 34 | + except ImportError: |
| 35 | + logger.error("Failed to import CrossEncoder. Make sure sentence-transformers is installed.") |
| 36 | + raise |
| 37 | + elif reranker_type == "flag-reranker": |
| 38 | + try: |
| 39 | + from FlagEmbedding import FlagReranker |
| 40 | + reranker = FlagReranker(model, use_fp16=use_fp16, trust_remote_code=True) |
| 41 | + logger.info(f"Initialized FlagReranker with model {model} (use_fp16={use_fp16})") |
| 42 | + except ImportError: |
| 43 | + logger.error("Failed to import FlagReranker. Make sure FlagEmbedding is installed.") |
| 44 | + raise |
| 45 | + elif reranker_type == "llm-reranker": |
| 46 | + try: |
| 47 | + from FlagEmbedding import LayerWiseFlagLLMReranker |
| 48 | + reranker = LayerWiseFlagLLMReranker(model, use_fp16=use_fp16, trust_remote_code=True) |
| 49 | + logger.info(f"Initialized LayerWiseFlagLLMReranker with model {model} (use_fp16={use_fp16})") |
| 50 | + except ImportError: |
| 51 | + logger.error("Failed to import LayerWiseFlagLLMReranker. Make sure FlagEmbedding is installed.") |
| 52 | + raise |
| 53 | + else: |
| 54 | + raise ValueError(f"Unknown reranker type: {reranker_type}") |
49 | 55 |
|
50 | 56 | def rerank(state: RetrievalState) -> RetrievalState: |
51 | 57 | """ |
|
0 commit comments