Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 143 additions & 84 deletions example/transform/google_paper_comparison_model.ipynb

Large diffs are not rendered by default.

109 changes: 73 additions & 36 deletions example/transform/openai_paper_comparison_model.ipynb

Large diffs are not rendered by default.

22 changes: 7 additions & 15 deletions tests/op/basic/test_group_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,23 @@

class TestGroupOp(unittest.TestCase):
def setUp(self):
self.preprocess_fn = lambda nodes_1, nodes_2: [
(
node_label.value_dict["response"][0],
node_summary.value_dict["response"][0],
)
for node_label, node_summary in zip(nodes_1, nodes_2)
]
self.group_fn = lambda labels, summaries: {
label: [s for l, s in zip(labels, summaries) if l == label]
for label in set(labels)
}
self.group_op = GroupOp("test_group", self.preprocess_fn, self.group_fn)
self.group_op = GroupOp("test_group", self.group_fn)

def test_init(self):
self.assertEqual(self.group_op._preprocess_fn, self.preprocess_fn)
self.assertEqual(self.group_op._fn, self.group_fn)

def test_call(self):
node_a0 = Node("node_a0", {"response": ["Introduction"]})
node_a1 = Node("node_a1", {"response": ["Introduction"]})
node_a2 = Node("node_a2", {"response": ["Abstract"]})
node_a0 = Node("node_a0", [Context(context=["Introduction"])])
node_a1 = Node("node_a1", [Context(context=["Introduction"])])
node_a2 = Node("node_a2", [Context(context=["Abstract"])])

node_b0 = Node("node_b0", {"response": ["A paper about life itself"]})
node_b1 = Node("node_b1", {"response": ["Life is complicated"]})
node_b2 = Node("node_b2", {"response": ["Happy wife, happy life"]})
node_b0 = Node("node_b0", [Context(context=["A paper about life itself"])])
node_b1 = Node("node_b1", [Context(context=["Life is complicated"])])
node_b2 = Node("node_b2", [Context(context=["Happy wife, happy life"])])

nodes_1 = [node_a0, node_a1, node_a2]
nodes_2 = [node_b0, node_b1, node_b2]
Expand Down
10 changes: 0 additions & 10 deletions uniflow/flow/transform/transform_comparison_google_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(
prompt_template (PromptTemplate): Guided prompt template.
model_config (Dict[str, Any]): Model config.
"""
# TODO: Refactoring needed to make model_op output Context format. Need to keep it in Context format and only convert back to dictionary format before exiting Flow
super().__init__()

# Expand list of nodes to two or more nodes
Expand All @@ -47,7 +46,6 @@ def __init__(
],
)

# TODO: Refactoring needed to make model_op output Context format
# Add label
label_prompt_template = PromptTemplate(
instruction="""
Expand All @@ -64,7 +62,6 @@ def __init__(
),
)

