TVM kernel causing slow data generation in PyTorch

Hi, I implemented a CUDA kernel using TVM and converted it to a PyTorch function. The computation itself is fast, but it makes the data generation take much longer. I wonder whether the same thing would happen in data loading when I deploy this function for model training.

To reproduce the problem:

import time

import torch
from tqdm import tqdm

# qk_mul_torch is the TVM kernel wrapped with to_pytorch_func (see the CPU snippet below).
torch.manual_seed(0)
tot = 0    # accumulated time for the qk_mul computation
gen_t = 0  # accumulated time for generating the input tensors

for _ in tqdm(range(5)):
    start_g = time.time()
    q = torch.randn(4, 2, 100, 100, 768).float().cuda(non_blocking=True)
    k = torch.randn(4, 2, 100, 100, 768).float().cuda(non_blocking=True)
    attn = torch.empty(4, 2, 100, 100, 2*3*3).float().cuda(non_blocking=True)
    end_g = time.time()
    gen_t += (end_g - start_g)

    start = time.time()
    qk_mul_torch(q, k, attn, 1)
    # op = q + k  # baseline used for comparison
    end = time.time()
    tot += (end - start)
    del q, k, attn

print(gen_t)
print(tot)

Here, qk_mul_torch is the TVM module converted to a PyTorch function. gen_t records the time for the torch.randn(...).float().cuda() calls and tot records the time for the qk_mul computation. In this case gen_t = 40.19069480895996 and tot = 0.0007958412170410156. When I disable qk_mul and use op = q + k instead, the results are 10.955339193344116 and 0.0002796649932861328. It seems that qk_mul is blocking something, but I cannot figure out what.
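To rule out measurement artifacts from asynchronous CUDA launches, here is a minimal sketch (my own, untested variant of the loop above, with a hypothetical helper timed_gpu) that forces synchronization before reading the clock, assuming torch.cuda.synchronize() waits for all queued kernels to finish:

import time
import torch

def timed_gpu(fn, *args):
    # Flush any work that earlier calls may have queued asynchronously.
    torch.cuda.synchronize()
    start = time.time()
    fn(*args)
    # Wait for the kernels launched by fn before stopping the clock.
    torch.cuda.synchronize()
    return time.time() - start

# Hypothetical usage with the tensors from the loop above:
# t_kernel = timed_gpu(qk_mul_torch, q, k, attn, 1)

With timing like this, I would expect any time that is currently attributed to the randn/.cuda() lines to move to whichever call actually waits on the GPU, but I have not verified that for my case.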

For a CPU version of qk_mul:

import time

import torch
from tqdm import tqdm
from tvm.contrib.dlpack import to_pytorch_func

# qk_mul is the TVM module built for the CPU target.
qk_mul_torch = to_pytorch_func(qk_mul)

torch.manual_seed(0)
tot = 0    # accumulated time for the qk_mul computation
gen_t = 0  # accumulated time for generating the input tensors

for _ in tqdm(range(5)):
    start_gen = time.time()
    q = torch.randn(4, 2, 100, 100, 768)
    k = torch.randn(4, 2, 100, 100, 768)
    attn = torch.empty(4, 2, 100, 100, 2*3*3)
    end_gen = time.time()
    gen_t += (end_gen - start_gen)

    start = time.time()
    qk_mul_torch(q, k, attn, 1)
    end = time.time()
    tot += (end - start)

print(gen_t, tot)

Here gen_t and tot are 4.727025747299194 and 5.899913549423218, respectively. The time for preparing the data is much smaller than with the GPU version.

As a newcomer to TVM, I wonder whether I made a mistake when splitting the long axis into smaller chunks and binding each chunk to a separate GPU thread/block, as sketched below.
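For concreteness, this is the kind of split-and-bind schedule I mean (a toy elementwise compute for illustration only, not my actual qk_mul kernel; the factor of 64 is an arbitrary choice):

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A", dtype="float32")
B = te.compute((n,), lambda i: A[i] * 2.0, name="B")

s = te.create_schedule(B.op)
# Split the long axis into chunks and bind the resulting outer/inner
# iterators to CUDA blocks and threads respectively.
outer, inner = s[B].split(B.op.axis[0], factor=64)
s[B].bind(outer, te.thread_axis("blockIdx.x"))
s[B].bind(inner, te.thread_axis("threadIdx.x"))

qk_mul_toy = tvm.build(s, [A, B], target="cuda")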

Looking forward to your reply. Thanks