Skip to content

Commit 3a0bbcc

Browse files
authored
Merge pull request #22 from ComplexData-MILA/dev/alignment_dataset
Alignment Dataset Object
2 parents 7d92d48 + 1d0dd8f commit 3a0bbcc

File tree

6 files changed

+683
-0
lines changed

6 files changed

+683
-0
lines changed

aif_gen/dataset/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from aif_gen.dataset.alignment_sample import AlignmentDatasetSample
2+
from aif_gen.dataset.alignment_dataset import AlignmentDataset
3+
from aif_gen.dataset.continual_alignment_dataset import ContinualAlignmentDataset
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import json
2+
from dataclasses import asdict
3+
from typing import Any, Dict, List, Union
4+
5+
from aif_gen.task import AlignmentTask
6+
7+
from .alignment_sample import AlignmentDatasetSample
8+
9+
10+
class AlignmentDataset:
    r"""Container object for an Alignment Dataset.

    Pairs one AlignmentTask with the list of samples that belong to it and
    offers list-like access (len, indexing, slicing, append, extend) together
    with dictionary / json (de)serialization.

    Args:
        task (AlignmentTask): The AlignmentTask associated with the dataset.
        samples (List[AlignmentDatasetSample]): The samples in this AlignmentDataset.
    """

    def __init__(
        self, task: AlignmentTask, samples: List[AlignmentDatasetSample]
    ) -> None:
        self._task = task
        # The list is stored by reference (not copied): append() and extend()
        # mutate the very list object that was passed in.
        self._samples = samples

    @property
    def task(self) -> AlignmentTask:
        """AlignmentTask: The task associated with the AlignmentDataset."""
        return self._task

    @property
    def samples(self) -> List[AlignmentDatasetSample]:
        """List[AlignmentDatasetSample]: The list of samples associated with the AlignmentDataset."""
        return self._samples

    @property
    def num_samples(self) -> int:
        """int: The number of samples associated with the AlignmentDataset."""
        return len(self.samples)

    def __len__(self) -> int:
        """int: The number of samples associated with the AlignmentDataset."""
        return self.num_samples

    def __getitem__(
        self, key: Union[slice, int]
    ) -> Union[AlignmentDatasetSample, List[AlignmentDatasetSample]]:
        r"""Index or slice directly into the list of samples.

        Args:
            key (Union[slice, int]): An integer index or a slice.

        Returns:
            Union[AlignmentDatasetSample, List[AlignmentDatasetSample]]: A single
            sample for an integer key, a list of samples for a slice.
        """
        return self.samples[key]

    def append(self, sample: AlignmentDatasetSample) -> None:
        r"""Append a single AlignmentDatasetSample to the Alignment Dataset.

        Args:
            sample (AlignmentDatasetSample): The new sample to add.

        Raises:
            TypeError: if the sample is not of type AlignmentDatasetSample.
        """
        if not isinstance(sample, AlignmentDatasetSample):
            raise TypeError(
                f'Sample: {sample} must be of type AlignmentDatasetSample but got {sample.__class__.__name__}'
            )
        self.samples.append(sample)

    def extend(self, samples: List[AlignmentDatasetSample]) -> None:
        r"""Add multiple AlignmentDatasetSample's to the Alignment Dataset.

        The operation is all-or-nothing: every element is type-checked before
        the dataset is modified, so a failure leaves the dataset unchanged.

        Args:
            samples (List[AlignmentDatasetSample]): The new samples to add.

        Raises:
            TypeError: if any sample is not of type AlignmentDatasetSample.
        """
        # Snapshot first: this allows validating before mutating, and makes
        # `dataset.extend(dataset.samples)` terminate instead of iterating a
        # list that grows while it is being traversed.
        new_samples = list(samples)
        for sample in new_samples:
            if not isinstance(sample, AlignmentDatasetSample):
                raise TypeError(
                    f'Sample: {sample} must be of type AlignmentDatasetSample but got {sample.__class__.__name__}'
                )
        self.samples.extend(new_samples)

    def to_json(self, file_path: str) -> None:
        r"""Save the AlignmentDataset to a json file.

        Note: Uses to_dict() under the hood to get a dictionary representation.

        Args:
            file_path (str): The os.pathlike object to write to.
        """
        # Build the dictionary before opening the file so that a failure in
        # to_dict() cannot leave a truncated file behind.
        dataset_dict = self.to_dict()
        # Explicit encoding: do not depend on the platform's locale default.
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(dataset_dict, f)

    def to_dict(self) -> Dict[str, Any]:
        r"""Convert the AlignmentDataset to dictionary representation.

        Note: This method is the functional inverse of AlignmentDataset.from_dict().

        Returns:
            Dict[str, Any]: The dictionary representation of the AlignmentDataset,
            with a 'task' entry (from task.to_dict()) and a 'samples' entry
            (one dictionary per sample, produced by dataclasses.asdict).
        """
        return {
            'task': self.task.to_dict(),
            'samples': [asdict(sample) for sample in self.samples],
        }

    @classmethod
    def from_json(cls, file_path: str) -> 'AlignmentDataset':
        r"""Load the AlignmentDataset from a json file.

        Note: Uses AlignmentDataset.from_dict() under the hood to parse the representation.

        Args:
            file_path (str): The os.pathlike object to read from.

        Returns:
            AlignmentDataset: The newly constructed AlignmentDataset.
        """
        # JSON text is UTF-8 by specification; decoding with the locale default
        # would corrupt non-ASCII content on platforms where it is not UTF-8.
        with open(file_path, 'r', encoding='utf-8') as f:
            dataset_dict = json.load(f)

        return cls.from_dict(dataset_dict)

    @classmethod
    def from_dict(cls, dataset_dict: Dict[str, Any]) -> 'AlignmentDataset':
        r"""Construct an AlignmentDataset from dictionary representation.

        Note:
            Expects 'task', and 'samples' keys to be present in the dictionary.
            The 'task' value should be parsable by AlignmentTask.from_dict().
            The 'samples' value should be a list of dictionaries, each of which
            is parsable by AlignmentDatasetSample.

        Args:
            dataset_dict (Dict[str, Any]): The dictionary that encodes the AlignmentDataset.

        Returns:
            AlignmentDataset: The newly constructed AlignmentDataset.

        Raises:
            ValueError: If the input dictionary is missing any required keys.
        """
        # Check up front so that a missing key surfaces as the documented
        # ValueError rather than as a bare KeyError from the lookups below.
        missing_keys = [key for key in ('task', 'samples') if key not in dataset_dict]
        if missing_keys:
            raise ValueError(
                f'Dataset dictionary is missing required keys: {missing_keys}'
            )

        task = AlignmentTask.from_dict(dataset_dict['task'])
        samples = [
            AlignmentDatasetSample(**sample_dict)
            for sample_dict in dataset_dict['samples']
        ]
        return cls(task, samples)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass
