Skip to content

Commit 7d92d48

Browse files
authored
Merge pull request #21 from ComplexData-MILA/dev/domain_spec
Add domain component
2 parents 8695c5e + cecdfbe commit 7d92d48

File tree

7 files changed

+585
-16
lines changed

7 files changed

+585
-16
lines changed

aif_gen/task/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from aif_gen.task.alignment_task import AlignmentTask
22
from aif_gen.task.domain import Domain
3+
from aif_gen.task.domain_component import DomainComponent

aif_gen/task/alignment_task.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,21 @@ def __init__(self, domain: Domain, objective: str, preference: str) -> None:
1919

2020
@classmethod
2121
def from_dict(cls, task_dict: Dict[str, Any]) -> 'AlignmentTask':
22+
r"""Construct an AlignmentTask from its dictionary representation.
23+
24+
Note:
25+
Expects 'domain', 'objective', and 'preference' keys to be present in the dictionary.
26+
Moreover, expects that the 'domain' value is parsable by Domain.from_dict().
27+
28+
Args:
29+
task_dict (Dict[str, Any]): The dictionary that encodes the AlignmentTask.
30+
31+
Returns:
32+
AlignmentTask: The newly constructed alignmentTask
33+
34+
Raises:
35+
ValueError: If the input dictionary is missing any required keys.
36+
"""
2237
expected_keys = {'domain', 'objective', 'preference'}
2338
missing_keys = expected_keys - set(task_dict)
2439
if len(missing_keys):
@@ -29,17 +44,34 @@ def from_dict(cls, task_dict: Dict[str, Any]) -> 'AlignmentTask':
2944
preference = task_dict['preference']
3045
return cls(domain, objective, preference)
3146

47+
def to_dict(self) -> Dict[str, Any]:
48+
r"""Convert the AlignmentTask to dictionary represenetation.
49+
50+
Note: This method is the functional inverse of AlignmentTask.from_dict().
51+
52+
Returns:
53+
Dict[str, Any]: The dictionary representation of the alignmentTask.
54+
"""
55+
return {
56+
'domain': self.domain.to_dict(),
57+
'objective': self.objective,
58+
'preference': self.preference,
59+
}
60+
3261
def __str__(self) -> str:
3362
return f'AlignmentTask({self.domain}, Objective: {self.objective}, Preference: {self.preference})'
3463

3564
@property
3665
def domain(self) -> Domain:
66+
"""Domain: The domain in the current AlignmentTask."""
3767
return self._domain
3868

3969
@property
4070
def objective(self) -> str:
71+
"""str: The objective in the current AlignmentTask."""
4172
return self._objective
4273

4374
@property
4475
def preference(self) -> str:
76+
"""str: The preference in the current AlignmentTask."""
4577
return self._preference

aif_gen/task/domain.py

Lines changed: 103 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,112 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, List, Optional
2+
3+
from .domain_component import DomainComponent
24

35

46
class Domain:
    r"""A domain is a combination of DomainComponents along with a weight for each component.

    Note: We do not enforce that the weights be normalized.

    Args:
        components (List[DomainComponent]): List of DomainComponents that constitute the domain.
        weights (Optional[List[float]]): Weights given to each constituent component (uniform if not specified).

    Raises:
        ValueError: If the list of components is empty.
        ValueError: If the number of weights does not match the number of components.
        ValueError: If any of the provided weights are negative.
    """

    def __init__(
        self, components: List[DomainComponent], weights: Optional[List[float]] = None
    ) -> None:
        if not len(components):
            raise ValueError(
                'Cannot initialize a Domain with an empty list of DomainComponents'
            )

        if weights is None:
            # Default to a uniform distribution over the components.
            weights = [1 / len(components)] * len(components)

        if len(weights) != len(components):
            raise ValueError(
                f'Number of components and weights must match, but got {len(components)} components and {len(weights)} weights'
            )
        for component, weight in zip(components, weights):
            if weight < 0:
                raise ValueError(
                    f'Got a negative weight for component: {component}'
                )

        self._components = components
        self._weights = weights

    @classmethod
    def from_dict(cls, domain_dict: Dict[str, Dict[str, Any]]) -> 'Domain':
        r"""Construct a Domain from its dictionary representation.

        Note:
            Expects each key to denote the 'name' of a DomainComponent, and each value
            to be a dictionary that is parsable by DomainComponent.from_dict().

            Moreover, each value may include a ('weight': weight_value: float) key-value
            pair that encodes the weight for that given DomainComponent.

            If none of the components carries a 'weight' key, the uniform weight
            initialization is used. Weights must be given for all components or for none.

            The component name is written into each value dictionary under the 'name'
            key, i.e. the nested dictionaries of ``domain_dict`` are modified in place.
            NOTE(review): existing callers/tests appear to rely on this side effect;
            confirm before switching to a copy.

        Args:
            domain_dict(Dict[str, Dict[str, Any]]): The dictionary that encodes the Domain.

        Returns:
            Domain: The newly constructed Domain.

        Raises:
            ValueError: If the input dictionary is empty or a component is missing required keys.
            ValueError: If only some of the components specify a 'weight'.
        """
        components, component_weights, unweighted_names = [], [], []
        for component_name, component_args in domain_dict.items():
            component_args['name'] = component_name
            components.append(DomainComponent.from_dict(component_args))

            if 'weight' in component_args:
                component_weights.append(component_args['weight'])
            else:
                unweighted_names.append(component_name)

        # A partial specification cannot be aligned with the component list, so
        # report exactly which components lack a weight instead of letting the
        # constructor fail with a generic length-mismatch message.
        if component_weights and unweighted_names:
            raise ValueError(
                f"Weights must be specified for all components or none, but missing 'weight' for: {unweighted_names}"
            )

        weights = component_weights if component_weights else None
        return cls(components, weights)

    def to_dict(self) -> Dict[str, Any]:
        r"""Convert the Domain to its dictionary representation.

        Note: The result maps each component name to that component's own
        dictionary representation, extended with a 'weight' entry, and is
        parsable by Domain.from_dict().

        Returns:
            Dict[str, Any]: The dictionary representation of the Domain.
        """
        domain_dict: Dict[str, Any] = {}
        for component, weight in zip(self.components, self.weights):
            component_dict = component.to_dict()
            component_dict['weight'] = weight
            domain_dict[component.name] = component_dict
        return domain_dict

    def __str__(self) -> str:
        # Every entry (including the last) is followed by ', ' to keep the
        # rendered text identical to the established format.
        entries = ''.join(
            f'({component}, weight={weight:.2f}), '
            for component, weight in zip(self.components, self.weights)
        )
        return 'Domain: [' + entries + ']'

    @property
    def components(self) -> List[DomainComponent]:
        """List[DomainComponent]: The list of components associated with this Domain."""
        return self._components

    @property
    def num_components(self) -> int:
        """int: The number of components associated with this Domain."""
        return len(self.components)

    @property
    def weights(self) -> List[float]:
        """List[float]: The weights associated with this Domain."""
        return self._weights

