Skip to content

Commit 034d18d

Browse files
Merge pull request #1190 from JohnSnowLabs/feature/implement-the-fuzz-tests-in-robustness
Feature/implement the fuzz tests in robustness
2 parents d996cbc + 993829d commit 034d18d

6 files changed

Lines changed: 418 additions & 5 deletions

File tree

langtest/modelhandler/llm_modelhandler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import importlib
22
import inspect
33

4-
from typing import Any, List, Union
4+
from typing import Any, List, Type, Union, TypeVar
55
import langchain.llms as lc
66
import langchain.chat_models as chat_models
77
from langchain.chains.llm import LLMChain
@@ -33,6 +33,8 @@ class PretrainedModelForQA(ModelAPI):
3333
ConfigError: If there is an error in the model configuration.
3434
"""
3535

36+
_T = TypeVar("_T", bound="PretrainedModelForQA")
37+
3638
HUB_PARAM_MAPPING = {
3739
"azure-openai": "max_tokens",
3840
"ai21": "maxTokens",
@@ -60,7 +62,9 @@ def __init__(self, hub: str, model: Any, *args, **kwargs):
6062
self.predict.cache_clear()
6163

6264
@classmethod
63-
def load_model(cls, hub: str, path: str, *args, **kwargs) -> "PretrainedModelForQA":
65+
def load_model(
66+
cls: Type[_T], hub: str, path: str, *args, **kwargs
67+
) -> "PretrainedModelForQA":
6468
"""Load the pretrained model.
6569
6670
Args:

langtest/modelhandler/modelhandler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections import defaultdict
33
from typing import Union
44
from functools import lru_cache
5+
from typing import Any, Type, TypeVar
56
from langtest.utils.lib_manager import try_import_lib
67

