# CUDA kernels for low-bit optimizers

9 messages

fiery spire

Implement (and extend) the low-bit optimizers from torchao with CUDA kernels.
The current Python-based implementation nicely separates the optimization algorithm from the data format used for the optimizer states. This separation should carry over to the CUDA implementation: instead of writing an 8-bit Adam kernel, write one Adam kernel template that can be instantiated with different parameter and optimizer-state dtypes.
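A minimal host-side sketch of the "one Adam template, many dtypes" idea, assuming each storage format is described by a small codec with `load` (to fp32) and `store` (from fp32). `F32Codec`, `BF16Codec` and `adam_step` are hypothetical names, and decoupled (AdamW-style) weight decay is an assumption. In a `.cu` file the same template body would sit inside a `__global__` kernel with a grid-stride loop. Real 8-bit codecs typically also need per-block scales, which are omitted here.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

// A codec describes how one element is stored in memory and how it converts to/from fp32.
struct F32Codec {
    using Storage = float;
    static float load(Storage s) { return s; }
    static Storage store(float x) { return x; }
};

// bf16 held as raw bits (host-side stand-in for __nv_bfloat16).
struct BF16Codec {
    using Storage = uint16_t;
    static float load(Storage s) {
        uint32_t bits = static_cast<uint32_t>(s) << 16;
        float x;
        std::memcpy(&x, &bits, sizeof x);
        return x;
    }
    static Storage store(float x) {  // round-to-nearest-even; NaN handling omitted
        uint32_t bits;
        std::memcpy(&bits, &x, sizeof bits);
        bits += 0x7FFFu + ((bits >> 16) & 1u);
        return static_cast<Storage>(bits >> 16);
    }
};

// The Adam update is written once; the storage formats are template parameters.
// Weight decay is applied AdamW-style (decoupled), which is an assumption here.
template <class ParamC, class GradC, class StateC>
void adam_step(typename ParamC::Storage* params, const typename GradC::Storage* grads,
               typename StateC::Storage* m, typename StateC::Storage* v, std::size_t n,
               float lr, float beta1, float beta2, float eps, float weight_decay, int step) {
    assert(step >= 1);
    const float bc1 = 1.0f - std::pow(beta1, static_cast<float>(step));  // bias corrections
    const float bc2 = 1.0f - std::pow(beta2, static_cast<float>(step));
    for (std::size_t i = 0; i < n; ++i) {  // in a kernel: grid-stride loop over i
        float p = ParamC::load(params[i]);  // load + dequant
        const float g = GradC::load(grads[i]);
        float mi = StateC::load(m[i]);
        float vi = StateC::load(v[i]);
        mi = beta1 * mi + (1.0f - beta1) * g;  // compute in fp32
        vi = beta2 * vi + (1.0f - beta2) * g * g;
        const float update = (mi / bc1) / (std::sqrt(vi / bc2) + eps);
        p -= lr * (update + weight_decay * p);
        params[i] = ParamC::store(p);  // quant + store
        m[i] = StateC::store(mi);
        v[i] = StateC::store(vi);
    }
}
```

Each dtype combination is then just an instantiation, e.g. `adam_step<BF16Codec, BF16Codec, F32Codec>` for bf16 weights and gradients with fp32 moments. This combinatorial set of instantiations is exactly what ahead-of-time compilation would have to cover.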
In addition to the features currently available, the low-bit optimizer should also support training where the parameters, not just the optimizer states, are low-bit. This implies that the weight updates need [stochastic rounding and/or error compensation](https://arxiv.org/abs/2010.06192): with plain round-to-nearest, an update smaller than half the spacing between adjacent representable weight values is rounded away entirely.
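A host-side sketch of both options for bf16 weights, under stated assumptions. Stochastic rounding adds random bits below the bf16 mantissa and then truncates. The second option is a Kahan-style compensated update, one common form of error compensation. The function names are hypothetical. Two choices are assumptions made to keep the sketch self-contained: the compensation buffer is fp32, and `std::mt19937` stands in for a device-side RNG such as cuRAND or a counter-based hash.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <random>

float bf16_to_float(uint16_t b) {
    uint32_t bits = static_cast<uint32_t>(b) << 16;
    float x;
    std::memcpy(&x, &bits, sizeof x);
    return x;
}

// Round-to-nearest-even fp32 -> bf16 (NaN handling omitted).
uint16_t float_to_bf16_rne(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof bits);
    bits += 0x7FFFu + ((bits >> 16) & 1u);
    return static_cast<uint16_t>(bits >> 16);
}

// Stochastic rounding: add uniform random bits to the 16 bits that bf16 drops, then
// truncate. The magnitude rounds up with probability equal to the dropped fraction,
// so (away from overflow) the result equals x in expectation.
uint16_t float_to_bf16_stochastic(float x, std::mt19937& rng) {
    assert(std::isfinite(x));
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof bits);
    bits += static_cast<uint32_t>(rng()) & 0xFFFFu;
    return static_cast<uint16_t>(bits >> 16);
}

// Kahan-style compensated update: `comp` carries (the negative of) the part of earlier
// updates that was lost when rounding the weight to bf16, and is re-applied next step.
void compensated_update(uint16_t& param, float& comp, float delta) {
    const float p = bf16_to_float(param);
    const float y = delta - comp;  // re-inject what was lost before
    const uint16_t t_bits = float_to_bf16_rne(p + y);
    const float t = bf16_to_float(t_bits);
    comp = (t - p) - y;  // rounding error of this step
    param = t_bits;
}
```

Consider applying 1000 updates of `1e-4` starting from `1.0`. Plain round-to-nearest stays at exactly `1.0`, because each update is below half the bf16 spacing there (2^-7). Both variants track the intended `1.1`: the compensated update to within one bf16 spacing, and stochastic rounding only in expectation.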

Stretch goals include efficiently handling multiple tensors in a single kernel call, and using CUDA's just-in-time compilation (e.g. via NVRTC) to avoid compiling the combinatorial explosion of possible low-bit Adam kernels ahead of time.
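One way to cover many tensors with a single launch is to precompute a table of (tensor, chunk) work items and let each thread block process one entry. This is similar in spirit to the `multi_tensor_apply` utility in NVIDIA apex, though the sketch below is not that code. It is host-side only, the names `ChunkDesc`, `build_chunks` and `multi_tensor_sgd` are hypothetical, and a plain SGD step stands in for the Adam update.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Work item for one thread block: a contiguous slice [begin, end) of one tensor.
struct ChunkDesc {
    int tensor;
    std::size_t begin;
    std::size_t end;
};

// Host side: split every tensor into chunks so that one launch with
// gridDim.x == chunks.size() covers the whole tensor list.
std::vector<ChunkDesc> build_chunks(const std::vector<std::size_t>& sizes, std::size_t chunk_size) {
    assert(chunk_size > 0);
    std::vector<ChunkDesc> chunks;
    for (int t = 0; t < static_cast<int>(sizes.size()); ++t)
        for (std::size_t b = 0; b < sizes[t]; b += chunk_size)
            chunks.push_back({t, b, std::min(b + chunk_size, sizes[t])});
    return chunks;
}

// Emulation of the kernel body: "block" i looks up chunks[i] and updates that slice.
// A plain SGD step (p -= lr * g) stands in for the Adam update.
void multi_tensor_sgd(const std::vector<ChunkDesc>& chunks, const std::vector<float*>& params,
                      const std::vector<const float*>& grads, float lr) {
    for (std::size_t block = 0; block < chunks.size(); ++block) {  // blockIdx.x
        const ChunkDesc& c = chunks[block];
        for (std::size_t i = c.begin; i < c.end; ++i)  // threads of the block
            params[c.tensor][i] -= lr * grads[c.tensor][i];
    }
}
```

In a real kernel, the chunk table and the tensor pointers would have to reach the device, either copied to device memory or packed into a kernel argument. Empty tensors simply contribute no chunks.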

ancient tapir

bubbling this to the top for myself

ancient tapir

@fiery spire So should I use unsigned char for these, and figure out a way to do float arithmetic with them?

fiery spire

I think you can just use CUDA's built-in fp8 type (e.g. `__nv_fp8_e4m3` from `cuda_fp8.h`)?

fiery spire

In any case, you're probably not doing any actual arithmetic in fp8: just load from memory as fp8, convert to fp32, do the update, convert back to fp8, and write to memory.
Or more generally, a load -> dequant -> compute -> quant -> store type of operation.
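This means the raw byte (`unsigned char`/`uint8_t`) is only a storage format. Below is a host-side C++ reference for FP8 E4M3 on raw bytes, which could be used to check a kernel's output, plus the load -> dequant -> compute -> quant -> store loop for Adam's first moment. Inside a kernel, `__nv_fp8_e4m3` would do the conversions instead. The helper names are hypothetical. Saturating to ±448 on encode is an assumption about the desired behaviour. Real 8-bit optimizers typically also apply per-block scales, which are omitted here.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Decode an FP8 E4M3 byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits;
// no infinities, and 0x7F / 0xFF are NaN.
float e4m3_to_float(uint8_t b) {
    const int exp = (b >> 3) & 0xF;
    const int man = b & 0x7;
    if (exp == 0xF && man == 0x7) return NAN;
    const float mag = (exp == 0) ? std::ldexp(man / 8.0f, -6)             // subnormal
                                 : std::ldexp(1.0f + man / 8.0f, exp - 7);  // normal
    return (b & 0x80) ? -mag : mag;
}

// Encode by picking the nearest finite code (ties -> even code), saturating at +-448.
// Brute force on purpose: a reference for tests, not something to run on the GPU.
uint8_t float_to_e4m3(float x) {
    if (std::isnan(x)) return 0x7F;
    const float mag = std::fmin(std::fabs(x), 448.0f);
    uint8_t best = 0;
    float best_err = INFINITY;
    for (int c = 0; c < 0x7F; ++c) {  // all non-negative finite codes, in increasing order
        const float err = std::fabs(e4m3_to_float(static_cast<uint8_t>(c)) - mag);
        if (err < best_err || (err == best_err && (c & 1) == 0)) {
            best = static_cast<uint8_t>(c);
            best_err = err;
        }
    }
    return static_cast<uint8_t>((std::signbit(x) ? 0x80 : 0x00) | best);
}

// First Adam moment stored as raw E4M3 bytes: load -> dequant -> compute -> quant -> store.
void update_first_moment(uint8_t* m, const float* grad, std::size_t n, float beta1) {
    assert(beta1 >= 0.0f && beta1 < 1.0f);
    for (std::size_t i = 0; i < n; ++i) {
        float mi = e4m3_to_float(m[i]);              // load + dequant
        mi = beta1 * mi + (1.0f - beta1) * grad[i];  // compute in fp32
        m[i] = float_to_e4m3(mi);                    // quant + store
    }
}
```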

ancient tapir

@fiery spire Here's a draft! The numbers look reasonable, but I still need a working PyTorch reference to check that the results match. I guess I'm doing this in the wrong order lol

8BitAdam.cu
test.py (fixing now)

fiery spire
#

About to board my flight, and I'll probably need some sleep after I get home, so it'll be a while before I can give this a detailed look.
From quickly scrolling over it, though:

```cuda
__global__ void quant_kernel(__nv_fp8_e4m3* params_memory, __nv_fp8_e4m3* grads_memory, __nv_fp8_e4m3* m_memory, __nv_fp8_e4m3* v_memory, long num_parameters,
                             float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction,
                             float eps, float weight_decay
```
I think in practice, even with 8-bit optimizers, you'd still keep the gradients (`grads_memory`) in 16-bit.

ancient tapir