diff --git a/flowrep/workflow.py b/flowrep/workflow.py
index 745f7e8..dce2fbe 100644
--- a/flowrep/workflow.py
+++ b/flowrep/workflow.py
@@ -831,10 +831,14 @@ def _get_missing_edges(edge_list: list[tuple[str, str]]) -> list[tuple[str, str]
         for tag in edge:
             if len(tag.split(".")) < 3:
                 continue
-            if tag.split(".")[1] == "inputs":
-                new_edge = (tag, tag.split(".")[0])
-            elif tag.split(".")[1] == "outputs":
-                new_edge = (tag.split(".")[0], tag)
+            assert tag.split(".")[-2] in ["inputs", "outputs"], (
+                f"Edge tag {tag} not recognized. "
+                "Expected format: <node>.(inputs|outputs).<arg>."
+            )
+            if tag.split(".")[-2] == "inputs":
+                new_edge = (tag, tag.rsplit(".", 2)[0])
+            elif tag.split(".")[-2] == "outputs":
+                new_edge = (tag.rsplit(".", 2)[0], tag)
             if new_edge not in extra_edges:
                 extra_edges.append(new_edge)
     return edge_list + extra_edges
@@ -900,6 +904,9 @@ def get_workflow_graph(workflow_dict: dict[str, Any]) -> nx.DiGraph:
         nx.DiGraph: A directed graph representing the workflow.
     """
 
+    def _get_items(d: dict[str, dict]) -> dict[str, dict]:
+        return {k: v for k, v in d.items() if k not in ["inputs", "outputs"]}
+
     def _to_gnode(node: str) -> str:
         node_split = node.rsplit(".", 2)
         if len(node_split) < 2:
@@ -923,9 +930,16 @@ def _to_gnode(node: str) -> str:
         G.add_node(f"outputs@{out}", step="output", **data)
 
     nodes_to_delete = []
+    if "test" in workflow_dict:
+        G.add_node("test", step="node", **_get_items(workflow_dict["test"]))
+    if "iter" in workflow_dict:
+        G.add_node("iter", step="node", **_get_items(workflow_dict["iter"]))
     for key, node in workflow_dict["nodes"].items():
-        assert node["type"] in ["atomic", "workflow"]
-        if node["type"] == "workflow":
+        assert node["type"] in ["atomic", "workflow", "while", "for"], (
+            f"Node {key} has unrecognized type {node['type']}. "
+            "Expected types are 'atomic', 'workflow', 'while', or 'for'."
+        )
+        if node["type"] in ["workflow", "for", "while"]:
             child_G = get_workflow_graph(node)
             for child_key in list(child_G.graph.keys()):
                 new_key = f"{key}/{child_key}" if child_key != "" else key
@@ -939,11 +953,7 @@ def _to_gnode(node: str) -> str:
             G = nx.union(nx.relabel_nodes(child_G, mapping), G)
             nodes_to_delete.append(key)
         else:
-            G.add_node(
-                key,
-                step="node",
-                **{k: v for k, v in node.items() if k not in ["inputs", "outputs"]},
-            )
+            G.add_node(key, step="node", **_get_items(node))
         for ii, (inp, data) in enumerate(node.get("inputs", {}).items()):
             G.add_node(f"{key}:inputs@{inp}", step="input", **({"position": ii} | data))
         for ii, (out, data) in enumerate(node.get("outputs", {}).items()):
@@ -1083,6 +1093,13 @@ def simple_run(G: nx.DiGraph) -> nx.DiGraph:
     return G
 
 
+def _get_edges_in_order(G: nx.DiGraph) -> list[tuple[str, str]]:
+    node_order = list(nx.topological_sort(G))
+    pos = {n: i for i, n in enumerate(node_order)}
+
+    return sorted(G.edges(), key=lambda e: (pos[e[0]], pos[e[1]]))
+
+
 class GNode:
     def __init__(self, key: str):
         self.key = key
@@ -1113,12 +1130,60 @@ def io(self) -> str | None:
     def arg(self) -> str | None:
         return self.key.split("@")[-1] if "@" in self.key else None
 
-    @property
     def is_io(self) -> bool:
         return self.io is not None
 
+    def is_global_io(self) -> bool:
+        return self.is_io() and self.node is None
+
+
+def _graph_to_flat_wf_dict(G: nx.DiGraph) -> dict:
+    G = flatten_graph(G)
+    wf_dict = tools.dict_to_recursive_dd(
+        {
+            "inputs": tools.dict_to_recursive_dd({}),
+            "outputs": tools.dict_to_recursive_dd({}),
+            "nodes": tools.dict_to_recursive_dd({}),
+            "edges": [],
+        }
+    )
+    for edge in _get_edges_in_order(G):
+        orig = GNode(edge[0])
+        dest = GNode(edge[1])
+        if not orig.is_io() or not dest.is_io():
+            continue
+        wf_dict["edges"].append(
+            tuple(
+                (
+                    ".".join(["/".join(n.node_list), n.io, n.arg])
+                    if n.node_list
+                    else ".".join([n.io, n.arg])
+                )
+                for n in [orig, dest]
+            )
+        )
+    for node, metadata in list(G.nodes.data()):
+        gn = GNode(node)
+        if gn.is_global_io():
+            for key, value in metadata.items():
+                if key == "step":
+                    continue
+                wf_dict[gn.io][gn.arg][key] = value
+        elif gn.is_io():
+            for key, value in metadata.items():
+                if key == "step":
+                    continue
+                wf_dict["nodes"]["/".join(gn.node_list)][gn.io][gn.arg][key] = value
+        else:
+            for key, value in metadata.items():
+                if key == "step":
+                    continue
+                wf_dict["nodes"]["/".join(gn.node_list)][key] = value
+
+    return tools.recursive_dd_to_dict(wf_dict)
+
 
-def graph_to_wf_dict(G: nx.DiGraph) -> dict:
+def graph_to_wf_dict(G: nx.DiGraph, flatten: bool = False) -> dict:
     """
     Convert a directed graph representation of a workflow into a workflow
     dictionary.
@@ -1130,13 +1195,15 @@ def graph_to_wf_dict(G: nx.DiGraph) -> dict:
         dict: The dictionary representation of the workflow.
""" wf_dict = tools.dict_to_recursive_dd({}) + if flatten: + return _graph_to_flat_wf_dict(G) for node, metadata in list(G.nodes.data()): gn = GNode(node) d = wf_dict for n in gn.node_list: d = d["nodes"][n] - if gn.is_io: + if gn.is_io(): d[gn.io][gn.arg] = { key: value for key, value in metadata.items() if key != "step" } @@ -1146,7 +1213,7 @@ def graph_to_wf_dict(G: nx.DiGraph) -> dict: for edge in G.edges: orig = GNode(edge[0]) dest = GNode(edge[1]) - if not orig.is_io or not dest.is_io: + if not orig.is_io() or not dest.is_io(): continue if len(orig.node_list) == len(dest.node_list): nodes = orig.node_list[:-1] @@ -1186,3 +1253,28 @@ def graph_to_wf_dict(G: nx.DiGraph) -> dict: for k, v in value.items(): d[k] = v return tools.recursive_dd_to_dict(wf_dict) + + +def flatten_graph(G: nx.DiGraph) -> nx.DiGraph: + H = G.copy() + nodes = [node for node, data in G.nodes.data() if data["step"] == "node"] + ios = [ + io + for n in nodes + for neighbors in [G.predecessors(n), G.successors(n)] + for io in neighbors + ] + for node in G.nodes: + gn = GNode(node) + if not gn.is_io() or node in ios or gn.is_global_io(): + continue + if gn.io == "input": + main_node = list(G.successors(node))[0] + else: + main_node = list(G.predecessors(node))[0] + for k, val in G.nodes[node].items(): + if k not in G.nodes[main_node]: + H.nodes[main_node][k] = val + H = nx.contracted_nodes(H, main_node, node, self_loops=False) + del H.nodes[main_node]["contraction"] + return H diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py index 17d93a7..f883bef 100644 --- a/tests/unit/test_workflow.py +++ b/tests/unit/test_workflow.py @@ -719,6 +719,18 @@ def test_wf_dict_to_graph(self): sorted(wf_dict["nodes"]["example_macro_0"]["edges"]), ) + def test_flattening(self): + wf_dict = example_workflow.serialize_workflow() + wf_dict["inputs"] = {"a": {"value": 1}, "b": {"default": 2}} + G = fwf.get_workflow_graph(wf_dict) + result = fwf.simple_run(G) + wf_dict_flat = fwf.graph_to_wf_dict(G, flatten=True) + G_flat = fwf.get_workflow_graph(wf_dict_flat) + result_flat = fwf.simple_run(G_flat) + self.assertEqual( + result_flat.nodes["outputs@z"]["value"], result.nodes["outputs@z"]["value"] + ) + if __name__ == "__main__": unittest.main()