Skip to content

Commit 75e29cd

Browse files
authored
Dataset Cleanup (#174)
* WIP * Fix tests * update continual dataset * Pipe instead of union * Annotate
1 parent 93bd8bb commit 75e29cd

File tree

7 files changed

+108
-118
lines changed

7 files changed

+108
-118
lines changed

aif_gen/dataset/alignment_dataset.py

Lines changed: 26 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1+
from __future__ import annotations
2+
13
import json
24
import pathlib
35
from dataclasses import asdict
4-
from typing import Any, Dict, List, Union
6+
from typing import Any, Dict, List
57

68
from datasets import Dataset
9+
from pydantic import Field
10+
from pydantic.dataclasses import dataclass
711

812
from aif_gen.task import AlignmentTask
913

1014
from .alignment_sample import AlignmentDatasetSample
1115

1216

17+
@dataclass(slots=True)
1318
class AlignmentDataset:
1419
r"""Container object for an Alignment Dataset.
1520
@@ -22,38 +27,14 @@ class AlignmentDataset:
2227
ValueError: If train_frac is not in the interval [0, 1.0]
2328
"""
2429

25-
def __init__(
26-
self,
27-
task: AlignmentTask,
28-
samples: List[AlignmentDatasetSample],
29-
train_frac: float = 1.0,
30-
) -> None:
31-
self._task = task
32-
self._samples = samples
33-
34-
if not (0 <= train_frac <= 1):
35-
raise ValueError(f'Train fraction must be in [0, 1] but got: {train_frac}')
36-
self._train_frac = train_frac
37-
38-
@property
39-
def task(self) -> AlignmentTask:
40-
r"""AlignmentTask: The task associated with the AlignmentDataset."""
41-
return self._task
42-
43-
@property
44-
def train_frac(self) -> float:
45-
r"""Fraction of samples that belong to the training split."""
46-
return self._train_frac
30+
task: AlignmentTask = Field(frozen=True)
31+
samples: List[AlignmentDatasetSample] = Field(frozen=True)
32+
train_frac: float = Field(default=1.0, ge=0, le=1)
4733

4834
@property
4935
def test_frac(self) -> float:
5036
r"""Fraction of samples that belong to the testing split."""
51-
return 1.0 - self._train_frac
52-
53-
@property
54-
def samples(self) -> List[AlignmentDatasetSample]:
55-
r"""List[AlignmentDatasetSample]: The list of samples associated with the AlignmentDataset."""
56-
return self._samples
37+
return 1.0 - self.train_frac
5738

5839
@property
5940
def train(self) -> List[AlignmentDatasetSample]:
@@ -84,12 +65,12 @@ def __len__(self) -> int:
8465
return self.num_samples
8566

8667
def __getitem__(
87-
self, key: Union[slice, int]
88-
) -> Union[AlignmentDatasetSample, List[AlignmentDatasetSample]]:
68+
self, key: slice | int
69+
) -> AlignmentDatasetSample | List[AlignmentDatasetSample]:
8970
# Slicing directly on the samples
9071
return self.samples[key]
9172

