Skip to content

Commit 9db44ce

Browse files
committed
improve prompt translation support for Conversation related types (#1441)
2 parents 4430702 + e2d3667 commit 9db44ce

File tree

2 files changed

+45
-28
lines changed

2 files changed

+45
-28
lines changed

garak/probes/base.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -72,24 +72,29 @@ def __init__(self, config_root=_config):
7272
"""
7373
self._load_config(config_root)
7474
self.probename = str(self.__class__).split("'")[1]
75-
75+
7676
# Handle deprecated recommended_detector migration
7777
if (
7878
self.primary_detector is None
7979
and self.recommended_detector != ["always.Fail"]
8080
and len(self.recommended_detector) > 0
8181
):
8282
from garak import command
83+
8384
command.deprecation_notice(
8485
f"recommended_detector in probe {self.probename}",
8586
"0.9.0.6",
8687
logging=logging,
8788
)
8889
self.primary_detector = self.recommended_detector[0]
8990
if len(self.recommended_detector) > 1:
90-
existing_extended = list(self.extended_detectors) if self.extended_detectors else []
91-
self.extended_detectors = existing_extended + list(self.recommended_detector[1:])
92-
91+
existing_extended = (
92+
list(self.extended_detectors) if self.extended_detectors else []
93+
)
94+
self.extended_detectors = existing_extended + list(
95+
self.recommended_detector[1:]
96+
)
97+
9398
if hasattr(_config.system, "verbose") and _config.system.verbose > 0:
9499
print(
95100
f"loading {Style.BRIGHT}{Fore.LIGHTYELLOW_EX}probe: {Style.RESET_ALL}{self.probename}"
@@ -362,11 +367,9 @@ def probe(self, generator) -> Iterable[garak.attempt.Attempt]:
362367

363368
# build list of attempts
364369
attempts_todo: Iterable[garak.attempt.Attempt] = []
365-
prompts = list(
370+
prompts = copy.deepcopy(
366371
self.prompts
367-
) # will this still make a copy if prompts are `Message` objects?
368-
lang = self.lang
369-
# account for visual jailbreak until Turn/Conversation is supported
372+
) # make a copy to avoid mutating source list
370373
preparation_bar = tqdm.tqdm(
371374
total=len(prompts),
372375
leave=False,
@@ -387,35 +390,47 @@ def probe(self, generator) -> Iterable[garak.attempt.Attempt]:
387390
for prompt in prompts:
388391
if isinstance(prompt, garak.attempt.Message):
389392
prompt.text = self.langprovider.get_text(
390-
prompt.text, notify_callback=preparation_bar.update
391-
)
393+
[prompt.text], notify_callback=preparation_bar.update
394+
)[0]
392395
prompt.lang = self.langprovider.target_lang
393396
if isinstance(prompt, garak.attempt.Conversation):
394397
for turn in prompt.turns:
395398
msg = turn.content
396399
msg.text = self.langprovider.get_text(
397-
msg.text, notify_callback=preparation_bar.update
398-
)
400+
[msg.text], notify_callback=preparation_bar.update
401+
)[0]
399402
msg.lang = self.langprovider.target_lang
400403
lang = self.langprovider.target_lang
401404
preparation_bar.close()
402405
for seq, prompt in enumerate(prompts):
403-
notes = (
404-
{
405-
"pre_translation_prompt": garak.attempt.Conversation(
406-
[
407-
garak.attempt.Turn(
408-
"user",
409-
garak.attempt.Message(
410-
self.prompts[seq], lang=self.lang
411-
),
412-
)
413-
]
414-
)
415-
}
416-
if lang != self.lang
417-
else None
418-
)
406+
notes = None
407+
if lang != self.lang:
408+
pre_translation_prompt = copy.deepcopy(self.prompts[seq])
409+
if isinstance(pre_translation_prompt, str):
410+
notes = {
411+
"pre_translation_prompt": garak.attempt.Conversation(
412+
[
413+
garak.attempt.Turn(
414+
"user",
415+
garak.attempt.Message(
416+
pre_translation_prompt, lang=self.lang
417+
),
418+
)
419+
]
420+
)
421+
}
422+
elif isinstance(pre_translation_prompt, garak.attempt.Message):
423+
pre_translation_prompt.lang = self.lang
424+
notes = {
425+
"pre_translation_prompt": garak.attempt.Conversation(
426+
[pre_translation_prompt]
427+
)
428+
}
429+
elif isinstance(pre_translation_prompt, garak.attempt.Message):
430+
for turn in pre_translation_prompt.turns:
431+
turn.context.lang = self.lang
432+
notes = {"pre_translation_prompt": pre_translation_prompt}
433+
419434
attempts_todo.append(self._mint_attempt(prompt, seq, notes, lang))
420435

421436
# buff hook

tests/langservice/probes/test_probes_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ def test_multi_modal_probe_translation(classname, mocker):
240240
expected_provision_calls += len(probe_instance.attempt_descrs) * 2
241241

242242
assert prompt_mock.call_count == expected_provision_calls
243+
for prompt in probe_instance.prompts:
244+
assert isinstance(prompt.text, str)
243245

244246

245247
@pytest.mark.parametrize("classname", PROBES)

0 commit comments

Comments
 (0)