[Questions] How to use the auto-scheduler to tune the conjugate gradient?

Hi everyone! I am new to the community and am currently trying to find a faster schedule for the conjugate gradient algorithm with the auto-scheduler. However, memory usage grows noticeably during tuning, and the program gets killed once it exhausts all available memory. I am wondering whether I am using the tensor expression language correctly, so I would really appreciate it if someone could help me figure this out!
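
For reference, the conjugate gradient iteration that both pcg_step (TE) and pcg_step_np (NumPy reference) below are meant to implement, with r = Ax - b and p_0 = -r_0, is:

$$\alpha_k = \frac{r_k^\top r_k}{p_k^\top A p_k}, \qquad x_{k+1} = x_k + \alpha_k p_k, \qquad r_{k+1} = r_k + \alpha_k A p_k,$$
$$\beta_k = \frac{r_{k+1}^\top r_{k+1}}{r_k^\top r_k}, \qquad p_{k+1} = -r_{k+1} + \beta_k p_k.$$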

Environment: Ubuntu 20.04, TVM main branch @458ca81c. Here's my code:

import tvm
import tvm.testing
from tvm import te, auto_scheduler
from tvm.contrib import tedd
import numpy as np
import time


tgt = tvm.target.Target(target="llvm", host="llvm")
dev = tvm.device(tgt.kind.name, 0)

N_ITR = 160
N = 80

def gen_vec_dot(a, b, n, name=""):
    # Dot product of two length-n vectors, returned as a one-element tensor.
    k2 = te.reduce_axis((0, n), name+'_k2')
    return te.compute((1, ), lambda i: te.sum(a[k2] * b[k2], axis=k2), name=name)

def gen_mat_vec_dot(a, b, n, name=""):
    # Matrix-vector product of an (n, n) matrix with a length-n vector.
    k3 = te.reduce_axis((0, n), name+'_k3')
    return te.compute((n, ), lambda i: te.sum(a[i, k3] * b[k3], axis=k3), name=name)

def pcg_step(_A, _x, _p, _r, n, itr):
    # One conjugate-gradient iteration expressed in TE; `itr` is only used to
    # give the intermediate tensors per-iteration names.
    _rTr = gen_vec_dot(_r, _r, n, f"rTr_{itr}")
    _Ap = gen_mat_vec_dot(_A, _p, n, name=f"Ap_{itr}")
    _pTAp = gen_vec_dot(_p, _Ap, n, name=f"pTAp_{itr}")
    _alpha = te.compute((1, ), lambda _: _rTr[0] / _pTAp[0], f"alpha_{itr}")
    _r_kp1 = te.compute((n, ), lambda i: _r[i] + _alpha[0] * _Ap[i], f"r_kp1_{itr}")
    _x_kp1 = te.compute((n, ), lambda i: _x[i] + _alpha[0] * _p[i], f"x_kp1_{itr}")
    _r_kp1Tr_kp1 = gen_vec_dot(_r_kp1, _r_kp1, n, f"r_kp1Tr_kp1_{itr}")

    _beta = te.compute((1, ), lambda _: _r_kp1Tr_kp1[0] / _rTr[0], f"beta_{itr}")
    _p_kp1 = te.compute((n, ), lambda i: -_r_kp1[i] + _beta[0] * _p[i], f"p_{itr}")

    return _x_kp1, _p_kp1, _r_kp1

def pcg_step_np(_A_np, _x_np, _p_np, _r_np, n, itr):
    _rTr = _r_np.dot(_r_np)
    _Ap = _A_np.dot(_p_np)
    _pTAp = _p_np.dot(_Ap)
    _alpha = _rTr / _pTAp
    _r_kp1 = _r_np + _alpha * _Ap
    _x_kp1 = _x_np + _alpha * _p_np

    _r_kp1Tr_kp1 = _r_kp1.dot(_r_kp1)
    beta = _r_kp1Tr_kp1 / _rTr
    _p_kp1 = -_r_kp1 + beta * _p_np
    return _x_kp1, _p_kp1, _r_kp1

def pcg():
    @auto_scheduler.register_workload
    def pcg_func(N):
        n = N  # concrete problem size (the auto-scheduler works with static shapes)
        A = te.placeholder(shape=(n, n), name='A')
        x0 = te.placeholder(shape=(n,), name='x0')
        b = te.placeholder(shape=(n,), name='b')

        k = te.reduce_axis((0, n), 'k')

        Ax0 = te.compute((n, ), lambda i: te.sum(A[i, k] * x0[k], axis=k), name='Ax0')
        r0 = te.compute((n, ), lambda i: Ax0[i] - b[i], name='r0')
        p0 = te.compute((n,), lambda i: -r0[i], "p0")

        # tensor_xs = [x0]
        # tensor_ps = [p0]
        # tensor_rs = [r0]
        x_kp1, p_kp1, r_kp1 = x0, p0, r0
        for i in range(N_ITR):
            # This unrolls all N_ITR iterations into a single compute DAG.
            x_kp1, p_kp1, r_kp1 = pcg_step(A, x_kp1, p_kp1, r_kp1, n, i)
            # tensor_xs.append(x_kp1)
            # tensor_ps.append(p_kp1)
            # tensor_rs.append(r_kp1)
        return [A, x0, b, x_kp1]

    def search_sch():
        task = tvm.auto_scheduler.SearchTask(func=pcg_func, args=[N], target=tgt)
        print("Computational DAG:")
        print(task.compute_dag)
        log_file = "pcg.json"
        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=2000,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            verbose=2,
        )
        print('begin tuning')
        task.tune(tune_option)
        sch, args = task.apply_best(log_file)
        tvm_cg = tvm.build(sch, args, tgt)
        return tvm_cg

    def default_sch():
        A, x0, b, res_x = pcg_func(N)
        s = te.create_schedule(res_x.op)
        tedd.viz_dataflow_graph(s, dot_file_path="./tmp/dfg.dot")
        print('begin build')
        tvm_cg = tvm.build(s, [A, x0, b, res_x], target=tgt)
        return tvm_cg

    tvm_cg = search_sch()
    # The default schedule overwrites the tuned module here, so only it gets timed below.
    tvm_cg = default_sch()
    test(tvm_cg, N, N_ITR, dev, dtype=np.float32)

def test(tvm_cg, n, N_ITR, dev, dtype):
    print('begin testing')
    np.random.seed(1)
    _A = np.random.uniform(size=(n, n)).astype(dtype)
    _A = _A.T @ _A  # make A symmetric positive definite
    _A = tvm.nd.array(_A, dev)
    _A_np = _A.numpy()

    _x0_np = np.random.uniform(size=(n,)).astype(dtype)
    _x0 = tvm.nd.array(_x0_np, dev)
    
    _b_np = np.random.uniform(size=(n,)).astype(dtype)
    _b = tvm.nd.array(_b_np, dev)
    
    _r0_np = _A_np.dot(_x0_np) - _b_np
    _p0_np = -_r0_np
    
    _r_kp1 = tvm.nd.array(np.zeros((n,), dtype=dtype), dev)
    _x_kp1 = tvm.nd.array(np.zeros((n,), dtype=dtype), dev)
    _p_kp1 = tvm.nd.array(np.zeros((n,), dtype=dtype), dev)

    def run():
        # One measurement runs both the TVM module and the NumPy reference loop,
        # so the timing below covers both.
        tvm_cg(_A, _x0, _b, _x_kp1)
        _x_np, _p_np, _r_np = _x0_np, _p0_np, _r0_np
        for i in range(N_ITR):
            _x_np, _p_np, _r_np = pcg_step_np(_A_np, _x_np, _p_np, _r_np, n, i)

    # warm up
    for _ in range(10):
        run()
    st = time.time()
    for _ in range(20):
        run()
    end = time.time()
    print(f'tvm default: {(end - st) / 20 * 1000} ms')
    
    # print(_x_np)
    # print(_x_kp1)


pcg()
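
In case the fully unrolled DAG turns out to be the issue, below is the kind of smaller workload I am considering as an alternative: tune a single CG step and call the built module N_ITR times from the host. This is only a rough sketch (pcg_single_step is a hypothetical name and I have not actually tuned it); it reuses pcg_step and the globals N / tgt from the script above.

@auto_scheduler.register_workload
def pcg_single_step(n):
    # Hypothetical single-iteration workload: roughly 8 compute ops
    # instead of N_ITR unrolled copies of them.
    A = te.placeholder((n, n), name='A')
    x = te.placeholder((n,), name='x')
    p = te.placeholder((n,), name='p')
    r = te.placeholder((n,), name='r')
    x_kp1, p_kp1, r_kp1 = pcg_step(A, x, p, r, n, 0)
    return [A, x, p, r, x_kp1, p_kp1, r_kp1]

single_task = auto_scheduler.SearchTask(func=pcg_single_step, args=[N], target=tgt)

Would tuning a single step like this and looping on the host side be the recommended way to handle an iterative solver, or is there a way to get the auto-scheduler to cope with the fully unrolled program?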

Thanks in advance!