92-
def to_json(self, file_path: Union[str, pathlib.Path]) -> None:
73+
def to_json(self, file_path: str | pathlib.Path) -> None:
9374
r"""Save the AlignmentDataset to a json file.
9475
9576
Note: Uses to_dict() under the hood to get a dictionary representation.
@@ -104,26 +85,17 @@ def to_json(self, file_path: Union[str, pathlib.Path]) -> None:
10485
def to_dict(self) -> Dict[str, Any]:
10586
r"""Convert the AlignmentDataset to dictionary represenetation.
10687
107-
Note: This method is the functional inverse of AlignmentDataset.from_dict().
108-
10988
Returns:
11089
Dict[str, Any]: The dictionary representation of the AlignmentDataset.
11190
"""
11291
dataset_dict: Dict[str, Any] = {}
11392
dataset_dict['task'] = self.task.to_dict()
114-
dataset_dict['train'] = []
115-
dataset_dict['test'] = []
116-
117-
for sample in self.train:
118-
dataset_dict['train'].append(asdict(sample))
119-
120-
for sample in self.test:
121-
dataset_dict['test'].append(asdict(sample))
122-
93+
dataset_dict['train'] = [asdict(sample) for sample in self.train]
94+
dataset_dict['test'] = [asdict(sample) for sample in self.test]
12395
return dataset_dict
12496

12597
@classmethod
126-
def from_json(cls, file_path: Union[str, pathlib.Path]) -> 'AlignmentDataset':
98+
def from_json(cls, file_path: str | pathlib.Path) -> AlignmentDataset:
12799
r"""Load the AlignmentDataset from a json file.
128100
129101
Note: Uses AlignmentDataset.from_dict() under the hood to parse the representation.
@@ -136,11 +108,10 @@ def from_json(cls, file_path: Union[str, pathlib.Path]) -> 'AlignmentDataset':
136108
"""
137109
with open(file_path, 'r') as f:
138110
dataset_dict = json.load(f)
139-
140111
return cls.from_dict(dataset_dict)
141112

142113
@classmethod
143-
def from_dict(cls, dataset_dict: Dict[str, Any]) -> 'AlignmentDataset':
114+
def from_dict(cls, dataset_dict: Dict[str, Any]) -> AlignmentDataset:
144115
r"""Construct an AlignmentDataset from dictionary representation.
145116
146117
Note:
@@ -161,14 +132,11 @@ def from_dict(cls, dataset_dict: Dict[str, Any]) -> 'AlignmentDataset':
161132
task = AlignmentTask.from_dict(dataset_dict['task'])
162133
samples = []
163134
for sample in dataset_dict['train']:
164-
sample = AlignmentDatasetSample(**sample)
165-
samples.append(sample)
166-
135+
samples.append(AlignmentDatasetSample(**sample))
167136
num_train_samples = len(samples)
168137

169138
for sample in dataset_dict['test']:
170-
sample = AlignmentDatasetSample(**sample)
171-
samples.append(sample)
139+
samples.append(AlignmentDatasetSample(**sample))
172140

173141
train_frac = num_train_samples / len(samples)
174142
return cls(task, samples, train_frac)
@@ -177,26 +145,22 @@ def to_hf_compatible(self) -> Dict[str, Dataset]:
177145
r"""Convert the AlignmentDataset to a dictionary compatible with HuggingFace datasets.
178146
179147
Returns:
180-
dict[str, Dataset]: The dictionary compatible with HuggingFace datasets.
148+
Dict[str, Dataset]: The dictionary compatible with HuggingFace datasets.
181149
"""
182-
dataset_dict: Dict[str, Any] = self.to_dict()
183-
184150
hf_dict: Dict[str, Dataset] = {
185151
'train': Dataset.from_dict(
186152
{
187-
'prompt': [sample['prompt'] for sample in dataset_dict['train']],
188-
'chosen': [sample['chosen'] for sample in dataset_dict['train']],
189-
'rejected': [
190-
sample['rejected'] for sample in dataset_dict['train']
191-
],
153+
'prompt': [sample.prompt for sample in self.train],
154+
'chosen': [sample.chosen for sample in self.train],
155+
'rejected': [sample.rejected for sample in self.train],
192156
},
193157
split='train',
194158
),
195159
'test': Dataset.from_dict(
196160
{
197-
'prompt': [sample['prompt'] for sample in dataset_dict['test']],
198-
'chosen': [sample['chosen'] for sample in dataset_dict['test']],
199-
'rejected': [sample['rejected'] for sample in dataset_dict['test']],
161+
'prompt': [sample.prompt for sample in self.test],
162+
'chosen': [sample.chosen for sample in self.test],
163+
'rejected': [sample.rejected for sample in self.test],
200164
},
201165
split='test',
202166
),

aif_gen/dataset/alignment_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import dataclass
1+
from pydantic.dataclasses import dataclass
22

33

44
@dataclass

aif_gen/dataset/continual_alignment_dataset.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,25 @@
1+
from __future__ import annotations
2+
13
import json
24
import pathlib
3-
from typing import Any, Dict, List, Union
5+
from typing import Any, Dict, List
46

57
from datasets import Dataset
6-
7-
from aif_gen.dataset.alignment_sample import AlignmentDatasetSample
8+
from pydantic.dataclasses import dataclass
89

910
from .alignment_dataset import AlignmentDataset
11+
from .alignment_sample import AlignmentDatasetSample
1012

1113

14+
@dataclass(slots=True)
1215
class ContinualAlignmentDataset:
1316
r"""Container object for a Continual Alignment Dataset.
1417
1518
Args:
1619
datasets (List[ContinualAlignmentDataset]): Temporal list of AlignmentDatasets constituents.
1720
"""
1821

19-
def __init__(self, datasets: List[AlignmentDataset]) -> None:
20-
self._datasets = datasets
21-
22-
@property
23-
def datasets(self) -> List[AlignmentDataset]:
24-
r"""List[AlignmentDataset]: The list of AlignmentDataset constituents."""
25-
return self._datasets
22+
datasets: List[AlignmentDataset]
2623

2724
@property
2825
def num_datasets(self) -> int:
@@ -39,16 +36,16 @@ def __len__(self) -> int:
3936
return self.num_samples
4037

4138
def __getitem__(
42-
self, key: Union[slice, int]
43-
) -> Union[AlignmentDatasetSample, List[AlignmentDatasetSample]]:
39+
self, key: slice | int
40+
) -> AlignmentDatasetSample | List[AlignmentDatasetSample]:
4441
# Indexing based on **samples** across datasets (not into datasets themselves)
4542
all_samples = [] # This should probably be cached
4643
for dataset in self.datasets:
4744
all_samples.extend(dataset.samples)
4845
return all_samples[key]
4946

