 import importlib_resources
 from langtest.errors import Errors, Warnings
 from langtest.modelhandler.modelhandler import ModelAPI
+from langtest.tasks.task import TaskManager
 from langtest.transform.base import ITests, TestFactory
 from langtest.transform.utils import GENERIC2BRAND_TEMPLATE, filter_unique_samples
 from langtest.utils.custom_types.helpers import (
@@ -924,3 +925,155 @@ def evaluate_responses( |
     )
 
     return evaluator.aggregate_results(data_retriever, results)
+
+
+class MedFuzz(BaseClinical):
+    """MedFuzz-style robustness test: an attacker LLM perturbs each benchmark
+    question, and the target LLM is then re-evaluated on the perturbed version."""
+
+    alias_name = "medfuzz"
+    supported_tasks = ["question-answering", "text-generation"]
+
+    @staticmethod
+    def transform(sample_list: List[Sample], *args, **kwargs):
+        from langtest.transform.utils import AttackerLLM, TargetLLM
+        from langtest.utils.custom_types.sample import MedFuzzSample
+        from tqdm.auto import tqdm
+
+        try:
+            # Resolve the attacker model through TaskManager so any supported
+            # hub can serve as the attacker.
+            attacker_model_info = kwargs.get("attacker_llm", None)
+            if attacker_model_info is not None:
+                task = TaskManager("question-answering")
+                model = task.model(
+                    model_path=attacker_model_info["model"],
+                    model_hub=attacker_model_info["hub"],
+                    model_type=attacker_model_info["type"],
+                )
+            else:
+                from textwrap import dedent
+
+                error_message = dedent(
+                    """
+                    Attacker model information was not provided in the configuration. Please provide it, e.g.:
+                    {
+                        "medfuzz": {
+                            "attacker_llm": {
+                                "model": "<model_name>",
+                                "hub": "<model_hub>",
+                                "type": "<chat | completion>"
+                            }
+                        }
+                    }
+                    """
+                ).strip()
+
+                raise ValueError(error_message)
+
+            samples = tqdm(
+                sample_list,
+                desc="Transforming the samples",
+                unit="samples",
+                position=1,
+            )
+
+            transformed_samples = []
+            for sample in samples:
+                # Wrap the attacker model in both roles: the attacker crafts
+                # the perturbation, the target answers the question.
+                llm_attacker = AttackerLLM(model)
+                llm_target = TargetLLM(model)
+
+                # Promote the sample to a MedFuzzSample and tag the test.
+                med_sample = MedFuzzSample(**sample.dict())
+                med_sample.test_type = "medfuzz"
+                med_sample.category = "clinical"
+
+                # Fold multiple-choice options into the question text so the
+                # attacker and target both see the full benchmark item.
+                if med_sample.options not in [None, ""]:
+                    med_sample.original_question = (
+                        f"{med_sample.original_question}\n{med_sample.options}"
+                    )
+                    med_sample.options = None
+
+                # Baseline pass: capture the target's reasoning and confidence
+                # on the unmodified question to seed the attack plan.
+                ot = llm_target.process_user_text(med_sample.original_question)
+
+                llm_attacker.generate_attack_plan(
+                    benchmark_item=med_sample.original_question,
+                    correct_answer="".join(med_sample.expected_results),
+                    reasoning=ot["reasoning"],
+                    confidence=ot["confidence_scores"],
+                )
+
+                med_sample.perturbed_question = llm_attacker.generate_modified_question(
+                    med_sample.original_question
+                )
+
+                # Keep only the leading option letter (e.g. "A") as the
+                # expected result for evaluation.
+                med_sample.expected_results = "".join(
+                    map(str, med_sample.expected_results)
+                )[:1]
+
+                transformed_samples.append(med_sample)
+
+            return transformed_samples
+        except Exception:
+            import traceback
+
+            traceback.print_exc()
+            raise
+
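+    # NOTE (assumed contract): TargetLLM.process_user_text returns a dict with
+    # "reasoning", "confidence_scores", and "final_answer" keys; transform()
+    # consumes the first two and run() reads the last.
+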
+    @staticmethod
+    async def run(sample_list: List[Sample], model: ModelAPI, *args, **kwargs):
+        from langtest.transform.utils import TargetLLM
+
+        progress_bar = kwargs.get("progress_bar", False)
+
+        for sample in sample_list:
+            if sample.state != "done":
+                target_llm = TargetLLM(model)
+
+                # Ask the target model the perturbed question and record its
+                # final answer.
+                response = target_llm.process_user_text(sample.perturbed_question)
+                sample.actual_results = response.get("final_answer", "")
+
+                # Release the wrapper before moving on to the next sample.
+                del target_llm
+
+                sample.state = "done"
+
+            if progress_bar:
+                progress_bar.update(1)
+
+        return sample_list
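+
+    # The helpers below are thin, interchangeable chat clients (a local Ollama
+    # server vs. the OpenAI API); each takes a model name plus a list of
+    # {"role": ..., "content": ...} messages and returns the reply text.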
+
+    @staticmethod
+    def ollama_model_client(model, messages):
+        from ollama import Client
+
+        client = Client()
+
+        res = client.chat(
+            model=model,
+            messages=messages,
+            options={
+                # Sample with a high temperature for more varied generations.
+                "temperature": 0.9,
+            },
+        )
+        return res.message.content
+
+    @staticmethod
+    def openai_model_client(model, messages):
+        import openai
+
+        client = openai.Client()
+
+        res = (
+            client.chat.completions.create(model=model, messages=messages)
+            .choices[0]
+            .message.content
+        )
+        return res
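
As a usage sketch (hedged: the wiring below assumes langtest's standard Harness
config layout; the model names and data_source are placeholders, not values
taken from this commit), the new test would be enabled roughly like this:

    from langtest import Harness

    harness = Harness(
        task="question-answering",
        model={"model": "<target_model>", "hub": "openai"},
        data={"data_source": "<benchmark_name>", "split": "test-tiny"},
        config={
            "tests": {
                "defaults": {"min_pass_rate": 0.65},
                "clinical": {
                    "medfuzz": {
                        "min_pass_rate": 0.70,
                        # Required: MedFuzz.transform raises ValueError without it.
                        "attacker_llm": {
                            "model": "<attacker_model>",
                            "hub": "<model_hub>",
                            "type": "chat",
                        },
                    }
                },
            }
        },
    )
    harness.generate().run().report()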