Skip to content

Commit 08aed8f

Browse files
authored
Transform cleanup (#176)
* Preference swap clean * Split is a transform * Separate transform module * Revert accidental file push * Fix tests
1 parent b717ffe commit 08aed8f

File tree

14 files changed

+345
-70
lines changed

14 files changed

+345
-70
lines changed

aif_gen/cli/commands/split.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import click
66

7-
import aif_gen.dataset.split.functional as F
7+
import aif_gen.transforms.functional as F
88
from aif_gen.dataset.continual_alignment_dataset import (
99
ContinualAlignmentDataset,
1010
)
@@ -73,15 +73,13 @@ def split(
7373

7474
seed_everything(random_seed)
7575
logging.info(f'Splitting dataset with test_sample_ratio={test_sample_ratio}')
76-
dataset = F.split(dataset, test_ratio=test_sample_ratio)
76+
transformed_dataset = F.split_transform(dataset, test_ratio=test_sample_ratio)
7777
logging.info(f'Writing dataset to: {output_file}')
78-
dataset.to_json(output_file)
79-
logging.info(f'Wrote {dataset.num_samples} samples to: {output_file}')
78+
transformed_dataset.to_json(output_file)
79+
logging.info(f'Wrote {transformed_dataset.num_samples} samples to: {output_file}')
8080

8181
if hf_repo_id_out is not None:
8282
upload_to_hf(hf_repo_id_out, output_file)
8383
logging.info(f'Uploaded dataset to HuggingFace repo: {hf_repo_id_out}')
8484
else:
8585
logging.info(f'No HuggingFace repo specified for upload.')
86-
87-
return

aif_gen/cli/commands/transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import click
66

7-
import aif_gen.dataset.transforms.functional as F
7+
import aif_gen.transforms.functional as F
88
from aif_gen.dataset.continual_alignment_dataset import (
99
ContinualAlignmentDataset,
1010
)

aif_gen/dataset/split/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

aif_gen/dataset/split/functional.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

aif_gen/dataset/transforms/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

aif_gen/transforms/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from aif_gen.transforms.base import DatasetTransform
2+
from aif_gen.transforms.preference_swap_transform import PreferenceSwapTransform
3+
from aif_gen.transforms.split_transform import SplitTransform
4+
from aif_gen.transforms.functional import *
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,19 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Union
2+
from typing import Any
33

4-
from aif_gen.dataset import AlignmentDataset, ContinualAlignmentDataset
5-
6-
# Typedef for convenience
7-
Dataset = Union[ContinualAlignmentDataset, AlignmentDataset]
4+
from aif_gen.typing import Dataset
85

96

107
class DatasetTransform(ABC):
118
r"""Base class for transforming Alignment Datasets."""
129

1310
@abstractmethod
14-
def apply(
15-
self, dataset: Dataset, in_place: bool = False, *args: Any, **kwargs: Any
16-
) -> Dataset:
11+
def apply(self, dataset: Dataset, in_place: bool = False) -> Dataset:
1712
r"""Apply the transform onto a dataset.
1813
1914
Args:
2015
dataset (Union[ContinualAlignmentDataset, AlignmentDataset]): The dataset to transform.
2116
in_place: Whether to apply the transform in-place or return a new dataset.
22-
args (Any): Optional positional arguments.
23-
kwargs (Any): Optional keyword arguments.
2417
2518
Returns:
2619
Union[ContinualAlignmentDataset, AlignmentDataset]: The transformed dataset.
@@ -30,8 +23,4 @@ def __call__(self, dataset: Dataset, *args: Any, **kwargs: Any) -> Dataset:
3023
return self.apply(dataset, *args, **kwargs)
3124

3225
def __str__(self) -> str:
33-
r"""Returns the type of Dataset transform."""
3426
return self.__class__.__name__
35-
36-
def _is_dataset_continual(self, dataset: Dataset) -> bool:
37-
return isinstance(dataset, ContinualAlignmentDataset)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .base import Dataset
22
from .preference_swap_transform import PreferenceSwapTransform
3+
from .split_transform import SplitTransform
34

45

56
def preference_swap_transform(
@@ -20,3 +21,23 @@ def preference_swap_transform(
2021
"""
2122
transform = PreferenceSwapTransform(swap_probability)
2223
return transform(dataset, in_place=in_place)
24+
25+
26+
def split_transform(
    dataset: Dataset, test_ratio: float, in_place: bool = False
) -> Dataset:
    r"""Functional entry point for SplitTransform.

    Builds a SplitTransform from ``test_ratio`` and calls it on ``dataset``.

    Args:
        dataset (Union[ContinualAlignmentDataset, AlignmentDataset]): The dataset to split.
        test_ratio (float): The test ratio handed to the SplitTransform constructor.
        in_place (bool): Forwarded to the transform; whether to modify ``dataset``
            itself or return a new dataset.

    Returns:
        Union[ContinualAlignmentDataset, AlignmentDataset]: The transformed dataset.

    Raises:
        ValueError: If a dataset in the Continual Dataset has test data.
    """
    return SplitTransform(test_ratio)(dataset, in_place=in_place)

aif_gen/dataset/transforms/preference_swap_transform.py renamed to aif_gen/transforms/preference_swap_transform.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,7 @@ def apply(self, dataset: Dataset, in_place: bool = False) -> Dataset:
4949
if self.swap_probability == 0:
5050
return dataset if in_place else copy.deepcopy(dataset)
5151

52-
if self._is_dataset_continual(dataset):
53-
# This assert is here to make mypy happy
54-
assert isinstance(dataset, ContinualAlignmentDataset)
52+
if isinstance(dataset, ContinualAlignmentDataset):
5553
if in_place:
5654
for i in range(dataset.num_datasets):
5755
dataset.datasets[i] = self._apply(dataset.datasets[i], in_place)
@@ -79,10 +77,10 @@ def _apply_inplace(
7977
) -> AlignmentDataset:
8078
for i in range(len(dataset)):
8179
if swap_outcomes[i]:
82-
chosen = dataset.samples[i].chosen
83-
rejected = dataset.samples[i].rejected
84-
dataset.samples[i].chosen = rejected
85-
dataset.samples[i].rejected = chosen
80+
dataset.samples[i].chosen, dataset.samples[i].rejected = (
81+
dataset.samples[i].rejected,
82+
dataset.samples[i].chosen,
83+
)
8684
return dataset
8785

8886
def _apply_copy(
@@ -103,8 +101,6 @@ def _apply_copy(
103101
train_frac=dataset.train_frac,
104102
)
105103

106-
def _validate_swap_probability(self, swap_probability: float) -> None:
107-
if not 0 <= swap_probability <= 1:
108-
raise ValueError(
109-
f'Expected a swap probability in the range [0, 1] but got: {swap_probability}'
110-
)
104+
def _validate_swap_probability(self, swap_prob: float) -> None:
105+
if not 0 <= swap_prob <= 1:
106+
raise ValueError(f'Swap probability must be in [0, 1], got: {swap_prob}')
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from aif_gen.dataset import AlignmentDataset, ContinualAlignmentDataset
2+
3+
from .base import Dataset, DatasetTransform
4+
5+
6+
class SplitTransform(DatasetTransform):
    r"""SplitTransform splits the training data into train/test datasets.

    The split is expressed by setting ``train_frac`` to ``1 - test_ratio`` on each
    AlignmentDataset (either the dataset itself, or every member of a
    ContinualAlignmentDataset).

    Args:
        test_ratio (float): The test ratio to split the dataset with.

    Raises:
        ValueError: If the test ratio is not in the range [0, 1].
    """

    def __init__(self, test_ratio: float) -> None:
        self._validate_test_ratio(test_ratio)
        self._test_ratio = test_ratio

    @property
    def test_ratio(self) -> float:
        r"""float: The test ratio to split the dataset with."""
        return self._test_ratio

    @test_ratio.setter
    def test_ratio(self, test_ratio: float) -> None:
        # Re-validate on assignment so the stored ratio is always in [0, 1].
        self._validate_test_ratio(test_ratio)
        self._test_ratio = test_ratio

    def apply(self, dataset: Dataset, in_place: bool = False) -> Dataset:
        r"""Splits a dataset's training data into train and test datasets.

        Args:
            dataset (Union[ContinualAlignmentDataset, AlignmentDataset]): The dataset to split.
            in_place (bool): Whether to apply the transform in-place or return a new dataset.

        Returns:
            Union[ContinualAlignmentDataset, AlignmentDataset]: The dataset with
            ``train_frac`` set to ``1 - test_ratio``. The same object as ``dataset``
            when ``in_place`` is True, otherwise a newly constructed dataset.

        Raises:
            ValueError: If the dataset (or any dataset in the Continual Dataset)
                already has test data.
        """
        # Validate every member up front so an in-place run never half-applies.
        self._check_test_frac_empty(dataset)

        if isinstance(dataset, ContinualAlignmentDataset):
            if in_place:
                for data in dataset.datasets:
                    self._split(data, in_place=True)
                return dataset
            return ContinualAlignmentDataset(
                [self._split(data, in_place=False) for data in dataset.datasets]
            )

        # This assert is here to make mypy happy
        assert isinstance(dataset, AlignmentDataset)
        return self._split(dataset, in_place)

    def _split(self, dataset: AlignmentDataset, in_place: bool) -> AlignmentDataset:
        r"""Set the train fraction of a single AlignmentDataset to ``1 - test_ratio``."""
        train_frac = 1 - self.test_ratio
        if in_place:
            dataset.train_frac = train_frac
            return dataset
        # NOTE(review): the new dataset is constructed from the same ``samples``
        # object as the input (no copy is made here) -- confirm whether
        # AlignmentDataset copies its samples if isolation from the input matters.
        return AlignmentDataset(dataset.task, dataset.samples, train_frac=train_frac)

    def _validate_test_ratio(self, test_ratio: float) -> None:
        r"""Raise ValueError unless ``test_ratio`` lies in the closed range [0, 1]."""
        if not 0 <= test_ratio <= 1:
            raise ValueError(f'Test ratio must be in [0, 1], got: {test_ratio}')

    def _check_test_frac_empty(self, dataset: Dataset) -> None:
        r"""Raise ValueError if any underlying AlignmentDataset has a non-zero test_frac."""
        if isinstance(dataset, ContinualAlignmentDataset):
            datasets = dataset.datasets
        else:
            assert isinstance(dataset, AlignmentDataset)
            datasets = [dataset]
        # Use a distinct loop name so the ``dataset`` parameter is not rebound.
        for data in datasets:
            if data.test_frac != 0:
                raise ValueError('AlignmentDataset cannot have test data for splitting')

0 commit comments

Comments
 (0)