Skip to content

Commit ad84608

Browse files
authored
Merge pull request #1674 from Giskard-AI/GSK-2332-fix-stochastic-behavior-ethical-bias-detector
Text transformations used by EthicalBiasDetector are not deterministic [GSK-2332]
2 parents 2a6dfaf + ed81295 commit ad84608

File tree

2 files changed

+15
-21
lines changed

2 files changed

+15
-21
lines changed

giskard/scanner/robustness/text_transformations.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
 import itertools
 import json
-import random
 import re
 from pathlib import Path

@@ -70,7 +69,7 @@ def execute(self, data: pd.DataFrame) -> pd.DataFrame:
 class TextTypoTransformation(TextTransformation):
     name = "Add typos"

-    def __init__(self, column, rate=0.05, min_length=10, rng_seed=None):
+    def __init__(self, column, rate=0.05, min_length=10, rng_seed=1729):
         super().__init__(column)
         from .entity_swap import typos

@@ -150,10 +149,11 @@ def make_perturbation(self, text):
 class TextLanguageBasedTransformation(TextTransformation):
     needs_dataset = True

-    def __init__(self, column):
+    def __init__(self, column, rng_seed=1729):
         super().__init__(column)
         self._lang_dictionary = dict()
         self._load_dictionaries()
+        self.rng = np.random.default_rng(seed=rng_seed)

     def _load_dictionaries(self):
         raise NotImplementedError()
@@ -236,7 +236,7 @@ def make_perturbation(self, row):
             mask_value = f"__GSK__ENT__RELIGION__{n_list}__{n_term}__"
             text, num_rep = re.subn(rf"\b{re.escape(term)}(s?)\b", rf"{mask_value}\1", text, flags=re.IGNORECASE)
             if num_rep > 0:
-                i = (n_term + 1 + random.randrange(len(term_list) - 1)) % len(term_list)
+                i = (n_term + 1 + self.rng.choice(len(term_list) - 1)) % len(term_list)
                 replacement = term_list[i]
                 replacements.append((mask_value, replacement))

@@ -278,7 +278,7 @@ def make_perturbation(self, row):
             )
             if num_rep > 0:
                 r_income_type = "low-income" if income_type == "high-income" else "high-income"
-                replacement = random.choice(nationalities_word_dict[entity_type][r_income_type])
+                replacement = self.rng.choice(nationalities_word_dict[entity_type][r_income_type])
                 replacements.append((mask_value, replacement))

         # Replace masks

tests/scan/test_text_transformations.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
-import random
 import re

 import pandas as pd
@@ -131,9 +130,8 @@ def test_religion_based_transformation():
     )
     from giskard.scanner.robustness.text_transformations import TextReligionTransformation

-    t = TextReligionTransformation(column="text")
+    t = TextReligionTransformation(column="text", rng_seed=10)

-    random.seed(0)
     transformed = dataset.transform(t)
     transformed_text = transformed.df.text.values

@@ -142,12 +140,12 @@ def test_religion_based_transformation():
         "mois de ramadan."
     )
     assert (
-        transformed_text[1] == "Une partie des chrétiens commémorent ce vendredi 5 mai la naissance, l’éveil et la "
-        "mort de muhammad, dit « le Bouddha »"
+        transformed_text[1] == "Une partie des hindous commémorent ce vendredi 5 mai la naissance, l’éveil et la "
+        "mort de abraham, dit « le Bouddha »"
     )
     assert (
         transformed_text[2] == "Signs have also been placed in the direction of kumbh mela along one of the Peak "
-        "District’s most popular hiking routes, Cave Dale, to help christians combine prayer "
+        "District’s most popular hiking routes, Cave Dale, to help jews combine prayer "
         "with enjoying the outdoors."
     )
     assert (
@@ -157,9 +155,6 @@ def test_religion_based_transformation():


 def test_country_based_transformation():
-    import random
-
-    random.seed(10)
     dataset = _dataset_from_dict(
         {
             "text": [
@@ -173,31 +168,30 @@ def test_country_based_transformation():
     )
     from giskard.scanner.robustness.text_transformations import TextNationalityTransformation

-    t = TextNationalityTransformation(column="text")
+    t = TextNationalityTransformation(column="text", rng_seed=0)

     transformed = dataset.transform(t)
     transformed_text = transformed.df.text.values

     assert (
-        transformed_text[0] == "Les musulmans de Eswatini fêtent vendredi 21 avril la fin du "
+        transformed_text[0] == "Les musulmans de Saint Thomas et Prince fêtent vendredi 21 avril la fin du "
         "jeûne pratiqué durant le mois de ramadan."
     )
-    assert transformed_text[1] == "Des incendies ravagent l'Congo depuis la fin août 2019."
+    assert transformed_text[1] == "Des incendies ravagent l'Liban depuis la fin août 2019."
     assert (
-        transformed_text[2] == "Bali is an Libyan island known for its forested volcanic mountains, iconic"
+        transformed_text[2] == "Bali is an Singaporean island known for its forested volcanic mountains, iconic"
         " rice paddies, beaches and coral reefs. The island is home to religious sites "
         "such as cliffside Uluwatu Temple"
     )
     assert (
         transformed_text[3]
-        == "President Joe Biden visited U.S.'s capital for the first time since Nigeria invaded the country"
+        == "President Joe Biden visited UAE's capital for the first time since Syria invaded the country"
     )


 def test_country_based_transformation_edge_cases():
     from giskard.scanner.robustness.text_transformations import TextNationalityTransformation

-    random.seed(0)
     df = pd.DataFrame(
         {
             "text": [
@@ -210,7 +204,7 @@ def test_country_based_transformation_edge_cases():
         }
     )

-    t = TextNationalityTransformation(column="text")
+    t = TextNationalityTransformation(column="text", rng_seed=0)

     t1 = t.make_perturbation(df.iloc[0])
     t2 = t.make_perturbation(df.iloc[1])

0 commit comments

Comments
 (0)