# TODO: Refactoring needed to make model_op output Context format
# Summarize
summary_prompt_template = PromptTemplate(
instruction="""
Expand All @@ -84,13 +81,6 @@ def __init__(
# Group summaries by label
self._group = GroupOp(
name="summaries_groupby_labels",
preprocss_fn=lambda nodes_1, nodes_2: [
(
node_label.value_dict["response"][0],
node_summary.value_dict["response"][0],
)
for node_label, node_summary in zip(nodes_1, nodes_2)
],
fn=lambda labels, summaries: {
label: [s for l, s in zip(labels, summaries) if l == label]
for label in set(labels)
Expand Down
12 changes: 1 addition & 11 deletions uniflow/flow/transform/transform_comparison_openai_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(
prompt_template (PromptTemplate): Guided prompt template.
model_config (Dict[str, Any]): Model config.
"""
# TODO: Refactoring needed to make model_op output Context format. Need to keep it in Context format and only convert back to dictionary format before exiting Flow
super().__init__()
if model_config["response_format"]["type"] == "json_object":
model = JsonLmModel(
Expand All @@ -52,12 +51,11 @@ def __init__(
name="split_to_chunks",
fn=lambda markdown_content: [
[Context(context=item.strip())]
for item in re.split(r"\n\s*\n", markdown_content[0].Context)
for item in re.split(r"\n\s*\n", markdown_content[0].context)
if item.strip()
],
)

# TODO: Refactoring needed to make model_op output Context format
# Add label
label_prompt_template = PromptTemplate(
instruction="""
Expand Down Expand Up @@ -104,7 +102,6 @@ def __init__(
),
)

# TODO: Refactoring needed to make model_op output Context format
# Summarize
summary_prompt_template = PromptTemplate(
instruction="""
Expand All @@ -124,13 +121,6 @@ def __init__(
# Group summaries by label
self._group = GroupOp(
name="summaries_groupby_labels",
preprocss_fn=lambda nodes_1, nodes_2: [
(
node_label.value_dict["response"][0],
node_summary.value_dict["response"][0],
)
for node_label, node_summary in zip(nodes_1, nodes_2)
],
fn=lambda labels, summaries: {
label: [s for l, s in zip(labels, summaries) if l == label]
for label in set(labels)
Expand Down
23 changes: 13 additions & 10 deletions uniflow/op/basic/group_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,25 @@ class GroupOp(Op):
def __init__(
self,
name: str,
preprocss_fn: Callable[
[Mapping[str, Any], Mapping[str, Any]], Mapping[str, Any]
],
fn: Callable[[Mapping[str, Any], Mapping[str, Any]], Mapping[str, Any]],
given_fixed_labels: Optional[list] = None,
) -> None:
"""Initializes group operation.

Args:
name (str): Name of the group operation.
preprocss_fn (callable): Function to extract.
fn (callable): Function to group.
given_fixed_labels (Optional[list]): A list of fixed, provided labels to help handle exceptions when there is no content for certain labels.
"""
super().__init__(name)
self._fn = fn
self._preprocess_fn = preprocss_fn
self._given_fixed_labels = given_fixed_labels if given_fixed_labels else []

def __call__(
self, nodes_1: Sequence[Node], nodes_2: Sequence[Node]
) -> Sequence[Node]:
"""Calls group operation.
The (preprocss_fn) preprocess function will first extract information such as label and summary out of node's dictionary.
Then (fn) function would groub by summaries based on their labels.
Then the (fn) function groups the summaries from nodes_2 by their labels from nodes_1.
The result would be a list of nodes where each node's dictionary is a sum of summaries of nodes with same label.
If given_fixed_labels is provided, labels with no summaries will still be included in the result.

Expand All @@ -49,7 +43,16 @@ def __call__(
"""
output_nodes = []

labels, summaries = zip(*self._preprocess_fn(nodes_1, nodes_2))
labels, summaries = zip(
*[
(
node_label.value_dict[0].context[0],
node_summary.value_dict[0].context[0],
)
for node_label, node_summary in zip(nodes_1, nodes_2)
]
)

aggregated_summaries = self._fn(labels, summaries)
sorted_labels = sorted(aggregated_summaries.keys())

Expand All @@ -63,7 +66,7 @@ def __call__(
label_nodes = {label: [] for label in sorted_labels}

for node in nodes_1:
label = node.value_dict["response"][0]
label = node.value_dict[0].context[0]
if label in label_nodes:
label_nodes[label].append(node)

Expand All @@ -76,7 +79,7 @@ def __call__(
prev_nodes = label_nodes[label]

for node in nodes_2:
if node.value_dict["response"][0] in summary_list:
if node.value_dict[0].context[0] in summary_list:
prev_nodes.append(node)

output_nodes.append(
Expand Down
6 changes: 2 additions & 4 deletions uniflow/op/model/lm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def _serialize(self, data: List[Context]) -> List[str]:
List[str]: Serialized data.
"""
output = []

for d in data:
if not isinstance(d, Context):
raise ValueError("Input data must be a Context object.")
Expand Down Expand Up @@ -58,10 +59,7 @@ def _deserialize(self, data: List[str]) -> List[Dict[str, Any]]:
Returns:
List[Dict[str, Any]]: Deserialized data.
"""
return {
RESPONSE: data,
ERROR: "No errors.",
}
return [Context(context=data)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: why are you using a list here instead of directly returning Context(context=data)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this because LmModel._serialize takes in List[Context]



class JsonLmModel(AbsModel):
Expand Down