Hey guys. Does anyone have experience working with T5-style models (specifically MT3)? I am working with a fork of the https://github.com/rlax59us/MT3-pytorch repository, and I am not getting results anywhere close to those reported in the paper the repo is based on (https://arxiv.org/abs/2107.09142v1). I have already found two bugs that could be the root cause (attention linear layers shared between q, k, and v; duplicate time tokens), but other than that I am a bit stuck. What is a good way to debug such a model to check that it is behaving correctly? The original JAX source code that this repo adapts is available at https://github.com/magenta/mt3/tree/main?tab=readme-ov-file
If you have any advice on how I should tackle this, let me know!
Thanks!!
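
P.S. For anyone who wants to check for the shared q/k/v layer problem in their own fork, here is a rough sketch of the kind of check that would flag it. It is not code from either repository: `find_shared` and the layer names in the demo are made up for illustration, and the demo uses plain Python objects as stand-ins for weight tensors so that it runs without PyTorch.

```python
from collections import defaultdict


def find_shared(named_objects):
    """Return groups of names that refer to the very same Python object.

    `named_objects` is any iterable of (name, obj) pairs.
    Hypothetical helper, written only for this post.
    """
    # Materialize first so every object stays alive while we compare ids.
    items = list(named_objects)
    groups = defaultdict(list)
    for name, obj in items:
        groups[id(obj)].append(name)
    # Keep only the objects that are reachable under more than one name.
    return [names for names in groups.values() if len(names) > 1]


if __name__ == "__main__":
    # Stand-ins for weights; the names are invented, not taken from the repo.
    shared = object()
    pairs = [
        ("attn.q.weight", shared),
        ("attn.k.weight", shared),
        ("attn.v.weight", shared),
        ("attn.o.weight", object()),
    ]
    for group in find_shared(pairs):
        print("shared:", ", ".join(group))
```

On a real model I would expect to feed it something like `model.named_parameters(remove_duplicate=False)`, assuming a PyTorch version where `named_parameters` accepts that argument; with the default setting the duplicates are filtered out before you ever see them.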