class AlignmentDatasetSample:
    r"""A single preference sample belonging to an Alignment Dataset.

    The field layout follows the "TRL Preference Format with explicit prompt".
    See: https://huggingface.co/docs/trl/en/dataset_formats.

    Attributes:
        prompt (str): The prompt associated with the sample.
        chosen (str): The winning response associated with the sample.
        rejected (str): The losing response associated with the sample.
    """

    prompt: str
    chosen: str
    rejected: str
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import json
2+
from typing import Any, Dict, List, Union
3+
4+
from aif_gen.dataset.alignment_sample import AlignmentDatasetSample
5+
6+
from .alignment_dataset import AlignmentDataset
7+
8+
9+
class ContinualAlignmentDataset:
    r"""Container object for a Continual Alignment Dataset.

    Holds an ordered list of AlignmentDataset objects. Length and indexing
    operate on the **samples** of all constituents concatenated in order,
    not on the constituent datasets themselves.

    Args:
        datasets (List[AlignmentDataset]): Temporal list of AlignmentDataset constituents.
    """

    def __init__(self, datasets: List[AlignmentDataset]) -> None:
        # The list is stored by reference (not copied): append() and extend()
        # mutate the very list object that was passed in.
        self._datasets = datasets

    @property
    def datasets(self) -> List[AlignmentDataset]:
        """List[AlignmentDataset]: The list of AlignmentDataset constituents."""
        return self._datasets

    @property
    def num_datasets(self) -> int:
        """int: The number of AlignmentDataset constituents."""
        return len(self.datasets)

    @property
    def num_samples(self) -> int:
        """int: The total number of samples across all AlignmentDataset constituents."""
        return sum(len(dataset) for dataset in self.datasets)

    def __len__(self) -> int:
        """int: The total number of samples across all AlignmentDataset constituents."""
        return self.num_samples

    def __getitem__(
        self, key: Union[slice, int]
    ) -> Union[AlignmentDatasetSample, List[AlignmentDatasetSample]]:
        r"""Index or slice into the samples of all constituents, concatenated in order.

        Args:
            key (Union[slice, int]): An integer index or a slice over the
                concatenated samples (not over the datasets themselves).

        Returns:
            Union[AlignmentDatasetSample, List[AlignmentDatasetSample]]: A single
            sample for an integer key, a list of samples for a slice.
        """
        # The flattened list is rebuilt on every call; caching it would need
        # invalidation whenever a constituent dataset changes.
        all_samples = [
            sample for dataset in self.datasets for sample in dataset.samples
        ]
        return all_samples[key]

    def append(self, dataset: AlignmentDataset) -> None:
        r"""Append a single AlignmentDataset to the ContinualAlignmentDataset.

        Args:
            dataset (AlignmentDataset): The new dataset to add.

        Raises:
            TypeError: if the dataset is not of type AlignmentDataset.
        """
        if not isinstance(dataset, AlignmentDataset):
            raise TypeError(
                f'Dataset: {dataset} must be of type AlignmentDataset but got {dataset.__class__.__name__}'
            )
        self.datasets.append(dataset)

    def extend(self, datasets: List[AlignmentDataset]) -> None:
        r"""Append multiple AlignmentDataset's to the ContinualAlignmentDataset.

        The operation is all-or-nothing: every element is type-checked before
        the container is modified, so a failure leaves it unchanged.

        Args:
            datasets (List[AlignmentDataset]): The new datasets to add.

        Raises:
            TypeError: if any dataset is not of type AlignmentDataset.
        """
        # Snapshot first: this allows validating before mutating, and makes
        # `cds.extend(cds.datasets)` terminate instead of iterating a list
        # that grows while it is being traversed.
        new_datasets = list(datasets)
        for dataset in new_datasets:
            if not isinstance(dataset, AlignmentDataset):
                raise TypeError(
                    f'Dataset: {dataset} must be of type AlignmentDataset but got {dataset.__class__.__name__}'
                )
        self.datasets.extend(new_datasets)

    def to_json(self, file_path: str) -> None:
        r"""Save the ContinualAlignmentDataset to a json file.

        Note: Uses to_dict() under the hood to get a dictionary representation.

        Args:
            file_path (str): The os.pathlike object to write to.
        """
        # Build the dictionary before opening the file so that a failure in
        # to_dict() cannot leave a truncated file behind.
        dataset_dict = self.to_dict()
        # Explicit encoding: do not depend on the platform's locale default.
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(dataset_dict, f)

    def to_dict(self) -> Dict[str, Any]:
        r"""Convert the ContinualAlignmentDataset to dictionary representation.

        Note: This method is the functional inverse of ContinualAlignmentDataset.from_dict().

        Returns:
            Dict[str, Any]: The dictionary representation of the ContinualAlignmentDataset:
            a single 'datasets' entry holding one dictionary per constituent,
            each produced by that constituent's to_dict().
        """
        return {'datasets': [dataset.to_dict() for dataset in self.datasets]}

    @classmethod
    def from_json(cls, file_path: str) -> 'ContinualAlignmentDataset':
        r"""Load the ContinualAlignmentDataset from a json file.

        Note: Uses ContinualAlignmentDataset.from_dict() under the hood to parse the representation.

        Args:
            file_path (str): The os.pathlike object to read from.

        Returns:
            ContinualAlignmentDataset: The newly constructed ContinualAlignmentDataset.
        """
        # JSON text is UTF-8 by specification; decoding with the locale default
        # would corrupt non-ASCII content on platforms where it is not UTF-8.
        with open(file_path, 'r', encoding='utf-8') as f:
            dataset_dict = json.load(f)
        return cls.from_dict(dataset_dict)

    @classmethod
    def from_dict(cls, dataset_dict: Dict[str, Any]) -> 'ContinualAlignmentDataset':
        r"""Construct a ContinualAlignmentDataset from dictionary representation.

        Note:
            Expects 'datasets' key to be present in the dictionary. The value is a list
            of dictionaries, each parsable by AlignmentDataset.from_dict(). Errors
            raised while parsing a nested dictionary propagate unchanged.

        Args:
            dataset_dict (Dict[str, Any]): The dictionary that encodes the ContinualAlignmentDataset.

        Returns:
            ContinualAlignmentDataset: The newly constructed ContinualAlignmentDataset.

        Raises:
            ValueError: If the input dictionary is missing any required keys.
        """
        # Check up front so that a missing key surfaces as the documented
        # ValueError rather than as a bare KeyError from the lookup below.
        if 'datasets' not in dataset_dict:
            raise ValueError(
                "Dataset dictionary is missing required keys: ['datasets']"
            )

        datasets = [
            AlignmentDataset.from_dict(constituent)
            for constituent in dataset_dict['datasets']
        ]
        return cls(datasets)

0 commit comments

Comments
 (0)