122 changes: 107 additions & 15 deletions flowrep/workflow.py
@@ -831,10 +831,14 @@ def _get_missing_edges(edge_list: list[tuple[str, str]]) -> list[tuple[str, str]]:
for tag in edge:
if len(tag.split(".")) < 3:
continue
if tag.split(".")[1] == "inputs":
new_edge = (tag, tag.split(".")[0])
elif tag.split(".")[1] == "outputs":
new_edge = (tag.split(".")[0], tag)
assert tag.split(".")[-2] in ["inputs", "outputs"], (
f"Edge tag {tag} not recognized. "
"Expected format: <function>.(inputs|outputs).<variable>"
)
if tag.split(".")[-2] == "inputs":
new_edge = (tag, tag.rsplit(".", 2)[0])
elif tag.split(".")[-2] == "outputs":
new_edge = (tag.rsplit(".", 2)[0], tag)
if new_edge not in extra_edges:
extra_edges.append(new_edge)
return edge_list + extra_edges
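
For reference, a minimal sketch (with a made-up tag) of what the rsplit-based handling buys over the old split(".")[1] indexing: the inputs/outputs segment is located from the right, so a node part that itself contains dots no longer breaks the lookup.

```python
# Illustrative only; the tag is hypothetical.
tag = "pkg.my_func.inputs.x"
node, kind, var = tag.rsplit(".", 2)   # ('pkg.my_func', 'inputs', 'x')
assert kind in ("inputs", "outputs")
extra_edge = (tag, node) if kind == "inputs" else (node, tag)
print(extra_edge)                      # ('pkg.my_func.inputs.x', 'pkg.my_func')
```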
@@ -900,6 +904,9 @@ def get_workflow_graph(workflow_dict: dict[str, Any]) -> nx.DiGraph:
nx.DiGraph: A directed graph representing the workflow.
"""

def _get_items(d: dict[str, dict]) -> dict[str, dict]:
return {k: v for k, v in d.items() if k not in ["inputs", "outputs"]}

def _to_gnode(node: str) -> str:
node_split = node.rsplit(".", 2)
if len(node_split) < 2:
@@ -923,9 +930,16 @@ def _to_gnode(node: str) -> str:
G.add_node(f"outputs@{out}", step="output", **data)

nodes_to_delete = []
if "test" in workflow_dict:
G.add_node("test", step="node", **_get_items(workflow_dict["test"]))
if "iter" in workflow_dict:
G.add_node("iter", step="node", **_get_items(workflow_dict["iter"]))
for key, node in workflow_dict["nodes"].items():
assert node["type"] in ["atomic", "workflow"]
if node["type"] == "workflow":
assert node["type"] in ["atomic", "workflow", "while", "for"], (
f"Node {key} has unrecognized type {node['type']}. "
"Expected types are 'atomic', 'workflow', 'while', or 'for'."
)
if node["type"] in ["workflow", "for", "while"]:
child_G = get_workflow_graph(node)
for child_key in list(child_G.graph.keys()):
new_key = f"{key}/{child_key}" if child_key != "" else key
@@ -939,11 +953,7 @@ def _to_gnode(node: str) -> str:
G = nx.union(nx.relabel_nodes(child_G, mapping), G)
nodes_to_delete.append(key)
else:
G.add_node(
key,
step="node",
**{k: v for k, v in node.items() if k not in ["inputs", "outputs"]},
)
G.add_node(key, step="node", **_get_items(node))
for ii, (inp, data) in enumerate(node.get("inputs", {}).items()):
G.add_node(f"{key}:inputs@{inp}", step="input", **({"position": ii} | data))
for ii, (out, data) in enumerate(node.get("outputs", {}).items()):
@@ -1083,6 +1093,13 @@ def simple_run(G: nx.DiGraph) -> nx.DiGraph:
return G


def _get_edges_in_order(G: nx.DiGraph) -> list[tuple[str, str]]:
node_order = list(nx.topological_sort(G))
pos = {n: i for i, n in enumerate(node_order)}

return sorted(G.edges(), key=lambda e: (pos[e[0]], pos[e[1]]))
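
A quick illustration of the ordering _get_edges_in_order produces, on a toy graph that is not from the package: edges are sorted by the topological positions of their endpoints, so upstream edges always come first.

```python
import networkx as nx

# Toy chain with edges deliberately listed out of order.
G = nx.DiGraph([("c", "d"), ("a", "b"), ("b", "c")])
pos = {n: i for i, n in enumerate(nx.topological_sort(G))}
print(sorted(G.edges(), key=lambda e: (pos[e[0]], pos[e[1]])))
# [('a', 'b'), ('b', 'c'), ('c', 'd')]
```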


class GNode:
def __init__(self, key: str):
self.key = key
@@ -1113,12 +1130,60 @@ def io(self) -> str | None:
def arg(self) -> str | None:
return self.key.split("@")[-1] if "@" in self.key else None

@property
def is_io(self) -> bool:
return self.io is not None

def is_global_io(self) -> bool:
return self.is_io() and self.node is None
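
The keys GNode wraps follow the naming scheme used in get_workflow_graph above: "outputs@z" for workflow-level IO, "add_0:inputs@x" for a step's port, and "/"-joined prefixes such as "macro_0/add_1" for nested steps. A hedged sketch of the two new predicates on such keys, assuming node_list, node, io, and arg parse as in the rest of the module:

```python
# Hypothetical keys; the shown values assume the parsing described above.
gn = GNode("macro_0/add_1:inputs@x")
gn.is_io()                          # True  -- the key names an IO port
gn.is_global_io()                   # False -- the port belongs to a step
GNode("outputs@z").is_global_io()   # True  -- workflow-level output
```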


def _graph_to_flat_wf_dict(G: nx.DiGraph) -> dict:
G = flatten_graph(G)
wf_dict = tools.dict_to_recursive_dd(
{
"inputs": tools.dict_to_recursive_dd({}),
"outputs": tools.dict_to_recursive_dd({}),
"nodes": tools.dict_to_recursive_dd({}),
"edges": [],
}
)
for edge in _get_edges_in_order(G):
orig = GNode(edge[0])
dest = GNode(edge[1])
if not orig.is_io() or not dest.is_io():
continue
wf_dict["edges"].append(
tuple(
(
".".join(["/".join(n.node_list), n.io, n.arg])
if n.node_list
else ".".join([n.io, n.arg])
)
for n in [orig, dest]
)
)
for node, metadata in list(G.nodes.data()):
gn = GNode(node)
if gn.is_global_io():
for key, value in metadata.items():
if key == "step":
continue
wf_dict[gn.io][gn.arg][key] = value
elif gn.is_io():
for key, value in metadata.items():
if key == "step":
continue
wf_dict["nodes"]["/".join(gn.node_list)][gn.io][gn.arg][key] = value
else:
for key, value in metadata.items():
if key == "step":
continue
wf_dict["nodes"]["/".join(gn.node_list)][key] = value

return tools.recursive_dd_to_dict(wf_dict)
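
For orientation, a hedged sketch of the dictionary shape this flat serialization produces (node and variable names are invented): nested steps collapse into single "/"-joined node keys, and edges use the <node>.(inputs|outputs).<variable> tag format that _get_missing_edges expects.

```python
# Hypothetical flat workflow dict; the real keys depend on the workflow.
wf_dict_flat = {
    "inputs": {"a": {"value": 1}},
    "outputs": {"z": {}},
    "nodes": {
        "macro_0/add_1": {"type": "atomic", "inputs": {"x": {}}, "outputs": {"y": {}}},
    },
    "edges": [
        ("inputs.a", "macro_0/add_1.inputs.x"),
        ("macro_0/add_1.outputs.y", "outputs.z"),
    ],
}
```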


def graph_to_wf_dict(G: nx.DiGraph) -> dict:
def graph_to_wf_dict(G: nx.DiGraph, flatten: bool = False) -> dict:
"""
Convert a directed graph representation of a workflow into a workflow
dictionary.
@@ -1130,13 +1195,15 @@ def graph_to_wf_dict(G: nx.DiGraph) -> dict:
dict: The dictionary representation of the workflow.
"""
wf_dict = tools.dict_to_recursive_dd({})
if flatten:
return _graph_to_flat_wf_dict(G)

for node, metadata in list(G.nodes.data()):
gn = GNode(node)
d = wf_dict
for n in gn.node_list:
d = d["nodes"][n]
if gn.is_io:
if gn.is_io():
d[gn.io][gn.arg] = {
key: value for key, value in metadata.items() if key != "step"
}
@@ -1146,7 +1213,7 @@ def graph_to_wf_dict(G: nx.DiGraph) -> dict:
for edge in G.edges:
orig = GNode(edge[0])
dest = GNode(edge[1])
if not orig.is_io or not dest.is_io:
if not orig.is_io() or not dest.is_io():
continue
if len(orig.node_list) == len(dest.node_list):
nodes = orig.node_list[:-1]
@@ -1186,3 +1253,28 @@ def graph_to_wf_dict(G: nx.DiGraph) -> dict:
for k, v in value.items():
d[k] = v
return tools.recursive_dd_to_dict(wf_dict)
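
A hedged usage sketch of the new flatten flag, with the module alias borrowed from the test file below; wf_dict stands in for any serialized workflow dictionary, and the import path is assumed from the repository layout:

```python
from flowrep import workflow as fwf   # assumed import path

G = fwf.get_workflow_graph(wf_dict)
nested = fwf.graph_to_wf_dict(G)               # sub-workflows stay nested
flat = fwf.graph_to_wf_dict(G, flatten=True)   # single level, "macro/child" node keys
```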


def flatten_graph(G: nx.DiGraph) -> nx.DiGraph:
H = G.copy()
nodes = [node for node, data in G.nodes.data() if data["step"] == "node"]
ios = [
io
for n in nodes
for neighbors in [G.predecessors(n), G.successors(n)]
for io in neighbors
]
for node in G.nodes:
gn = GNode(node)
if not gn.is_io() or node in ios or gn.is_global_io():
continue
if gn.io == "input":
main_node = list(G.successors(node))[0]
else:
main_node = list(G.predecessors(node))[0]
for k, val in G.nodes[node].items():
if k not in G.nodes[main_node]:
H.nodes[main_node][k] = val
H = nx.contracted_nodes(H, main_node, node, self_loops=False)
del H.nodes[main_node]["contraction"]
return H
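
A toy illustration of the contraction step flatten_graph relies on (graph and attribute names are invented): an IO port referenced only by its own step is merged into that step with nx.contracted_nodes, after copying over any attributes the step does not already carry.

```python
import networkx as nx

# Toy graph: an output port hanging off its step, carrying a computed value.
H = nx.DiGraph([("add_0", "add_0:outputs@y")])
H.nodes["add_0:outputs@y"]["value"] = 3
H.nodes["add_0"].setdefault("value", 3)        # copy attributes the step lacks
H = nx.contracted_nodes(H, "add_0", "add_0:outputs@y", self_loops=False)
del H.nodes["add_0"]["contraction"]            # drop networkx's bookkeeping entry
print(list(H.nodes), dict(H.nodes["add_0"]))   # ['add_0'] {'value': 3}
```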
12 changes: 12 additions & 0 deletions tests/unit/test_workflow.py
@@ -719,6 +719,18 @@ def test_wf_dict_to_graph(self):
sorted(wf_dict["nodes"]["example_macro_0"]["edges"]),
)

def test_flattening(self):
wf_dict = example_workflow.serialize_workflow()
wf_dict["inputs"] = {"a": {"value": 1}, "b": {"default": 2}}
G = fwf.get_workflow_graph(wf_dict)
result = fwf.simple_run(G)
wf_dict_flat = fwf.graph_to_wf_dict(G, flatten=True)
G_flat = fwf.get_workflow_graph(wf_dict_flat)
result_flat = fwf.simple_run(G_flat)
self.assertEqual(
result_flat.nodes["outputs@z"]["value"], result.nodes["outputs@z"]["value"]
)


if __name__ == "__main__":
unittest.main()