aif_gen/task/domain_component.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from typing import Any, Dict, List, Optional
2+
3+
4+
class DomainComponent:
    r"""A domain component is an alias for a set of 'seed words' that describe a
    specific sphere of activity or knowledge.

    Args:
        name (str): The name that describes the content of the domain component.
        seed_words (List[str]): The list of seed words that describe the domain component.
        description (Optional[str]): An optional description of the domain component.

    Raises:
        ValueError: If no seed words are provided.
    """

    def __init__(
        self,
        name: str,
        seed_words: List[str],
        description: Optional[str] = None,
    ) -> None:
        if not len(seed_words):
            raise ValueError(
                'Cannot initialize a DomainComponent with an empty list of seed words'
            )

        self._name = name
        self._seed_words = seed_words
        self._description = description

    @classmethod
    def from_dict(cls, component_dict: Dict[str, Any]) -> 'DomainComponent':
        r"""Construct a DomainComponent from its dictionary representation.

        Note:
            Expects 'name' and 'seed_words' keys to be present in the dictionary.
            A 'description' key is optional; any other keys are ignored.

        Args:
            component_dict(Dict[str, Any]): The dictionary that encodes the DomainComponent.

        Returns:
            DomainComponent: The newly constructed DomainComponent.

        Raises:
            ValueError: If the input dictionary is missing any required keys.
            ValueError: If the 'seed_words' value is empty.
        """
        expected_keys = {'name', 'seed_words'}
        missing_keys = expected_keys - set(component_dict)
        if missing_keys:
            raise ValueError(f'Missing required keys: {missing_keys}')

        return cls(
            component_dict['name'],
            component_dict['seed_words'],
            component_dict.get('description'),
        )

    def to_dict(self) -> Dict[str, Any]:
        r"""Convert the DomainComponent to its dictionary representation.

        Note: The result is parsable by DomainComponent.from_dict(). The
        'description' key is only emitted when a description is set.

        Returns:
            Dict[str, Any]: The dictionary representation of the DomainComponent.
        """
        component_dict: Dict[str, Any] = {
            'name': self.name,
            # Shallow copy: callers routinely extend/modify the returned dict, and
            # handing out the internal list would let them empty the seed words
            # and break the non-empty invariant enforced by the constructor.
            'seed_words': self.seed_words[:],
        }
        if self.description is not None:
            component_dict['description'] = self.description
        return component_dict

    def __str__(self) -> str:
        s = f'{self.name} '
        if self.description is not None:
            s += f'({self.description}) '

        # Truncate number of seed words to first 3 to avoid spamming output stream
        if len(self.seed_words) > 3:
            s += str(self.seed_words[:3])[:-1] + ', ...]'
        else:
            s += str(self.seed_words)

        return s

    @property
    def name(self) -> str:
        """str: The name of this DomainComponent."""
        return self._name

    @property
    def seed_words(self) -> List[str]:
        """List[str]: The list of seed words aliased by this DomainComponent."""
        return self._seed_words

    @property
    def num_seed_words(self) -> int:
        """int: The number of seed words aliased by this DomainComponent."""
        return len(self.seed_words)

    @property
    def description(self) -> Optional[str]:
        """Optional[str]: The description in the current DomainComponent, if it exists."""
        return self._description

