How would one go about implementing Mixture-of-Depths?

Mixture-of-Depths relies on a k for top-k that is derived from the runtime sequence length, but in Relax `op.topk` gets very upset when k is not a fixed int. Here's an example of the gating in MoD:

    # Context: forward() of the MoD block wrapper. self.router produces one
    # routing score per token; self.capacity and self.training_step are plain
    # Python attributes on the module.
    import warnings
    from typing import Any, Optional, Tuple

    import torch
    from transformers.cache_utils import DynamicCache

    def forward(self,
                x: torch.Tensor,
                attention_mask: torch.Tensor,
                position_ids: torch.Tensor,
                past_key_value: Optional[DynamicCache],
                output_attentions: bool,
                use_cache: bool,
                cache_position: Optional[torch.Tensor] = None,
                **kwargs: Any
                ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        b, s, d = x.shape

        # one routing score per token, shape (b, s, 1)
        weights = self.router(x)

        if self.router.training:
            self.training_step += 1 if self.training_step < 1000 else 999
            # anneal capacity from 1.0 toward 0.125 over the first ~1000 steps
            self.capacity = 0.125 + ((1 - 0.125) * (1. / self.training_step))

        if torch.isnan(x).any():
            warnings.warn(
                "NaN detected in input tokens, this is not intended to happen, please check your model. Before retraining, you could try the model with flash-attn-2 enabled.")

        # dynamic k: a fixed fraction of the runtime sequence length
        k = max(1, int(self.capacity * s))
        top_k_values, _ = torch.topk(weights, k, dim=1, sorted=True)
        # threshold = k-th largest routing score per batch element
        threshold = top_k_values[:, -1]
        selected_mask = weights > threshold.unsqueeze(-1) if k > 1 else weights >= threshold.unsqueeze(-1)
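
The only genuinely dynamic piece above is the k-th-largest threshold, so the same gate can be phrased as a rank comparison in which k only appears inside an elementwise comparison, never as an output shape, which I'd guess is easier to express once the sequence length is symbolic. A rough PyTorch sketch of that reformulation (not from the MoD code, and tie handling at the boundary differs slightly from the `>` / `>=` branches above):

    import torch

    def topk_mask_via_rank(weights: torch.Tensor, capacity: float) -> torch.Tensor:
        # weights: (batch, seq_len, 1) router scores, as returned by self.router(x)
        w = weights.squeeze(-1)                                # (b, s)
        k = max(1, int(capacity * w.shape[1]))
        # rank[b, i] = number of tokens in row b scoring strictly higher than token i
        rank = (w.unsqueeze(1) > w.unsqueeze(2)).sum(dim=-1)   # (b, s); O(s^2) memory
        return (rank < k).unsqueeze(-1)                        # (b, s, 1) boolean mask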

I tried to do something like this a few weeks ago and got stuck on topk as well:

    # Context: inside the forward() of the relax nn.Module; `op` and `Tensor`
    # come from the relax nn frontend (tvm.relax.frontend.nn).
    import tvm
    from tvm import te, tir

    router_logits = op.squeeze(op.matmul(expert_tokens, router_for_expert), axis=-1)

    def topk_using_seqlen(r: Tensor, x: Tensor) -> Tensor:
        print(f"X: {x}")
        print(f"R: {r}")

        def compute(s: tir.Var, d: tir.Var):
            # k = seq_len * capacity, clamped to at least 1, built as a PrimExpr
            k = tir.expr.Select(
                tir.multiply(d, self.capacity) > tir.const(1),
                true_value=tir.Cast("int32", tir.multiply(d, self.capacity)),
                false_value=tir.const(1, "int32")
            )
            # returns a whole topi op from inside te.compute's fcompute
            return tvm.topi.topk(r, k=k, ret_type="values", is_ascend=False)

        return te.compute(x.shape, compute, name="compute_top_k_w_seqlen")

    topk_tensor = op.tensor_expr_op(topk_using_seqlen, "top_k_w_seqlen", [router_logits, x])
    print(f"Top K Values: {topk_tensor}")
This fails with: `shape.size() == indices.size() (2 vs. 0) : Tensor dimension mismatch in read ndim = 2, indices.size=0`
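
If I read the error right (not certain), the mismatch comes from the fcompute passed to `te.compute`: it is called with the output index variables and is expected to return a scalar PrimExpr, so returning a whole `topi.topk` tensor from inside it ends up reading the 2-D `r` with zero indices. One direction that might sidestep topk entirely is the rank formulation from the PyTorch sketch above: k then only appears in an elementwise comparison and can stay a PrimExpr built from the symbolic sequence length. Rough, untested sketch (my own helper name, O(seq_len^2), assumes `router_logits` is (batch, seq_len) and `self.capacity` is a Python float):

    from tvm import te, tir

    def topk_mask_via_rank(w: te.Tensor) -> te.Tensor:
        b, s = w.shape  # s can be a symbolic tir.Var
        # k = max(1, floor(capacity * seq_len)) as a PrimExpr, never a Python int
        k = tir.max(
            tir.const(1, "int64"),
            tir.Cast("int64", tir.floor(self.capacity * tir.Cast("float32", s))),
        )
        j = te.reduce_axis((0, s), name="j")
        # rank[b, i] = number of tokens in the same row scoring strictly higher
        rank = te.compute(
            (b, s),
            lambda bi, i: te.sum(
                tir.if_then_else(w[bi, j] > w[bi, i],
                                 tir.const(1, "int64"),
                                 tir.const(0, "int64")),
                axis=j,
            ),
            name="rank",
        )
        # keep a token iff it is among the k best of its row (boundary ties are kept)
        return te.compute(
            (b, s),
            lambda bi, i: tir.if_then_else(rank[bi, i] < k,
                                           tir.const(1, w.dtype),
                                           tir.const(0, w.dtype)),
            name="topk_mask",
        )

    selected_mask = op.tensor_expr_op(topk_mask_via_rank, "topk_mask_via_rank", [router_logits])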

I wish TVM had better docs on how tensor expressions and dynamic-shape binding work at a higher level :frowning: I would love to improve my understanding, as it is such a cool project.