
Building asynchronous `future` callback chain from compile-time dependency graph (DAG)


The easiest way is to start from the entry node of the graph, as if you were writing the code by hand. To solve the join problem you cannot use a recursive solution: you need a topological ordering of your graph, and then you build the chain according to that ordering.

This guarantees that when you build a node, all of its predecessors have already been created.

To obtain this ordering we can use a DFS with reverse postordering.
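For intuition, here is the same DFS with reverse postordering written as an ordinary runtime function over an adjacency list. This is an illustrative sketch, not part of the compile-time solution; the `Graph` alias and the `topological_sort` function are hypothetical names chosen here, and the graph is assumed to be acyclic.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Adjacency list: succ[v] lists the successors of node v.
using Graph = std::vector<std::vector<int>>;

// Topological sort by DFS with reverse postordering, starting from `entry`.
// Only nodes reachable from `entry` are returned; the graph must be acyclic.
std::vector<int> topological_sort(const Graph& succ, int entry) {
    assert(entry >= 0 && entry < static_cast<int>(succ.size()));
    std::vector<bool> visited(succ.size(), false);
    std::vector<int> postorder;

    std::function<void(int)> visit = [&](int v) {
        if (visited[v]) return;          // already visited: do nothing
        visited[v] = true;
        for (int s : succ[v]) visit(s);  // visit every successor first...
        postorder.push_back(v);          // ...then record v (postorder)
    };
    visit(entry);

    // Reverse postorder: every node comes before all of its successors.
    return std::vector<int>(postorder.rbegin(), postorder.rend());
}
```

With A..F numbered 0..5 and successors listed as A → {B, C, E}, B → {D}, C → {D}, D → {F}, E → {F}, this sketch yields 0, 4, 2, 1, 3, 5 (A, E, C, B, D, F). Any topological order works for building the chain; which one you get depends on the order in which successors are listed.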

Once you have a topological sorting, you can forget the original node IDs and refer to nodes by their position in the list. To do that, you need to create a compile-time map that retrieves the predecessors of a node using its index in the topological sorting instead of its original index.
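The remapping can also be illustrated at runtime: given the successor lists and a topological order, rebuild the predecessor lists so that every node is referred to by its position in the order. This is a hypothetical sketch (`predecessors_by_position` is a name made up here); in the compile-time solution the `graph_map` adapter of Step 2 plays this role.

```cpp
#include <cassert>
#include <vector>

using Graph = std::vector<std::vector<int>>;

// order[k] is the original id of the k-th node in topological order.
// Returns pred, where pred[k] lists the positions of the predecessors of order[k].
Graph predecessors_by_position(const Graph& succ, const std::vector<int>& order) {
    std::vector<int> position(succ.size(), -1);  // original id -> position
    for (int k = 0; k < static_cast<int>(order.size()); ++k) position[order[k]] = k;

    Graph pred(order.size());
    for (int k = 0; k < static_cast<int>(order.size()); ++k)
        for (int s : succ[order[k]]) {
            // In a topological order every edge points forward.
            assert(position[s] > k);
            pred[position[s]].push_back(k);
        }
    return pred;
}
```

Every entry of `pred[k]` is smaller than `k`, which is exactly the property Step 3 relies on.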


EDIT: Following up on how to implement topological sorting at compile time, I refactored this answer.

To be on the same page, I will assume that your graph looks like this:

struct mygraph
{
     template<int Id>
     static constexpr auto successors(node_id<Id>) ->
        list< node_id<> ... >; //List of successors for the input node

     template<int Id>
     static constexpr auto predecessors(node_id<Id>) ->
        list< node_id<> ... >; //List of predecessors for the input node

     //Get the task associated with the given node.
     template<int Id>
     static constexpr auto task(node_id<Id>);

     using entry_node = node_id<0>;
};

Step 1: topological sort

The basic ingredient that you need is a compile-time set of node ids. In TMP a set is also a list, simply because in set<Ids...> the order of the Ids matters. This means that you can use the same data structure to encode both whether a node has been visited AND the resulting ordering.
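A minimal sketch of such a set, assuming C++17 fold expressions. `node_id`, `set`, `contains` and `insert` are hypothetical definitions written to match how they are used in the code below (the answer does not show them). `contains` returns `std::true_type` or `std::false_type`, which is what lets `visit` dispatch between the two `visit_impl` overloads.

```cpp
#include <type_traits>

template<int Id> struct node_id {};

// Hypothetical compile-time set: the pack Is... records both which nodes are
// in the set and the order in which they were inserted.
template<int... Is>
struct set {
    // std::true_type if Id is in the set, std::false_type otherwise.
    template<int Id>
    static constexpr auto contains(node_id<Id>) {
        return std::bool_constant<((Is == Id) || ...)>{};
    }

    // Returns a new set with Id appended at the back; callers are expected
    // to check contains() first, as visit does.
    template<int Id>
    static constexpr set<Is..., Id> insert(node_id<Id>) {
        return {};
    }
};
```

Because `insert` appends, the final pack is the postorder, and reversing it gives the topological order.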

/** Topological sort using DFS with reverse-postordering **/
template<class Graph>
struct topological_sort
{
private:
    struct visit;

    // If we reach a node that we already visited, do nothing.
    template<int Id, int ... Is>
    static constexpr auto visit_impl( node_id<Id>,
                                      set<Is...> visited,
                                      std::true_type )
    {
        return visited;
    }

    // This overload kicks in when node has not been visited yet.
    template<int Id, int ... Is>
    static constexpr auto visit_impl( node_id<Id> node,
                                      set<Is...> visited,
                                      std::false_type )
    {
        // Get the list of successors for the current node
        constexpr auto succ = Graph::successors(node);

        // Reverse postordering: we call insert *after* visiting the successors
        // This will call "visit" on each successor, updating the
        // visited set after each step.
        // Then we insert the current node in the set.
        // Notice that if the graph is cyclic we end up in an infinite
        // recursion here.
        return fold( succ,
                     visited,
                     visit() ).insert(node);

        // Conventional DFS would be:
        // return fold( succ, visited.insert(node), visit() );
    }

    struct visit
    {
        // Dispatch to visit_impl depending on the result of visited.contains(node)
        // Note that "contains" returns a type convertible to
        // integral_constant<bool,x>
        template<int Id, int ... Is>
        constexpr auto operator()( set<Is...> visited, node_id<Id> node ) const
        {
            return visit_impl(node, visited, visited.contains(node) );
        }
    };

public:
    template<int StartNodeId>
    static constexpr auto compute( node_id<StartNodeId> node )
    {
        // Start visiting from the entry node
        // The set of visited nodes is initially empty.
        // "as_list" converts set<Is ... > to list< node_id<Is> ... >.
        return reverse( visit()( set<>{}, node ).as_list() );
    }
};

Applied to the graph from your last example (assuming A = node_id<0>, B = node_id<1>, etc.), this algorithm produces list<A,B,C,D,E,F>.

Step 2: graph map

This is simply an adapter that modifies the Id of each node in your graph according to a given ordering. So, assuming that the previous step returned list<C,D,A,B>, this graph_map would map index 0 to C, index 1 to D, etc.

template<class Graph, class List>
class graph_map
{
    // Convert a node_id from underlying graph.
    // Use a function-object so that it can be passed to algorithms.
    struct from_underlying
    {
        template<int I>
        constexpr auto operator()(node_id<I> id)
        { return node_id< find(id, List{}) >{}; }
    };

    struct to_underlying
    {
        template<int I>
        constexpr auto operator()(node_id<I> id)
        { return get<I>(List{}); }
    };

public:
    template<int Id>
    static constexpr auto successors( node_id<Id> id )
    {
        constexpr auto orig_id = to_underlying()(id);
        constexpr auto orig_succ = Graph::successors( orig_id );
        return transform( orig_succ, from_underlying() );
    }

    template<int Id>
    static constexpr auto predecessors( node_id<Id> id )
    {
        constexpr auto orig_id = to_underlying()(id);
        constexpr auto orig_pred = Graph::predecessors( orig_id );
        return transform( orig_pred, from_underlying() );
    }

    template<int Id>
    static constexpr auto task( node_id<Id> id )
    {
        return Graph::task( to_underlying()(id) );
    }

    using entry_node = decltype( from_underlying()( typename Graph::entry_node{} ) );
};
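`graph_map` relies on two helpers that are not defined in the answer: `find(id, List{})`, the position of a node in the sorted list, and `get<I>(List{})`, the node at position `I`. One possible C++17 sketch, assuming `list<node_id<Is>...>` is a plain type-list; both definitions are hypothetical, and `node_id` gets an extra `value` member here only for convenience.

```cpp
#include <tuple>

template<int Id> struct node_id { static constexpr int value = Id; };
template<class... Ts> struct list {};

// Position of node_id<Id> in list<node_id<Is>...>, or the list length if absent.
template<int Id, int... Is>
constexpr int find(node_id<Id>, list<node_id<Is>...>) {
    const int ids[] = {Is..., -1};  // trailing sentinel avoids a zero-size array
    int pos = 0;
    while (pos < static_cast<int>(sizeof...(Is)) && ids[pos] != Id) ++pos;
    return pos;
}

// The node at position I of list<node_id<Is>...>.
template<int I, int... Is>
constexpr auto get(list<node_id<Is>...>) {
    return std::tuple_element_t<I, std::tuple<node_id<Is>...>>{};
}
```

`find` and `get` are inverses on the list, which is what makes `from_underlying` and `to_underlying` round-trip.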

Step 3: assemble the result

We can now iterate over each node id in order. Thanks to the way we built the graph map, we know that for every node I, all the predecessors of I have a node id less than I.
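The same invariant also makes a runtime version possible with only the standard library, which has neither `then` nor `when_all`. In this hypothetical sketch (`make_chain` and its parameters are names chosen here), node `k` is launched with `std::async` and first calls `get()` on the `std::shared_future`s of its predecessors; those futures already exist because their indices are smaller than `k`.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <future>
#include <vector>

// pred[k] lists the positions of node k's predecessors; all of them are < k.
// Returns one shared_future per node, in the same (topological) order.
std::vector<std::shared_future<void>>
make_chain(const std::vector<std::vector<int>>& pred,
           const std::vector<std::function<void()>>& tasks) {
    assert(pred.size() == tasks.size());
    std::vector<std::shared_future<void>> futures;
    for (std::size_t k = 0; k < tasks.size(); ++k) {
        // Collect the futures of the predecessors; they were created earlier.
        std::vector<std::shared_future<void>> deps;
        for (int p : pred[k]) {
            assert(p >= 0 && static_cast<std::size_t>(p) < k);
            deps.push_back(futures[p]);
        }
        // The "join": wait for every predecessor, then run this node's task.
        // get() also rethrows an exception thrown by a predecessor.
        futures.push_back(std::async(std::launch::async, [deps, task = tasks[k]] {
            for (const auto& d : deps) d.get();
            task();
        }).share());
    }
    return futures;
}
```

Unlike a `then` continuation, every waiting node occupies a thread while it blocks, so this is only reasonable for small graphs.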

// Returns a tuple<> of futures
// The tuple holds shared_futures, because make_cont pushes the result of
// .share() at the back and then receives that tuple again on the next step.
template<class GraphMap, class ... Ts>
auto make_cont( std::tuple< shared_future<Ts> ... > && pred )
{
     // The next node to work with is N = sizeof...(Ts):
     constexpr auto current_node = node_id< sizeof ... (Ts) >();

     // Get a list of all the predecessors for the current node.
     auto indices = GraphMap::predecessors( current_node );

     // "select" is some magic function that takes a tuple of Ts
     // and an index_sequence, and returns a tuple of references to the elements
     // from the input tuple that are in the indices list.
     auto futures = select( pred, indices );

     // Assuming you have an overload of when_all that takes a tuple,
     // otherwise use C++17 apply.
     auto join = when_all( futures );

     // Note: when_all with an empty parameter list returns a future< tuple<> >,
     // which is always ready.
     // In general this has to be a shared_future, but you can avoid that
     // by checking if this node has only one successor.
     auto next = join.then( GraphMap::task( current_node ) ).share();

     // Return a new tuple of futures, pushing the new future at the back.
     return std::tuple_cat( std::move(pred),
                            std::make_tuple(std::move(next)) );
}

// Returns a tuple of futures, you can take the last element if you
// know that your DAG has only one leaf, or do some additional
// processing to extract only the leaf nodes.
template<class Graph>
auto make_callback_chain()
{
    constexpr auto entry_node = typename Graph::entry_node{};
    constexpr auto sorted_list =
         topological_sort<Graph>::compute( entry_node );

    using map = graph_map< Graph, decltype(sorted_list) >;

    // Note: we are not really using the "index" in the functor here,
    // we only want to call make_cont once for each node in the graph
    return fold( sorted_list,
                 std::make_tuple(), //Start with an empty tuple
                 []( auto && tuple, auto index )
                 {
                     return make_cont<map>(std::move(tuple));
                 } );
}
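`select` is left as "some magic function" above. A possible C++17 sketch, assuming the indices have already been turned into a `std::index_sequence` (converting the `list<node_id<...>>` returned by `predecessors` is not shown); it returns a tuple of references built with `std::tie`, so no future is copied.

```cpp
#include <cstddef>
#include <tuple>
#include <utility>

// Returns a tuple of references to the elements of `t` at positions Is...,
// e.g. select(t, std::index_sequence<2, 0>{}) yields tuple<T2&, T0&>.
template<class Tuple, std::size_t... Is>
constexpr auto select(Tuple& t, std::index_sequence<Is...>) {
    return std::tie(std::get<Is>(t)...);
}
```

Writing through the returned references modifies the original tuple, which matches the "tuple of references" described in the comment.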



If redundant dependencies may occur, remove them first (see e.g. https://mathematica.stackexchange.com/questions/33638/remove-redundant-dependencies-from-a-directed-acyclic-graph).
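Removing redundant dependencies amounts to a transitive reduction. For a DAG, one simple (not asymptotically optimal) way is to drop an edge u → v whenever v is still reachable from u through another successor of u. A C++ sketch with hypothetical names (`reachable`, `remove_redundant`) over an adjacency list:

```cpp
#include <cassert>
#include <vector>

using Graph = std::vector<std::vector<int>>;

// True if `target` is reachable from `from` (iterative DFS).
bool reachable(const Graph& succ, int from, int target) {
    assert(from >= 0 && from < static_cast<int>(succ.size()));
    std::vector<bool> seen(succ.size(), false);
    std::vector<int> stack{from};
    while (!stack.empty()) {
        int v = stack.back();
        stack.pop_back();
        if (v == target) return true;
        if (seen[v]) continue;
        seen[v] = true;
        for (int s : succ[v]) stack.push_back(s);
    }
    return false;
}

// Transitive reduction of a DAG: keep edge u->v only if v cannot be reached
// from u through some other successor w of u.
Graph remove_redundant(const Graph& succ) {
    Graph out(succ.size());
    for (int u = 0; u < static_cast<int>(succ.size()); ++u)
        for (int v : succ[u]) {
            bool redundant = false;
            for (int w : succ[u])
                if (w != v && reachable(succ, w, v)) { redundant = true; break; }
            if (!redundant) out[u].push_back(v);
        }
    return out;
}
```

For example, with edges 0 → 1, 1 → 2 and 0 → 2, the edge 0 → 2 is dropped because 2 is already reached through 1.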

Then perform the following graph transformations (building sub-expressions in merged nodes) until you are down to a single node (in a way similar to how you'd calculate a network of resistors):

Graph transformations

*: Additional incoming or outgoing dependencies, depending on placement

(...): Expression in a single node

Java code including setup for your more complex example:

import java.util.LinkedHashSet;
import java.util.Set;

public class DirectedGraph {
  /** Set of all nodes in the graph */
  static Set<Node> allNodes = new LinkedHashSet<>();

  static class Node {
    /** Set of all preceding nodes */
    Set<Node> prev = new LinkedHashSet<>();

    /** Set of all following nodes */
    Set<Node> next = new LinkedHashSet<>();

    String value;

    Node(String value) {
      this.value = value;
      allNodes.add(this);
    }

    void addPrev(Node other) {
      prev.add(other);
      other.next.add(this);
    }

    /** Returns one of the next nodes */
    Node anyNext() {
      return next.iterator().next();
    }

    /** Merges this node with other, then removes other */
    void merge(Node other) {
      prev.addAll(other.prev);
      next.addAll(other.next);
      for (Node on: other.next) {
        on.prev.remove(other);
        on.prev.add(this);
      }
      for (Node op: other.prev) {
        op.next.remove(other);
        op.next.add(this);
      }
      prev.remove(this);
      next.remove(this);
      allNodes.remove(other);
    }

    @Override
    public String toString() {
      return value;
    }
  }

  /**
   * Merges sequential or parallel nodes following the given node.
   * Returns true if any node was merged.
   */
  public static boolean processNode(Node node) {
    // Check if we are the start of a sequence. Merge if so.
    if (node.next.size() == 1 && node.anyNext().prev.size() == 1) {
      Node then = node.anyNext();
      node.value += " then " + then.value;
      node.merge(then);
      return true;
    }

    // See if any of the next nodes has a parallel node with
    // the same one level indirect target.
    for (Node next : node.next) {

      // Nodes must have only one in and out connection to be merged.
      if (next.prev.size() == 1 && next.next.size() == 1) {

        // Collect all parallel nodes with only one in and out connection
        // and the same target; the same source is implied by iterating over
        // node.next again.
        Node target = next.anyNext();
        Set<Node> parallel = new LinkedHashSet<Node>();
        for (Node other: node.next) {
          if (other != next && other.prev.size() == 1
             && other.next.size() == 1 && other.anyNext() == target) {
            parallel.add(other);
          }
        }

        // If we have found any "parallel" nodes, merge them.
        // We return right after merging, so modifying node.next here
        // does not disturb the enclosing iteration.
        if (parallel.size() > 0) {
          StringBuilder sb = new StringBuilder("all(");
          sb.append(next.value);
          for (Node other: parallel) {
            sb.append(", ").append(other.value);
            next.merge(other);
          }
          sb.append(")");
          next.value = sb.toString();
          return true;
        }
      }
    }
    return false;
  }

  public static void main(String[] args) {
    Node a = new Node("A");
    Node b = new Node("B");
    Node c = new Node("C");
    Node d = new Node("D");
    Node e = new Node("E");
    Node f = new Node("F");

    f.addPrev(d);
    f.addPrev(e);

    e.addPrev(a);

    d.addPrev(b);
    d.addPrev(c);

    b.addPrev(a);
    c.addPrev(a);

    boolean anyChange;
    do {
      anyChange = false;
      for (Node node: allNodes) {
        if (processNode(node)) {
          anyChange = true;
          // We need to leave the inner loop here because changes
          // invalidate the for iteration.
          break;
        }
      }
      // We are done if we can't find any node to merge.
    } while (anyChange);

    System.out.println(allNodes.toString());
  }
}

Output: [A then all(E, all(B, C) then D) then F]
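To show how the resulting expression maps onto code, here is a sketch with two hypothetical combinators: `seq(a, b)` runs `a` and then `b`, and `all({...})` runs its parts concurrently with `std::async` and joins them. Neither comes from the answer or from a library; they only mirror the "then" and "all" in the output.

```cpp
#include <functional>
#include <future>
#include <vector>

using Job = std::function<void()>;

// seq(a, b): run a, then b (the "then" in the printed expression).
Job seq(Job first, Job second) {
    return [first, second] {
        first();
        second();
    };
}

// all({x, y, ...}): run every part concurrently and wait for all of them.
Job all(std::vector<Job> parts) {
    return [parts] {
        std::vector<std::future<void>> running;
        for (const Job& p : parts)
            running.push_back(std::async(std::launch::async, p));
        for (auto& f : running) f.get();  // join; get() rethrows a part's exception
    };
}
```

With jobs A..F, the printed expression becomes `seq(seq(A, all({E, seq(all({B, C}), D)})), F)`.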


This seems reasonably easy if you stop thinking about it in terms of explicit dependencies and organizing a DAG. Every task can be organized into something like the following (C#, because it makes the idea much simpler to explain):

class MyTask
{
    // a list of all tasks that depend on this to be finished
    private readonly ICollection<MyTask> _dependenants;

    // number of not finished dependencies of this task
    private int _nrDependencies;

    public int NrDependencies
    {
        get { return _nrDependencies; }
        private set { _nrDependencies = value; }
    }
}

If you have organized your DAG in such a form, the problem is actually really simple: every task where _nrDependencies == 0 can be executed. So we need a run method that looks something like the following:

public async Task RunTask()
{
    // Execute actual code of the task.
    var tasks = new List<Task>();
    foreach (var dependent in _dependenants)
    {
        if (Interlocked.Decrement(ref dependent._nrDependencies) == 0)
        {
            tasks.Add(Task.Run(() => dependent.RunTask()));
        }
    }
    await Task.WhenAll(tasks);
}

Basically, as soon as our task has finished, we go through all our dependents and execute those that have no more unfinished dependencies.
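Run serially, the same counting rule is Kahn's topological sort: keep a count of unfinished dependencies per task and run any task whose count has reached zero. A single-threaded C++ sketch; the function name `run_order` and the `dependents` list representation are assumptions made here.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// dependents[v] lists the tasks that wait for task v.
// Returns the order in which the tasks would be executed.
std::vector<int> run_order(const std::vector<std::vector<int>>& dependents) {
    const int n = static_cast<int>(dependents.size());
    std::vector<int> pending(n, 0);  // plays the role of _nrDependencies
    for (const auto& ds : dependents)
        for (int d : ds) ++pending[d];

    std::queue<int> ready;
    for (int v = 0; v < n; ++v)
        if (pending[v] == 0) ready.push(v);  // tasks with no dependencies start first

    std::vector<int> order;
    while (!ready.empty()) {
        int v = ready.front();
        ready.pop();
        order.push_back(v);  // "execute" v
        for (int d : dependents[v])
            if (--pending[d] == 0) ready.push(d);  // v was d's last dependency
    }
    assert(static_cast<int>(order.size()) == n && "graph has a cycle");
    return order;
}
```

For A..F numbered 0..5 with A → B, C, E; B, C → D; D, E → F, it returns 0, 1, 2, 4, 3, 5.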

To start the whole thing off, call RunTask() for all tasks that have zero dependencies to begin with (at least one such task must exist, since we have a DAG). As soon as all of these tasks have finished, we know that the whole DAG has been executed.
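A C++ transcription of the same idea, as a sketch: `Task`, `add_dependency`, `run_task` and `run_all` are hypothetical names, `std::atomic<int>::fetch_sub` stands in for `Interlocked.Decrement`, and `std::async` plus waiting on the returned futures stands in for `Task.Run` and `Task.WhenAll`. `run_all` collects the initial tasks before launching any of them; otherwise a task whose counter drops to zero while the loop is still running could be started twice.

```cpp
#include <atomic>
#include <functional>
#include <future>
#include <vector>

struct Task {
    std::function<void()> work;
    std::vector<Task*> dependents;  // tasks that wait for this one
    std::atomic<int> pending{0};    // unfinished dependencies (_nrDependencies)
};

// Registers that `after` cannot start until `before` has finished.
void add_dependency(Task& before, Task& after) {
    before.dependents.push_back(&after);
    after.pending.fetch_add(1);
}

// Runs t, then launches each dependent whose counter this call brings to zero,
// and waits for them (the equivalent of awaiting Task.WhenAll).
void run_task(Task& t) {
    t.work();
    std::vector<std::future<void>> launched;
    for (Task* d : t.dependents)
        if (d->pending.fetch_sub(1) == 1)  // we finished its last dependency
            launched.push_back(std::async(std::launch::async, run_task, std::ref(*d)));
    for (auto& f : launched) f.get();
}

// Starts every task that has no dependencies and waits for all of them;
// when this returns, the whole DAG has executed.
void run_all(std::vector<Task>& tasks) {
    std::vector<Task*> roots;
    for (Task& t : tasks)
        if (t.pending.load() == 0) roots.push_back(&t);  // collect before launching
    std::vector<std::future<void>> running;
    for (Task* r : roots)
        running.push_back(std::async(std::launch::async, run_task, std::ref(*r)));
    for (auto& f : running) f.get();
}
```

`Task` holds an atomic and is therefore not movable, so the tasks are created up front, e.g. `std::vector<Task> tasks(6);`, and the vector must not be resized afterwards because `dependents` stores raw pointers into it.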