57 changes: 35 additions & 22 deletions libmultilabel/nn/data_utils.py
@@ -123,21 +123,19 @@ def get_dataset_loader(
return dataset_loader


def _load_raw_data(path, is_test=False, tokenize_text=True, remove_no_label_data=False):
"""Load and tokenize raw data.
def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data=False):
"""Load and tokenize raw data from a file or a dataframe.

Args:
path (str): Path to training, test, or validation data.
data (Union[str, pandas.DataFrame]): Training, test, or validation data as a file path or a dataframe.
is_test (bool, optional): Whether the data is for test or not. Defaults to False.
tokenize_text (bool, optional): Whether to tokenize the text. Defaults to True.
remove_no_label_data (bool, optional): Whether to remove training/validation instances that have no labels.
This is effective only when is_test=False. Defaults to False.

Returns:
pandas.DataFrame: Data composed of index, label, and tokenized text.
"""
logging.info(f'Load data from {path}.')
data = pd.read_csv(path, sep='\t', header=None,
error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
data = data.astype(str)
if data.shape[1] == 2:
data.columns = ['label', 'text']
data = data.reset_index()
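Note: with this hunk, `_load_raw_data` no longer reads the TSV itself; callers now pass a pre-loaded dataframe (or `load_datasets` reads the file first, as shown further down). A minimal sketch of the new contract, assuming the two-column label/text layout handled above; the sample rows and the direct import of the private helper are for illustration only:

```python
import pandas as pd
from libmultilabel.nn.data_utils import _load_raw_data  # private helper, imported here only for illustration

# Two unnamed columns, matching the header=None TSV layout the loader expects;
# the rows themselves are made up.
df = pd.DataFrame([['label_1 label_2', 'first document text'],
                   ['label_3', 'second document text']])

train = _load_raw_data(df, tokenize_text=True, remove_no_label_data=False)
```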
@@ -166,21 +164,21 @@ def _load_raw_data(path, is_test=False, tokenize_text=True, remove_no_label_data


def load_datasets(
training_file=None,
test_file=None,
val_file=None,
training_data=None,
test_data=None,
val_data=None,
val_size=0.2,
merge_train_val=False,
tokenize_text=True,
remove_no_label_data=False
):
"""Load data from the specified data paths (i.e., `training_file`, `test_file`, and `val_file`).
If `valid.txt` does not exist but `val_size` > 0, the validation set will be split from the training dataset.
"""Load data from the specified data paths or the given dataframe.
If `val_data` does not exist but `val_size` > 0, the validation set will be split from the training dataset.

Args:
training_file (str, optional): Path to training data.
test_file (str, optional): Path to test data.
val_file (str, optional): Path to validation data.
training_data (Union[str, pandas.DataFrame], optional): Path to training data or a dataframe.
test_data (Union[str, pandas.DataFrame], optional): Path to test data or a dataframe.
val_data (Union[str, pandas.DataFrame], optional): Path to validation data or a dataframe.
val_size (float, optional): Training-validation split: a ratio in [0, 1] or an integer for the size of the validation set.
Defaults to 0.2.
merge_train_val (bool, optional): Whether to merge the training and validation data.
@@ -192,22 +190,37 @@ def load_datasets(
Returns:
dict: A dictionary of datasets.
"""
assert training_file or test_file, "At least one of `training_file` and `test_file` must be specified."
# Accept either a path or a dataframe per split; an empty dataframe counts as unspecified.
has_training = training_data is not None and (not isinstance(training_data, pd.DataFrame) or not training_data.empty)
has_test = test_data is not None and (not isinstance(test_data, pd.DataFrame) or not test_data.empty)
assert has_training or has_test, "At least one of `training_data` and `test_data` must be specified."

datasets = {}
if training_file is not None:
datasets['train'] = _load_raw_data(training_file, tokenize_text=tokenize_text,
if training_data is not None:
if isinstance(training_data, str):
logging.info(f'Load data from {training_data}.')
training_data = pd.read_csv(training_data, sep='\t', header=None,
error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
datasets['train'] = _load_raw_data(training_data, tokenize_text=tokenize_text,
remove_no_label_data=remove_no_label_data)

if val_file is not None:
datasets['val'] = _load_raw_data(val_file, tokenize_text=tokenize_text,
remove_no_label_data=remove_no_label_data)
if val_data is not None:
if isinstance(val_data, str):
logging.info(f'Load data from {val_data}.')
val_data = pd.read_csv(val_data, sep='\t', header=None,
error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
datasets['val'] = _load_raw_data(val_data, tokenize_text=tokenize_text,
remove_no_label_data=remove_no_label_data)
elif val_size > 0:
datasets['train'], datasets['val'] = train_test_split(
datasets['train'], test_size=val_size, random_state=42)

if test_file is not None:
datasets['test'] = _load_raw_data(test_file, is_test=True, tokenize_text=tokenize_text,
if test_data is not None:
if isinstance(test_data, str):
logging.info(f'Load data from {test_data}.')
test_data = pd.read_csv(test_data, sep='\t', header=None,
error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
datasets['test'] = _load_raw_data(test_data, is_test=True, tokenize_text=tokenize_text,
remove_no_label_data=remove_no_label_data)

if merge_train_val:
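Taken together, the hunks above make each split accept either a path (parsed with the same `pd.read_csv` settings as before) or a ready-made dataframe. A hedged usage sketch; file names and rows are placeholders:

```python
import pandas as pd
from libmultilabel.nn import data_utils

# Path-based call, unchanged in spirit from the old `training_file=` API
# (placeholder file names):
datasets = data_utils.load_datasets(training_data='train.txt', test_data='test.txt')

# Dataframe-based call: the validation split is still carved out of the
# training set when no val_data is given (rows are made up).
train_df = pd.DataFrame([['label_1', 'some document text'],
                         ['label_2', 'another document text']])
datasets = data_utils.load_datasets(training_data=train_df, val_size=0.2)
print(sorted(datasets))  # expected: ['train', 'val']
```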
6 changes: 3 additions & 3 deletions torch_trainer.py
@@ -56,9 +56,9 @@ def __init__(
# Load dataset
if datasets is None:
self.datasets = data_utils.load_datasets(
training_file=config.training_file,
test_file=config.test_file,
val_file=config.val_file,
training_data=config.training_file,
test_data=config.test_file,
val_data=config.val_file,
val_size=config.val_size,
merge_train_val=config.merge_train_val,
tokenize_text=tokenize_text,
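For completeness, a sketch of how the trainer side lines up with the renamed keywords; `SimpleNamespace` is a hypothetical stand-in for the real config object consumed in torch_trainer.py, and only the fields touched by this hunk are filled in:

```python
from types import SimpleNamespace
from libmultilabel.nn import data_utils

# Hypothetical config stand-in; the real object carries many more fields.
config = SimpleNamespace(
    training_file='train.txt',  # still path-valued on the config side
    test_file='test.txt',
    val_file=None,
    val_size=0.2,
    merge_train_val=False,
)

# Mirrors the torch_trainer.py call above: path-valued config fields are
# forwarded through the renamed *_data keyword arguments.
datasets = data_utils.load_datasets(
    training_data=config.training_file,
    test_data=config.test_file,
    val_data=config.val_file,
    val_size=config.val_size,
    merge_train_val=config.merge_train_val,
)
```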