Sorry, I don't understand what `CorrectLayout` does. The code is below, but I can't follow it:
/*!
 * \brief Rebuild `src` so that every edge connects tensors of matching data
 *        layout, inserting layout-transform nodes where they disagree.
 *
 * What the pass does, in order:
 *  1. Nodes are visited in IndexedGraph order (DFS post-order). Every
 *     producer must already have been visited when its consumer is; this is
 *     enforced by the CHECKs on mirror_vec / new_layouts. Each node is
 *     copied, the copies form the returned graph, and `new_layouts` records
 *     the layout of every output of every copy.
 *  2. Variable nodes (graph inputs) take their layout from the optional graph
 *     attribute "layout_inputs". It holds one Layout per entry of
 *     idx.input_nodes(), indexed by position in that list. When the attribute
 *     is absent, the layout is Layout::Undef().
 *  3. For an operator node, `request_ilayouts` starts out as the layouts its
 *     producers actually output. If the operator registers "FCorrectLayout",
 *     that function may rewrite:
 *       - `request_ilayouts`: the layouts the operator wants for its inputs;
 *       - `produce_olayouts`: the layouts of its outputs.
 *     The function also receives the input/output layouts stored in the graph
 *     attribute "layout", if present. Per the original comments, that
 *     attribute comes from a previous layout pass. If FCorrectLayout returns
 *     false, the pass aborts via CHECK ("Layout infer fail").
 *  4. For each input edge, the producer's layout is compared with the
 *     requested one:
 *       - Producer layout defined and different: a layout-transform node
 *         (CreateLayoutTransformNode(produce, request)) is spliced into the
 *         edge. For example, if an op requests B from a producer emitting A,
 *           producer(A) -> op
 *         becomes
 *           producer(A) -> transform(A->B) -> op.
 *       - Producer layout undefined: no transform is inserted. Instead, the
 *         requested layout is written back as that producer's output layout
 *         ("reverse infer"). Consumers visited earlier have already seen
 *         Undef for that output; only later consumers and the final "layout"
 *         attribute see the new value.
 *  5. The returned graph carries the attribute "layout": a
 *     std::vector<Layout> indexed by node-entry id of the returned graph.
 *
 * \param src Graph to correct. Optional attributes read: "layout_inputs"
 *            and "layout".
 * \return A new graph with transforms inserted and "layout" set. No other
 *         attributes of `src` are copied to it.
 */
