Skip to content

Commit b35c28a

Browse files
Merge pull request #1103 from JohnSnowLabs/patch/2.3.1
Patch/2.3.1
2 parents 78cb31f + 134de82 commit b35c28a

File tree

7 files changed

+157
-87
lines changed

7 files changed

+157
-87
lines changed

.github/workflows/build_and_test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ on:
88
pull_request:
99
branches:
1010
- "release/*"
11+
- "patch/*"
1112
- "main"
1213

1314
jobs:

langtest/augmentation/base.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from langtest.utils.custom_types.predictions import NERPrediction, SequenceLabel
2020
from langtest.utils.custom_types.sample import NERSample
2121
from langtest.tasks import TaskManager
22-
from ..utils.lib_manager import try_import_lib
2322
from ..errors import Errors
2423

2524

@@ -358,6 +357,9 @@ def __init__(
358357
# Extend the existing templates list
359358

360359
self.__templates.extend(generated_templates[:num_extra_templates])
360+
except ModuleNotFoundError:
361+
raise ImportError(Errors.E097())
362+
361363
except Exception as e_msg:
362364
raise Errors.E095(e=e_msg)
363365

@@ -606,19 +608,19 @@ def __generate_templates(
606608
num_extra_templates: int,
607609
model_config: Union[OpenAIConfig, AzureOpenAIConfig] = None,
608610
) -> List[str]:
609-
if try_import_lib("openai"):
610-
from langtest.augmentation.utils import (
611-
generate_templates_azoi, # azoi means Azure OpenAI
612-
generate_templates_openai,
613-
)
611+
"""This method is used to generate extra templates from a given template."""
612+
from langtest.augmentation.utils import (
613+
generate_templates_azoi, # azoi means Azure OpenAI
614+
generate_templates_openai,
615+
)
614616

615-
params = model_config.copy() if model_config else {}
617+
params = model_config.copy() if model_config else {}
616618

617-
if model_config and model_config.get("provider") == "openai":
618-
return generate_templates_openai(template, num_extra_templates, params)
619+
if model_config and model_config.get("provider") == "openai":
620+
return generate_templates_openai(template, num_extra_templates, params)
619621

620-
elif model_config and model_config.get("provider") == "azure":
621-
return generate_templates_azoi(template, num_extra_templates, params)
622+
elif model_config and model_config.get("provider") == "azure":
623+
return generate_templates_azoi(template, num_extra_templates, params)
622624

623-
else:
624-
return generate_templates_openai(template, num_extra_templates)
625+
else:
626+
return generate_templates_openai(template, num_extra_templates)

langtest/augmentation/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,13 @@ class OpenAIConfig(TypedDict):
1919
class AzureOpenAIConfig(TypedDict):
2020
"""Azure OpenAI Configuration for API Key and Provider."""
2121

22-
from openai.lib.azure import AzureADTokenProvider
23-
2422
azure_endpoint: str
2523
api_version: str
2624
api_key: str
2725
provider: str
2826
azure_deployment: Union[str, None] = None
2927
azure_ad_token: Union[str, None] = (None,)
30-
azure_ad_token_provider: Union[AzureADTokenProvider, None] = (None,)
28+
azure_ad_token_provider = (None,)
3129
organization: Union[str, None] = (None,)
3230

3331

@@ -76,6 +74,7 @@ def generate_templates_azoi(
7674
template: str, num_extra_templates: int, model_config: AzureOpenAIConfig
7775
):
7876
"""Generate new templates based on the provided template using Azure OpenAI API."""
77+
7978
import openai
8079

8180
if "provider" in model_config:
@@ -139,6 +138,7 @@ def generate_templates_openai(
139138
template: str, num_extra_templates: int, model_config: OpenAIConfig = OpenAIConfig()
140139
):
141140
"""Generate new templates based on the provided template using OpenAI API."""
141+
142142
import openai
143143

144144
if "provider" in model_config:

langtest/datahandler/datasource.py

Lines changed: 92 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ class DataFactory:
172172
data_sources: Dict[str, BaseDataset] = BaseDataset.data_sources
173173
CURATED_BIAS_DATASETS = ["BoolQ", "XSum"]
174174

175-
def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
175+
def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) -> None:
176176
"""Initializes DataFactory object.
177177
178178
Args:
@@ -232,6 +232,9 @@ def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
232232
self.init_cls: BaseDataset = None
233233
self.kwargs = kwargs
234234

235+
if self.task == "ner" and "doc_wise" in self._custom_label:
236+
self.kwargs.update({"doc_wise": self._custom_label.get("doc_wise", False)})
237+
235238
def load_raw(self):
236239
"""Loads the data into a raw format"""
237240
self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
@@ -257,7 +260,9 @@ def load(self) -> List[Sample]:
257260
return DataFactory.load_curated_bias(self._file_path)
258261
else:
259262
self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
260-
self._file_path, task=self.task, **self.kwargs
263+
self._file_path,
264+
task=self.task,
265+
**self.kwargs,
261266
)
262267

263268
loaded_data = self.init_cls.load_data()
@@ -425,7 +430,9 @@ class ConllDataset(BaseDataset):
425430

426431
COLUMN_NAMES = {task: COLUMN_MAPPER[task] for task in supported_tasks}
427432

428-
def __init__(self, file_path: str, task: TaskManager) -> None:
433+
def __init__(
434+
self, file_path: Union[str, Dict[str, str]], task: TaskManager, **kwargs
435+
) -> None:
429436
"""Initializes ConllDataset object.
430437
431438
Args:
@@ -434,7 +441,7 @@ def __init__(self, file_path: str, task: TaskManager) -> None:
434441
"""
435442
super().__init__()
436443
self._file_path = file_path
437-
444+
self.doc_wise = kwargs.get("doc_wise") if "doc_wise" in kwargs else False
438445
self.task = task
439446

440447
def load_raw_data(self) -> List[Dict]:
@@ -495,42 +502,42 @@ def load_data(self) -> List[NERSample]:
495502
]
496503
for d_id, doc in enumerate(docs):
497504
# file content to sentence split
498-
sentences = re.split(r"\n\n|\n\s+\n", doc.strip())
499-
500-
if sentences == [""]:
501-
continue
502-
503-
for sent in sentences:
504-
# sentence string to token level split
505-
tokens = sent.strip().split("\n")
506-
507-
# get annotations from token level split
508-
valid_tokens, token_list = self.__token_validation(tokens)
509-
510-
if not valid_tokens:
511-
logging.warning(Warnings.W004(sent=sent))
512-
continue
513-
514-
# get token and labels from the split
505+
if self.doc_wise:
506+
tokens = doc.strip().split("\n")
515507
ner_labels = []
516508
cursor = 0
517-
for split in token_list:
518-
ner_labels.append(
519-
NERPrediction.from_span(
520-
entity=split[-1],
521-
word=split[0],
509+
510+
for token in tokens:
511+
token_list = token.split()
512+
513+
if len(token_list) == 0:
514+
pred = NERPrediction.from_span(
515+
entity="",
516+
word="\n",
522517
start=cursor,
523-
end=cursor + len(split[0]),
524-
doc_id=d_id,
525-
doc_name=(
526-
docs_strings[d_id] if len(docs_strings) > 0 else ""
527-
),
528-
pos_tag=split[1],
529-
chunk_tag=split[2],
518+
end=cursor,
519+
pos_tag="",
520+
chunk_tag="",
530521
)
531-
)
532-
# +1 to account for the white space
533-
cursor += len(split[0]) + 1
522+
ner_labels.append(pred)
523+
else:
524+
ner_labels.append(
525+
NERPrediction.from_span(
526+
entity=token_list[-1],
527+
word=token_list[0],
528+
start=cursor,
529+
end=cursor + len(token_list[0]),
530+
doc_id=d_id,
531+
doc_name=(
532+
docs_strings[d_id]
533+
if len(docs_strings) > 0
534+
else ""
535+
),
536+
pos_tag=token_list[1],
537+
chunk_tag=token_list[2],
538+
)
539+
)
540+
cursor += len(token_list[0]) + 1
534541

535542
original = " ".join([label.span.word for label in ner_labels])
536543

@@ -540,6 +547,55 @@ def load_data(self) -> List[NERSample]:
540547
expected_results=NEROutput(predictions=ner_labels),
541548
)
542549
)
550+
551+
else:
552+
sentences = re.split(r"\n\n|\n\s+\n", doc.strip())
553+
554+
if sentences == [""]:
555+
continue
556+
557+
for sent in sentences:
558+
# sentence string to token level split
559+
tokens = sent.strip().split("\n")
560+
561+
# get annotations from token level split
562+
valid_tokens, token_list = self.__token_validation(tokens)
563+
564+
if not valid_tokens:
565+
logging.warning(Warnings.W004(sent=sent))
566+
continue
567+
568+
# get token and labels from the split
569+
ner_labels = []
570+
cursor = 0
571+
for split in token_list:
572+
ner_labels.append(
573+
NERPrediction.from_span(
574+
entity=split[-1],
575+
word=split[0],
576+
start=cursor,
577+
end=cursor + len(split[0]),
578+
doc_id=d_id,
579+
doc_name=(
580+
docs_strings[d_id]
581+
if len(docs_strings) > 0
582+
else ""
583+
),
584+
pos_tag=split[1],
585+
chunk_tag=split[2],
586+
)
587+
)
588+
# +1 to account for the white space
589+
cursor += len(split[0]) + 1
590+
591+
original = " ".join([label.span.word for label in ner_labels])
592+
593+
data.append(
594+
self.task.get_sample_class(
595+
original=original,
596+
expected_results=NEROutput(predictions=ner_labels),
597+
)
598+
)
543599
self.dataset_size = len(data)
544600
return data
545601

langtest/datahandler/format.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -195,43 +195,53 @@ def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, st
195195
test_case = sample.test_case
196196
original = sample.original
197197
if test_case:
198-
test_case_items = test_case.split()
199-
norm_test_case_items = test_case.lower().split()
200-
norm_original_items = original.lower().split()
198+
test_case_items = test_case.split(" ")
199+
norm_test_case_items = test_case.lower().split(" ")
200+
norm_original_items = original.lower().split(" ")
201201
temp_len = 0
202202
for jdx, item in enumerate(norm_test_case_items):
203-
try:
204-
if item in norm_original_items and jdx >= norm_original_items.index(
205-
item
206-
):
207-
oitem_index = norm_original_items.index(item)
208-
j = sample.expected_results.predictions[oitem_index + temp_len]
209-
if temp_id != j.doc_id and jdx == 0:
210-
text += f"{j.doc_name}\n\n"
211-
temp_id = j.doc_id
212-
text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
213-
norm_original_items.pop(oitem_index)
214-
temp_len += 1
215-
else:
216-
o_item = sample.expected_results.predictions[jdx].span.word
217-
letters_count = len(set(item) - set(o_item))
203+
if test_case_items[jdx] == "\n":
204+
text += "\n" # add a newline character after each sentence
205+
else:
206+
try:
218207
if (
219-
len(norm_test_case_items) == len(original.lower().split())
220-
or letters_count < 2
208+
item in norm_original_items
209+
and jdx >= norm_original_items.index(item)
221210
):
222-
tl = sample.expected_results.predictions[jdx]
223-
text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
211+
oitem_index = norm_original_items.index(item)
212+
j = sample.expected_results.predictions[
213+
oitem_index + temp_len
214+
]
215+
if temp_id != j.doc_id and jdx == 0:
216+
text += f"{j.doc_name}\n\n"
217+
temp_id = j.doc_id
218+
text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
219+
norm_original_items.pop(oitem_index)
220+
temp_len += 1
224221
else:
225-
text += f"{test_case_items[jdx]} -X- -X- O\n"
226-
except IndexError:
227-
text += f"{test_case_items[jdx]} -X- -X- O\n"
222+
o_item = sample.expected_results.predictions[jdx].span.word
223+
letters_count = len(set(item) - set(o_item))
224+
if (
225+
len(norm_test_case_items)
226+
== len(original.lower().split(" "))
227+
or letters_count < 2
228+
):
229+
tl = sample.expected_results.predictions[jdx]
230+
text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
231+
else:
232+
text += f"{test_case_items[jdx]} -X- -X- O\n"
233+
except IndexError:
234+
text += f"{test_case_items[jdx]} -X- -X- O\n"
228235

229236
else:
230237
for j in sample.expected_results.predictions:
231-
if temp_id != j.doc_id:
232-
text += f"{j.doc_name}\n\n"
233-
temp_id = j.doc_id
234-
text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
238+
if j.span.word == "\n":
239+
text += "\n"
240+
else:
241+
if temp_id != j.doc_id:
242+
text += f"{j.doc_name}\n\n"
243+
temp_id = j.doc_id
244+
text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
235245

236246
return text, temp_id
237247

langtest/errors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ class Errors(metaclass=ErrorsWithCodes):
275275
E094 = ("Unsupported category: '{category}'. Supported categories: {supported_category}")
276276
E095 = ("Failed to make API request: {e}")
277277
E096 = ("Failed to generate the templates in Augmentation: {msg}")
278+
E097 = ("Failed to load openai. Please install it using `pip install openai`")
278279

279280

280281
class ColumnNameError(Exception):

0 commit comments

Comments (0)