Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions nlptest/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,7 @@ def run(cls, sample_list: Dict[str, List[Sample]], model: ModelFactory, raw_data

elif data[0].task == "question-answering":
dataset_name = data[0].dataset_name.split('-')[0].lower()
user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
prompt_template = """Context: {context}\nQuestion: {question}\n """ + user_prompt
prompt_template = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))

if data[0].expected_results is None:
raise RuntimeError(f'The dataset {dataset_name} does not contain labels and fairness tests cannot be run with it. Skipping the fairness tests.')
Expand All @@ -646,8 +645,7 @@ def run(cls, sample_list: Dict[str, List[Sample]], model: ModelFactory, raw_data

elif data[0].task == "summarization":
dataset_name = data[0].dataset_name.split('-')[0].lower()
user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
prompt_template = user_prompt + """Context: {context}\n\n Summary: """
prompt_template = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
if data[0].expected_results is None:
raise RuntimeError(f'The dataset {dataset_name} does not contain labels and fairness tests cannot be run with it. Skipping the fairness tests.')

Expand Down Expand Up @@ -802,8 +800,7 @@ def run(cls, sample_list: Dict[str, List[Sample]], model: ModelFactory, raw_data

elif raw_data[0].task=="question-answering":
dataset_name = raw_data[0].dataset_name.split('-')[0].lower()
user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
prompt_template = """Context: {context}\nQuestion: {question}\n """ + user_prompt
prompt_template = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))

y_true = pd.Series(raw_data).apply(lambda x: x.expected_results)
X_test = pd.Series(raw_data)
Expand All @@ -812,8 +809,7 @@ def run(cls, sample_list: Dict[str, List[Sample]], model: ModelFactory, raw_data

elif raw_data[0].task=="summarization":
dataset_name = raw_data[0].dataset_name.split('-')[0].lower()
user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
prompt_template = user_prompt + """Context: {context}\n\n Summary: """
prompt_template = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))

y_true = pd.Series(raw_data).apply(lambda x: x.expected_results)
X_test = pd.Series(raw_data)
Expand Down
6 changes: 2 additions & 4 deletions nlptest/transform/bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,14 @@ async def run(sample_list: List[Sample], model: ModelFactory, **kwargs) -> List[
if sample.state != "done":
if sample.task == "question-answering":
dataset_name = sample.dataset_name.split('-')[0].lower()
user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
prompt_template = """Context: {context}\nQuestion: {question}\n """ + user_prompt
prompt_template = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
sample.expected_results = model(text={'context':sample.original_context, 'question': sample.original_question},
prompt={"template":prompt_template, 'input_variables':["context", "question"]})
sample.actual_results = model(text={'context':sample.perturbed_context, 'question': sample.perturbed_question},
prompt={"template":prompt_template, 'input_variables':["context", "question"]})
elif sample.task == "summarization":
dataset_name = sample.dataset_name.split('-')[0].lower()
user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
prompt_template = user_prompt + """Context: {context}\n\n Summary: """
prompt_template = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
sample.expected_results = model(text={'context':sample.original},
prompt={"template":prompt_template, 'input_variables':["context"]})
sample.actual_results = model(text={'context':sample.original},
Expand Down
7 changes: 3 additions & 4 deletions nlptest/transform/robustness.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,15 @@ async def run(sample_list: List[Sample], model: ModelFactory, **kwargs) -> List[
if sample.state != "done":
if sample.task == 'question-answering':
dataset_name = sample.dataset_name.split('-')[0].lower()
user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
prompt_template = """Context: {context}\nQuestion: {question}\n """ + user_prompt
prompt_template = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
sample.expected_results = model(text={'context':sample.original_context, 'question': sample.original_question},
prompt={"template":prompt_template, 'input_variables':["context", "question"]})
sample.actual_results = model(text={'context':sample.perturbed_context, 'question': sample.perturbed_question},
prompt={"template":prompt_template, 'input_variables':["context", "question"]})

elif sample.task == 'summarization':
dataset_name = sample.dataset_name.split('-')[0].lower()
user_prompt = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
prompt_template = user_prompt + """Context: {context}\n\n Summary: """
prompt_template = kwargs.get('user_prompt', default_user_prompt.get(dataset_name, ""))
sample.expected_results = model(text={'context':sample.original},
prompt={"template":prompt_template, 'input_variables':["context"]})
sample.actual_results = model(text={'context':sample.original},
Expand Down Expand Up @@ -101,6 +99,7 @@ async def async_run(cls, sample_list: List[Sample], model: ModelFactory, **kwarg
return created_task



class UpperCase(BaseRobustness):
alias_name = "uppercase"

Expand Down
6 changes: 3 additions & 3 deletions nlptest/transform/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7191,9 +7191,9 @@ def get_entity_representation_proportions(entity_representation):
return entity_representation_proportion

default_user_prompt = {
"boolq": "I've provided a question and context. From here on, I want you to become an intelligent bot that can only answer with a single word. The words you are capable of saying are True and False. If you think the answer to the question is True, then say 'True'. If it is False, then say 'False'. Do not say anything else other than that.",
"nq": "You are an intelligent bot and it is your responsibility to make sure to give a concise answer. Answer:",
"xsum": "You are an intelligent Context summarizer. Please read the following context carefully. After understanding its content, create a concise summary, capturing the essential themes and key details. Please ensure that the summary does not end abruptly and remains within the max_tokens word limit."
"boolq": "Context: {context}\nQuestion: {question}\n I've provided a question and context. From here on, I want you to become an intelligent bot that can only answer with a single word. The words you are capable of saying are True and False. If you think the answer to the question is True, then say 'True'. If it is False, then say 'False'. Do not say anything else other than that.",
"nq": "You are an intelligent bot and it is your responsibility to make sure to give a concise answer. Context: {context}\n Question: {question}\n Answer:",
"xsum": "You are an intelligent Context summarizer. Please read the following context carefully. After understanding its content, create a concise summary, capturing the essential themes and key details. Please ensure that the summary does not end abruptly and remains within the max_tokens word limit. Context: {context}\n\n Summary: "
}

qa_prompt_template ="""
Expand Down