[RFC] CSE Optimization

  • The user API looks like the following:
    /*!
     * \brief Eliminate common subexpressions among \p input_grads and between them and \p output.
     *
     * \param output The output tensor.
     * \param input_grads The gradients of input tensors.
     * \return The output tensor and the input gradients after common subexpression elimination.
     */
    std::pair<Tensor, std::vector<Tensor> >
    CSE(const Tensor& output, const std::vector<Tensor>& input_grads);
    
  • The tensor expression tree carries no information beyond the tensor expressions themselves. It is just a convenient data structure for comparing them.
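
To illustrate how an expression tree can serve purely as a comparison structure, here is a minimal sketch of CSE by structural interning. It is not the proposed implementation: `Node`, `Make`, `Interner`, and `CseSketch` are hypothetical toy names standing in for the real tensor types, and two subtrees are assumed to be "common" only when they have the same operator label and the same operands in the same order (commutativity is not considered, and leaves with the same label are treated as the same variable). The signature of `CseSketch` mirrors the shape of the API above so that one table is shared between the output and all gradients.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy expression node; a stand-in, NOT the real Tensor type.
struct Node {
  std::string op;                                 // operator or variable label
  std::vector<std::shared_ptr<const Node>> args;  // operand subtrees
};
using NodePtr = std::shared_ptr<const Node>;

// Helper (hypothetical) to build a node.
NodePtr Make(std::string op, std::vector<NodePtr> args = {}) {
  return std::make_shared<const Node>(Node{std::move(op), std::move(args)});
}

// Maps each structurally distinct subtree to one canonical node.
class Interner {
  // A subtree is identified by its label plus the ids of its operands.
  using Key = std::pair<std::string, std::vector<std::size_t>>;
  std::map<Key, std::size_t> ids_;
  std::vector<NodePtr> nodes_;

 public:
  // Returns the id of the canonical node structurally equal to `e`.
  std::size_t Intern(const NodePtr& e) {
    assert(e != nullptr);
    std::vector<std::size_t> arg_ids;
    std::vector<NodePtr> args;
    for (const NodePtr& a : e->args) {
      std::size_t id = Intern(a);  // canonicalize operands first
      arg_ids.push_back(id);
      args.push_back(nodes_[id]);
    }
    Key key{e->op, std::move(arg_ids)};
    auto it = ids_.find(key);
    if (it != ids_.end()) return it->second;  // seen before: reuse it
    std::size_t id = nodes_.size();
    nodes_.push_back(Make(e->op, std::move(args)));
    ids_.emplace(std::move(key), id);
    return id;
  }

  NodePtr Get(std::size_t id) const { return nodes_[id]; }
};

// Same shape as the API above, on toy nodes: the output and every gradient
// go through one shared table, so subtrees common to any of them are merged.
std::pair<NodePtr, std::vector<NodePtr>> CseSketch(
    const NodePtr& output, const std::vector<NodePtr>& input_grads) {
  Interner table;
  NodePtr new_output = table.Get(table.Intern(output));
  std::vector<NodePtr> new_grads;
  for (const NodePtr& g : input_grads) {
    new_grads.push_back(table.Get(table.Intern(g)));
  }
  return {new_output, new_grads};
}
```

For example, if the output is `add(mul(x, y), x)` and one gradient is `add(mul(x, y), y)`, built from separate nodes, then after `CseSketch` both results point at the same `mul(x, y)` node. The tree is only used to decide equality; nothing else is read from it.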