78
RENAME_HUBS = {
@@ -34,10 +35,13 @@ class ModelAPI(ABC):
3435
Implementations should inherit from this class and override load_model() and predict() methods.
3536
"""
3637

37-
model_registry = defaultdict(lambda: defaultdict(lambda: ModelAPI))
38+
_T = TypeVar("_T", bound="ModelAPI")
3839

40+
model_registry: defaultdict[str, dict[str, Type[_T]]] = defaultdict(dict)
41+
42+
@classmethod
3943
@abstractmethod
40-
def load_model(cls, *args, **kwargs):
44+
def load_model(cls: Type[_T], *args: Any, **kwargs: Any) -> _T:
4145
"""Load the model."""
4246
raise NotImplementedError()
4347

langtest/transform/clinical.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import importlib_resources
1010
from langtest.errors import Errors, Warnings
1111
from langtest.modelhandler.modelhandler import ModelAPI
12+
from langtest.tasks.task import TaskManager
1213
from langtest.transform.base import ITests, TestFactory
1314
from langtest.transform.utils import GENERIC2BRAND_TEMPLATE, filter_unique_samples
1415
from langtest.utils.custom_types.helpers import (
@@ -924,3 +925,155 @@ def evaluate_responses(
924925
)
925926

926927
return evaluator.aggregate_results(data_retriever, results)
928+
929+
930+
class MedFuzz(BaseClinical):
    """Adversarial "MedFuzz"-style robustness test for clinical QA.

    An attacker LLM rewrites each benchmark question so that a target LLM is
    more likely to answer it incorrectly, while the correct answer itself is
    left intact. ``transform`` builds the perturbed samples; ``run`` queries
    the target model with them.
    """

    alias_name = "medfuzz"
    supported_tasks = ["question-answering", "text-generation"]

    @staticmethod
    def transform(sample_list: List[Sample], *args, **kwargs):
        """Generate adversarially perturbed questions for each sample.

        Args:
            sample_list: Samples whose questions should be perturbed.
            **kwargs: Must contain ``attacker_llm`` — a dict with keys
                ``model``, ``hub`` and ``type`` describing the attacker model.

        Returns:
            List of ``MedFuzzSample`` objects with ``perturbed_question`` set.

        Raises:
            ValueError: If ``attacker_llm`` is missing from the configuration.
        """
        from langtest.transform.utils import AttackerLLM, TargetLLM
        from langtest.utils.custom_types.sample import MedFuzzSample
        from tqdm.auto import tqdm

        try:
            # Guard clause: fail fast with a copy-pasteable config snippet
            # when the attacker model is not configured.
            attacker_model_info = kwargs.get("attacker_llm", None)
            if attacker_model_info is None:
                from textwrap import dedent

                error_message = dedent(
                    """
                    Attack model information is not provided in Configuration. Please provide the attack model information.
                    {
                        "medfuzz": {
                            "attacker_llm": {
                                "model": "<model_name>",
                                "hub": "<model_hub>",
                                "type": "<chat | completion>"
                            }
                        }
                    }
                    """
                ).strip()

                raise ValueError(error_message)

            # Load the attacker model through the QA task manager.
            task = TaskManager("question-answering")
            model = task.model(
                model_path=attacker_model_info["model"],
                model_hub=attacker_model_info["hub"],
                model_type=attacker_model_info["type"],
            )

            samples = tqdm(
                sample_list,
                desc="Transforming the samples",
                unit="samples",
                position=1,
            )

            transformed_samples = []
            for sample in samples:
                # Fresh attacker/target wrappers per sample so no
                # conversational state leaks between samples.
                llm_attacker = AttackerLLM(model)
                llm_target = TargetLLM(model)

                med_sample = MedFuzzSample(**sample.dict())
                med_sample.test_type = "medfuzz"
                med_sample.category = "clinical"

                # Fold multiple-choice options into the question text so the
                # attacker sees the full item as one prompt.
                if med_sample.options not in [None, ""]:
                    med_sample.original_question = (
                        f"{med_sample.original_question}\n{med_sample.options}"
                    )
                    med_sample.options = None

                # Baseline response (reasoning + confidence) from the target.
                ot = llm_target.process_user_text(med_sample.original_question)

                # Plan the attack using the target's own reasoning/confidence,
                # keeping the correct answer fixed.
                llm_attacker.generate_attack_plan(
                    benchmark_item=med_sample.original_question,
                    correct_answer="".join(med_sample.expected_results),
                    reasoning=ot["reasoning"],
                    confidence=ot["confidence_scores"],
                )

                med_sample.perturbed_question = llm_attacker.generate_modified_question(
                    med_sample.original_question
                )

                # Keep only the leading answer token (e.g. the option letter)
                # as the expected result for evaluation.
                med_sample.expected_results = "".join(
                    map(str, med_sample.expected_results)
                )[:1]

                transformed_samples.append(med_sample)

            return transformed_samples
        except Exception:
            # Surface the full traceback for debugging, then propagate.
            import traceback

            traceback.print_exc()
            raise

    @staticmethod
    async def run(sample_list: List[Sample], model: ModelAPI, *args, **kwargs):
        """Query the target model with each sample's perturbed question.

        Args:
            sample_list: Samples produced by :meth:`transform`.
            model: The target model under test.
            **kwargs: Optional ``progress_bar`` (a tqdm-like object) that is
                advanced once per processed sample.

        Returns:
            The same list, with ``actual_results`` filled and ``state`` set
            to ``"done"``.
        """
        from langtest.transform.utils import TargetLLM

        progress_bar = kwargs.get("progress_bar", False)

        for sample in sample_list:
            if sample.state != "done":
                target_llm = TargetLLM(model)

                response = target_llm.process_user_text(sample.perturbed_question)

                sample.actual_results = response.get("final_answer", "")

                # Release the per-sample wrapper before moving on.
                del target_llm

                sample.state = "done"

            if progress_bar:
                progress_bar.update(1)

        return sample_list

    @staticmethod
    def ollama_model_client(model, messages):
        """Chat with a local Ollama model and return the reply text.

        Args:
            model: Ollama model name.
            messages: Chat messages in Ollama's message format.

        Returns:
            The assistant message content as a string.
        """
        from ollama import Client

        client = Client()

        # temperature 0.9: deliberately high-variance generations for fuzzing.
        res = client.chat(
            model=model,
            messages=messages,
            options={
                "temperature": 0.9,
            },
        )
        return res.message.content

    @staticmethod
    def openai_model_client(model, messages):
        """Chat with an OpenAI model and return the reply text.

        Args:
            model: OpenAI model name.
            messages: Chat messages in OpenAI's message format.

        Returns:
            The assistant message content of the first choice.
        """
        import openai

        client = openai.Client()

        completion = client.chat.completions.create(model=model, messages=messages)
        return completion.choices[0].message.content

0 commit comments

Comments
 (0)