Hi there, I'm trying to deploy ResNet-50 in the browser, following web-stable-diffusion (really great job) and tvm-webgpu-example (which is outdated), but I can't get correct results.
First, I built the TVM wasm runtime following this, and then built resnet50_v1.wasm as follows:
```python
import os

import onnx
import tvm
from tvm import relax, relay
from tvm.relax.testing import relay_translator

target = tvm.target.Target("webgpu", host="llvm --mtriple=wasm32-unknown-unknown-wasm")

onnx_mod = onnx.load("./resnet50_v1.onnx")
# from_onnx returns a (IRModule, params) tuple, so unpack it
mod, params = relay.frontend.from_onnx(onnx_mod, shape={"input_tensor:0": [1, 3, 224, 224]})
# Translate the Relay main function into a Relax module for the WebGPU target
mod = relay_translator.from_relay(mod["main"], target=target)

with target, tvm.transform.PassContext(opt_level=3):
    ex = relax.build(mod, target=target, system_lib=True)
ex.export_library(os.path.join("./dist", "resnet50_v1.wasm"))
```
Then I deployed resnet50_v1.wasm on the web. First, initialization:
```javascript
const wasmSource = await (
  await fetch("./resnet50_v1.wasm")
).arrayBuffer();
this.tvm = await tvmjs.instantiate(
  new Uint8Array(wasmSource),
  new EmccWASI(),
  logger
);
const gpuDevice = await tvmjs.detectGPUDevice();
this.tvm.initWebGPU(gpuDevice.device);
this.device = this.tvm.webgpu();
this.vm = this.tvm.withNewScope(() => {
  return this.tvm.detachFromCurrentScope(
    this.tvm.createVirtualMachine(this.device)
  );
});
```
Then classify:
```javascript
this.tvm.beginScope();
const inputData = this.tvm.empty([1, 3, 224, 224], "float32", this.device);
const outputData = this.tvm.empty([1, 1001], "float32", this.tvm.cpu());
const executor = this.vm.getFunction("main");
inputData.copyFrom(processedImage);
const output = executor(inputData);
outputData.copyFrom(output.get(1));
await this.device.sync();
const outputArray = outputData.toArray();
console.log(outputArray);
this.tvm.endScope();
```
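The classification code assumes `processedImage` already holds the image as a flat NCHW `float32` buffer of 3 * 224 * 224 values. A minimal sketch of producing such a buffer from canvas pixels might look like the following; the helper name `rgbaToNCHW` is hypothetical, and the mean/std defaults are an assumption that has to match whatever preprocessing the ONNX model was exported with.

```javascript
// Hypothetical helper (not part of tvmjs): convert an RGBA pixel buffer,
// e.g. ImageData.data from a 224x224 canvas, into a planar NCHW Float32Array.
// ASSUMPTION: the mean/std defaults are illustrative; use the model's own values.
function rgbaToNCHW(rgba, width, height,
                    mean = [123.68, 116.78, 103.94], std = [1, 1, 1]) {
  const planeSize = width * height;
  if (rgba.length !== planeSize * 4) {
    throw new Error(`expected ${planeSize * 4} RGBA values, got ${rgba.length}`);
  }
  const out = new Float32Array(3 * planeSize);
  for (let i = 0; i < planeSize; i++) {
    for (let c = 0; c < 3; c++) {
      // Channel c of pixel i goes into plane c; the alpha channel is dropped.
      out[c * planeSize + i] = (rgba[i * 4 + c] - mean[c]) / std[c];
    }
  }
  return out;
}
```

With a 2D canvas context `ctx` holding the resized image, `rgbaToNCHW(ctx.getImageData(0, 0, 224, 224).data, 224, 224)` would give a buffer that can be passed to `inputData.copyFrom`.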
However, no matter which input image I use, the output array is all zeros: `Float32Array(1001) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…]`. Did I do something wrong?
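When debugging a result like this, it can help to check programmatically whether the scores are all zero, contain NaN, or actually peak at some class. A small sketch, where `summarizeScores` is a hypothetical helper rather than a tvmjs API, which could be called on `outputArray` from the classification code:

```javascript
// Hypothetical helper: summarize classifier scores so an all-zero output
// is easy to tell apart from a real prediction.
function summarizeScores(scores) {
  let argmax = 0;
  let nonZero = 0; // NaN !== 0, so NaN entries are counted here too
  let hasNaN = false;
  for (let i = 0; i < scores.length; i++) {
    const v = scores[i];
    if (Number.isNaN(v)) hasNaN = true;
    if (v !== 0) nonZero++;
    if (v > scores[argmax]) argmax = i;
  }
  return { argmax, nonZero, hasNaN, allZero: nonZero === 0 };
}
```

For example, `console.log(summarizeScores(outputArray))` would report `allZero: true` for the output shown above.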
BTW, I have tested the built model with the native TVM runtime, and everything works fine:
```python
vm = relax.VirtualMachine(ex, tvm.cpu())
res = vm["main"](data)
```
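To narrow down where the web and native runs diverge, the corresponding native output tensor could be exported (for example with `.numpy().tolist()` and saved as JSON) and compared against `outputArray` in the page. A hypothetical helper for that comparison, assuming both arrays hold the same 1001 scores in the same order:

```javascript
// Hypothetical helper: largest absolute element-wise difference between
// the browser output and a reference array from the native run.
// A NaN in either input propagates to the result via Math.max.
function maxAbsDiff(actual, expected) {
  if (actual.length !== expected.length) {
    throw new Error(`length mismatch: ${actual.length} vs ${expected.length}`);
  }
  let maxDiff = 0;
  for (let i = 0; i < actual.length; i++) {
    maxDiff = Math.max(maxDiff, Math.abs(actual[i] - expected[i]));
  }
  return maxDiff;
}
```

A small value (on the order of float32 rounding error) would suggest the WebGPU build computes the same result; a large one, as with an all-zero output, points at the web-side pipeline.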