#jax
@cinder pumice
JAX!!!!!!
Did I miss a public announcement on this? Or is this in anticipation of something being released soon?
Google JAX is a machine learning framework for transforming numerical functions, to be used in Python. It is described as bringing together a modified version of autograd (automatic obtaining of the gradient function through differentiation of a function) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure a...
this channel is actually for discussion of Jax Taylor
Wait so you're telling me this channel is dedicated to me?
I think it was released like a month ago?
but I missed it entirely too
hello
What is this? Didn't hear about this
Hi my name is Rick Roll and I'm a developer.
Hello! My name is Muhammad Muso.
So, can anyone link good resources for learning jax? Maybe some blogs or examples?
If you intend to use jax for NLP, this would be the best way to learn:
- Read JAX documentation
- Apply for free access to TPU VMs from TRC
- Read tpu-starter
- Implement MLP model (+ training on MNIST), then implement data parallelism
- Implement a recent Transformer model (+ implement API for fine-tuning, generation), then implement model parallelism
- Ask for free access to TPU Pods from TRC
- Test the implementation on TPU Pods
- By then you will be an expert and will already know what to do next.
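A compressed sketch of the "MLP + training on MNIST, then data parallelism" step from the list above. It assumes synthetic random arrays in place of real MNIST (no data loading), plain SGD, and arbitrary layer sizes and learning rate. The training step is jit-compiled, and data parallelism is done by sharding the batch axis across `jax.devices()` with `jax.sharding`. On a single device the sharding is effectively a no-op; on several devices the batch size must be divisible by the device count.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

SIZES = [784, 128, 10]  # hypothetical: flattened 28x28 inputs, one hidden layer, 10 classes

def init_params(key, sizes=SIZES):
    """He-initialized list of (W, b) pairs for a plain MLP."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (din, dout)) * jnp.sqrt(2.0 / din), jnp.zeros(dout))
            for k, din, dout in zip(keys, sizes[:-1], sizes[1:])]

def forward(params, x):
    """Batched forward pass: x has shape (batch, 784); returns logits."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def loss_fn(params, x, y):
    """Mean softmax cross-entropy against integer labels y."""
    log_probs = jax.nn.log_softmax(forward(params, x))
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))

@jax.jit
def train_step(params, x, y, lr=0.01):
    """One SGD step; returns updated params and the pre-update loss."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

# Data parallelism: a 1-D device mesh with a single "data" axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
data_sharding = NamedSharding(mesh, P("data"))  # split the batch axis across devices
replicated = NamedSharding(mesh, P())            # full copy of params on every device

n_dev = len(jax.devices())
params = jax.device_put(init_params(jax.random.key(0)), replicated)
kx, ky = jax.random.split(jax.random.key(1))
x = jax.random.normal(kx, (32 * n_dev, 784))      # stand-in for flattened MNIST images
y = jax.random.randint(ky, (32 * n_dev,), 0, 10)  # stand-in for labels
x, y = jax.device_put((x, y), data_sharding)

losses = []
for step in range(10):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))
```

Because the inputs carry shardings, `jit` lets the compiler partition the computation and insert the cross-device gradient reduction, so the training step itself does not change between one device and many.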
Hey guys, I'm having trouble identifying which quota allows me to get TPU v5. Any help?
"TPU v5 Lite PodSlice chips"?
Just added a new transform in optax https://github.com/google-deepmind/optax/pull/958
hi, I have a JAX (or rather autodiff-specific) question: in "You Only Linearize Once" (https://arxiv.org/abs/2204.10923) they define the VJP using inner products... why did they do this?
Automatic differentiation (AD) is conventionally understood as a family of distinct algorithms, rooted in two "modes" -- forward and reverse -- which are typically presented (and implemented) separately. Can there be only one? Following up on the AD systems developed in the JAX and Dex projects, we formalize a decomposition of reverse-mode AD in...
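One way to see why the inner-product definition is natural: the transpose of a linear map A is the unique map Aᵀ satisfying ⟨v, A u⟩ = ⟨Aᵀ v, u⟩ for all u and v, which does not depend on choosing coordinates or materializing a Jacobian. The sketch below, using a made-up function `f`, checks this identity numerically with `jax.jvp` and `jax.vjp`, and also builds the same VJP as linearize-then-transpose with `jax.linearize` and `jax.linear_transpose`. It only illustrates the identity; it is not the paper's own construction.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example function R^3 -> R^2.
    return jnp.stack([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

x = jnp.array([0.3, -1.2, 2.0])
u = jnp.array([1.0, 0.5, -0.7])  # tangent vector (input space)
v = jnp.array([0.2, -1.5])       # cotangent vector (output space)

# Forward mode: J u.
_, ju = jax.jvp(f, (x,), (u,))

# Reverse mode: J^T v.
_, f_vjp = jax.vjp(f, x)
(jtv,) = f_vjp(v)

# The defining identity of the transpose: <v, J u> == <J^T v, u>.
lhs = jnp.vdot(v, ju)
rhs = jnp.vdot(jtv, u)
print(lhs, rhs)

# The same VJP as "linearize, then transpose the linear map".
_, f_lin = jax.linearize(f, x)            # f_lin(t) computes J t
f_lin_t = jax.linear_transpose(f_lin, x)  # x only fixes shapes/dtypes
(jtv2,) = f_lin_t(v)
```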
A couple of weeks ago, Aritra and I released a JAX implementation of Black Forest Labs' Flux.1 family of models. If you have some free time and want to contribute to a JAX codebase, we have some open issues.
Hi, I'm just learning JAX and had a question about when to use vmap/jit versus explicit batching for forward propagation in neural nets. I'm specifically learning the Equinox library. Coming from PyTorch, I instinctively want to include a batch dimension in my input data shape, e.g. (batch_size, channels, height, width) for MNIST. But I know that I can exclude it, using (channels, height, width), and apply vmap or jit to the forward function to parallelize. I'm not sure which approach to take, and I was wondering whether there's some overhead from using vmap/jit that slows down processing of smaller batch sizes.
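A minimal pure-JAX sketch of the two options. It uses plain functions rather than Equinox modules, and the tiny linear+tanh model is hypothetical. One is a per-example forward written without a batch axis and lifted with `jax.vmap`; the other is hand-batched, PyTorch-style. Note that `vmap` is what adds the batch dimension, while `jit` only compiles. `vmap` rewrites the operations into batched ones at trace time, so under `jit` the two versions should compile to essentially the same computation rather than a per-example loop.

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    # Per-example forward pass: x has shape (channels, height, width),
    # with no batch dimension, as in the Equinox style.
    w, b = params
    return jnp.tanh(w @ x.reshape(-1) + b)

def forward_batched(params, xs):
    # Hand-batched version: xs has shape (batch, channels, height, width).
    w, b = params
    return jnp.tanh(xs.reshape(xs.shape[0], -1) @ w.T + b)

key_w, key_x = jax.random.split(jax.random.key(0))
params = (jax.random.normal(key_w, (10, 1 * 28 * 28)) * 0.01, jnp.zeros(10))
batch = jax.random.normal(key_x, (64, 1, 28, 28))  # (batch, C, H, W)

# in_axes=(None, 0): share params across the batch, map over axis 0 of x.
batched_forward = jax.jit(jax.vmap(forward, in_axes=(None, 0)))
out = batched_forward(params, batch)  # shape (64, 10)
ref = jax.jit(forward_batched)(params, batch)
```

For small batches, the main cost of `jit` is compilation, paid once per new input shape; feeding many different batch sizes triggers recompiles, so padding to a few fixed sizes is a common workaround.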
How can I start learning JAX?
I would recommend the following notebooks by fellow GDE AakashKumarNain
Link: https://www.kaggle.com/code/aakashnain/tf-jax-tutorials-part1/notebook
These tutorials are quite good as well:
https://docs.jaxstack.ai/en/latest/tutorials.html
the question of where to call vmap may be particular to Equinox. In Flax/NNX, the inputs to the model still contain the batch dimension, e.g.:
def __call__(self, x: jax.Array) -> jax.Array:
    print('SHAPE', x.shape)  # => (batch_size, ...input dims)
Kicking off the discussion of the JAX Layers project, here is the original proposal:
Project title: JAX Layers
Description: Create an open-source library of reusable layers by porting GenAI models from PyTorch to JAX, following good coding and documentation practices. Use JAX to optimize the implementations while maintaining numerical correctness.
Additional information: An example model at https://docs.jaxstack.ai/en/latest/JAX_Vision_transformer.html and more models are upcoming
Similar existing libraries:
https://pytorch.org/docs/stable/nn.html
https://www.tensorflow.org/api_docs/python/tf/keras/layers
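As an illustration of the "maintaining numerical correctness" goal, here is a sketch of how a ported layer might be checked. `layer_norm` is a hypothetical JAX port; it is compared against a float64 NumPy restatement of the formula that `torch.nn.LayerNorm` documents (biased variance, default eps of 1e-5). This is not part of the proposal, and a real port would compare against the PyTorch module's outputs directly.

```python
import numpy as np
import jax
import jax.numpy as jnp

def layer_norm(x, weight, bias, eps=1e-5):
    """Hypothetical JAX port of LayerNorm over the last axis."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)  # biased variance, as in PyTorch
    return (x - mean) * jax.lax.rsqrt(var + eps) * weight + bias

def layer_norm_reference(x, weight, bias, eps=1e-5):
    """float64 NumPy reference implementing the same formula."""
    x = x.astype(np.float64)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight + bias

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16, 64)).astype(np.float32)
weight = rng.standard_normal(64).astype(np.float32)
bias = rng.standard_normal(64).astype(np.float32)

out = np.asarray(jax.jit(layer_norm)(x, weight, bias), dtype=np.float64)
ref = layer_norm_reference(x, weight, bias)
max_err = float(np.max(np.abs(out - ref)))
print("max abs error vs float64 reference:", max_err)
np.testing.assert_allclose(out, ref, rtol=1e-4, atol=1e-4)
```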
I'll create a thread
Hi everyone!
We're hosting a JAX and OpenXLA "DevLab" on June 23rd and 24th in Sunnyvale. "DevLabs" are small, informal user group gatherings, with a keynote, tutorials, lunch, round table discussions, and a happy hour.
You can find more details about the event here: https://openxla.org/events/summer_devlab_2025
If you'd like to attend, please fill out this form to express interest! https://forms.gle/8dyAAfivmKsmHhFt5
Space is limited, and we'll let you know by 5/30 if we're able to confirm your registration. Thanks so much!
Hi Josh, post it on #events too
GDG AI for Science Australia have two virtual JAX talks coming up. Come along if you are interested:
Virtual Event - Unlock the secrets of JAX! Learn why it's revolutionizing AI development through its modular ecosystem and how leaders use it to build and deploy state-of-the-art models and empower the next generation of intelligent applications.
Why does this channel keep showing up, but never has any postings? Maybe the channel gets only spam, which is then deleted, but it leaves the "new items" annotation in the TOC. Ugh.
yeah, I think it only gets spam, which gets deleted automatically or by the mods
Unfortunately, messages that appear and are then deleted don't remove the white ping / new message marking on Discord :(
What do you mean?
if a spammer puts a message here, and a bot (or person) later deletes it, the channel in the left column will still show "new messages" even though the only new message has been deleted.
I'm working on porting the VibeVoice TTS model to JAX. So far I have something that produces sensible output (https://github.com/pevers/VibeVoice-jax). However, the quality is really bad compared to the HF model. After debugging every layer, I believe it is caused by errors accumulating across layers from small numerical differences in certain operations like nnx.Conv. For testing, I run it on an NVIDIA 5090 with jax-cuda. Does anyone know how to tackle this? Or should I "just" fine-tune the model to fix it?
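One hedged guess, not a diagnosis: on recent NVIDIA GPUs, float32 matmuls and convolutions may run at reduced internal precision (e.g. TF32) under JAX's default precision setting, and small per-layer differences can compound through a deep model. A cheap check is to run a layer at `lax.Precision.DEFAULT` and at `lax.Precision.HIGHEST` and compare both against a float64 reference. The sketch uses a standalone `lax.conv_general_dilated` with made-up shapes, not the VibeVoice layers. On CPU both settings should agree closely; the informative run is on the GPU.

```python
import numpy as np
from jax import lax

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 64, 256)).astype(np.float32)  # (N, C_in, W), made-up shapes
w = rng.standard_normal((128, 64, 7)).astype(np.float32)  # (C_out, C_in, K)

def conv1d(x, w, precision):
    # Default dimension numbers: x is (N, C, W), w is (O, I, K), output is (N, O, W').
    # Cross-correlation with VALID padding (no kernel flip), like PyTorch's Conv1d.
    return lax.conv_general_dilated(x, w, window_strides=(1,), padding="VALID",
                                    precision=precision)

def conv1d_reference(x, w):
    # float64 NumPy reference: out[n, o, t] = sum_{c,k} x[n, c, t+k] * w[o, c, k]
    windows = np.lib.stride_tricks.sliding_window_view(
        x.astype(np.float64), w.shape[-1], axis=2)  # (N, C, W-K+1, K)
    return np.einsum("nctk,ock->not", windows, w.astype(np.float64))

ref = conv1d_reference(x, w)
errors = {}
for prec in (lax.Precision.DEFAULT, lax.Precision.HIGHEST):
    out = np.asarray(conv1d(x, w, prec), dtype=np.float64)
    errors[prec] = float(np.max(np.abs(out - ref)))
    print(prec, errors[prec])
```

If HIGHEST closes most of the gap on the GPU, setting `jax.config.update("jax_default_matmul_precision", "highest")` at startup is a way to test the whole model at full float32 precision before deciding whether fine-tuning is needed.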
Hey everyone! New to JAX and want to start learning.
What's the best beginner path?
• Any good tutorials for someone coming from NumPy?
• Simple "hello world" example?
this is good
making a random number generator is probably the best move for new coders
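For the "hello world" question, a minimal NumPy-flavored first script might look like the sketch below: `jax.numpy` for array math, `jax.grad` and `jax.jit` as function transformations, and `jax.random` with explicit keys. The explicit keys are the main difference from NumPy's random number generation, which ties in with the RNG suggestion above.

```python
import jax
import jax.numpy as jnp

# 1. jax.numpy mirrors much of the NumPy API.
x = jnp.arange(5.0)     # 0., 1., 2., 3., 4.
total = jnp.sum(x ** 2)
print(total)            # 30.0

# 2. Transformations: grad differentiates a scalar function, jit compiles it.
def f(x):
    return jnp.sum(x ** 2)

df = jax.jit(jax.grad(f))
print(df(x))            # [0. 2. 4. 6. 8.]

# 3. Random numbers: explicit keys instead of a hidden global state.
key = jax.random.key(42)
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, (3,))
# Reusing a key reproduces the same numbers, so split before each new draw.
again = jax.random.normal(subkey, (3,))
```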