Build a Python-only alternative to torch.amp.autocast that lets users specify which ops get autocast. In LLMs today, most people want to run softmax and layernorm with half-precision inputs and outputs, but AMP in PyTorch currently casts these ops to float32 (https://pytorch.org/docs/stable/amp.html#cuda-ops-that-can-autocast-to-float32) and returns float32 tensors. Subsequent matmuls then have to cast those float32 outputs back to half precision. This causes a "ping-pong" of casts that can kill performance.
This only scratches the surface, though. People will likely want flexible control over which operations run in lower precision (fp8, etc.).
I believe a TorchDispatchMode would make this easy to implement.
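
To make the idea concrete, here is a minimal sketch of what such a mode could look like. It is not a proposed final API: the class name `SelectiveAutocast` and the `policy` dict (mapping an ATen op overload to a target dtype) are hypothetical names for illustration. The sketch assumes forward-only use under torch.no_grad(), since a TorchDispatchMode runs below autograd and the casts it inserts are not recorded for backward. It also assumes that torch.softmax and torch.mm reach the mode as `aten._softmax.default` and `aten.mm.default`, which can be confirmed by logging `func` inside `__torch_dispatch__`.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map

aten = torch.ops.aten


class SelectiveAutocast(TorchDispatchMode):
    """Hypothetical helper: cast floating-point inputs of user-chosen ops."""

    def __init__(self, policy):
        super().__init__()
        # policy maps an ATen op overload to the dtype it should run in.
        self.policy = policy

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        dtype = self.policy.get(func)
        if dtype is None:
            # Op is not in the policy: run it unchanged.
            return func(*args, **kwargs)

        def cast(x):
            # Only floating-point tensors are cast; ints, bools, and
            # non-tensor arguments pass through untouched.
            if isinstance(x, torch.Tensor) and x.is_floating_point():
                return x.to(dtype)
            return x

        return func(*tree_map(cast, args), **tree_map(cast, kwargs))


# Assumed policy: keep softmax and mm in bfloat16 so there is no
# float32 round trip between them.
policy = {
    aten._softmax.default: torch.bfloat16,
    aten.mm.default: torch.bfloat16,
}

x = torch.randn(4, 8)
w = torch.randn(8, 8)

with torch.no_grad(), SelectiveAutocast(policy):
    probs = torch.softmax(x, dim=-1)  # listed in policy, runs in bfloat16
    out = torch.mm(probs, w)          # listed in policy, w is cast to bfloat16
    total = x + x                     # not listed, stays float32

print(probs.dtype, out.dtype, total.dtype)
```

Other ops would be added by putting their overloads in the dict (layernorm likely shows up as `aten.native_layer_norm.default`, but that should be checked the same way). A real implementation would also need to handle autograd, cache repeated casts of the same weight, and decide how to treat ops with mixed-dtype outputs, none of which this sketch attempts.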