1+ from __future__ import annotations
2+
13import json
24import pathlib
35from dataclasses import asdict
4- from typing import Any , Dict , List , Union
6+ from typing import Any , Dict , List
57
68from datasets import Dataset
9+ from pydantic import Field
10+ from pydantic .dataclasses import dataclass
711
812from aif_gen .task import AlignmentTask
913
1014from .alignment_sample import AlignmentDatasetSample
1115
1216
17+ @dataclass (slots = True )
1318class AlignmentDataset :
1419 r"""Container object for an Alignment Dataset.
1520
@@ -22,38 +27,14 @@ class AlignmentDataset:
2227 ValueError: If train_frac is not in the interval [0, 1.0]
2328 """
2429
25- def __init__ (
26- self ,
27- task : AlignmentTask ,
28- samples : List [AlignmentDatasetSample ],
29- train_frac : float = 1.0 ,
30- ) -> None :
31- self ._task = task
32- self ._samples = samples
33-
34- if not (0 <= train_frac <= 1 ):
35- raise ValueError (f'Train fraction must be in [0, 1] but got: { train_frac } ' )
36- self ._train_frac = train_frac
37-
38- @property
39- def task (self ) -> AlignmentTask :
40- r"""AlignmentTask: The task associated with the AlignmentDataset."""
41- return self ._task
42-
43- @property
44- def train_frac (self ) -> float :
45- r"""Fraction of samples that belong to the training split."""
46- return self ._train_frac
30+ task : AlignmentTask = Field (frozen = True )
31+ samples : List [AlignmentDatasetSample ] = Field (frozen = True )
32+ train_frac : float = Field (default = 1.0 , ge = 0 , le = 1 )
4733
@property
def test_frac(self) -> float:
    r"""Fraction of samples that belong to the testing split.

    Returns:
        float: The complement of ``train_frac``, i.e. ``1.0 - train_frac``.
    """
    # Read the public ``train_frac`` attribute; the private ``_train_frac``
    # backing attribute is not used here.
    return 1.0 - self.train_frac
5738
5839 @property
5940 def train (self ) -> List [AlignmentDatasetSample ]:
@@ -84,12 +65,12 @@ def __len__(self) -> int:
8465 return self .num_samples
8566
def __getitem__(
    self, key: slice | int
) -> AlignmentDatasetSample | List[AlignmentDatasetSample]:
    r"""Index or slice the samples of the dataset.

    Args:
        key (slice | int): The position or slice to look up.

    Returns:
        AlignmentDatasetSample | List[AlignmentDatasetSample]: Whatever
            ``self.samples[key]`` evaluates to for the given key.
    """
    # Slicing directly on the samples
    return self.samples[key]
9172
92- def to_json (self , file_path : Union [ str , pathlib .Path ] ) -> None :
73+ def to_json (self , file_path : str | pathlib .Path ) -> None :
9374 r"""Save the AlignmentDataset to a json file.
9475
9576 Note: Uses to_dict() under the hood to get a dictionary representation.
@@ -104,26 +85,17 @@ def to_json(self, file_path: Union[str, pathlib.Path]) -> None:
def to_dict(self) -> Dict[str, Any]:
    r"""Convert the AlignmentDataset to dictionary representation.

    Returns:
        Dict[str, Any]: The dictionary representation of the AlignmentDataset.
            It holds the result of ``self.task.to_dict()`` under ``'task'`` and
            the samples of each split, converted with ``dataclasses.asdict``,
            under ``'train'`` and ``'test'``.
    """
    dataset_dict: Dict[str, Any] = {}
    dataset_dict['task'] = self.task.to_dict()
    # One comprehension per split replaces the manual append loops.
    dataset_dict['train'] = [asdict(sample) for sample in self.train]
    dataset_dict['test'] = [asdict(sample) for sample in self.test]
    return dataset_dict
12496
12597 @classmethod
126- def from_json (cls , file_path : Union [ str , pathlib .Path ] ) -> ' AlignmentDataset' :
98+ def from_json (cls , file_path : str | pathlib .Path ) -> AlignmentDataset :
12799 r"""Load the AlignmentDataset from a json file.
128100
129101 Note: Uses AlignmentDataset.from_dict() under the hood to parse the representation.
@@ -136,11 +108,10 @@ def from_json(cls, file_path: Union[str, pathlib.Path]) -> 'AlignmentDataset':
136108 """
137109 with open (file_path , 'r' ) as f :
138110 dataset_dict = json .load (f )
139-
140111 return cls .from_dict (dataset_dict )
141112
142113 @classmethod
143- def from_dict (cls , dataset_dict : Dict [str , Any ]) -> ' AlignmentDataset' :
114+ def from_dict (cls , dataset_dict : Dict [str , Any ]) -> AlignmentDataset :
144115 r"""Construct an AlignmentDataset from dictionary representation.
145116
146117 Note:
@@ -161,14 +132,11 @@ def from_dict(cls, dataset_dict: Dict[str, Any]) -> 'AlignmentDataset':
161132 task = AlignmentTask .from_dict (dataset_dict ['task' ])
162133 samples = []
163134 for sample in dataset_dict ['train' ]:
164- sample = AlignmentDatasetSample (** sample )
165- samples .append (sample )
166-
135+ samples .append (AlignmentDatasetSample (** sample ))
167136 num_train_samples = len (samples )
168137
169138 for sample in dataset_dict ['test' ]:
170- sample = AlignmentDatasetSample (** sample )
171- samples .append (sample )
139+ samples .append (AlignmentDatasetSample (** sample ))
172140
173141 train_frac = num_train_samples / len (samples )
174142 return cls (task , samples , train_frac )
@@ -177,26 +145,22 @@ def to_hf_compatible(self) -> Dict[str, Dataset]:
177145 r"""Convert the AlignmentDataset to a dictionary compatible with HuggingFace datasets.
178146
179147 Returns:
180- dict [str, Dataset]: The dictionary compatible with HuggingFace datasets.
148+ Dict [str, Dataset]: The dictionary compatible with HuggingFace datasets.
181149 """
182- dataset_dict : Dict [str , Any ] = self .to_dict ()
183-
184150 hf_dict : Dict [str , Dataset ] = {
185151 'train' : Dataset .from_dict (
186152 {
187- 'prompt' : [sample ['prompt' ] for sample in dataset_dict ['train' ]],
188- 'chosen' : [sample ['chosen' ] for sample in dataset_dict ['train' ]],
189- 'rejected' : [
190- sample ['rejected' ] for sample in dataset_dict ['train' ]
191- ],
153+ 'prompt' : [sample .prompt for sample in self .train ],
154+ 'chosen' : [sample .chosen for sample in self .train ],
155+ 'rejected' : [sample .rejected for sample in self .train ],
192156 },
193157 split = 'train' ,
194158 ),
195159 'test' : Dataset .from_dict (
196160 {
197- 'prompt' : [sample [ ' prompt' ] for sample in dataset_dict [ ' test' ] ],
198- 'chosen' : [sample [ ' chosen' ] for sample in dataset_dict [ ' test' ] ],
199- 'rejected' : [sample [ ' rejected' ] for sample in dataset_dict [ ' test' ] ],
161+ 'prompt' : [sample . prompt for sample in self . test ],
162+ 'chosen' : [sample . chosen for sample in self . test ],
163+ 'rejected' : [sample . rejected for sample in self . test ],
200164 },
201165 split = 'test' ,
202166 ),
0 commit comments