@@ -172,7 +172,7 @@ class DataFactory:
172172    data_sources : Dict [str , BaseDataset ] =  BaseDataset .data_sources 
173173    CURATED_BIAS_DATASETS  =  ["BoolQ" , "XSum" ]
174174
175-     def  __init__ (self , file_path : dict , task : TaskManager , ** kwargs ) ->  None :
175+     def  __init__ (self , file_path : Union [ str ,  dict ] , task : TaskManager , ** kwargs ) ->  None :
176176        """Initializes DataFactory object. 
177177
178178        Args: 
@@ -232,6 +232,9 @@ def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
232232        self .init_cls : BaseDataset  =  None 
233233        self .kwargs  =  kwargs 
234234
235+         if  self .task  ==  "ner"  and  "doc_wise"  in  self ._custom_label :
236+             self .kwargs .update ({"doc_wise" : self ._custom_label .get ("doc_wise" , False )})
237+ 
235238    def  load_raw (self ):
236239        """Loads the data into a raw format""" 
237240        self .init_cls  =  self .data_sources [self .file_ext .replace ("." , "" )](
@@ -257,7 +260,9 @@ def load(self) -> List[Sample]:
257260            return  DataFactory .load_curated_bias (self ._file_path )
258261        else :
259262            self .init_cls  =  self .data_sources [self .file_ext .replace ("." , "" )](
260-                 self ._file_path , task = self .task , ** self .kwargs 
263+                 self ._file_path ,
264+                 task = self .task ,
265+                 ** self .kwargs ,
261266            )
262267
263268        loaded_data  =  self .init_cls .load_data ()
@@ -425,7 +430,9 @@ class ConllDataset(BaseDataset):
425430
426431    COLUMN_NAMES  =  {task : COLUMN_MAPPER [task ] for  task  in  supported_tasks }
427432
428-     def  __init__ (self , file_path : str , task : TaskManager ) ->  None :
433+     def  __init__ (
434+         self , file_path : Union [str , Dict [str , str ]], task : TaskManager , ** kwargs 
435+     ) ->  None :
429436        """Initializes ConllDataset object. 
430437
431438        Args: 
@@ -434,7 +441,7 @@ def __init__(self, file_path: str, task: TaskManager) -> None:
434441        """ 
435442        super ().__init__ ()
436443        self ._file_path  =  file_path 
437- 
444+          self . doc_wise   =   kwargs . get ( "doc_wise" )  if   "doc_wise"   in   kwargs   else   False 
438445        self .task  =  task 
439446
440447    def  load_raw_data (self ) ->  List [Dict ]:
@@ -495,42 +502,42 @@ def load_data(self) -> List[NERSample]:
495502            ]
496503            for  d_id , doc  in  enumerate (docs ):
497504                #  file content to sentence split 
498-                 sentences  =  re .split (r"\n\n|\n\s+\n" , doc .strip ())
499- 
500-                 if  sentences  ==  ["" ]:
501-                     continue 
502- 
503-                 for  sent  in  sentences :
504-                     # sentence string to token level split 
505-                     tokens  =  sent .strip ().split ("\n " )
506- 
507-                     # get annotations from token level split 
508-                     valid_tokens , token_list  =  self .__token_validation (tokens )
509- 
510-                     if  not  valid_tokens :
511-                         logging .warning (Warnings .W004 (sent = sent ))
512-                         continue 
513- 
514-                     #  get token and labels from the split 
505+                 if  self .doc_wise :
506+                     tokens  =  doc .strip ().split ("\n " )
515507                    ner_labels  =  []
516508                    cursor  =  0 
517-                     for  split  in  token_list :
518-                         ner_labels .append (
519-                             NERPrediction .from_span (
520-                                 entity = split [- 1 ],
521-                                 word = split [0 ],
509+ 
510+                     for  token  in  tokens :
511+                         token_list  =  token .split ()
512+ 
513+                         if  len (token_list ) ==  0 :
514+                             pred  =  NERPrediction .from_span (
515+                                 entity = "" ,
516+                                 word = "\n " ,
522517                                start = cursor ,
523-                                 end = cursor  +  len (split [0 ]),
524-                                 doc_id = d_id ,
525-                                 doc_name = (
526-                                     docs_strings [d_id ] if  len (docs_strings ) >  0  else  "" 
527-                                 ),
528-                                 pos_tag = split [1 ],
529-                                 chunk_tag = split [2 ],
518+                                 end = cursor ,
519+                                 pos_tag = "" ,
520+                                 chunk_tag = "" ,
530521                            )
531-                         )
532-                         # +1 to account for the white space 
533-                         cursor  +=  len (split [0 ]) +  1 
522+                             ner_labels .append (pred )
523+                         else :
524+                             ner_labels .append (
525+                                 NERPrediction .from_span (
526+                                     entity = token_list [- 1 ],
527+                                     word = token_list [0 ],
528+                                     start = cursor ,
529+                                     end = cursor  +  len (token_list [0 ]),
530+                                     doc_id = d_id ,
531+                                     doc_name = (
532+                                         docs_strings [d_id ]
533+                                         if  len (docs_strings ) >  0 
534+                                         else  "" 
535+                                     ),
536+                                     pos_tag = token_list [1 ],
537+                                     chunk_tag = token_list [2 ],
538+                                 )
539+                             )
540+                             cursor  +=  len (token_list [0 ]) +  1 
534541
535542                    original  =  " " .join ([label .span .word  for  label  in  ner_labels ])
536543
@@ -540,6 +547,55 @@ def load_data(self) -> List[NERSample]:
540547                            expected_results = NEROutput (predictions = ner_labels ),
541548                        )
542549                    )
550+ 
551+                 else :
552+                     sentences  =  re .split (r"\n\n|\n\s+\n" , doc .strip ())
553+ 
554+                     if  sentences  ==  ["" ]:
555+                         continue 
556+ 
557+                     for  sent  in  sentences :
558+                         # sentence string to token level split 
559+                         tokens  =  sent .strip ().split ("\n " )
560+ 
561+                         # get annotations from token level split 
562+                         valid_tokens , token_list  =  self .__token_validation (tokens )
563+ 
564+                         if  not  valid_tokens :
565+                             logging .warning (Warnings .W004 (sent = sent ))
566+                             continue 
567+ 
568+                         #  get token and labels from the split 
569+                         ner_labels  =  []
570+                         cursor  =  0 
571+                         for  split  in  token_list :
572+                             ner_labels .append (
573+                                 NERPrediction .from_span (
574+                                     entity = split [- 1 ],
575+                                     word = split [0 ],
576+                                     start = cursor ,
577+                                     end = cursor  +  len (split [0 ]),
578+                                     doc_id = d_id ,
579+                                     doc_name = (
580+                                         docs_strings [d_id ]
581+                                         if  len (docs_strings ) >  0 
582+                                         else  "" 
583+                                     ),
584+                                     pos_tag = split [1 ],
585+                                     chunk_tag = split [2 ],
586+                                 )
587+                             )
588+                             # +1 to account for the white space 
589+                             cursor  +=  len (split [0 ]) +  1 
590+ 
591+                         original  =  " " .join ([label .span .word  for  label  in  ner_labels ])
592+ 
593+                         data .append (
594+                             self .task .get_sample_class (
595+                                 original = original ,
596+                                 expected_results = NEROutput (predictions = ner_labels ),
597+                             )
598+                         )
543599        self .dataset_size  =  len (data )
544600        return  data 
545601
0 commit comments