Skip to content
8 changes: 4 additions & 4 deletions giskard/scanner/robustness/text_transformations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import itertools
import json
import random
import re
from pathlib import Path

Expand Down Expand Up @@ -150,10 +149,11 @@ def make_perturbation(self, text):
class TextLanguageBasedTransformation(TextTransformation):
needs_dataset = True

def __init__(self, column):
def __init__(self, column, rng_seed=None):
super().__init__(column)
self._lang_dictionary = dict()
self._load_dictionaries()
self.rng = np.random.default_rng(seed=rng_seed)

def _load_dictionaries(self):
raise NotImplementedError()
Expand Down Expand Up @@ -236,7 +236,7 @@ def make_perturbation(self, row):
mask_value = f"__GSK__ENT__RELIGION__{n_list}__{n_term}__"
text, num_rep = re.subn(rf"\b{re.escape(term)}(s?)\b", rf"{mask_value}\1", text, flags=re.IGNORECASE)
if num_rep > 0:
i = (n_term + 1 + random.randrange(len(term_list) - 1)) % len(term_list)
i = (n_term + 1 + self.rng.choice(len(term_list) - 1)) % len(term_list)
replacement = term_list[i]
replacements.append((mask_value, replacement))

Expand Down Expand Up @@ -278,7 +278,7 @@ def make_perturbation(self, row):
)
if num_rep > 0:
r_income_type = "low-income" if income_type == "high-income" else "high-income"
replacement = random.choice(nationalities_word_dict[entity_type][r_income_type])
replacement = self.rng.choice(nationalities_word_dict[entity_type][r_income_type])
replacements.append((mask_value, replacement))

# Replace masks
Expand Down
26 changes: 10 additions & 16 deletions tests/scan/test_text_transformations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import random
import re

import pandas as pd
Expand Down Expand Up @@ -131,9 +130,8 @@ def test_religion_based_transformation():
)
from giskard.scanner.robustness.text_transformations import TextReligionTransformation

t = TextReligionTransformation(column="text")
t = TextReligionTransformation(column="text", rng_seed=10)

random.seed(0)
transformed = dataset.transform(t)
transformed_text = transformed.df.text.values

Expand All @@ -142,12 +140,12 @@ def test_religion_based_transformation():
"mois de ramadan."
)
assert (
transformed_text[1] == "Une partie des chrétiens commémorent ce vendredi 5 mai la naissance, l’éveil et la "
"mort de muhammad, dit « le Bouddha »"
transformed_text[1] == "Une partie des hindous commémorent ce vendredi 5 mai la naissance, l’éveil et la "
"mort de abraham, dit « le Bouddha »"
)
assert (
transformed_text[2] == "Signs have also been placed in the direction of kumbh mela along one of the Peak "
"District’s most popular hiking routes, Cave Dale, to help christians combine prayer "
"District’s most popular hiking routes, Cave Dale, to help jews combine prayer "
"with enjoying the outdoors."
)
assert (
Expand All @@ -157,9 +155,6 @@ def test_religion_based_transformation():


def test_country_based_transformation():
import random

random.seed(10)
dataset = _dataset_from_dict(
{
"text": [
Expand All @@ -173,31 +168,30 @@ def test_country_based_transformation():
)
from giskard.scanner.robustness.text_transformations import TextNationalityTransformation

t = TextNationalityTransformation(column="text")
t = TextNationalityTransformation(column="text", rng_seed=0)

transformed = dataset.transform(t)
transformed_text = transformed.df.text.values

assert (
transformed_text[0] == "Les musulmans de Eswatini fêtent vendredi 21 avril la fin du "
transformed_text[0] == "Les musulmans de Saint Thomas et Prince fêtent vendredi 21 avril la fin du "
"jeûne pratiqué durant le mois de ramadan."
)
assert transformed_text[1] == "Des incendies ravagent l'Congo depuis la fin août 2019."
assert transformed_text[1] == "Des incendies ravagent l'Liban depuis la fin août 2019."
assert (
transformed_text[2] == "Bali is an Libyan island known for its forested volcanic mountains, iconic"
transformed_text[2] == "Bali is an Singaporean island known for its forested volcanic mountains, iconic"
" rice paddies, beaches and coral reefs. The island is home to religious sites "
"such as cliffside Uluwatu Temple"
)
assert (
transformed_text[3]
== "President Joe Biden visited U.S.'s capital for the first time since Nigeria invaded the country"
== "President Joe Biden visited UAE's capital for the first time since Syria invaded the country"
)


def test_country_based_transformation_edge_cases():
from giskard.scanner.robustness.text_transformations import TextNationalityTransformation

random.seed(0)
df = pd.DataFrame(
{
"text": [
Expand All @@ -210,7 +204,7 @@ def test_country_based_transformation_edge_cases():
}
)

t = TextNationalityTransformation(column="text")
t = TextNationalityTransformation(column="text", rng_seed=0)

t1 = t.make_perturbation(df.iloc[0])
t2 = t.make_perturbation(df.iloc[1])
Expand Down