Autoscheduling Multiple Shapes Sequentially Results In Unexpected Measure Timeouts

Hi everyone, I'm hoping for some guidance on what I might be doing wrong. I'm trying to tune multiple operators/shapes sequentially in a script using auto_scheduler.

My code looks roughly as follows:

    from tvm import auto_scheduler
    from tvm.auto_scheduler import SketchPolicy, XGBModel

    for shape in shapes_list:
        # `task` is the auto_scheduler.SearchTask built for the current shape
        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=num_trials,
            num_measures_per_round=num_measures_per_round,
            verbose=2,
            measure_callbacks=[
                auto_scheduler.RecordToFile(
                    "outputs/" + operator.__name__ + str(compute_args) + ".json"
                )
            ],
        )
        cost_model = XGBModel()
        s_policy = SketchPolicy(task, cost_model, params=sketch_policy_params)
        sch, args = auto_scheduler.auto_schedule(
            task, search_policy=s_policy, tuning_options=tune_option
        )

When I run multiple shapes in this loop, the first shape finds schedules without a problem, but every subsequent shape hits numerous measurement timeout errors. Here is a log that runs two GEMM shapes back to back:

----------------------------------------------------------------------
------------------------------  [ Search ]
----------------------------------------------------------------------
Generate Sketches		#s: 3
Number of sketches generated: 3
Init Population, sketch select strategy: random, total sketches: 3
bounds: Sample Initial Population	#s: 1949	fail_ct: 27	Time elapsed: 0.39
GA Iter: 0	Max score: 1.0000	Min score: 0.9387	#Pop: 128	#M+: 0	#M-: 0
GA Iter: 5	Max score: 1.0000	Min score: 0.9903	#Pop: 128	#M+: 1438	#M-: 83
GA Iter: 10	Max score: 1.0000	Min score: 0.9944	#Pop: 128	#M+: 1567	#M-: 90
EvolutionarySearch		#s: 128	Time elapsed: 3.26
----------------------------------------------------------------------
------------------------------  [ Measure ]
----------------------------------------------------------------------
Get 10 programs to measure.
Gathering gemm shapes
op: batch_matmul_nkkm compute arg: (1, 128, 128, 128)
..........**********==================================================

.
.
.

----------------------------------------------------------------------
------------------------------  [ Search ]
----------------------------------------------------------------------
Generate Sketches		#s: 3
Number of sketches generated: 3
Init Population, sketch select strategy: random, total sketches: 3
bounds: Sample Initial Population	#s: 1966	fail_ct: 22	Time elapsed: 0.46
GA Iter: 0	Max score: 0.9984	Min score: 0.9365	#Pop: 128	#M+: 0	#M-: 0
GA Iter: 5	Max score: 1.0000	Min score: 0.9889	#Pop: 128	#M+: 1436	#M-: 79
GA Iter: 10	Max score: 1.0000	Min score: 0.9938	#Pop: 128	#M+: 1571	#M-: 86
EvolutionarySearch		#s: 128	Time elapsed: 3.18
----------------------------------------------------------------------
------------------------------  [ Measure ]
----------------------------------------------------------------------
Get 10 programs to measure.

Execution time of this operator: 0.021 ms
op: batch_matmul_nkkm compute arg: (1, 512, 32, 512)
.........*T*T*T*T*T*T*T*T*T==================================================

You can see that every schedule measured for the (1, 512, 32, 512) shape times out (each `*T` in the progress line above).

However, if I run only this shape (or run it as the first shape searched), all of the schedules are measured successfully (note that this is a separate run):

----------------------------------------------------------------------
------------------------------  [ Search ]
----------------------------------------------------------------------
Generate Sketches		#s: 3
Number of sketches generated: 3
Init Population, sketch select strategy: random, total sketches: 3
bounds: Sample Initial Population	#s: 1951	fail_ct: 21	Time elapsed: 0.39
GA Iter: 0	Max score: 0.9997	Min score: 0.9417	#Pop: 128	#M+: 0	#M-: 0
GA Iter: 5	Max score: 1.0000	Min score: 0.9898	#Pop: 128	#M+: 1438	#M-: 74
GA Iter: 10	Max score: 1.0000	Min score: 0.9948	#Pop: 128	#M+: 1565	#M-: 81
EvolutionarySearch		#s: 128	Time elapsed: 3.31
----------------------------------------------------------------------
------------------------------  [ Measure ]
----------------------------------------------------------------------
Get 10 programs to measure.
Gathering gemm shapes
op: batch_matmul_nkkm compute arg: (1, 512, 32, 512)
..........**********==================================================

Note that I'm using the default builder and runner.

Is there some state cleanup I need to do between successive calls to auto_scheduler.auto_schedule, or can anyone point me in a direction to fix this behavior?

It might be caused by an RPC tracker conflict during measurement. You can try reusing the tune_option, or passing a single runner to every tune_option. Something like:

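    from tvm import auto_scheduler

    # One RPC measure context (one local RPC tracker and server) shared by all shapes
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=300, timeout=10)

    for shape in shapes:
        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=200,  # change this to 20000 to achieve the best performance
            runner=measure_ctx.runner,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        # ... build the task for `shape` and call auto_scheduler.auto_schedule with tune_option

    # Shut down the tracker and server once every shape has been tuned
    del measure_ctx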

I tried what you suggested, but also redefined the measure_ctx inside the loop, and that seems to have fixed the issue.

Can you expand on what you mean by "RPC tracker conflict for measurement"? I'm a bit confused about why the default runner hits this issue, while defining the measure_ctx and passing its runner explicitly avoids it.

Is this a known issue, or is it something I should report as an issue on GitHub?

Either way, thank you very, very much, @comaniac!

This is expected. We use an RPC runner to measure the latency of each schedule. Specifically, when constructing an RPC measure context, we start an RPC server in a separate process that receives the measurement requests. This isolates the measurement environment, which avoids some issues (e.g., CUDA context memory mapping) and gives more accurate latencies.
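To make the isolation idea concrete, here is a minimal stdlib-only sketch; it is not TVM's RPC implementation, and `measure_isolated` and the child script are hypothetical. Each "measurement" runs in a fresh Python process, and if it does not finish within `timeout` seconds the child is killed and the result is treated as a timeout, which is roughly the role the `timeout` argument of `LocalRPCMeasureContext` plays.

```python
import subprocess
import sys
from typing import Optional

# Hypothetical stand-in for "run one candidate schedule": a child Python process
# that sleeps for the requested time and prints how long it actually took.
_CHILD = (
    "import sys, time\n"
    "start = time.perf_counter()\n"
    "time.sleep(float(sys.argv[1]))\n"
    "print(time.perf_counter() - start)\n"
)


def measure_isolated(work_seconds: float, timeout: float) -> Optional[float]:
    """Run one measurement in a separate process; return its latency, or None on timeout."""
    try:
        done = subprocess.run(
            [sys.executable, "-c", _CHILD, str(work_seconds)],
            capture_output=True,
            text=True,
            timeout=timeout,  # subprocess.run kills the child if it runs longer than this
            check=True,
        )
    except subprocess.TimeoutExpired:
        return None  # treat as a measurement timeout
    return float(done.stdout)


if __name__ == "__main__":
    print(measure_isolated(0.01, timeout=30))  # a small latency, around 0.01 s
    print(measure_isolated(5.0, timeout=0.5))  # None
```

Because each measurement lives in its own process, a hung or crashing candidate cannot corrupt the tuning process itself; the cost is process startup overhead per measurement.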

Accordingly, your original code was trying to construct a second RPC server on the same IP and port. That causes a conflict, and IIUC the new server is moved to another port. However, the measurement requests are still sent to the first RPC server, which Python may already have garbage collected. That could be why the RPC client fails to connect to the server and the measurements time out.
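The port-conflict scenario can be reproduced in miniature with plain sockets. This sketch does not use TVM; `bind_first_free` and `can_connect` are hypothetical helpers, and the "try the next port" fallback is an assumption modeled on the "adjusted to another port" behavior described above. It shows that a second server asking for an occupied port lands elsewhere, while a client still holding the first address cannot connect once the first server is gone.

```python
import socket


def bind_first_free(start_port: int, attempts: int = 50) -> socket.socket:
    """Hypothetical helper: listen on the first free port at or after start_port."""
    for port in range(start_port, start_port + attempts):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.bind(("127.0.0.1", port))
        except OSError:  # port already in use (e.g. by the first server): try the next one
            sock.close()
            continue
        sock.listen()
        return sock
    raise RuntimeError("no free port in range")


def can_connect(port: int) -> bool:
    """Return True if something is listening on 127.0.0.1:port."""
    try:
        with socket.create_connection(("127.0.0.1", port), timeout=1.0):
            return True
    except OSError:  # connection refused or timed out
        return False


# The first "server" takes a port, and the client remembers that address.
first = bind_first_free(20000)
client_target = first.getsockname()[1]

# A second "server" asks for the same port and ends up on a different one.
second = bind_first_free(client_target)
second_port = second.getsockname()[1]
print("first server port:", client_target, "second server port:", second_port)

# Simulate the first server being torn down (e.g. garbage collected).
first.close()
print("client reaches its stored address:", can_connect(client_target))  # False
print("second server reachable:", can_connect(second_port))  # True
second.close()
```

Sharing one long-lived measure context across all tasks (or creating a fresh one per task, as in the follow-up above) keeps the runner's address and the live server in sync, which is what avoids the stale-address failure.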