#jax


raw jewel

Any videos?

onyx flare

@cinder pumice

cinder pumice

JAX!!!!!!

hardy dune

Did I miss a public announcement on this? Or is this in anticipation of something being released soon?


Google JAX is a machine learning framework for transforming numerical functions, to be used in Python. It is described as bringing together a modified version of autograd (automatic obtaining of the gradient function through differentiation of a function) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure a...

dusk grail

this channel is actually for discussion of Jax Taylor

onyx flare

but missed it entirely too

severe prism

hello

noble nova
undone spear

What is this? Didn't hear about this

acoustic fractal

awwwww


jax is awesome!

raw jewel

Hi my name is Rick Roll and I'm a developer. 😂

paper perch
static bay

So, can anyone link good resources for learning jax? Maybe some blogs or examples?

acoustic fractal
(replying to static bay: "So, can anyone link good resources for learning jax? Maybe some blogs or exampl...")

If you intend to use jax for NLP, this would be the best way to learn:

  • Read JAX documentation
  • Apply for free access to TPU VMs from TRC
  • Read tpu-starter
  • Implement MLP model (+ training on MNIST), then implement data parallelism
  • Implement a recent Transformer model (+ implement API for fine-tuning, generation), then implement model parallelism
  • Ask for free access to TPU Pods from TRC
  • Test the implementation on TPU Pods
  • By then you will be an expert and will know what to do next.
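
A minimal sketch of the "implement an MLP" step in plain JAX (no Flax or Equinox). Random arrays shaped like flattened 28x28 MNIST images stand in for the real dataset, which is an assumption made to keep it self-contained; the layer sizes, learning rate, and step count are arbitrary choices. The data-parallelism step that follows in the list is not shown.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes=(784, 128, 10)):
    """He-initialised weights and zero biases for each (in, out) layer pair."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params

def mlp(params, x):
    # ReLU hidden layers, then a linear output layer producing logits
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def loss_fn(params, x, labels):
    # Mean cross-entropy over the batch
    logp = jax.nn.log_softmax(mlp(params, x))
    return -jnp.mean(jnp.take_along_axis(logp, labels[:, None], axis=1))

@jax.jit
def train_step(params, x, labels):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, labels)
    # Plain SGD, applied leaf-by-leaf over the parameter pytree
    params = jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, params, grads)
    return params, loss

# Synthetic stand-in for one batch of flattened MNIST images and labels
key = jax.random.PRNGKey(0)
pkey, xkey, ykey = jax.random.split(key, 3)
params = init_mlp(pkey)
x = jax.random.normal(xkey, (256, 784))
labels = jax.random.randint(ykey, (256,), 0, 10)

losses = []
for step in range(50):
    params, loss = train_step(params, x, labels)
    losses.append(float(loss))
```

Because `train_step` is a pure function of `(params, batch)`, swapping in real MNIST batches, and later splitting each batch across devices, does not require restructuring it.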
lone marsh

Hey guys, I'm having trouble identifying which quota gives me access to TPU v5. Any help?

acoustic fractal
hushed dawn
tender mantle

Hi, I have a JAX (or rather autodiff-specific) question: in You Only Linearize Once (https://arxiv.org/abs/2204.10923), they define the VJP using inner products... why did they do this?
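
One likely reason, offered as an interpretation rather than the authors' stated motive: the transpose of a linear map J is defined by ⟨u, J v⟩ = ⟨Jᵀ u, v⟩ for all u and v, which does not require writing J out as a matrix. That is exactly what reverse mode needs when you linearize once and then transpose the resulting linear map. The sketch below checks the identity numerically with `jax.linearize` and `jax.linear_transpose`; the function `f` and the vectors are arbitrary examples.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x[::-1]

x = jnp.array([0.5, 1.0, 1.5, 2.0])
v = jnp.array([0.3, -1.0, 2.0, 0.5])   # tangent vector (input space)
u = jnp.array([1.0, 0.0, -2.0, 4.0])   # cotangent vector (output space)

# Linearize once: y = f(x) and the linear map f_jvp(t) = J t
y, f_jvp = jax.linearize(f, x)

# Transpose that linear map to get J^T u, without re-linearizing f
f_vjp = jax.linear_transpose(f_jvp, x)
(ju,) = f_vjp(u)

# Defining property of the transpose w.r.t. the inner product: <u, J v> = <J^T u, v>
lhs = jnp.vdot(u, f_jvp(v))
rhs = jnp.vdot(ju, v)

# Compare with JAX's built-in reverse mode
_, vjp_fn = jax.vjp(f, x)
(ju_builtin,) = vjp_fn(u)
```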

hushed dawn
dense jacinth

Hi, I'm just learning JAX and have a question about when to use vmap/jit versus batching for forward propagation in neural nets. I'm specifically learning the Equinox library. Coming from PyTorch, I instinctively want to include a batch dimension in my input data shape, e.g. (batch_size, channels, height, width) for MNIST. But I know I can leave it out, using (channels, height, width), and apply vmap or jit to the forward function to parallelize. I'm not sure which approach to take, and I was wondering whether there's some overhead from using vmap/jit that slows down processing of smaller batch sizes.
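
A minimal sketch in plain JAX (not Equinox, to keep it self-contained) comparing the two approaches: a per-example forward function lifted with `jax.vmap` versus a hand-written batched version. The single linear "model" and the shapes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def forward_single(params, image):
    # image: (channels, height, width) for ONE example -> logits of shape (10,)
    x = image.reshape(-1)
    return x @ params["w"] + params["b"]

# vmap maps over axis 0 of `images`; params are shared across the batch (None)
batched_forward = jax.jit(jax.vmap(forward_single, in_axes=(None, 0)))

key = jax.random.PRNGKey(0)
wkey, xkey = jax.random.split(key)
params = {"w": jax.random.normal(wkey, (1 * 28 * 28, 10)) * 0.01, "b": jnp.zeros(10)}
images = jax.random.normal(xkey, (32, 1, 28, 28))  # (batch, channels, height, width)

logits_vmap = batched_forward(params, images)

# Hand-written batched version of the same computation, for comparison
logits_manual = images.reshape(images.shape[0], -1) @ params["w"] + params["b"]
```

`vmap` is resolved at trace time into batched array operations rather than a Python loop over examples, and `jit` compiles once per distinct input shape, so the main overhead is one compilation per new batch size, not a per-example cost.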

quartz yacht

How can I start learning JAX?

hushed dawn
terse zinc
agile creek

Kicking off the discussion of the JAX Layers project, here is the original proposal:

Project title: JAX Layers
Description: Create an open-source library of reusable layers built by porting GenAI models from PyTorch to JAX. Follow good coding and documentation practices. Use JAX to optimize the implementations while maintaining numerical correctness.
Additional information: An example model is at https://docs.jaxstack.ai/en/latest/JAX_Vision_transformer.html, and more models are coming.
Similar equivalents:
https://pytorch.org/docs/stable/nn.html
https://www.tensorflow.org/api_docs/python/tf/keras/layers
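
As an illustration of what one ported layer could look like, here is a hypothetical PyTorch-compatible linear layer in plain JAX (the names `init_linear` and `linear` are made up for this sketch). It keeps PyTorch's `(out_features, in_features)` weight layout so weights from a state_dict could be copied over without transposing, uses the same uniform bounds as PyTorch's default `nn.Linear` initialization, and a fixed-weight check stands in for a numerical-correctness test against PyTorch output.

```python
import math
import jax
import jax.numpy as jnp

def init_linear(key, in_features, out_features):
    # Uniform(-1/sqrt(in_features), 1/sqrt(in_features)) for weight and bias,
    # the same bounds as PyTorch's default nn.Linear initialisation
    bound = 1.0 / math.sqrt(in_features)
    wkey, bkey = jax.random.split(key)
    weight = jax.random.uniform(wkey, (out_features, in_features), minval=-bound, maxval=bound)
    bias = jax.random.uniform(bkey, (out_features,), minval=-bound, maxval=bound)
    return {"weight": weight, "bias": bias}

def linear(params, x):
    # Same convention as torch.nn.Linear: weight is (out, in), y = x W^T + b
    return x @ params["weight"].T + params["bias"]

# Hand-picked weights, e.g. as if copied from a PyTorch state_dict
params = {
    "weight": jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    "bias": jnp.array([0.5, 0.0, -1.0]),
}
y = linear(params, jnp.array([[1.0, 1.0]]))  # -> [[3.5, 7.0, 10.0]]
```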

I'll create a thread

pale heron

Hi everyone!

We're hosting a JAX and OpenXLA "DevLab" on June 23rd and 24th in Sunnyvale. "DevLabs" are small, informal user group gatherings, with a keynote, tutorials, lunch, round table discussions, and a happy hour.

You can find more details about the event here: https://openxla.org/events/summer_devlab_2025

If you'd like to attend, please fill out this form to express interest! https://forms.gle/8dyAAfivmKsmHhFt5

Space is limited, and we'll let you know by 5/30 if we're able to confirm your registration. Thanks so much!

vestal python
analog flint

GDG AI for Science Australia have two virtual JAX talks coming up. Come along if you are interested:


Virtual Event - Unlock the secrets of JAX! Learn why it's revolutionizing AI development through its modular ecosystem and how leaders use it to build and deploy state-of-the-art models and empower the next generation of intelligent applications.


Virtual Event - Dive into parallel continuous local search for SAT and pseudo-Boolean problems. Learn how JAX enables high-performance, scalable search algorithms to tackle complex problems efficiently. Discover the power of GPU-accelerated optimisation.

hardy dune

Why does this channel keep showing up, but never has any postings? Maybe the channel gets only spam, which is then deleted, but it leaves the "new items" annotation in the TOC. Ugh.

misty condor
onyx flare

Unfortunately, messages that appear and are then deleted don't remove the white ping / new message marking on Discord :(

hardy dune

If a spammer posts a message here and a bot (or person) later deletes it, the channel in the left column will still show "new messages" even though the only new message has been deleted.

true tree

I'm working on porting the VibeVoice TTS model to JAX. So far I have something that produces sensible output: https://github.com/pevers/VibeVoice-jax. However, the quality is really bad compared to the HF model. After debugging every layer, I believe the cause is error stacking up across layers from small numerical differences in certain operations like nnx.Conv. For testing, I run it on an NVIDIA 5090 with jax-cuda. Does anyone know how to tackle this? Or should I "just" fine-tune the model to fix it?
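
One possible contributor, offered as a guess to rule out rather than a diagnosis: on recent NVIDIA GPUs, XLA may run float32 convolutions and matmuls with reduced-precision TF32 arithmetic unless higher precision is requested. A sketch of a layer-level check compares a 1D convolution at default and at `jax.lax.Precision.HIGHEST` precision against a float64 NumPy reference. The shapes are arbitrary, and it calls `jax.lax.conv_general_dilated` directly rather than `nnx.Conv`.

```python
import numpy as np
import jax

def conv1d_jax(x, w, precision=None):
    # x: (batch, in_channels, length), w: (out_channels, in_channels, kernel).
    # With dimension_numbers left as None, lax uses NC.../OI... layouts and
    # computes a cross-correlation, like torch's Conv1d.
    return jax.lax.conv_general_dilated(
        x, w, window_strides=(1,), padding="VALID", precision=precision
    )

def conv1d_reference(x, w):
    # Plain float64 NumPy cross-correlation used as ground truth
    k = w.shape[-1]
    out_len = x.shape[-1] - k + 1
    out = np.empty((x.shape[0], w.shape[0], out_len))
    for i in range(out_len):
        out[:, :, i] = np.einsum("bck,ock->bo", x[:, :, i:i + k], w)
    return out

def report(name, ours, ref):
    # Max absolute and relative deviation from the float64 reference
    diff = np.abs(np.asarray(ours, dtype=np.float64) - ref)
    rel = diff.max() / np.abs(ref).max()
    print(f"{name}: max abs diff {diff.max():.2e}, relative {rel:.2e}")
    return rel

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 64, 256)).astype(np.float32)
w = rng.standard_normal((64, 64, 7)).astype(np.float32)
ref = conv1d_reference(x.astype(np.float64), w.astype(np.float64))

rel_default = report("default precision", conv1d_jax(x, w), ref)
rel_highest = report("HIGHEST precision", conv1d_jax(x, w, jax.lax.Precision.HIGHEST), ref)
```

If, on the GPU, the default-precision error is orders of magnitude larger than the HIGHEST-precision error, requesting higher precision for the affected layers (or raising JAX's default matmul precision globally) would be worth trying before fine-tuning. On CPU the two numbers should be essentially the same.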


stuck cobalt

Hey everyone! New to JAX and want to start learning.

What's the best beginner path?
  • Any good tutorials for someone coming from NumPy?
  • Simple "hello world" example?
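
A minimal "hello world" sketch for someone coming from NumPy: `jax.numpy` mirrors most of the NumPy API, `jax.grad` differentiates a Python function, and `jax.jit` compiles it with XLA. The toy least-squares problem and the learning rate are arbitrary. The last lines show one common surprise for NumPy users: JAX arrays are immutable, so in-place updates go through `.at[...]`.

```python
import jax
import jax.numpy as jnp

# jax.numpy mirrors most of the NumPy API
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])

def loss(w):
    # Mean squared error of the linear model x @ w
    return jnp.mean((x @ w - y) ** 2)

# jax.grad builds the gradient function; jax.jit compiles it with XLA
grad_loss = jax.jit(jax.grad(loss))

w = jnp.zeros(2)
for _ in range(200):
    w = w - 0.01 * grad_loss(w)  # plain gradient descent
print("loss before:", loss(jnp.zeros(2)), "after:", loss(w))

# Arrays are immutable: .at[...].set(...) returns a new array instead of mutating x
x2 = x.at[0, 0].set(10.0)
print(x[0, 0], x2[0, 0])
```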

amber igloo