Skip to content

Commit 12f1285

Browse files
authored
Merge pull request #1664 from Giskard-AI/GSK-2152-scanner-mulitlanguage-input
Add language support in LLM generators [GSK-2152]
2 parents 1da6ba0 + 4108196 commit 12f1285

File tree

9 files changed

+266
-7
lines changed

9 files changed

+266
-7
lines changed

giskard/datasets/base/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,26 @@ def upload(self, client: GiskardClient, project_key: str):
507507
)
508508
return dataset_id
509509

510+
def extract_languages(self, columns=None):
    """Extract all languages present in the dataset's text columns.

    Languages are read from the ``language`` metadata computed for each
    ``text`` column; missing (NaN) detections are ignored.

    Args:
        columns (Optional[list[str]]): the columns from which languages
            should be extracted. If ``None`` (default), all dataset
            columns are considered.

    Returns:
        list[str]: deduplicated language codes (according to ISO 639-1)
        found in the selected text columns. Empty if no text column is
        selected or no language was detected.
    """
    if columns is None:
        columns = self.columns

    # Only "text" columns carry language metadata; all other column
    # types are skipped even when explicitly listed in `columns`.
    langs_per_feature = [
        self.column_meta[col, "text"]["language"].dropna().unique()
        for col, col_type in self.column_types.items()
        if col_type == "text" and col in columns
    ]

    # set().union(*[]) is the empty set, so this is safe when there are
    # no text columns at all.
    return list(set().union(*langs_per_feature))
