What determines the order of the input arguments of OpenCL device code?

Hi, I wrote a TE computation and generated OpenCL code from it as follows:

import tvm
from tvm import te

def test0():
    """Lower and build an elementwise TE kernel for OpenCL, printing the IR and the device source.

    Computes ``out0 = in0 * scale0 + in1 * scale1 - in2 * scale2`` over a 1-D
    domain of symbolic length ``n``. The single axis is split by 64 and the two
    resulting loops are bound to ``blockIdx.x`` / ``threadIdx.x``.
    """
    # Inputs: 1-D tensors of symbolic length n. ``(n,)`` is a one-element shape
    # tuple; a bare ``(n)`` is only a parenthesized expression, not a tuple.
    n = te.var("n")
    in0 = te.placeholder((n,), name="in0")
    in1 = te.placeholder((n,), name="in1")
    in2 = te.placeholder((n,), name="in2")
    # Scalar scale factors. No dtype is passed, so in the generated OpenCL they
    # show up as ``int`` kernel arguments that are cast to float at each use.
    scale0 = te.var("scale0")
    scale1 = te.var("scale1")
    scale2 = te.var("scale2")
    # Output: same shape as the inputs.
    out0 = te.compute(
        in0.shape,
        lambda *i: in0(*i) * scale0 + in1(*i) * scale1 - in2(*i) * scale2,
        name="out0",
    )

    # Schedule: split the axis into 64-wide chunks and bind the outer/inner
    # loops to GPU block/thread indices.
    s = te.create_schedule(out0.op)
    bx, tx = s[out0].split(out0.op.axis[0], factor=64)
    s[out0].bind(bx, te.thread_axis("blockIdx.x"))
    s[out0].bind(tx, te.thread_axis("threadIdx.x"))

    # Note: the argument order given here is not the argument order of the
    # generated OpenCL kernel (compare with the main_kernel0 signature it prints).
    fapi = tvm.lower(s, [scale0, scale1, scale2, in0, out0, in1, in2], simple_mode=True)
    print(fapi)

    rt_mod = tvm.build(fapi, target=tvm.target.Target("opencl", host="c"), name="addgpu")
    # The first imported module carries the OpenCL device source.
    print(rt_mod.imported_modules[0].get_source())

Generated OpenCL source code:

// Function: main_kernel0
// NOTE(review): observed argument order. Buffer pointers come first
// (out0, in0, in1, in2), then the scalar ints (n, stride, stride_1, scale0,
// stride_2, scale1, stride_3, scale2). Within each group the order matches
// the order in which the variables first appear in the body below: n in the
// bound check, then out0/stride, in0/stride_1, scale0, in1/stride_2, scale1,
// in2/stride_3, scale2. It does not follow the argument list passed to
// tvm.lower (scale0, scale1, scale2, in0, out0, in1, in2). Presumably a pass
// that runs before codegen builds the parameter list this way -- TODO confirm
// in the TVM source.
__kernel void main_kernel0(__global float* restrict out0, __global float* restrict in0, __global float* restrict in1, __global float* restrict in2, int n, int stride, int stride_1, int scale0, int stride_2, int scale1, int stride_3, int scale2) {
  // Element index used below: get_group_id(0) * 64 + get_local_id(0); each
  // buffer access multiplies it by that buffer's own stride argument.
  // Fast path: groups with id < (n >> 6) store without a per-element bound
  // check (presumably every index they touch is < n, given a local size of 64).
  if (((int)get_group_id(0)) < (n >> 6)) {
    // out0 = in0 * scale0 + in1 * scale1 - in2 * scale2; the int scales are cast to float.
    out0[(((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride)] = (((in0[(((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride_1)] * ((float)scale0)) + (in1[(((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride_2)] * ((float)scale1))) - (in2[(((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride_3)] * ((float)scale2)));
  } else {
    // Tail: remaining groups store only when the element index is < n.
    if (((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) < n) {
      out0[(((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride)] = (((in0[(((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride_1)] * ((float)scale0)) + (in1[(((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride_2)] * ((float)scale1))) - (in2[(((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride_3)] * ((float)scale2)));
    }
  }
}

I would like to know how the arguments of the OpenCL device function are ordered. It seems that the output buffer pointer comes first, then the input buffer pointers, and then the remaining scalar parameters, but what determines the order of those scalar parameters in the device function? Is there any documentation or a link to the relevant source code?

Thanks in advance :smile:

I found the relevant source code; I think I need to understand the code below:

// https://github.com/apache/tvm/blob/eda84e7804be63a74f0089be221da36c6555b9f9/src/target/source/codegen_c.cc#L96-L115
void CodeGenC::AddFunction(const PrimFunc& f) {
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
  ReserveKeywordsAsUnique();

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol.defined())
      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

  this->PrintFuncPrefix(stream);
  this->PrintExtraAttrs(f);
  this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
    std::string vid = AllocVarID(v.get());
    if (i != 0) stream << ", ";
    if (v.dtype().is_handle()) {
      auto it = alloc_storage_scope_.find(v.get());
      if (it != alloc_storage_scope_.end()) {
        PrintStorageScope(it->second, stream);
      }

      PrintType(GetType(v), stream);
      // Register handle data type
      // TODO(tvm-team): consider simply keep type info in the
      // type annotation(via a normalizing rewriting).
      if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
          RegisterHandleType(v.get(), prim->dtype);
        }
      }

      if (no_alias) {
        PrintRestrict(v, stream);
      }
    } else {
      PrintType(GetType(v), stream);
    }
    stream << ' ' << vid;
  }
  stream << ") {\n";
  this->PreFunctionBody(f);
  int func_scope = this->BeginScope();
  this->PrintStmt(f->body);
  this->PrintFinalReturn();
  this->EndScope(func_scope);
  this->PrintIndent();
  this->stream << "}\n\n";
}

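It seems the order is determined by `f->params`: the parameters are printed in exactly the order they appear there. So the ordering must be decided wherever `f->params` is built, unless some step sorts it first.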

I think this sorting happens in a pass that runs before codegen. I’m still reading the source code :crying_cat_face: