diff --git a/giskard/scanner/robustness/text_transformations.py b/giskard/scanner/robustness/text_transformations.py index 93e4caeb35..7de6f614e6 100644 --- a/giskard/scanner/robustness/text_transformations.py +++ b/giskard/scanner/robustness/text_transformations.py @@ -1,6 +1,5 @@ import itertools import json -import random import re from pathlib import Path @@ -70,7 +69,7 @@ def execute(self, data: pd.DataFrame) -> pd.DataFrame: class TextTypoTransformation(TextTransformation): name = "Add typos" - def __init__(self, column, rate=0.05, min_length=10, rng_seed=None): + def __init__(self, column, rate=0.05, min_length=10, rng_seed=1729): super().__init__(column) from .entity_swap import typos @@ -150,10 +149,11 @@ def make_perturbation(self, text): class TextLanguageBasedTransformation(TextTransformation): needs_dataset = True - def __init__(self, column): + def __init__(self, column, rng_seed=1729): super().__init__(column) self._lang_dictionary = dict() self._load_dictionaries() + self.rng = np.random.default_rng(seed=rng_seed) def _load_dictionaries(self): raise NotImplementedError() @@ -236,7 +236,7 @@ def make_perturbation(self, row): mask_value = f"__GSK__ENT__RELIGION__{n_list}__{n_term}__" text, num_rep = re.subn(rf"\b{re.escape(term)}(s?)\b", rf"{mask_value}\1", text, flags=re.IGNORECASE) if num_rep > 0: - i = (n_term + 1 + random.randrange(len(term_list) - 1)) % len(term_list) + i = (n_term + 1 + self.rng.choice(len(term_list) - 1)) % len(term_list) replacement = term_list[i] replacements.append((mask_value, replacement)) @@ -278,7 +278,7 @@ def make_perturbation(self, row): ) if num_rep > 0: r_income_type = "low-income" if income_type == "high-income" else "high-income" - replacement = random.choice(nationalities_word_dict[entity_type][r_income_type]) + replacement = self.rng.choice(nationalities_word_dict[entity_type][r_income_type]) replacements.append((mask_value, replacement)) # Replace masks diff --git a/tests/scan/test_text_transformations.py b/tests/scan/test_text_transformations.py index dd83da5ac0..03981e4a2c 100644 --- a/tests/scan/test_text_transformations.py +++ b/tests/scan/test_text_transformations.py @@ -1,4 +1,3 @@ -import random import re import pandas as pd @@ -131,9 +130,8 @@ def test_religion_based_transformation(): ) from giskard.scanner.robustness.text_transformations import TextReligionTransformation - t = TextReligionTransformation(column="text") + t = TextReligionTransformation(column="text", rng_seed=10) - random.seed(0) transformed = dataset.transform(t) transformed_text = transformed.df.text.values @@ -142,12 +140,12 @@ def test_religion_based_transformation(): "mois de ramadan." ) assert ( - transformed_text[1] == "Une partie des chrétiens commémorent ce vendredi 5 mai la naissance, l’éveil et la " - "mort de muhammad, dit « le Bouddha »" + transformed_text[1] == "Une partie des hindous commémorent ce vendredi 5 mai la naissance, l’éveil et la " + "mort de abraham, dit « le Bouddha »" ) assert ( transformed_text[2] == "Signs have also been placed in the direction of kumbh mela along one of the Peak " - "District’s most popular hiking routes, Cave Dale, to help christians combine prayer " + "District’s most popular hiking routes, Cave Dale, to help jews combine prayer " "with enjoying the outdoors." ) assert ( @@ -157,9 +155,6 @@ def test_religion_based_transformation(): def test_country_based_transformation(): - import random - - random.seed(10) dataset = _dataset_from_dict( { "text": [ @@ -173,31 +168,30 @@ def test_country_based_transformation(): ) from giskard.scanner.robustness.text_transformations import TextNationalityTransformation - t = TextNationalityTransformation(column="text") + t = TextNationalityTransformation(column="text", rng_seed=0) transformed = dataset.transform(t) transformed_text = transformed.df.text.values assert ( - transformed_text[0] == "Les musulmans de Eswatini fêtent vendredi 21 avril la fin du " + transformed_text[0] == "Les musulmans de Saint Thomas et Prince fêtent vendredi 21 avril la fin du " "jeûne pratiqué durant le mois de ramadan." ) - assert transformed_text[1] == "Des incendies ravagent l'Congo depuis la fin août 2019." + assert transformed_text[1] == "Des incendies ravagent l'Liban depuis la fin août 2019." assert ( - transformed_text[2] == "Bali is an Libyan island known for its forested volcanic mountains, iconic" + transformed_text[2] == "Bali is an Singaporean island known for its forested volcanic mountains, iconic" " rice paddies, beaches and coral reefs. The island is home to religious sites " "such as cliffside Uluwatu Temple" ) assert ( transformed_text[3] - == "President Joe Biden visited U.S.'s capital for the first time since Nigeria invaded the country" + == "President Joe Biden visited UAE's capital for the first time since Syria invaded the country" ) def test_country_based_transformation_edge_cases(): from giskard.scanner.robustness.text_transformations import TextNationalityTransformation - random.seed(0) df = pd.DataFrame( { "text": [ @@ -210,7 +204,7 @@ def test_country_based_transformation_edge_cases(): } ) - t = TextNationalityTransformation(column="text") + t = TextNationalityTransformation(column="text", rng_seed=0) t1 = t.make_perturbation(df.iloc[0]) t2 = t.make_perturbation(df.iloc[1])