It relies on a dynamic k for topk based on the sequence length, but Relax’s op.topk gets very upset when k is not a fixed int known at compile time. Here’s an example of the gating in MoD:
def forward(self,
            x: torch.Tensor,
            attention_mask: torch.Tensor,
            position_ids: torch.Tensor,
            past_key_value: Optional[DynamicCache],
            output_attentions: bool,
            use_cache: bool,
            cache_position: Optional[torch.Tensor] = None,
            **kwargs: Any
            ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    b, s, d = x.shape
    weights = self.router(x)  # per-token routing scores

    if self.router.training:
        self.training_step += 1 if self.training_step < 1000 else 999
        self.capacity = 0.125 + ((1 - 0.125) * (1. / self.training_step))

    if torch.isnan(x).any():
        warnings.warn(
            "NaN detected in input tokens, this is not intended to happen, please check your model. Before retraining, you could try the model with flash-attn-2 enabled.")

    # k depends on the runtime sequence length s, so it is not a compile-time constant
    k = max(1, int(self.capacity * s))
    top_k_values, _ = torch.topk(weights, k, dim=1, sorted=True)
    threshold = top_k_values[:, -1]
    selected_mask = weights > threshold.unsqueeze(-1) if k > 1 else weights >= threshold.unsqueeze(-1)
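To make the problem concrete, here is a minimal standalone sketch (mine, not part of the MoD source) of that dynamic-k pattern. It reuses the default capacity of 0.125 from the gating above and just shows how k tracks the sequence length, which is exactly what a topk that wants a fixed int cannot accommodate:

import torch

CAPACITY = 0.125  # same default capacity as in the MoD gating above

def dynamic_k_topk(weights: torch.Tensor) -> torch.Tensor:
    """Pick the top-k router scores, with k derived from the sequence length."""
    s = weights.shape[1]
    k = max(1, int(CAPACITY * s))  # k is shape-dependent, not a constant
    top_k_values, _ = torch.topk(weights, k, dim=1, sorted=True)
    return top_k_values

# Two different sequence lengths give two different values of k,
# so there is no single fixed int to hand to the compiler.
for s in (32, 100):
    w = torch.randn(1, s, 1)
    print(s, dynamic_k_topk(w).shape[1])  # prints 32 4, then 100 12

In eager PyTorch this is a non-issue because k is recomputed on every call; it only becomes a problem once the sequence length turns into a symbolic dimension and k has to be pinned down at trace time.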