Skip to content

Commit b8fd288

Browse files
committed
Implement in_place transform and tests
1 parent d1f8a69 commit b8fd288

5 files changed

Lines changed: 205 additions & 138 deletions

File tree

aif_gen/dataset/transforms/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,8 @@ def apply(
2626
Union[ContinualAlignmentDataset, AlignmentDataset]: The transformed dataset.
2727
"""
2828

29-
def __call__(
30-
self, dataset: Dataset, in_place: bool = False, *args: Any, **kwargs: Any
31-
) -> Dataset:
32-
return self.apply(dataset, in_place, *args, **kwargs)
29+
def __call__(self, dataset: Dataset, *args: Any, **kwargs: Any) -> Dataset:
30+
return self.apply(dataset, *args, **kwargs)
3331

3432
def __str__(self) -> str:
3533
r"""Returns the type of Dataset transform."""

aif_gen/dataset/transforms/functional.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,4 @@ def preference_swap_transform(
1919
ValueError: If the swap probability is not in the range [0, 1].
2020
"""
2121
transform = PreferenceSwapTransform(swap_probability)
22-
if in_place:
23-
transform(dataset)
24-
else:
25-
dataset = transform(dataset)
26-
return dataset
22+
return transform(dataset, in_place=in_place)

aif_gen/dataset/transforms/preference_swap_transform.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,42 @@ def apply(self, dataset: Dataset, in_place: bool = False) -> Dataset:
5252
if self._is_dataset_continual(dataset):
5353
# This assert is here to make mypy happy
5454
assert isinstance(dataset, ContinualAlignmentDataset)
55-
return ContinualAlignmentDataset(
56-
[self._apply(data) for data in dataset.datasets]
57-
)
55+
if in_place:
56+
for i in range(dataset.num_datasets):
57+
dataset.datasets[i] = self._apply(dataset.datasets[i], in_place)
58+
return dataset
59+
else:
60+
return ContinualAlignmentDataset(
61+
[self._apply(data, in_place) for data in dataset.datasets]
62+
)
5863
else:
5964
# This assert is here to make mypy happy
6065
assert isinstance(dataset, AlignmentDataset)
61-
return self._apply(dataset)
66+
return self._apply(dataset, in_place)
6267

63-
def _apply(self, dataset: AlignmentDataset) -> AlignmentDataset:
68+
def _apply(self, dataset: AlignmentDataset, in_place: bool) -> AlignmentDataset:
6469
swap_outcomes = np.random.binomial(
6570
n=1, p=self.swap_probability, size=len(dataset)
6671
)
72+
if in_place:
73+
return self._apply_inplace(dataset, swap_outcomes)
74+
else:
75+
return self._apply_copy(dataset, swap_outcomes)
76+
77+
def _apply_inplace(
78+
self, dataset: AlignmentDataset, swap_outcomes: np.ndarray
79+
) -> AlignmentDataset:
80+
for i in range(len(dataset)):
81+
if swap_outcomes[i]:
82+
chosen = dataset.samples[i].chosen
83+
rejected = dataset.samples[i].rejected
84+
dataset.samples[i].chosen = rejected
85+
dataset.samples[i].rejected = chosen
86+
return dataset
87+
88+
def _apply_copy(
89+
self, dataset: AlignmentDataset, swap_outcomes: np.ndarray
90+
) -> AlignmentDataset:
6791
transformed_samples = []
6892
for i, sample in enumerate(dataset.samples):
6993
if swap_outcomes[i]:

test/test_transforms/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
3+
from aif_gen.util.seed import seed_everything
4+
5+
6+
@pytest.fixture(autouse=True)
7+
def run_seed_before_tests():
8+
seed_everything(1)
9+
yield
10+
11+
12+
@pytest.fixture(params=[True, False])
13+
def in_place(request):
14+
return request.param
15+
16+
17+
@pytest.fixture(params=['call', 'apply', 'functional'])
18+
def application_type(request):
19+
return request.param

0 commit comments

Comments
 (0)