529+
510530
@property
511531
def meta(self):
512532
return DatasetMeta(

giskard/llm/generators/adversarial.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,14 @@ def _make_dataset_name(self, model: BaseModel, num_samples):
4646
return truncate(f"Adversarial Examples for requirement “{self.requirement}”")
4747

4848
def _make_generate_input_prompt(self, model: BaseModel, num_inputs: int):
    """Build the adversarial input-generation prompt for *model*.

    Fills the prompt template with the issue description, the model
    metadata and the requirement under test; when ``self.languages`` is
    a non-empty sequence, the language-requirement instruction is
    appended to the prompt.
    """
    prompt = self.prompt.format(
        issue_description=self.issue_description,
        model_name=model.meta.name,
        model_description=model.meta.description,
        feature_names=", ".join(model.meta.feature_names),
        num_samples=num_inputs,
        requirement=self.requirement,
    )
    if not self.languages:
        return prompt
    return prompt + self._default_language_requirement.format(languages=self.languages)

giskard/llm/generators/base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, Sequence
22

33
from abc import ABC, abstractmethod
44

@@ -20,20 +20,25 @@
2020
Think step by step and then call the `generate_inputs` function with the generated inputs. You must generate {num_samples} inputs.
2121
"""
2222

23+
LANGUAGE_REQUIREMENT_PROMPT = "You must generate input using different languages among the following list: {languages}."
24+
2325

2426
class LLMGenerator(ABC):
2527
_default_temperature = 0.5
2628
_default_model = "gpt-4"
2729
_default_prompt = DEFAULT_GENERATE_INPUTS_PROMPT
30+
_default_language_requirement = LANGUAGE_REQUIREMENT_PROMPT
2831

2932
def __init__(
    self,
    llm_temperature: Optional[float] = None,
    llm_client: LLMClient = None,
    prompt: Optional[str] = None,
    languages: Optional[Sequence[str]] = None,
):
    """Configure the generator.

    Args:
        llm_temperature: sampling temperature; falls back to the class
            default (``_default_temperature``) when ``None``.
        llm_client: client used to query the LLM; a default client is
            obtained when none is given.
        prompt: prompt template; falls back to ``_default_prompt`` when
            ``None``.
        languages: optional language codes the generated inputs should
            be written in.
    """
    if llm_temperature is None:
        llm_temperature = self._default_temperature
    self.llm_temperature = llm_temperature
    self.llm_client = llm_client if llm_client else get_default_client()
    self.prompt = self._default_prompt if prompt is None else prompt
    self.languages = languages
3843

3944
@abstractmethod
@@ -43,12 +48,15 @@ def generate_dataset(self, model, num_samples=10, column_types=None) -> Dataset:
4348

4449
class BaseDataGenerator(LLMGenerator):
4550
def _make_generate_input_prompt(self, model: BaseModel, num_samples: int):
    """Build the generic input-generation prompt for *model*.

    Interpolates the model metadata and sample count into the prompt
    template, appending the language-requirement instruction when
    ``self.languages`` is a non-empty sequence.
    """
    prompt = self.prompt.format(
        model_name=model.meta.name,
        model_description=model.meta.description,
        feature_names=", ".join(model.meta.feature_names),
        num_samples=num_samples,
    )
    if not self.languages:
        return prompt
    return prompt + self._default_language_requirement.format(languages=self.languages)
5260

5361
def _make_generate_input_functions(self, model: BaseModel, num_samples: int):
5462
return [

giskard/llm/generators/sycophancy.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,15 @@ class SycophancyDataGenerator(LLMGenerator):
3939
_default_prompt = GENERATE_INPUTS_PROMPT
4040

4141
def _make_generate_input_prompt(self, model: BaseModel, num_samples):
    """Build the sycophancy input-generation prompt for *model*.

    Appends the language-requirement instruction when ``self.languages``
    is a non-empty sequence.
    """
    result = self.prompt.format(
        model_name=model.meta.name,
        model_description=model.meta.description,
        feature_names=", ".join(model.meta.feature_names),
        num_samples=num_samples,
    )
    if self.languages:
        result += self._default_language_requirement.format(languages=self.languages)
    return result
4851

4952
def _make_generate_input_functions(self, model: BaseModel):
5053
return [

giskard/scanner/llm/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,12 @@ def run(self, model: BaseModel, dataset: Dataset, features=None) -> Sequence[Iss
6565
issues = []
6666
for requirement in requirements:
6767
logger.info(f"{self.__class__.__name__}: Evaluating requirement: {requirement}")
68-
dg = AdversarialDataGenerator(issue_description=issue_description, requirement=requirement)
68+
69+
languages = dataset.extract_languages(columns=model.meta.feature_names)
70+
71+
dg = AdversarialDataGenerator(
72+
issue_description=issue_description, requirement=requirement, languages=languages
73+
)
6974
eval_dataset = dg.generate_dataset(model, self.num_samples)
7075

7176
evaluator = RequirementEvaluator([requirement])

giskard/scanner/llm/llm_basic_sycophancy_detector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ def get_cost_estimate(self, model: BaseModel, dataset: Dataset) -> dict:
7878

7979
def run(self, model: BaseModel, dataset: Dataset, features=None) -> Sequence[Issue]:
8080
# Prepare datasets
81-
generator = SycophancyDataGenerator()
81+
languages = dataset.extract_languages(columns=model.meta.feature_names)
82+
83+
generator = SycophancyDataGenerator(languages=languages)
8284
dataset1, dataset2 = generator.generate_dataset(
8385
model, num_samples=self.num_samples, column_types=dataset.column_types
8486
)

giskard/scanner/llm/llm_implausible_output_detector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ def get_cost_estimate(self, model: BaseModel, dataset: Dataset) -> dict:
5959

6060
def run(self, model: BaseModel, dataset: Dataset, features=None) -> Sequence[Issue]:
6161
# Generate inputs
62-
generator = ImplausibleDataGenerator(llm_temperature=0.1)
62+
languages = dataset.extract_languages(columns=model.meta.feature_names)
63+
64+
generator = ImplausibleDataGenerator(llm_temperature=0.1, languages=languages)
6365
eval_dataset = generator.generate_dataset(
6466
model, num_samples=self.num_samples, column_types=dataset.column_types
6567
)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import pandas as pd
2+
3+
from giskard.datasets import Dataset
4+
5+
6+
def test_dataset_language_exhaustive_text_column_extraction():
    """Languages are collected from every column typed as text, and only those."""
    # col1 holds English/French/Spanish; col2 mixes a Czech question with
    # category-like values.
    data = pd.DataFrame(
        {
            "col1": [
                "How does deforestation contribute to climate change according to IPCC reports?",
                "Quel est le rôle des gaz à effet de serre dans le réchauffement climatique?",
                "¿Cuál es el papel de los gases de efecto invernadero en el calentamiento global?",
            ],
            "col2": [
                "Proč zpráva IPCC naznačuje, že lidské aktivity nejsou hlavní příčinou klimatických změn?",
                "CAT1",
                "CAT2",
            ],
            "col3": [0, 1, 2],
        }
    )

    # When col2 is a category, its content must not contribute languages.
    ds = Dataset(data, column_types={"col1": "text", "col2": "category"}, target="col3")
    assert sorted(ds.extract_languages()) == ["en", "es", "fr"]

    # When col2 is text as well, Czech is also detected.
    ds = Dataset(data, column_types={"col1": "text", "col2": "text"}, target="col3")
    assert sorted(ds.extract_languages()) == ["cs", "en", "es", "fr"]
32+
33+
34+
def test_dataset_language_when_empty():
    """Text columns with no detectable language yield an empty result."""
    # Text columns holding only category-like or numeric values.
    data = pd.DataFrame(
        {
            "col1": [
                "How does deforestation contribute to climate change according to IPCC reports?",
                "Quel est le rôle des gaz à effet de serre dans le réchauffement climatique?",
                "¿Cuál es el papel de los gases de efecto invernadero en el calentamiento global?",
            ],
            "col2": ["CAT0", "CAT1", "CAT2"],
            "col3": [0, 1, 2],
            "col4": [3, 4, 5],
        }
    )
    ds = Dataset(data, column_types={"col1": "category", "col2": "text", "col3": "text"}, target="col4")
    assert sorted(ds.extract_languages()) == []

    # Mostly-missing text values must not produce spurious detections either.
    data = pd.DataFrame(
        {
            "col1": [
                "How does deforestation contribute to climate change according to IPCC reports?",
                "Quel est le rôle des gaz à effet de serre dans le réchauffement climatique?",
                "¿Cuál es el papel de los gases de efecto invernadero en el calentamiento global?",
            ],
            "col2": [None, "CAT1", "CAT2"],
            "col3": ["Bonjour", None, None],
            "col4": [3, 4, 5],
        }
    )
    ds = Dataset(data, column_types={"col1": "category", "col2": "text", "col3": "text"}, target="col4")
    assert sorted(ds.extract_languages()) == []
70+
71+
72+
def test_dataset_language_column_filtering():
    """The ``columns`` argument restricts which columns are scanned."""
    data = pd.DataFrame(
        {
            "col1": [
                "How does deforestation contribute to climate change according to IPCC reports?",
                "Quel est le rôle des gaz à effet de serre dans le réchauffement climatique?",
                "¿Cuál es el papel de los gases de efecto invernadero en el calentamiento global?",
            ],
            "col2": [
                "Proč zpráva IPCC naznačuje, že lidské aktivity nejsou hlavní příčinou klimatických změn?",
                "CAT1",
                "CAT2",
            ],
            "col3": [0, 1, 2],
            "col4": [3, 4, 5],
        }
    )

    # Restricting to col2 keeps only the Czech detection.
    ds = Dataset(data, column_types={"col1": "text", "col2": "text", "col3": "numeric"}, target="col4")
    assert sorted(ds.extract_languages(columns=["col2"])) == ["cs"]

    # Selecting both text columns collects all four languages.
    ds = Dataset(data, column_types={"col1": "text", "col2": "text", "col3": "numeric"}, target="col4")
    assert sorted(ds.extract_languages(columns=["col1", "col2"])) == ["cs", "en", "es", "fr"]

    # Non-text columns never contribute languages, even when selected.
    ds = Dataset(data, column_types={"col1": "text", "col2": "text", "col3": "numeric"}, target="col4")
    assert sorted(ds.extract_languages(columns=["col3"])) == []

    # Same for the target column, which has no declared column type.
    ds = Dataset(data, column_types={"col1": "text", "col2": "text", "col3": "numeric"}, target="col4")
    assert sorted(ds.extract_languages(columns=["col4"])) == []

tests/llm/generators/test_base_llm_generators.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,111 @@ def test_generator_casts_based_on_column_types(Generator, args, kwargs):
159159

160160
assert dataset.column_types["question"] == "text"
161161
assert dataset.column_types["other_feature"] == "numeric"
162+
163+
164+
@pytest.mark.parametrize(
    "Generator,args,kwargs",
    [
        (BaseDataGenerator, [], {}),
        (ImplausibleDataGenerator, [], {}),
        (AdversarialDataGenerator, ["demo", "demo"], {}),
    ],
)
def test_generator_adds_languages_requirements_in_prompts(Generator, args, kwargs):
    """When languages are set, the language instruction is appended to the prompt."""
    fake_model = Mock()
    fake_model.meta.feature_names = ["question", "other_feature"]
    fake_model.meta.name = "Mock model for test"
    fake_model.meta.description = "This is a model for testing purposes"

    client = Mock()
    client.complete.side_effect = [
        LLMOutput(
            None,
            LLMFunctionCall(
                "generate_inputs",
                {
                    "inputs": [
                        {"question": "What is the meaning of life?", "other_feature": "test"},
                        {
                            "question": "Quel est le rôle des gaz à effet de serre dans le réchauffement climatique??",
                            "other_feature": "pass",
                        },
                    ]
                },
            ),
        )
    ]

    gen = Generator(
        *args,
        **kwargs,
        llm_client=client,
        llm_temperature=1.416,
        prompt="My custom prompt {model_name} {model_description} {feature_names}, with {num_samples} samples.\n",
        languages=["en", "fr"],
    )

    generated = gen.generate_dataset(fake_model, num_samples=2)
    client.complete.assert_called_once()

    sent_prompt = client.complete.call_args[1]["messages"][0]["content"]
    expected_prompt = "My custom prompt Mock model for test This is a model for testing purposes question, other_feature, with 2 samples.\nYou must generate input using different languages among the following list: ['en', 'fr']."

    assert isinstance(generated, Dataset)
    assert sent_prompt == expected_prompt
215+
216+
217+
@pytest.mark.parametrize(
    "Generator,args,kwargs",
    [
        (BaseDataGenerator, [], {}),
        (ImplausibleDataGenerator, [], {}),
        (AdversarialDataGenerator, ["demo", "demo"], {}),
    ],
)
def test_generator_empty_languages_requirements(Generator, args, kwargs):
    """An empty languages list must leave the prompt untouched."""
    fake_model = Mock()
    fake_model.meta.feature_names = ["question", "other_feature"]
    fake_model.meta.name = "Mock model for test"
    fake_model.meta.description = "This is a model for testing purposes"

    client = Mock()
    client.complete.side_effect = [
        LLMOutput(
            None,
            LLMFunctionCall(
                "generate_inputs",
                {
                    "inputs": [
                        {"question": "What is the meaning of life?", "other_feature": "test"},
                        {
                            "question": "Quel est le rôle des gaz à effet de serre dans le réchauffement climatique??",
                            "other_feature": "pass",
                        },
                    ]
                },
            ),
        )
    ]

    gen = Generator(
        *args,
        **kwargs,
        llm_client=client,
        llm_temperature=1.416,
        prompt="My custom prompt {model_name} {model_description} {feature_names}, with {num_samples} samples.\n",
        languages=[],
    )

    generated = gen.generate_dataset(fake_model, num_samples=2)
    client.complete.assert_called_once()

    sent_prompt = client.complete.call_args[1]["messages"][0]["content"]

    assert isinstance(generated, Dataset)
    assert (
        sent_prompt
        == "My custom prompt Mock model for test This is a model for testing purposes question, other_feature, with 2 samples.\n"
    )

0 commit comments

Comments
 (0)