Skip to content

Commit 7cb8d2c

Browse files
committed
trust remote code when downloading rerankers
1 parent cfcc4e3 commit 7cb8d2c

File tree

2 files changed

+40
-34
lines changed

2 files changed

+40
-34
lines changed

PAKTON Framework/Researcher/src/Researcher/config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ reranking:
6767
use_reranker: True
6868
top_k: 64
6969
similarity_threshold: 0
70-
reranker_type: "llm-reranker" # Options: "cross-encoder", "flag-reranker", llm-reranker
71-
model: BAAI/bge-reranker-v2-minicpm-layerwise # cross-encoder/ms-marco-MiniLM-L6-v2 , BAAI/bge-reranker-v2-m3 , BAAI/bge-reranker-v2-minicpm-layerwise
70+
reranker_type: "flag-reranker" # Options: "cross-encoder", "flag-reranker", "llm-reranker"
71+
model: BAAI/bge-reranker-v2-m3 # cross-encoder/ms-marco-MiniLM-L6-v2 , BAAI/bge-reranker-v2-m3 , BAAI/bge-reranker-v2-minicpm-layerwise
7272
use_fp16: True # Whether to use FP16 for flag-reranker (speeds up computation at the cost of a slight accuracy degradation)
73-
cutoff_layers: 28 # layers to cut off for flag-reranker, only for minicpm layerwise
73+
# cutoff_layers: 28 # layers to cut off for flag-reranker, only for minicpm layerwise
7474

7575
llm_filtering:
7676
use_llm_filtering: True

PAKTON Framework/Researcher/src/Researcher/graph/nodes/reranking.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,44 +8,50 @@
88
Date: 2025/03/10
99
"""
1010
import torch
11-
from sentence_transformers import CrossEncoder
11+
1212
from Researcher.types import RetrievalState
1313
from Researcher.utils import logger, config
1414

1515
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
1616

17-
cross_encoder_config = config.get("reranking", {})
18-
top_k = cross_encoder_config.get("top_k", 5)
19-
similarity_threshold = cross_encoder_config.get("similarity_threshold", None)
20-
reranker_type = cross_encoder_config.get("reranker_type", "cross-encoder")
21-
model = cross_encoder_config.get("model", None)
22-
use_reranker = cross_encoder_config.get("use_reranker", False)
23-
use_fp16 = cross_encoder_config.get("use_fp16", False)
24-
cutoff_layers = cross_encoder_config.get("cutoff_layers", None)
17+
reranker_config = config.get("reranking", {})
18+
top_k = reranker_config.get("top_k", 5)
19+
similarity_threshold = reranker_config.get("similarity_threshold", None)
20+
reranker_type = reranker_config.get("reranker_type", "cross-encoder")
21+
model = reranker_config.get("model", None)
22+
use_reranker = reranker_config.get("use_reranker", False)
23+
use_fp16 = reranker_config.get("use_fp16", False)
24+
cutoff_layers = reranker_config.get("cutoff_layers", None)
2525

2626
# Import and initialize appropriate reranker based on configuration
27-
reranker = None
28-
if reranker_type == "cross-encoder":
29-
reranker = CrossEncoder(model, default_activation_function=torch.nn.Sigmoid())
30-
logger.info(f"Initialized CrossEncoder reranker with model {model}")
31-
elif reranker_type == "flag-reranker":
32-
try:
33-
from FlagEmbedding import FlagReranker
34-
reranker = FlagReranker(model, use_fp16=use_fp16)
35-
logger.info(f"Initialized FlagReranker with model {model} (use_fp16={use_fp16})")
36-
except ImportError:
37-
logger.error("Failed to import FlagReranker. Make sure FlagEmbedding is installed.")
38-
raise
39-
elif reranker_type == "llm-reranker":
40-
try:
41-
from FlagEmbedding import LayerWiseFlagLLMReranker
42-
reranker = LayerWiseFlagLLMReranker(model, use_fp16=use_fp16)
43-
logger.info(f"Initialized FlagReranker with model {model} (use_fp16={use_fp16})")
44-
except ImportError:
45-
logger.error("Failed to import FlagReranker. Make sure FlagEmbedding is installed.")
46-
raise
47-
else:
48-
raise ValueError(f"Unknown reranker type: {reranker_type}")
27+
if use_reranker:
28+
reranker = None
29+
if reranker_type == "cross-encoder":
30+
try:
31+
from sentence_transformers import CrossEncoder
32+
reranker = CrossEncoder(model, default_activation_function=torch.nn.Sigmoid(), trust_remote_code=True)
33+
logger.info(f"Initialized CrossEncoder reranker with model {model}")
34+
except ImportError:
35+
logger.error("Failed to import CrossEncoder. Make sure sentence-transformers is installed.")
36+
raise
37+
elif reranker_type == "flag-reranker":
38+
try:
39+
from FlagEmbedding import FlagReranker
40+
reranker = FlagReranker(model, use_fp16=use_fp16, trust_remote_code=True)
41+
logger.info(f"Initialized FlagReranker with model {model} (use_fp16={use_fp16})")
42+
except ImportError:
43+
logger.error("Failed to import FlagReranker. Make sure FlagEmbedding is installed.")
44+
raise
45+
elif reranker_type == "llm-reranker":
46+
try:
47+
from FlagEmbedding import LayerWiseFlagLLMReranker
48+
reranker = LayerWiseFlagLLMReranker(model, use_fp16=use_fp16, trust_remote_code=True)
49+
logger.info(f"Initialized LayerWiseFlagLLMReranker with model {model} (use_fp16={use_fp16})")
50+
except ImportError:
51+
logger.error("Failed to import LayerWiseFlagLLMReranker. Make sure FlagEmbedding is installed.")
52+
raise
53+
else:
54+
raise ValueError(f"Unknown reranker type: {reranker_type}")
4955

5056
def rerank(state: RetrievalState) -> RetrievalState:
5157
"""

0 commit comments

Comments
 (0)