5047
def append(self, dataset: AlignmentDataset) -> None:
51-
r"""Append a single AlignmentDataset to the ConitnualAlignmentDataset.
48+
r"""Append a single AlignmentDataset to the ContinualAlignmentDataset.
5249
5350
Args:
5451
dataset (AlignmentDataset): The new dataset to add.
@@ -64,7 +61,7 @@ def append(self, dataset: AlignmentDataset) -> None:
6461
)
6562

6663
def extend(self, datasets: List[AlignmentDataset]) -> None:
67-
r"""Append multiple AlignmentDataset's to the ConitnualAlignmentDataset.
64+
r"""Append multiple AlignmentDataset's to the ContinualAlignmentDataset.
6865
6966
Args:
7067
datasets (List[AlignmentDataset]): The new datasets to add.
@@ -75,7 +72,7 @@ def extend(self, datasets: List[AlignmentDataset]) -> None:
7572
for dataset in datasets:
7673
self.append(dataset)
7774

78-
def to_json(self, file_path: Union[str, pathlib.Path]) -> None:
75+
def to_json(self, file_path: str | pathlib.Path) -> None:
7976
r"""Save the ContinualAlignmentDataset to a json file.
8077
8178
Note: Uses to_dict() under the hood to get a dictionary representation.
@@ -90,8 +87,6 @@ def to_json(self, file_path: Union[str, pathlib.Path]) -> None:
9087
def to_dict(self) -> Dict[str, Any]:
9188
r"""Convert the ContinualAlignmentDataset to dictionary represenetation.
9289
93-
Note: This method is the functional inverse of ContinualAlignmentDataset.from_dict().
94-
9590
Returns:
9691
Dict[str, Any]: The dictionary representation of the ContinualAlignmentDataset.
9792
"""
@@ -101,9 +96,7 @@ def to_dict(self) -> Dict[str, Any]:
10196
return dataset_dict
10297

10398
@classmethod
104-
def from_json(
105-
cls, file_path: Union[str, pathlib.Path]
106-
) -> 'ContinualAlignmentDataset':
99+
def from_json(cls, file_path: str | pathlib.Path) -> ContinualAlignmentDataset:
107100
r"""Load the ContinualAlignmentDataset from a json file.
108101
109102
Note: Uses ContinualAlignmentDataset.from_dict() under the hood to parse the representation.
@@ -119,7 +112,7 @@ def from_json(
119112
return cls.from_dict(dataset_dict)
120113

121114
@classmethod
122-
def from_dict(cls, dataset_dict: Dict[str, Any]) -> 'ContinualAlignmentDataset':
115+
def from_dict(cls, dataset_dict: Dict[str, Any]) -> ContinualAlignmentDataset:
123116
r"""Construct a ContinualAlignmentDataset from dictionary representation.
124117
125118
Note:

aif_gen/dataset/split/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ def split(
2323
)
2424
# just change the fractions
2525
for i in range(len(dataset.datasets)):
26-
dataset.datasets[i]._train_frac = 1 - test_ratio
26+
dataset.datasets[i].train_frac = 1 - test_ratio
2727

2828
return dataset

test/test_validation/test_count_validation.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
ContinualAlignmentDataset,
55
)
66
from aif_gen.dataset.validation import count_validation
7+
from aif_gen.task.alignment_task import AlignmentTask
8+
from aif_gen.task.domain import Domain
79

810

911
def test_count_validation_all_unique():
@@ -18,7 +20,9 @@ def test_count_validation_all_unique():
1820
'Mock prompt C 1', 'Winning Response C 1', 'Losing Response C 1'
1921
),
2022
]
21-
mock_task = None
23+
mock_task = AlignmentTask(
24+
domain=Domain.from_dict({'education': {}}), objective='', preference=''
25+
)
2226
dataset = AlignmentDataset(task=mock_task, samples=samples)
2327
expected_counts = [
2428
{
@@ -44,7 +48,9 @@ def test_count_validation_all_same_prompts():
4448
'Mock prompt A 2', 'Winning Response C 2', 'Losing Response C 2'
4549
),
4650
]
47-
mock_task = None
51+
mock_task = AlignmentTask(
52+
domain=Domain.from_dict({'education': {}}), objective='', preference=''
53+
)
4854
dataset = AlignmentDataset(task=mock_task, samples=samples)
4955
expected_counts = [
5056
{
@@ -70,7 +76,9 @@ def test_count_validation_all_same_responses():
7076
'Mock prompt C 3', 'Winning Response A 3', 'Losing Response B 3'
7177
),
7278
]
73-
mock_task = None
79+
mock_task = AlignmentTask(
80+
domain=Domain.from_dict({'education': {}}), objective='', preference=''
81+
)
7482
dataset = AlignmentDataset(task=mock_task, samples=samples)
7583
expected_counts = [
7684
{
@@ -96,7 +104,9 @@ def test_count_validation_all_same_everything():
96104
'Mock prompt A 4', 'Winning Response A 4', 'Losing Response A 4'
97105
),
98106
]
99-
mock_task = None
107+
mock_task = AlignmentTask(
108+
domain=Domain.from_dict({'education': {}}), objective='', preference=''
109+
)
100110
dataset = AlignmentDataset(task=mock_task, samples=samples)
101111
expected_counts = [
102112
{
@@ -122,7 +132,9 @@ def test_count_countinual_dataset():
122132
'Mock prompt C 1', 'Winning Response C 1', 'Losing Response C 1'
123133
),
124134
]
125-
mock_task = None
135+
mock_task = AlignmentTask(
136+
domain=Domain.from_dict({'education': {}}), objective='', preference=''
137+
)
126138
dataset_one = AlignmentDataset(task=mock_task, samples=samples)
127139

128140
samples = [
@@ -136,7 +148,6 @@ def test_count_countinual_dataset():
136148
'Mock prompt A 2', 'Winning Response C 2', 'Losing Response C 2'
137149
),
138150
]
139-
mock_task = None
140151
dataset_two = AlignmentDataset(task=mock_task, samples=samples)
141152

142153
samples = [
@@ -150,7 +161,6 @@ def test_count_countinual_dataset():
150161
'Mock prompt C 3', 'Winning Response A 3', 'Losing Response B 3'
151162
),
152163
]
153-
mock_task = None
154164
dataset_three = AlignmentDataset(task=mock_task, samples=samples)
155165

156166
samples = [
@@ -164,7 +174,6 @@ def test_count_countinual_dataset():
164174
'Mock prompt A 4', 'Winning Response A 4', 'Losing Response A 4'
165175
),
166176
]
167-
mock_task = None
168177
dataset_four = AlignmentDataset(task=mock_task, samples=samples)
169178

170179
dataset = ContinualAlignmentDataset(
@@ -216,7 +225,9 @@ def test_count_validation_stop_words_removed():
216225
'with Mock prompt A 4', 'by Winning Response A 4', 'is Losing Response A 4'
217226
),
218227
]
219-
mock_task = None
228+
mock_task = AlignmentTask(
229+
domain=Domain.from_dict({'education': {}}), objective='', preference=''
230+
)
220231
dataset = AlignmentDataset(task=mock_task, samples=samples)
221232
expected_counts = [
222233
{

0 commit comments

Comments
 (0)