nnvm::Graph CorrectLayout(nnvm::Graph src) {
  static auto& op_correct_layout =
      nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout");
  const IndexedGraph& idx = src.indexed_graph();
  const auto& input_nodes = idx.input_nodes();

  // Position of each variable node inside idx.input_nodes() (-1 if absent).
  // This table is built once so the per-variable lookup below is O(1) instead
  // of a linear std::find over all inputs. The first occurrence is kept, which
  // is exactly what std::find would return.
  std::vector<int64_t> input_pos(idx.num_nodes(), -1);
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    int64_t& pos = input_pos[input_nodes[i]];
    if (pos < 0) pos = static_cast<int64_t>(i);
  }

  // Graph attributes do not change inside the loop, so fetch them once.
  // Both pointers refer to storage owned by `src`, which outlives the loop.
  const std::vector<Layout>* layout_inputs = nullptr;
  if (src.HasAttr("layout_inputs")) {
    layout_inputs = &src.GetAttr<std::vector<Layout> >("layout_inputs");
  }
  const std::vector<Layout>* last_layouts = nullptr;
  if (src.HasAttr("layout")) {
    last_layouts = &src.GetAttr<std::vector<Layout> >("layout");
    // This attribute is indexed by node-entry id below. Reject a short one
    // instead of reading out of bounds.
    CHECK_GE(last_layouts->size(), idx.num_node_entries())
        << "graph attribute \"layout\" has fewer entries than the graph";
  }

  // mirror_vec[nid] holds the copy of node `nid` in the graph being built.
  std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
  // Maps each (new) node to the layout of each of its outputs.
  LayoutAttrDict new_layouts;

  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    nnvm::NodePtr new_node = nnvm::Node::Create();
    *new_node = *(inode.source);

    if (new_node->is_variable()) {
      // Variable node: no operator and exactly one output entry.
      const int64_t input_id = input_pos[nid];
      CHECK_GE(input_id, 0);
      if (layout_inputs != nullptr) {
        CHECK_LT(static_cast<size_t>(input_id), layout_inputs->size())
            << "graph attribute \"layout_inputs\" has no entry for input "
            << input_id;
        new_layouts[new_node.get()] = {(*layout_inputs)[input_id]};
      } else {
        // May still be filled in later by reverse inference.
        new_layouts[new_node.get()] = {Layout::Undef()};
      }
      mirror_vec[nid] = new_node;
      continue;
    }

    const uint32_t num_inputs = inode.inputs.size();
    const uint32_t num_outputs = inode.source->num_outputs();

    // Start by requesting exactly what each producer currently outputs.
    // FCorrectLayout may change request_ilayouts; produce_ilayouts keeps the
    // producers' layouts so the two can be compared afterwards.
    std::vector<Layout> request_ilayouts(num_inputs, Layout::Undef());
    for (uint32_t i = 0; i < num_inputs; ++i) {
      const IndexedGraph::NodeEntry& input_entry = inode.inputs[i];
      const nnvm::NodePtr& new_input_node = mirror_vec[input_entry.node_id];
      // The producer has already been visited (DFS post-order).
      CHECK(new_input_node != nullptr);
      auto layouts_iter = new_layouts.find(new_input_node.get());
      CHECK(layouts_iter != new_layouts.end());
      request_ilayouts[i] = layouts_iter->second[input_entry.index];
    }
    const std::vector<Layout> produce_ilayouts(request_ilayouts);

    // Input/output layouts recorded in the "layout" attribute, if present.
    std::vector<Layout> last_request_ilayouts(num_inputs, Layout::Undef());
    std::vector<Layout> produce_olayouts(num_outputs, Layout::Undef());
    if (last_layouts != nullptr) {
      for (uint32_t i = 0; i < num_outputs; ++i) {
        produce_olayouts[i] = (*last_layouts)[idx.entry_id(nid, i)];
      }
      for (uint32_t i = 0; i < num_inputs; ++i) {
        last_request_ilayouts[i] =
            (*last_layouts)[idx.entry_id(inode.inputs[i])];
      }
    }

    // Let the operator choose the input layouts it wants and report the
    // layouts of its outputs.
    if (op_correct_layout.count(new_node->op())) {
      const auto& flayout = op_correct_layout[new_node->op()];
      CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts,
                    &produce_olayouts))
          << "Layout infer fail";
      CHECK_EQ(request_ilayouts.size(), num_inputs);
      CHECK_EQ(produce_olayouts.size(), num_outputs);
    }

    new_layouts[new_node.get()] = std::move(produce_olayouts);

    // Point the copy at the copied producers, splicing in a layout transform
    // wherever the producer's layout differs from the requested one.
    for (uint32_t i = 0; i < num_inputs; ++i) {
      const auto& e = inode.inputs[i];
      const nnvm::NodePtr& in = mirror_vec[e.node_id];
      new_node->inputs[i] = nnvm::NodeEntry{in, e.index, e.version};

      const Layout& produce = produce_ilayouts[i];
      const Layout& request = request_ilayouts[i];
      if (produce != request && produce.defined()) {
        // NOTE(review): a fresh transform node is created per consuming edge,
        // and its name is reset to "<producer>_<layout>". Two transforms from
        // the same producer to the same layout therefore share a name
        // (several consumers, or several outputs of one producer). Confirm
        // this is acceptable downstream.
        nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
        tnode->attrs.name =
            idx[e.node_id].source->attrs.name + "_" + request.name();
        tnode->inputs.emplace_back(new_node->inputs[i]);
        new_node->inputs[i] = nnvm::NodeEntry{tnode, 0, 0};
        // The transform's single output carries the requested layout.
        new_layouts[tnode.get()] = {request};
      } else if (!produce.defined()) {
        // Reverse inference: the producer had no layout, so adopt the one the
        // consumer asked for. No transform is inserted.
        new_layouts[in.get()][e.index] = request;
      }
    }
    mirror_vec[nid] = new_node;
  }

  std::vector<nnvm::NodeEntry> outputs;
  outputs.reserve(idx.outputs().size());
  for (const auto& e : idx.outputs()) {
    outputs.emplace_back(
        nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
  }

  // Build a temporary graph only to obtain node-entry ids for the new nodes,
  // then translate new_layouts into a per-entry vector.
  nnvm::Graph ret;
  ret.outputs = outputs;
  const auto& ret_idx = ret.indexed_graph();
  std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
  for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
    const auto& inode = ret_idx[nid];
    auto layout_iter = new_layouts.find(inode.source);
    if (layout_iter != new_layouts.end()) {
      for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
        ret_layouts[ret_idx.entry_id(nid, i)] =
            std::move(layout_iter->second[i]);
      }
    }
  }

  // Per the original author, indexed_graph() must not already have been
  // called on the Graph that is returned. So a fresh Graph with the same
  // outputs (hence the same entry ids) is returned instead.
  nnvm::Graph new_ret;
  new_ret.outputs = std::move(outputs);
  new_ret.attrs["layout"] = std::make_shared<any>(std::move(ret_layouts));
  return new_ret;
}
I hope you can help. Thank you very much!