[RFC] CSE Optimization

Motivation

Common Subexpression Elimination is a compiler optimization that effectively avoids repeated computation of the same expression. It is needed in the AutoDiff submodule to improve its performance. As an example, consider a compute definition that is as simple as the tanh activation function:

Y = tanh(X)

whose gradient definition is

dX = (1 - tanh(X) * tanh(X)) * dY

We can clearly see that the value of tanh(X) is evaluated 3 times in this example: once in the forward pass, and twice in the backward pass. Although such re-evaluation might not be a big deal here (because tanh is a relatively lightweight operator), the cost grows as the operators become more complicated, as the examples later in this text will show.
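To make the redundancy concrete, here is a minimal scalar sketch in plain C++ (not TVM code). The function names are hypothetical and only illustrate the idea: the naive backward pass re-evaluates tanh(X), while the CSE'd version reuses the value Y that the forward pass already produced.

```cpp
#include <cassert>
#include <cmath>

// Forward pass: Y = tanh(X). The result doubles as the stashed value.
double TanhForward(double x) { return std::tanh(x); }

// Naive backward pass: evaluates tanh(X) two more times.
double TanhBackwardNaive(double x, double dy) {
  return (1.0 - std::tanh(x) * std::tanh(x)) * dy;
}

// Backward pass after CSE: tanh(X) is replaced by the stashed Y,
// so no transcendental function is evaluated here.
double TanhBackwardCSE(double y, double dy) { return (1.0 - y * y) * dy; }
```

Both backward variants compute the same dX; the second one simply takes Y as an input instead of X.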

Key Ideas

To illustrate our idea better, we consider the compute definition for the BatchNorm operator and its gradient on the placeholder γ:

Y = γ × x_hat + β,  where x_hat = (X - E_X) / sqrt(VAR_X + ε)

dγ = Σ dY × x_hat

To avoid re-evaluating the normalized x (denoted as x_hat), we must complete the following steps:

  1. Analysis

    We need to find out that x_hat is the LARGEST common subexpression between the forward and the backward pass.

  2. Transform (Forward)

    We need to set x_hat as one of the outputs, so that it can be stashed in memory by the forward pass.

  3. Transform (Backward)

    We need to replace the x_hat in the backward compute definition with a placeholder, whose value will be fed by the one that we stashed in the previous step.

As one might notice, those steps are in essence inferring the feature maps (a.k.a. backward dependency). In legacy machine learning frameworks (e.g., MXNet, TensorFlow) and deep learning libraries (e.g., cuDNN), this is done in a manual, hard-coded way. But here we are doing this automatically, that is, the feature maps are determined in such a way that the amount of computation needed to evaluate the backward gradients is minimized.
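The two transform steps can be sketched numerically as follows. This is a simplified illustration in plain C++, assuming a single channel and the usual affine BatchNorm form (scale γ, shift β); the struct and function names are hypothetical and not part of the proposed implementation. The forward pass exposes x_hat as an extra output, and the backward pass consumes it as an input instead of recomputing it.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct BNForwardResult {
  std::vector<double> y;      // regular output
  std::vector<double> x_hat;  // extra output, stashed for the backward pass
};

// Transform (Forward): x_hat becomes one of the outputs.
BNForwardResult BatchNormForward(const std::vector<double>& x, double gamma,
                                 double beta, double eps) {
  const double n = static_cast<double>(x.size());
  double mean = 0.0, var = 0.0;
  for (double v : x) mean += v / n;
  for (double v : x) var += (v - mean) * (v - mean) / n;
  BNForwardResult out;
  for (double v : x) {
    const double xh = (v - mean) / std::sqrt(var + eps);
    out.x_hat.push_back(xh);
    out.y.push_back(gamma * xh + beta);
  }
  return out;
}

// Transform (Backward): x_hat is a placeholder fed by the stashed value,
// so the mean, variance, and normalization are not evaluated again.
double GammaGrad(const std::vector<double>& dy,
                 const std::vector<double>& x_hat) {
  double dgamma = 0.0;
  for (std::size_t i = 0; i < dy.size(); ++i) dgamma += dy[i] * x_hat[i];
  return dgamma;
}
```

The analysis step, which decides that x_hat is the expression worth stashing, is what the CSE pass automates.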

Implementation Details

To implement CSE, we need to go through the following steps:

  1. Tensor AutoInliner

    The tensor inliner is needed to simplify the tensor expressions generated by the AD pass and has already been implemented as part of autodiff/ad_util. As an example, consider the raw compute definition from the AD pass on the γ gradient of the BatchNorm operator:

    extracted_tensor[b, h, w, -c] = 
            (X[b, -c, h, w] - E_X[-c]) / sqrt(VAR_X[-c] + ε)
    dγ[c] += dY[b, c, h, w] × extracted_tensor[b, h, w, -c]
    

    As we can see, the introduction of the intermediate tensor extracted_tensor causes trouble in the optimization: not only does it have a different indexing order from X (and dY), but it also uses a reversed index (i.e., -c). Furthermore, the access to extracted_tensor adds an extra level of indirection in the tensor expression comparison (described later). Therefore, the first step of the CSE is to automatically inline all the injective computes into their respective consumers.

  2. CSE Optimizer

    // Note that for simplicity, some low-level details are omitted here.
    class CSEOptimizer {
     private:
      Tensor* src_;
      TensorExprTree src_tensor_expr_tree_, tgt_tensor_expr_tree_;  //< described later
     private:
      /// @brief Optimize the expression (to a placeholder) if
      ///
      ///            src_tensor_expr_tree_.Find(tgt_tensor_expr_tree_.at(expr))
      ///
      ///        Dispatch to the following operation nodes:
      ///          - Call
      ///          - +, -, *, /
      ///          - Reduce
      ///          - Int/FloatImm
      Expr Optimize(const Expr& expr);
     public:
      void Optimize(Tensor* const src, Tensor* const tgt) {
        src_tensor_expr_tree_.Visit(*src);
        tgt_tensor_expr_tree_.Visit(*tgt);
        auto new_body_stmt = Optimize(tgt->op->body);
        *tgt = ComputeOpNode::make(new_body_stmt, ...);
      }
    };
    

    The CSE optimizer takes a source tensor src and a target tensor tgt, and internally stores a tensor expression tree for each of them (described later). It operates on the body statement of tgt: for each expression it visits, it looks up the corresponding subtree in the target expression tree and replaces the expression with a placeholder if the source expression tree contains the same subtree.

  3. Tensor Expression Tree

    class TensorExprTree {
     private:
      /*! \brief Construct a tensor expression, whose operator type is \p node
       *         and shape is \p axis.
       *
       *         Dispatch to the following operation nodes:
       *           - Call
       *           - ProducerLoad
       *           - +, -, *, /
       *           - Reduce
       *           - Int/FloatImm
       */
      TensorExprPtr Construct(const ObjectRef& node, const Array<IterVar>& axis);
     public:
      /*! \brief Visit `tensor`'s body statement to construct the expression tree. 
       */
      void Visit(const Tensor& tensor);
      /*! \brief Find a tensor expression in the tree.
       */
      bool Find(const TensorExpr& expr) const;
    };
    

    The tensor expression tree, constructed from a tensor, is a tree-like structure that represents the expressions a tensor has evaluated. It is in fact very similar to the legacy NNVM graph. The tree can search its subtrees to determine whether any of them matches a given tensor expression, which, as shown earlier, is what the CSE optimization relies on.
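The interplay between the optimizer and the expression tree can be illustrated with a toy model. The sketch below is plain C++ and does not use any TVM types: `Node`, `Make`, `Same`, `Find`, and `Optimize` are hypothetical names. It assumes inlining has already happened and ignores axes and index order entirely. It walks the target expression top-down, so the largest subtree that also occurs in the source is the one replaced by a placeholder.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy expression node: a leaf (tensor access) or an operator with children.
struct Node;
using NodePtr = std::shared_ptr<const Node>;
struct Node {
  std::string op;             // e.g. "X", "sub", "div", "mul"
  std::vector<NodePtr> args;  // empty for leaves
};

NodePtr Make(std::string op, std::vector<NodePtr> args = {}) {
  return std::make_shared<Node>(Node{std::move(op), std::move(args)});
}

// Structural equality of two expression trees.
bool Same(const NodePtr& a, const NodePtr& b) {
  if (a->op != b->op || a->args.size() != b->args.size()) return false;
  for (std::size_t i = 0; i < a->args.size(); ++i)
    if (!Same(a->args[i], b->args[i])) return false;
  return true;
}

// Does `tree` contain a subtree structurally equal to `expr`?
bool Find(const NodePtr& tree, const NodePtr& expr) {
  if (Same(tree, expr)) return true;
  for (const auto& child : tree->args)
    if (Find(child, expr)) return true;
  return false;
}

// Rewrite `tgt` top-down: the first (hence largest) non-leaf subtree that
// also occurs in `src` becomes a placeholder. Leaves are left untouched
// because they are already inputs.
NodePtr Optimize(const NodePtr& src, const NodePtr& tgt) {
  if (!tgt->args.empty() && Find(src, tgt)) return Make("placeholder");
  std::vector<NodePtr> args;
  for (const auto& child : tgt->args) args.push_back(Optimize(src, child));
  return Make(tgt->op, args);
}

// x_hat = (X - E_X) / sqrt(VAR_X + eps), built as a toy tree.
NodePtr XHat() {
  return Make("div", {Make("sub", {Make("X"), Make("E_X")}),
                      Make("sqrt", {Make("add", {Make("VAR_X"), Make("eps")})})});
}
```

With a forward body of `add(mul(gamma, x_hat), beta)` and a backward body of `mul(dY, x_hat)`, `Optimize` returns `mul(dY, placeholder)`. A real implementation would also have to compare indices and reduction axes, which this toy omits.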

@yzhliu

Thanks for the RFC. Nice finding and it looks good to me overall. Could you also provide an example of what the user API looks like? Besides, what information does the Tensor Expression Tree provide in addition to Tensor?

  • The user API looks like the following:
    /*!
     * \brief Eliminate common subexpressions among \p input_grads and between them and \p output .
     *
     * \param output The output tensor.
     * \param input_grads The gradients of input tensors.
     */
    std::pair<Tensor, std::vector<Tensor> >
    CSE(const Tensor& output, const std::vector<Tensor>& input_grads);
    
  • The tensor expression tree does not provide extra information. It is just a useful data structure for comparison purposes.

@tqchen @junrushao1994 how much of the work do you think we can reuse once we move to https://discuss.tvm.apache.org/t/rfc-tensorir-a-schedulable-ir-for-tvm?

I think there are two potential ways to think about it. We can either try to do CSE at the Expr level, or we can do CSE at the TE level. I think both will bring some of the benefit, so it would be helpful to support both variants.

I agree with @tqchen. CSE is not hard to implement, and CSE on both sides provides different benefits. So we can potentially support both variants.

I suppose CSE would solve my earlier question, where 2 identical adds with the same tensor shapes were not simplified into 1 add with both tensors added?