NOTE(review): reconstructed from a copy of this patch whose line breaks,
indentation and template arguments had been stripped. Blank context lines,
indentation and the template arguments on context/removed lines (notably
node_colocate_childs, colocate_nodes and edges_to_delete) are inferred and
must be checked against the tree before applying. Hunk headers were
recomputed for the review fixes; the index hashes predate those fixes.

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 95616ebaf29..45af055ff8f 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -4443,6 +4443,7 @@ tf_cc_tests(
         "graph/optimizer_cse_test.cc",
         "graph/optimizer_fusion_engine_test.cc",
         "graph/star_server_graph_partition_test.cc",
+        "graph/stream_subgraph_test.cc",
         "graph/subgraph_test.cc",
         "graph/tensor_id_test.cc",
         "graph/validate_test.cc",
diff --git a/tensorflow/core/graph/stream_subgraph.cc b/tensorflow/core/graph/stream_subgraph.cc
index 56fe090470c..07559e2132e 100644
--- a/tensorflow/core/graph/stream_subgraph.cc
+++ b/tensorflow/core/graph/stream_subgraph.cc
@@ -22,6 +22,11 @@ limitations under the License.
 
 namespace tensorflow {
 namespace stream_subgraph {
+// Adjacency list indexed by node id: dag[u] holds the ids u has edges to.
+using DAG = std::vector<std::vector<int>>;
+// Square adjacency matrix: bigraph[u][v] is true when left node u is
+// connected to right node v.
+using Bigraph = std::vector<std::vector<bool>>;
 
 namespace {
 
@@ -43,52 +48,250 @@ std::string GetDeviceNamePrefix(const std::string& device_name) {
   return device_name_prefix;
 }
 
-} // namesapce
+// Converts `g` into an adjacency list indexed by node id. Parallel edges
+// between the same two nodes produce repeated entries.
+DAG GraphToDAG(const Graph* g) {
+  DAG dag;
+  dag.resize(g->num_node_ids());
+  for (auto node : g->nodes()) {
+    for (auto edge : node->out_edges()) {
+      int dst_id = edge->dst()->id();
+      dag[node->id()].push_back(dst_id);
+    }
+  }
 
-void MarkStreamSubGraph(Graph* g, const MultiStreamOptions& opt) {
-  int num_streams = opt.multi_stream_num();
-  MultiStreamPartitionPolicy policy = opt.partition_policy();
+  return dag;
+}
 
-  if (policy == MultiStreamPartitionPolicy::EMBEDDING_GRAPH_PARTITION) {
-    MarkEmbeddingGraph(g, num_streams);
+// Marks `curr` and every node reachable from it in `visited`. Uses an
+// explicit stack so that long chains cannot overflow the call stack.
+void DFS(int curr, const DAG& graph, std::vector<bool>& visited) {
+  std::vector<int> pending(1, curr);
+  visited[curr] = true;
+  while (!pending.empty()) {
+    const int node = pending.back();
+    pending.pop_back();
+    for (int next : graph[node]) {
+      if (!visited[next]) {
+        visited[next] = true;
+        pending.push_back(next);
+      }
+    }
   }
 }
 
-void MarkEmbeddingGraph(Graph* g, int num_streams) {
-  bool train_graph = false;
+// Returns a matrix whose entry [u][v] is true when v can be reached from u
+// by following edges of `dag`; the diagonal is always false.
+// TODO: Optimize the algorithm
+std::vector<std::vector<bool>> GetReachableNodes(const DAG& dag) {
+  std::vector<std::vector<bool>> reachable_nodes;
+  int num_nodes = dag.size();
+  reachable_nodes.reserve(num_nodes);
+  for (int i = 0; i < num_nodes; i++) {
+    std::vector<bool> reachable(num_nodes, false);
+    DFS(i, dag, reachable);
+    reachable[i] = false;
+    reachable_nodes.push_back(std::move(reachable));
+  }
+
+  return reachable_nodes;
+}
+
+// Get minimum equivalent graph: removes every edge u->w for which another
+// child of u already reaches w, keeping reachability for acyclic input.
+DAG GetMEG(const DAG& dag) {
+  const auto& reachable_nodes = GetReachableNodes(dag);
+  int num_nodes = dag.size();
+  DAG meg = dag;
+  for (int i = 0; i < num_nodes; i++) {
+    auto& meg_child_nodes = meg[i];
+    auto& child_nodes = dag[i];
+    for (auto child : child_nodes) {
+      // Skip children that were already pruned through a sibling.
+      if (std::find(meg_child_nodes.begin(),
+                    meg_child_nodes.end(), child) ==
+          meg_child_nodes.end()) {
+        continue;
+      }
+      for (auto another : child_nodes) {
+        if (reachable_nodes[child][another]) {
+          auto it = std::find(meg_child_nodes.begin(),
+                              meg_child_nodes.end(), another);
+          if (it != meg_child_nodes.end()) {
+            meg_child_nodes.erase(it);
+          }
+        }
+      }
+    }
+  }
+
+  return meg;
+}
+
+// Expands the adjacency list `meg` into a square adjacency matrix.
+Bigraph MEGToBigraph(const DAG& meg) {
+  Bigraph bigraph;
+  int num_nodes = meg.size();
+  for (int i = 0; i < num_nodes; i++) {
+    std::vector<bool> adjacency(num_nodes, false);
+    for (auto child : meg[i]) {
+      adjacency[child] = true;
+    }
+    bigraph.push_back(std::move(adjacency));
+  }
+
+  return bigraph;
+}
+
+// Builds the bipartite graph from the reachability matrix of `dag`, i.e.
+// entry [u][v] is true when v is reachable from u.
+Bigraph DAGToBigraph(const DAG& dag) {
+  return GetReachableNodes(dag);
+}
+
+// Builds a graph over chains: chain i points to chain j when the first node
+// of j is reachable in `dag` from the last node of i. Every entry of
+// `stream_chains` is a {first node, last node} pair.
+DAG BuildStreamDAG(
+    const DAG& dag,
+    const std::vector<std::vector<int>>& stream_chains) {
+  const auto& reachable_nodes = GetReachableNodes(dag);
+  const int num_chains = stream_chains.size();
+  DAG stream_dag;
+  stream_dag.reserve(num_chains);
+  for (int i = 0; i < num_chains; i++) {
+    std::vector<int> ensuing_streams;
+    auto chain_end = stream_chains[i][1];
+    for (int j = 0; j < num_chains; j++) {
+      auto chain_begin = stream_chains[j][0];
+      if (reachable_nodes[chain_end][chain_begin]) {
+        ensuing_streams.push_back(j);
+      }
+    }
+    stream_dag.push_back(std::move(ensuing_streams));
+  }
+
+  return stream_dag;
+}
+
+// Tries to find an augmenting path for left node `start`. `match_status[v]`
+// holds the left node currently matched to right node v, or -1.
+// Returns true when `start` ends up matched.
+bool FindMatching(int start, const Bigraph& graph,
+                  std::vector<bool>& visited,
+                  std::vector<int>& match_status) {
+  int num = graph[start].size();
+  for (int i = 0; i < num; i++) {
+    if (graph[start][i] && !visited[i]) {
+      visited[i] = true;
+      int curr_match = match_status[i];
+      if (match_status[i] == -1 ||
+          FindMatching(curr_match, graph, visited, match_status)) {
+        match_status[i] = start;
+        return true;
+      }
+    }
+  }
+
+  return false;
+}
+
+// Computes a bipartite matching by searching an augmenting path for every
+// left node. Entry v of the result is the left node matched to right node v,
+// or -1 when v is unmatched. An empty graph yields an empty result.
+std::vector<int> MaximumMatching(const Bigraph& graph) {
+  if (graph.empty()) {
+    return {};
+  }
+  int num = graph[0].size();
+  std::vector<int> match_result(num, -1);
+  int num_bigraph = graph.size();
+  for (int i = 0; i < num_bigraph; i++) {
+    std::vector<bool> visited(num, false);
+    FindMatching(i, graph, visited, match_result);
+  }
+
+  return match_result;
+}
+
+// Turns a matching (matching[v] = predecessor of v, or -1) into chains.
+// Returns {chain index of every node, {first, last} node of every chain,
+// number of chains}.
+std::tuple<std::vector<int>, std::vector<std::vector<int>>, int>
+GetMapping(const std::vector<int>& matching) {
+  const int num_nodes = matching.size();
+  // A node that is nobody's predecessor is the last node of a chain.
+  std::vector<bool> has_successor(num_nodes, false);
+  for (int pred : matching) {
+    if (pred != -1) {
+      has_successor[pred] = true;
+    }
+  }
+
+  std::vector<std::vector<int>> chains;
+  std::vector<int> mapping(num_nodes, -1);
+  // Labels the chain ending at `chain_end`, walking back over predecessors
+  // until the chain start or an already labelled node is reached.
+  auto label_chain = [&](int chain_end) {
+    const int group_id = chains.size();
+    int curr = chain_end;
+    mapping[curr] = group_id;
+    while (matching[curr] != -1 && mapping[matching[curr]] == -1) {
+      curr = matching[curr];
+      mapping[curr] = group_id;
+    }
+    chains.push_back({curr, chain_end});
+  };
+  for (int i = 0; i < num_nodes; i++) {
+    if (!has_successor[i]) {
+      label_chain(i);
+    }
+  }
+  // Nodes matched into a cycle have no chain end; cut the cycle at the
+  // first unlabelled node so that every node receives a valid chain index.
+  for (int i = 0; i < num_nodes; i++) {
+    if (mapping[i] == -1) {
+      label_chain(i);
+    }
+  }
+
+  const int group_num = chains.size();
+  return std::make_tuple(mapping, chains, group_num);
+}
+
+} // namespace
+
+// Dispatches on the partition policy in `opt` to assign streams to the nodes
+// of `g`. Returns early for non-training graphs and unrecognized policies.
+void MarkStreamSubGraph(Graph* g, const MultiStreamOptions& opt) {
 
   // trained graph
   if (!g->IsTrainingGraph()) {
     return;
   }
-
   //for (Node* n : g->nodes()) {
   //  if (n->type_string() == "IsVariableInitialized" &&
   //      n->name() != "global_step/IsVariableInitialized") {
   //    return;
   //  }
   //}
 
+  int num_streams = opt.multi_stream_num();
+  MultiStreamPartitionPolicy policy = opt.partition_policy();
+  if (policy == MultiStreamPartitionPolicy::EMBEDDING_GRAPH_PARTITION) {
+    MarkEmbeddingGraph(g, num_streams);
+  } else if (policy == MultiStreamPartitionPolicy::FULL_GRAPH_PARTITION) {
+    MarkFullGraph(g, num_streams);
+  } else {
+    // Unrecognized policy
+    return;
+  }
+
   std::unordered_map<std::string, Node*> name_to_node;
-
   // User marked subgraph
   for (Node* n : g->nodes()) {
     name_to_node[n->name()] = n;
-
-    if (n->assigned_device_name().find("device:GPU:") == std::string::npos ||
-        n->def().attr().find("_stream_id") == n->def().attr().end()) {
-      continue;
-    }
-
-    int stream_id = n->def().attr().at("_stream_id").i();
-    std::string required_device =
-        GetDeviceNamePrefix(n->assigned_device_name()) +
-        std::to_string(stream_id);
-    if (n->assigned_device_name() != required_device) {
-      n->set_assigned_device_name(required_device);
-    }
   }
-
   // Colocate nodes
   std::unordered_map<std::string, std::vector<std::string>> node_colocate_childs;
   std::unordered_set<std::string> colocate_nodes;
@@ -130,7 +333,6 @@ void MarkEmbeddingGraph(Graph* g, int num_streams) {
       continue;
     }
 
-    //std::vector<const Edge*> edges_to_delete;
     std::vector<const Edge*> in_edges(n->in_edges().begin(),
                                       n->in_edges().end());
     for (const Edge* e : in_edges) {
@@ -161,5 +363,81 @@ void MarkEmbeddingGraph(Graph* g, int num_streams) {
   }
 }
 
+// Return stream id vector which is indexed by node id.
+std::vector<int> GenerateNodeStreamId(const Graph* graph) {
+  // Cover the graph with node chains: the matching on the minimum
+  // equivalent graph links every node to at most one successor.
+  const auto& dag = GraphToDAG(graph);
+  const auto& meg = GetMEG(dag);
+  const auto& bigraph = MEGToBigraph(meg);
+  const auto& matching = MaximumMatching(bigraph);
+  const auto& result = GetMapping(matching);
+  const std::vector<int>& node_to_chain = std::get<0>(result);
+
+  // Rematching stream, some streams can have the same id.
+  const auto& stream_chains = std::get<1>(result);
+  const auto& stream_dag = BuildStreamDAG(meg, stream_chains);
+  const auto& stream_bigraph = DAGToBigraph(stream_dag);
+  const auto& rematching = MaximumMatching(stream_bigraph);
+  const auto& remapping = GetMapping(rematching);
+  const std::vector<int>& chain_to_stream = std::get<0>(remapping);
+
+  const int num_nodes = node_to_chain.size();
+  std::vector<int> stream_ids(num_nodes, -1);
+  for (int node_id = 0; node_id < num_nodes; ++node_id) {
+    stream_ids[node_id] = chain_to_stream[node_to_chain[node_id]];
+  }
+
+  return stream_ids;
+}
+
+// Splits the whole graph into chains and gives every GPU node a stream:
+// sets its "_stream_id" attr and rewrites its assigned device name.
+void MarkFullGraph(Graph* g, int num_streams) {
+  // The stream id is reduced modulo num_streams below.
+  if (num_streams <= 0) {
+    return;
+  }
+
+  std::vector<int> node_stream_ids = GenerateNodeStreamId(g);
+
+  for (Node* n : g->nodes()) {
+    if (n->assigned_device_name().find("device:GPU:") ==
+        std::string::npos) {
+      continue;
+    }
+
+    int stream_id = node_stream_ids[n->id()] % num_streams;
+    n->AddAttr("_stream_id", stream_id);
+
+    std::string required_device =
+        GetDeviceNamePrefix(n->assigned_device_name()) +
+        std::to_string(stream_id);
+    if (n->assigned_device_name() != required_device) {
+      n->set_assigned_device_name(required_device);
+    }
+  }
+}
+
+// Rewrites the assigned device name of every GPU node that carries a
+// "_stream_id" attr so that it ends in that id. `num_streams` is unused.
+void MarkEmbeddingGraph(Graph* g, int num_streams) {
+  // User marked subgraph
+  for (Node* n : g->nodes()) {
+    if (n->assigned_device_name().find("device:GPU:") == std::string::npos ||
+        n->def().attr().find("_stream_id") == n->def().attr().end()) {
+      continue;
+    }
+
+    int stream_id = n->def().attr().at("_stream_id").i();
+    std::string required_device =
+        GetDeviceNamePrefix(n->assigned_device_name()) +
+        std::to_string(stream_id);
+    if (n->assigned_device_name() != required_device) {
+      n->set_assigned_device_name(required_device);
+    }
+  }
+}
+
 } // namespace stream_subgraph
 } // namespace tensorflow
diff --git a/tensorflow/core/graph/stream_subgraph.h b/tensorflow/core/graph/stream_subgraph.h
index 28a3415084b..ac568a121a8 100644
--- a/tensorflow/core/graph/stream_subgraph.h
+++ b/tensorflow/core/graph/stream_subgraph.h
@@ -34,6 +34,13 @@ void MarkStreamSubGraph(Graph* g, const MultiStreamOptions& opt);
 // Assign embedding graphs stream.
 void MarkEmbeddingGraph(Graph* g, int num_streams);
 
+// Auto split full graph to subgraphs,
+// and assign stream to each subgraph.
+void MarkFullGraph(Graph* g, int num_streams);
+
+// Return stream id vector which is indexed by node id.
+std::vector<int> GenerateNodeStreamId(const Graph* graph);
+
 } // namespace stream_subgraph
 } // namespace tensorflow
 
diff --git a/tensorflow/core/graph/stream_subgraph_test.cc b/tensorflow/core/graph/stream_subgraph_test.cc
new file mode 100644
index 00000000000..8159fb1e02b
--- /dev/null
+++ b/tensorflow/core/graph/stream_subgraph_test.cc
@@ -0,0 +1,61 @@
+#include "tensorflow/core/graph/stream_subgraph.h"
+
+#include <vector>
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(GenerateNodeStreamId, TestGraph) {
+  Graph graph(OpRegistry::Global());
+  std::vector<Node*> nodes;
+  nodes.push_back(graph.source_node());
+  nodes.push_back(graph.sink_node());
+  // RemoveEdge() mutates the edge set, so iterate over a snapshot of it.
+  std::vector<const Edge*> source_out_edges(
+      graph.source_node()->out_edges().begin(),
+      graph.source_node()->out_edges().end());
+  for (const Edge* edge : source_out_edges) {
+    graph.RemoveEdge(edge);
+  }
+  for (int i = 0; i < 5; ++i) {
+    Node* node;
+    TF_CHECK_OK(NodeBuilder(strings::StrCat("v", i + 1), "NoOp")
+                    .Finalize(&graph, &node));
+    nodes.push_back(node);
+  }
+
+  graph.AddEdge(nodes[0], 0, nodes[1], 0);
+  graph.AddEdge(nodes[0], 1, nodes[2], 0);
+  graph.AddEdge(nodes[0], 2, nodes[3], 0);
+  graph.AddEdge(nodes[0], 3, nodes[5], 0);
+
+  graph.AddEdge(nodes[1], 0, nodes[4], 0);
+
+  graph.AddEdge(nodes[2], 0, nodes[4], 1);
+  graph.AddEdge(nodes[2], 1, nodes[6], 0);
+  graph.AddEdge(nodes[2], 2, nodes[5], 1);
+
+  graph.AddEdge(nodes[3], 0, nodes[5], 2);
+
+  graph.AddEdge(nodes[4], 0, nodes[6], 1);
+
+  auto mapping = stream_subgraph::GenerateNodeStreamId(&graph);
+
+  EXPECT_EQ(mapping.size(), 7);
+  EXPECT_EQ(mapping[0], mapping[1]);
+  EXPECT_EQ(mapping[1], mapping[4]);
+  EXPECT_EQ(mapping[4], mapping[6]);
+  EXPECT_EQ(mapping[2], mapping[5]);
+  for (int i = 0; i < mapping.size(); ++i) {
+    VLOG(2) << i+1 << ": " << mapping[i];
+  }
+}
+
+} // namespace
+} // namespace tensorflow