#JAX Layers Project

long crystal
To get things started, let's discuss here. We may need to set up a GVC meeting if this doesn't work.

In the meantime, since this project was originally proposed there has been some discussion about whether or not it's actually too ambitious. There is some concern that many of the layer types that are used in state-of-the-art models are so specific to the individual model architecture that they cannot be reused in other models. So it's been suggested that instead of a layer library it would make more sense to create a model garden.

primal vortex
# long crystal In the meantime, since this project was originally proposed there has been some ...

Hey Robert, I think the model garden is a great approach.

Here’s the roadmap I am brainstorming for the project.

  1. Create a core library of layers that are reusable across popular model architectures. A few examples:
  • Attention (FlashAttention, sliding window, rotary embedding)
  • Normalization (LayerNorm, RMSNorm, GroupNorm)
  • Utilities (Dropout, stochastic depth, activation functions)
  2. Implement models that use the reusable components from the core library. These models serve as examples of best practices, so their documentation must reflect that. Not all layers will be reusable, and we should separate or label the non-reusable parts.
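
To make the split in the roadmap above concrete, here is a minimal sketch, assuming plain JAX with parameters held in dicts. The names `init_rms_norm`, `rms_norm`, `dropout`, and `residual_block` are hypothetical and only for illustration; they are not an existing API or an agreed design for this project. The first three stand in for the reusable core library, and `residual_block` stands in for model-specific code that composes them.

```python
import jax
import jax.numpy as jnp


def init_rms_norm(dim):
    # Hypothetical helper: one learnable scale vector, initialised to ones.
    return {"scale": jnp.ones((dim,))}


def rms_norm(params, x, eps=1e-6):
    # Reusable piece: divide the last axis by its root mean square,
    # then apply the learned per-feature scale.
    mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(mean_sq + eps) * params["scale"]


def dropout(key, x, rate, deterministic=False):
    # Reusable piece: inverted dropout. Zero each element with
    # probability `rate` and rescale the survivors by 1 / (1 - rate).
    if deterministic or rate == 0.0:
        return x
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


def residual_block(params, key, x, rate=0.1, deterministic=False):
    # Model-specific piece (assumed for illustration): a residual
    # connection around norm + dropout, built from the shared layers.
    h = rms_norm(params["norm"], x)
    h = dropout(key, h, rate, deterministic)
    return x + h


# Separate keys for data generation and dropout.
data_key, drop_key = jax.random.split(jax.random.PRNGKey(0))
params = {"norm": init_rms_norm(8)}
x = jax.random.normal(data_key, (2, 4, 8))

normed = rms_norm(params["norm"], x)
y_eval = residual_block(params, drop_key, x, deterministic=True)
y_train = residual_block(params, drop_key, x, rate=0.5)
```

Keeping layers as pure functions over explicit parameters and PRNG keys is just one possible design; the same split between shared layers and model-specific code would apply if the library were built on a module system instead.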

I really like this idea, since a torch.nn equivalent in JAX would be a great addition.