I want to do a matmul, e.g. (1024, 768) * (3072, 768), on my accelerator, which has 4 clusters of 4 cores each. SRAM is on-chip RAM shared by all cores within one cluster, and NRAM is on-chip RAM private to each core. I need to split the weight W(3072, 768) into 16 parts, so that each core loads one (192, 768) part through the dataflow W → WS → WN → C, where W/WS/WN reside in global memory/SRAM/NRAM and hold (3072, 768)/(768, 768)/(192, 768) elements respectively (3072 / 4 clusters = 768 rows per cluster, 768 / 4 cores = 192 rows per core). The schedule I use is as follows:
### Python code
A, weights = C.op.input_tensors
AA = s.cache_read(A, "nram", [C])        # activation: global -> NRAM
WS = s.cache_read(weights, "sram", [C])  # weight: global -> SRAM (shared per cluster)
CC = s.cache_write(C, "nram")            # accumulate the result in NRAM
WN = s.cache_read(WS, "nram", [CC])      # weight: SRAM -> NRAM (private per core)
b, o = C.op.axis
oo, oi = s[C].split(o, factor=768)       # 3072 / 768 = 4 cluster tiles
s[C].reorder(oo, b, oi)
s[WS].compute_at(s[C], oo)               # one (768, 768) weight tile per cluster in SRAM
s[C].bind(oo, te.thread_axis("clusterId"))
s[CC].compute_at(s[C], oo)
b, i = s[CC].op.axis
io, ii = s[CC].split(i, factor=192)      # 768 / 192 = 4 core tiles
s[CC].reorder(io, b, ii)
s[WN].compute_at(s[CC], io)              # one (192, 768) weight tile per core in NRAM
# s[CC].bind(io, te.thread_axis("coreId"))
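For completeness, here is a minimal, self-contained sketch of the compute definition the schedule above is applied to. The shapes and dtypes are inferred from the IR dumps below; the placeholder names and the exact expression are my own approximation of the real kernel:

```python
import tvm
from tvm import te

# Sketch only: shapes/dtypes taken from the IR dumps below, names are assumptions.
M, N, K = 1024, 3072, 768
A = te.placeholder((M, K), name="A", dtype="float32")      # activation
weights = te.placeholder((N, K), name="W", dtype="int16")  # weight in NT layout
k = te.reduce_axis((0, K), name="k")
C = te.compute(
    (M, N),
    lambda i, j: te.sum(A[i, k] * weights[j, k].astype("float32"), axis=k),
    name="T_matmul_NT",
)
s = te.create_schedule(C.op)
```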
Before binding to the ThreadIndex “coreId”, the resulting IR is as follows. The data loaded into placeholder.sram is (768, 768) and that loaded into placeholder.sram.nram is (192, 768), which is what I want:
IR before binding to coreId:
//attr [IterVar(clusterId: int32, (nullptr), "ThreadIndex", "clusterId")] "thread_extent" = 4;
allocate(placeholder.sram: Pointer(sram int16), int16, [589824]), storage_scope = sram;
allocate(T_matmul_NT.nram: Pointer(nram float32), float32, [786432]), storage_scope = nram;
allocate(placeholder.sram.nram: Pointer(nram int16), int16, [147456]), storage_scope = nram {
  for (ax0_1: int32, 0, 768) {
    for (ax1_1: int32, 0, 768) {
      placeholder.sram[((ax0_1*768) + ax1_1)] = (int16*)placeholder_4[(((clusterId*589824) + (ax0_1*768)) + ax1_1)]
    }
  }
  **for (j.c.outer: int32, 0, 4) {**
    for (ax0_2: int32, 0, 192) {
      for (ax1_2: int32, 0, 768) {
        placeholder.sram.nram[((ax0_2*768) + ax1_2)] = (int16*)placeholder.sram[(((j.c.outer*147456) + (ax0_2*768)) + ax1_2)]
      }
    }
    for (i.c: int32, 0, 1024) {
      for (j.c.inner: int32, 0, 192) {
        T_matmul_NT.nram[(((i.c*768) + (j.c.outer*192)) + j.c.inner)] = 0f32
        for (k: int32, 0, 768) {
          T_matmul_NT.nram[(((i.c*768) + (j.c.outer*192)) + j.c.inner)] = ((float32*)T_matmul_NT.nram[(((i.c*768) + (j.c.outer*192)) + j.c.inner)] + ((float32*)placeholder.nram[((i.c*768) + k)]*cast(float32, (int16*)placeholder.sram.nram[((j.c.inner*768) + k)])))
        }
      }
    }
  }
  for (i: int32, 0, 1024) {
    for (j.inner: int32, 0, 768) {
      T_matmul_NT_2[(((i*3072) + (clusterId*768)) + j.inner)] = (float32*)T_matmul_NT.nram[((i*768) + j.inner)]
    }
  }
}
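Both IR dumps in this post were obtained by lowering the schedule; assuming the names from the sketch above, roughly:

```python
# Print the lowered TIR to check the allocation sizes of the SRAM/NRAM buffers.
print(tvm.lower(s, [A, weights, C], simple_mode=True))
```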
But if I bind j.c.outer in the first IR to the ThreadIndex “coreId”, the data loaded into placeholder.sram also shrinks from (768, 768) to (192, 768), which is not what I want:
IR after binding to coreId:
//attr [IterVar(clusterId: int32, (nullptr), "ThreadIndex", "clusterId")] "thread_extent" = 4;
allocate(placeholder.sram: Pointer(sram int16), int16, [147456]), storage_scope = sram;
allocate(placeholder.sram.nram: Pointer(nram int16), int16, [147456]), storage_scope = nram;
allocate(T_matmul_NT.nram: Pointer(nram float32), float32, [786432]), storage_scope = nram {
  for (ax0_1: int32, 0, 192) {
    for (ax1_1: int32, 0, 768) {
      placeholder.sram[((ax0_1*768) + ax1_1)] = (int16*)placeholder_4[((((clusterId*589824) + (coreId: int32*147456)) + (ax0_1*768)) + ax1_1)]
    }
  }
  //**attr [IterVar(coreId, (nullptr), "ThreadIndex", "coreId")] "thread_extent" = 4 {**
    for (ax0_2: int32, 0, 192) {
      for (ax1_2: int32, 0, 768) {
        placeholder.sram.nram[((ax0_2*768) + ax1_2)] = (int16*)placeholder.sram[((ax0_2*768) + ax1_2)]
      }
    }
    for (i.c: int32, 0, 1024) {
      for (j.c.inner: int32, 0, 192) {
        T_matmul_NT.nram[((i.c*768) + j.c.inner)] = 0f32
        for (k: int32, 0, 768) {
          T_matmul_NT.nram[((i.c*768) + j.c.inner)] = ((float32*)T_matmul_NT.nram[((i.c*768) + j.c.inner)] + ((float32*)placeholder.nram[((i.c*768) + k)]*cast(float32, (int16*)placeholder.sram.nram[(((j.c.inner*768) + k) - (coreId*147456))])))
        }
      }
    }
  }
  for (i: int32, 0, 1024) {
    for (j.inner: int32, 0, 768) {
      T_matmul_NT_2[(((i*3072) + (clusterId*768)) + j.inner)] = (float32*)T_matmul_NT.nram[((i*768) + j.inner)]
    }
  }
}
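To put numbers on it: before binding, placeholder.sram is allocated 768 × 768 = 589824 int16 elements per cluster, while placeholder.sram.nram gets 192 × 768 = 147456; after binding, placeholder.sram also shrinks to 147456 elements, i.e. each cluster's SRAM now stages only the same (192, 768) tile as a single core's NRAM instead of the full (768, 768) tile.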
Is there any primitive I can use in the schedule shown above to make sure that the amount of data loaded into the cluster-shared SRAM does not shrink to the same size as each core's NRAM tile?