test/test_alignment_task.py

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,101 @@
44

55

66
def test_init():
    """An AlignmentTask built directly from its parts renders them all in str()."""
    domain = Domain.from_dict(
        {
            'Component A': {
                'seed_words': ['a_foo', 'a_bar', 'a_baz'],
                'description': 'A Mock Domain Component',
            },
            'Component B': {
                'seed_words': ['b_foo', 'b_bar', 'b_baz'],
                'description': 'B Mock Domain Component',
            },
        }
    )

    task = AlignmentTask(domain, 'Mock Objective', 'Mock Preference')

    # The domain part is rendered through Domain.__str__, so reuse it here.
    expected = f'AlignmentTask({domain}, Objective: Mock Objective, Preference: Mock Preference)'
    assert str(task) == expected
1424

1525

1626
def test_init_from_dict():
    """An AlignmentTask parsed from a dictionary renders the same as its parts."""
    domain_spec = {
        'Component A': {
            'seed_words': ['a_foo', 'a_bar', 'a_baz'],
            'description': 'A Mock Domain Component',
        },
        'Component B': {
            'seed_words': ['b_foo', 'b_bar', 'b_baz'],
            'description': 'B Mock Domain Component',
        },
    }
    task_dict = {
        'domain': domain_spec,
        'objective': 'Mock Objective',
        'preference': 'Mock Preference',
    }

    task = AlignmentTask.from_dict(task_dict)

    # Build the reference domain from the same specification to obtain its str().
    reference_domain = Domain.from_dict(task_dict['domain'])
    expected = f'AlignmentTask({reference_domain}, Objective: Mock Objective, Preference: Mock Preference)'
    assert str(task) == expected
2646

2747

2848
def test_init_from_dict_missing_keys():
    """from_dict rejects a dictionary that lacks a required key."""
    # The 'domain' key is deliberately left out.
    incomplete_task_dict = {
        'objective': 'Mock Objective',
        'preference': 'Mock Preference',
    }

    with pytest.raises(ValueError):
        AlignmentTask.from_dict(incomplete_task_dict)
56+
57+
58+
def test_to_dict_no_weights():
    """to_dict() reports uniform weights when the input specified none."""
    task_dict = {
        'domain': {
            'Component A': {
                'seed_words': ['a_foo', 'a_bar', 'a_baz'],
                'description': 'A Mock Domain Component',
            },
            'Component B': {
                'seed_words': ['b_foo', 'b_bar', 'b_baz'],
                'description': 'B Mock Domain Component',
            },
        },
        'objective': 'Mock Objective',
        'preference': 'Mock Preference',
    }

    # The expectation is an independent literal on purpose: aliasing it to
    # task_dict would inject the weights into the input before parsing, so the
    # uniform-default path would never actually be exercised.
    # Note: uniform weights are added automatically when none are specified,
    # and each serialized component carries its own 'name'.
    expected_dict = {
        'domain': {
            'Component A': {
                'name': 'Component A',
                'seed_words': ['a_foo', 'a_bar', 'a_baz'],
                'description': 'A Mock Domain Component',
                'weight': 0.5,
            },
            'Component B': {
                'name': 'Component B',
                'seed_words': ['b_foo', 'b_bar', 'b_baz'],
                'description': 'B Mock Domain Component',
                'weight': 0.5,
            },
        },
        'objective': 'Mock Objective',
        'preference': 'Mock Preference',
    }

    task = AlignmentTask.from_dict(task_dict)
    assert expected_dict == task.to_dict()
81+
82+
83+
def test_to_dict_with_weights():
    """to_dict() preserves explicitly specified component weights."""
    task_dict = {
        'domain': {
            'Component A': {
                'weight': 0.3,
                'seed_words': ['a_foo', 'a_bar', 'a_baz'],
                'description': 'A Mock Domain Component',
            },
            'Component B': {
                'weight': 0.7,
                'seed_words': ['b_foo', 'b_bar', 'b_baz'],
                'description': 'B Mock Domain Component',
            },
        },
        'objective': 'Mock Objective',
        'preference': 'Mock Preference',
    }

    # Independent literal rather than an alias of task_dict: comparing the
    # input object with itself would pass regardless of what to_dict() emits
    # for anything that parsing writes back into the input.
    expected_dict = {
        'domain': {
            'Component A': {
                'name': 'Component A',
                'weight': 0.3,
                'seed_words': ['a_foo', 'a_bar', 'a_baz'],
                'description': 'A Mock Domain Component',
            },
            'Component B': {
                'name': 'Component B',
                'weight': 0.7,
                'seed_words': ['b_foo', 'b_bar', 'b_baz'],
                'description': 'B Mock Domain Component',
            },
        },
        'objective': 'Mock Objective',
        'preference': 'Mock Preference',
    }

    task = AlignmentTask.from_dict(task_dict)
    assert expected_dict == task.to_dict()

0 commit comments

Comments
 (0)