Run Any HuggingFace Model on TPUs: A Beginner's Guide to TorchAX

What if you could run any HuggingFace model on TPUs — without rewriting a single line of model code?

Why This Matters: The HuggingFace + JAX Problem

What is JAX?

Enter TorchAX

TorchAX vs. the Alternatives

Prerequisites & Setup

Key Concepts for Beginners

Pytrees: JAX's Data Containers

JIT Compilation: Translate Once, Run Fast Forever

Static vs. Dynamic Values

Step 1: Your First Forward Pass

Step 2: Speed It Up with JIT Compilation

The extract_jax Approach

Register HuggingFace Output Types as Pytrees

Handle Static Arguments with a Closure

Benchmark: Eager vs. JIT

Step 3: The Simpler API — torchax.compile

Step 4: Text Classification

Step 5: Text Generation (Autoregressive Decoding)

How Autoregressive Decoding Works

The KV Cache Solution

Implementation with StaticCache

Step 6: Distributed Inference (Tensor Parallelism)

How Tensor Parallelism Works

Sharding the Weights

Step 7: Build a Mini Chatbot

Swapping to a Larger Model

Troubleshooting

Conclusion

Here is what the end result looks like:

Five lines. Your PyTorch model is now executing on JAX, with access to TPUs, JIT compilation, and automatic parallelism across devices.

In this tutorial, we will go from zero to building a working chatbot powered by a HuggingFace model running on JAX. Along the way, you will learn key JAX concepts, see real benchmarks, and understand why this approach exists.

Why This Matters: The HuggingFace + JAX Problem

In 2024, HuggingFace removed native JAX and TensorFlow support from its transformers library to focus development on PyTorch. This left thousands of JAX users, especially those running on Google Cloud TPUs, without a straightforward way to use HuggingFace's massive model collection.

What is JAX?

If you are new to JAX, think of it as Google's high-performance numerical computing library. It looks like NumPy on the surface, but under the hood it offers three powerful capabilities:

JIT Compilation: JAX can compile your Python code into optimized machine code using the XLA compiler. The first run is slower (compilation), but every subsequent call is dramatically faster.

TPU Support: JAX is the native programming model for Google's Tensor Processing Units. If you want to use TPUs, JAX is the most natural path.

Automatic Parallelism: JAX can automatically distribute computation across multiple devices (TPUs or GPUs) using a single-program model called GSPMD. You describe what should be sharded; the compiler figures out how.

Enter TorchAX

torchax is a library from Google that bridges PyTorch and JAX. It works by creating a special torch.Tensor subclass that secretly holds a jax.Array inside. When PyTorch operations are called on this tensor, torchax intercepts them and executes the JAX equivalent instead.

Think of it like a Trojan horse: PyTorch thinks it is working with regular tensors, but the computation is actually happening on JAX. This means you can take any PyTorch model, including HuggingFace models, and run it on JAX without modifying the model code at all.
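To make the interception idea concrete, here is a toy sketch in plain Python. It is not torchax's implementation: the class names BackendArray and FrontendTensor are hypothetical, and a list-backed object stands in for jax.Array. It only illustrates the pattern described above, in which model code talks to a tensor-like wrapper and every operation is forwarded to a hidden backend array.

```python
class BackendArray:
    """Hypothetical stand-in for jax.Array: owns the data and does the math."""

    def __init__(self, values):
        self.values = list(values)

    def add(self, other):
        return BackendArray(a + b for a, b in zip(self.values, other.values))

    def mul(self, other):
        return BackendArray(a * b for a, b in zip(self.values, other.values))


class FrontendTensor:
    """Hypothetical stand-in for the tensor subclass: delegates every op."""

    def __init__(self, backend_array):
        self._inner = backend_array  # the hidden backend array

    def __add__(self, other):
        # Intercept "+" and run the backend equivalent instead.
        return FrontendTensor(self._inner.add(other._inner))

    def __mul__(self, other):
        # Intercept "*" and run the backend equivalent instead.
        return FrontendTensor(self._inner.mul(other._inner))

    def tolist(self):
        return self._inner.values


def model(x, w, b):
    # "Model code" written against the frontend API only; it never
    # needs to know which backend executes the operations.
    return x * w + b


x = FrontendTensor(BackendArray([1.0, 2.0, 3.0]))
w = FrontendTensor(BackendArray([2.0, 2.0, 2.0]))
b = FrontendTensor(BackendArray([0.5, 0.5, 0.5]))
result = model(x, w, b).tolist()
print(result)
```

The point of the design is that model() is unchanged no matter what sits behind the wrapper, which is why existing model code does not need to be rewritten.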
Credits: This tutorial builds on the excellent 3-part blog series by Han Qi (@qihqi), the author of torchax, and on the torchax documentation. We expand on those tutorials with beginner-friendly explanations, a different model (Gemma instead of Llama), benchmarks, and a complete Colab-ready notebook.

TorchAX vs. the Alternatives

Before diving into code, it helps to understand where torchax fits in the broader ecosystem.

When should you use torchax? When you have a PyTorch model (especially from HuggingFace) and want to leverage JAX's JIT compilation, TPU support, or interop with JAX libraries, without rewriting the model.

Prerequisites & Setup

Zero-setup option: Click the Colab badge above. The notebook handles all installation automatically.

Key Concepts for Beginners

Before we write code, let's demystify three JAX concepts you will encounter throughout this tutorial.

Pytrees: JAX's Data Containers

A pytree is any nested structure of Python containers (dicts, lists, tuples) with arrays as leaves. JAX uses pytrees everywhere: model weights are pytrees, and function inputs and outputs are pytrees.

Think of a pytree like a shipping box with labeled compartments. JAX knows how to open standard boxes (dicts, lists, tuples), pull out all the arrays, do math on them, and put them back.

The catch: JAX does not know how to open custom boxes. HuggingFace defines custom output types like CausalLMOutputWithPast, and we need to teach JAX how to unpack and repack these. This is called pytree registration, and we will see it in action shortly.

JIT Compilation: Translate Once, Run Fast Forever

JIT (Just-In-Time) compilation is like translating a recipe from English to machine code. The first time you call a JIT-compiled function, JAX traces through it, records all the operations, and compiles an optimized version. Subsequent calls skip the tracing and run the compiled version directly.

The speedup can be 10-100x or more, depending on the model and hardware. The trade-off is that the compiled function is specialized for the input shapes it was traced with: if shapes change, JAX recompiles.
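The shape specialization described above can be illustrated with a toy model in plain Python. This is a simulation only, not how jax.jit is implemented: toy_jit is a hypothetical decorator that keeps one "compiled" entry per input shape and counts how often the slow path runs.

```python
def toy_jit(fn):
    """Toy model of shape-specialized compilation (not the real jax.jit)."""
    cache = {}
    stats = {"compiles": 0}

    def wrapper(xs):
        shape = (len(xs),)          # the cache key depends only on the shape
        if shape not in cache:
            stats["compiles"] += 1  # slow path: "trace and compile"
            cache[shape] = fn       # a real compiler would store optimized code
        return cache[shape](xs)     # fast path: reuse the cached version

    wrapper.stats = stats
    return wrapper


@toy_jit
def sum_of_squares(xs):
    return sum(x * x for x in xs)


sum_of_squares([1, 2, 3])      # new shape (3,): compiles
sum_of_squares([4, 5, 6])      # same shape, different values: cache hit
sum_of_squares([1, 2, 3, 4])   # new shape (4,): compiles again
print(sum_of_squares.stats["compiles"])
```

Changing the values does not trigger a new compilation, but changing the length does. This is the behavior that makes growing input sequences expensive under JIT.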
Static vs. Dynamic Values

When JAX traces a function for JIT, it treats array inputs as abstract placeholders that carry a shape and dtype but no concrete values. If your code has a branch like if use_cache:, JAX cannot evaluate it during tracing because use_cache is abstract. This causes a ConcretizationTypeError.

The fix: mark such values as static (compile-time constants) so JAX knows their actual value during tracing. We will see two ways to do this: closures and static_argnums.

Step 1: Your First Forward Pass

Let's load a model and run it on JAX. We will use Gemma 3 1B IT, a small, instruction-tuned model from Google that runs comfortably on free Colab hardware.

The output logits tensor has shape (1, sequence_length, vocab_size). Each position contains a score for every token in the vocabulary, and the highest score at the last position is the model's prediction for the next token.

Step 2: Speed It Up with JIT Compilation

The eager forward pass works, but it is slow because every operation goes through Python one at a time. Let's compile the model for dramatically faster inference.

The extract_jax Approach

The torchax.extract_jax() function converts a PyTorch model into a pure JAX function:

This returns two things: the model weights as a pytree, and a JAX-callable function that takes those weights together with the model's arguments.

Register HuggingFace Output Types as Pytrees

Before we can JIT this function, we need to teach JAX about HuggingFace's custom types:

Handle Static Arguments with a Closure

The use_cache flag is a boolean that the model branches on, so it cannot be passed as a traced value. We wrap it in a closure to make it a compile-time constant:

Benchmark: Eager vs. JIT

Expected output (times will vary by hardware):

The JIT-compiled version runs orders of magnitude faster than eager mode. This is the power of XLA compilation: operations are fused, memory is optimized, and the accelerator runs a single optimized program.

Step 3: The Simpler API (torchax.compile)

The extract_jax + manual JIT approach gives you full control, but for most cases there is a simpler way. The catch is that torchax.compile() uses jax.jit under the hood, so we need to avoid passing dynamic boolean flags like use_cache. We wrap the model in a thin module that bakes in these constants:

Under the hood, torchax.compile() wraps your model in a JittableModule and applies jax.jit. The first call triggers compilation; subsequent calls are fast.
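Why a boolean flag has to be a constant rather than a traced input can be seen in a small simulation. The sketch below uses plain Python and does not call JAX: ToyTracer and toy_trace are hypothetical names, and the TypeError merely stands in for the concretization error JAX raises. It shows the failure when the flag is traced and the fix when the flag is baked into a closure.

```python
class ToyTracer:
    """Hypothetical stand-in for a traced value: it has no concrete truth value."""

    def __bool__(self):
        raise TypeError("abstract value used in Python control flow")


def forward(x, use_cache):
    # Python-level branch: needs a concrete True/False while tracing.
    if use_cache:
        return ("with_cache", x)
    return x


def toy_trace(fn, n_args):
    """Call fn with abstract placeholders for every argument, as a tracer would."""
    return fn(*[ToyTracer() for _ in range(n_args)])


# 1) Passing the flag as a traced argument fails at the `if`.
try:
    toy_trace(forward, 2)
    traced_flag_ok = True
except TypeError:
    traced_flag_ok = False


# 2) Baking the flag into a closure leaves only x to be traced.
def forward_no_cache(x):
    return forward(x, use_cache=False)


out = toy_trace(forward_no_cache, 1)
closure_ok = isinstance(out, ToyTracer)
print(traced_flag_ok, closure_ok)
```

With the flag fixed inside forward_no_cache, the branch is resolved in ordinary Python before any abstract value is inspected, so tracing succeeds.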
The NoCacheModel wrapper ensures that boolean flags are constants (not traced) and that the output is a plain tensor (not a custom HuggingFace type that needs pytree registration).

Step 4: Text Classification

Let's use our JIT-compiled model for a practical task: sentiment classification. Since Gemma is an instruction-tuned model, we can use prompt engineering:

Step 5: Text Generation (Autoregressive Decoding)

Classification is useful, but the real power of LLMs is generating text. Let's understand how this works.

How Autoregressive Decoding Works

An LLM predicts one token at a time. Given an input of length n, it produces scores for the next token. We pick one (e.g., the highest-scoring token via greedy decoding), append it to the input, and repeat:

The problem: input shapes change every iteration. JIT compilation specializes for fixed shapes, so changing shapes means recompilation every step, which is worse than eager mode.

The KV Cache Solution

The KV (Key-Value) cache stores intermediate computations from previous tokens so the model only needs to process the new token each iteration:

With a DynamicCache, the cache grows each step, so shapes still change. With a StaticCache, the cache has a fixed maximum length, so shapes stay constant, making it JIT-friendly.

Step 6: Distributed Inference (Tensor Parallelism)

If you have access to multiple devices (e.g., a TPU v2-8 with 8 cores, or multi-GPU), you can shard the model weights across devices for faster inference.

How Tensor Parallelism Works

In tensor parallelism, we split weight matrices across devices. In the sharding code for this step, the q_proj, k_proj, v_proj, gate_proj, and up_proj weights are split along their first axis (column-parallel), the o_proj, down_proj, lm_head, and embed_tokens weights are split along their second axis (row-parallel), and everything else, such as layer norms, is replicated.

JAX's GSPMD handles the communication automatically; you just specify how each weight should be sharded.

Sharding the Weights

With sharded weights, the same jax.jit-compiled function now runs in parallel across all devices. The XLA compiler automatically inserts the necessary all-reduce operations.

Note: Tensor parallelism requires a multi-device environment. On free Colab TPU (single device), this section is for illustration. Use a TPU v2-8 or multi-GPU setup to run it.

Step 7: Build a Mini Chatbot

Let's wrap everything into a simple chat function using Gemma's instruction template:

Swapping to a Larger Model

Everything above uses google/gemma-3-1b-it (1B parameters).
To use a larger model, change the model name:

The rest of the code remains identical. Larger models produce higher quality outputs but require more memory and compute. The 7B model benefits significantly from tensor parallelism on multi-device setups.

Other standard HuggingFace AutoModelForCausalLM architectures, such as GPT-2, Llama, Mistral, and Phi, work with torchax as well.

Troubleshooting

TypeError: ... is not a valid JAX type

You need to register the type as a pytree. See the registration examples for CausalLMOutputWithPast, DynamicCache, and StaticCache.

ConcretizationTypeError: Abstract tracer value encountered

A value that Python control flow depends on (like a boolean flag) needs to be either: (1) made static via static_argnums in jax.jit, or (2) baked into a closure as a constant.

UserWarning: A large amount of constants were captured

Model weights are being inlined as constants in the compiled graph. Pass them as explicit function arguments instead of closing over them.

RuntimeError: No available devices

Ensure JAX can see your accelerator: print(jax.devices()). In Colab, check that your runtime type is set to TPU or GPU.

Conclusion

In this tutorial, we went from zero to a working chatbot running a HuggingFace model on JAX.

The key insight: torchax lets you use the entire HuggingFace ecosystem (models, tokenizers, configs) while running on JAX's high-performance backend. No model rewrites needed.

This tutorial would not be possible without the work credited earlier: the blog series by Han Qi (@qihqi), the author of torchax, and the torchax documentation.

What model will you try running on TPUs first? Let me know in the comments!

Command

Copy

$ from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it", torch_dtype="bfloat16") import torchax torchax.enable_globally() # Enable AFTER loading the model model.to("jax") # That's it. Now running on JAX. from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it", torch_dtype="bfloat16") import torchax torchax.enable_globally() # Enable AFTER loading the model model.to("jax") # That's it. Now running on JAX. from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it", torch_dtype="bfloat16") import torchax torchax.enable_globally() # Enable AFTER loading the model model.to("jax") # That's it. Now running on JAX. PyTorch Model | v torchax.Tensor (looks like torch.Tensor) | v jax.Array (actual computation on TPU/GPU) PyTorch Model | v torchax.Tensor (looks like torch.Tensor) | v jax.Array (actual computation on TPU/GPU) PyTorch Model | v torchax.Tensor (looks like torch.Tensor) | v jax.Array (actual computation on TPU/GPU) # 1. Install PyTorch (CPU version — torchax handles the accelerator) -weight: 500;">pip -weight: 500;">install torch --index-url https://download.pytorch.org/whl/cpu # Linux # -weight: 500;">pip -weight: 500;">install torch # macOS # 2. Install JAX for your accelerator -weight: 500;">pip -weight: 500;">install -U jax[tpu] # Google Cloud TPU # -weight: 500;">pip -weight: 500;">install -U jax[cuda12] # NVIDIA GPU # -weight: 500;">pip -weight: 500;">install -U jax # CPU only # 3. Install torchax, transformers, and flax (for JAX compatibility) -weight: 500;">pip -weight: 500;">install -U torchax transformers flax # 1. 
Install PyTorch (CPU version — torchax handles the accelerator) -weight: 500;">pip -weight: 500;">install torch --index-url https://download.pytorch.org/whl/cpu # Linux # -weight: 500;">pip -weight: 500;">install torch # macOS # 2. Install JAX for your accelerator -weight: 500;">pip -weight: 500;">install -U jax[tpu] # Google Cloud TPU # -weight: 500;">pip -weight: 500;">install -U jax[cuda12] # NVIDIA GPU # -weight: 500;">pip -weight: 500;">install -U jax # CPU only # 3. Install torchax, transformers, and flax (for JAX compatibility) -weight: 500;">pip -weight: 500;">install -U torchax transformers flax # 1. Install PyTorch (CPU version — torchax handles the accelerator) -weight: 500;">pip -weight: 500;">install torch --index-url https://download.pytorch.org/whl/cpu # Linux # -weight: 500;">pip -weight: 500;">install torch # macOS # 2. Install JAX for your accelerator -weight: 500;">pip -weight: 500;">install -U jax[tpu] # Google Cloud TPU # -weight: 500;">pip -weight: 500;">install -U jax[cuda12] # NVIDIA GPU # -weight: 500;">pip -weight: 500;">install -U jax # CPU only # 3. Install torchax, transformers, and flax (for JAX compatibility) -weight: 500;">pip -weight: 500;">install -U torchax transformers flax First call: Python code → trace → compile → execute (slow) Second call: compiled code → execute (fast!) First call: Python code → trace → compile → execute (slow) Second call: compiled code → execute (fast!) First call: Python code → trace → compile → execute (slow) Second call: compiled code → execute (fast!) 
import torch import torchax import jax import time from transformers import AutoModelForCausalLM, AutoTokenizer # Load model and tokenizer model_name = "google/gemma-3-1b-it" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map="cpu" ) # Enable torchax globally AFTER model loading # This prevents intercepting unsupported initialization ops torchax.enable_globally() # Move model weights to the JAX device model.to("jax") # Tokenize an input prompt prompt = "The secret to baking a good cake is" inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to("jax") # Run a forward pass (eager mode) -weight: 500;">start = time.perf_counter() with torch.no_grad(): outputs = model(input_ids, use_cache=False) elapsed = time.perf_counter() - -weight: 500;">start print(f"Output logits shape: {outputs.logits.shape}") print(f"Eager forward pass: {elapsed:.3f}s") import torch import torchax import jax import time from transformers import AutoModelForCausalLM, AutoTokenizer # Load model and tokenizer model_name = "google/gemma-3-1b-it" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map="cpu" ) # Enable torchax globally AFTER model loading # This prevents intercepting unsupported initialization ops torchax.enable_globally() # Move model weights to the JAX device model.to("jax") # Tokenize an input prompt prompt = "The secret to baking a good cake is" inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to("jax") # Run a forward pass (eager mode) -weight: 500;">start = time.perf_counter() with torch.no_grad(): outputs = model(input_ids, use_cache=False) elapsed = time.perf_counter() - -weight: 500;">start print(f"Output logits shape: {outputs.logits.shape}") print(f"Eager forward pass: {elapsed:.3f}s") import torch import torchax import jax 
import time from transformers import AutoModelForCausalLM, AutoTokenizer # Load model and tokenizer model_name = "google/gemma-3-1b-it" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map="cpu" ) # Enable torchax globally AFTER model loading # This prevents intercepting unsupported initialization ops torchax.enable_globally() # Move model weights to the JAX device model.to("jax") # Tokenize an input prompt prompt = "The secret to baking a good cake is" inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to("jax") # Run a forward pass (eager mode) -weight: 500;">start = time.perf_counter() with torch.no_grad(): outputs = model(input_ids, use_cache=False) elapsed = time.perf_counter() - -weight: 500;">start print(f"Output logits shape: {outputs.logits.shape}") print(f"Eager forward pass: {elapsed:.3f}s") # Extract a JAX-callable function and the model weights as a pytree weights, jax_func = torchax.extract_jax(model) # Extract a JAX-callable function and the model weights as a pytree weights, jax_func = torchax.extract_jax(model) # Extract a JAX-callable function and the model weights as a pytree weights, jax_func = torchax.extract_jax(model) from jax.tree_util import register_pytree_node from transformers import modeling_outputs, cache_utils # Register CausalLMOutputWithPast def output_flatten(v): return v.to_tuple(), None def output_unflatten(aux, children): return modeling_outputs.CausalLMOutputWithPast(*children) register_pytree_node( modeling_outputs.CausalLMOutputWithPast, output_flatten, output_unflatten, ) # Register DynamicCache def _flatten_dynamic_cache(cache): return (cache.key_cache, cache.value_cache), None def _unflatten_dynamic_cache(aux, children): c = cache_utils.DynamicCache() c.key_cache, c.value_cache = children return c register_pytree_node( cache_utils.DynamicCache, _flatten_dynamic_cache, _unflatten_dynamic_cache, ) 
from jax.tree_util import register_pytree_node from transformers import modeling_outputs, cache_utils # Register CausalLMOutputWithPast def output_flatten(v): return v.to_tuple(), None def output_unflatten(aux, children): return modeling_outputs.CausalLMOutputWithPast(*children) register_pytree_node( modeling_outputs.CausalLMOutputWithPast, output_flatten, output_unflatten, ) # Register DynamicCache def _flatten_dynamic_cache(cache): return (cache.key_cache, cache.value_cache), None def _unflatten_dynamic_cache(aux, children): c = cache_utils.DynamicCache() c.key_cache, c.value_cache = children return c register_pytree_node( cache_utils.DynamicCache, _flatten_dynamic_cache, _unflatten_dynamic_cache, ) from jax.tree_util import register_pytree_node from transformers import modeling_outputs, cache_utils # Register CausalLMOutputWithPast def output_flatten(v): return v.to_tuple(), None def output_unflatten(aux, children): return modeling_outputs.CausalLMOutputWithPast(*children) register_pytree_node( modeling_outputs.CausalLMOutputWithPast, output_flatten, output_unflatten, ) # Register DynamicCache def _flatten_dynamic_cache(cache): return (cache.key_cache, cache.value_cache), None def _unflatten_dynamic_cache(aux, children): c = cache_utils.DynamicCache() c.key_cache, c.value_cache = children return c register_pytree_node( cache_utils.DynamicCache, _flatten_dynamic_cache, _unflatten_dynamic_cache, ) def forward_no_cache(weights, input_ids): return jax_func(weights, (input_ids,), {"use_cache": False}) jitted_forward = jax.jit(forward_no_cache) def forward_no_cache(weights, input_ids): return jax_func(weights, (input_ids,), {"use_cache": False}) jitted_forward = jax.jit(forward_no_cache) def forward_no_cache(weights, input_ids): return jax_func(weights, (input_ids,), {"use_cache": False}) jitted_forward = jax.jit(forward_no_cache) # Convert input to a native JAX array for jax.jit jax_input_ids = jax.device_put(inputs["input_ids"].numpy()) # Warm up (first call 
triggers compilation) res = jitted_forward(weights, jax_input_ids) jax.block_until_ready(res) # Benchmark 3 runs for i in range(3): -weight: 500;">start = time.perf_counter() res = jitted_forward(weights, jax_input_ids) jax.block_until_ready(res) elapsed = time.perf_counter() - -weight: 500;">start print(f"Run {i}: {elapsed:.4f}s") # Convert input to a native JAX array for jax.jit jax_input_ids = jax.device_put(inputs["input_ids"].numpy()) # Warm up (first call triggers compilation) res = jitted_forward(weights, jax_input_ids) jax.block_until_ready(res) # Benchmark 3 runs for i in range(3): -weight: 500;">start = time.perf_counter() res = jitted_forward(weights, jax_input_ids) jax.block_until_ready(res) elapsed = time.perf_counter() - -weight: 500;">start print(f"Run {i}: {elapsed:.4f}s") # Convert input to a native JAX array for jax.jit jax_input_ids = jax.device_put(inputs["input_ids"].numpy()) # Warm up (first call triggers compilation) res = jitted_forward(weights, jax_input_ids) jax.block_until_ready(res) # Benchmark 3 runs for i in range(3): -weight: 500;">start = time.perf_counter() res = jitted_forward(weights, jax_input_ids) jax.block_until_ready(res) elapsed = time.perf_counter() - -weight: 500;">start print(f"Run {i}: {elapsed:.4f}s") Run 0: 0.0142s # Already compiled from warm-up Run 1: 0.0038s Run 2: 0.0035s Run 0: 0.0142s # Already compiled from warm-up Run 1: 0.0038s Run 2: 0.0035s Run 0: 0.0142s # Already compiled from warm-up Run 1: 0.0038s Run 2: 0.0035s import torch.nn as nn class NoCacheModel(nn.Module): def __init__(self, base_model): super().__init__() self.base_model = base_model def forward(self, input_ids): # Return only logits to avoid HuggingFace output class pytree issues return self.base_model(input_ids, use_cache=False, return_dict=False)[0] # One-liner: compile the wrapped model compiled_model = torchax.compile(NoCacheModel(model)) # Use it like a normal PyTorch model with torch.no_grad(): logits = compiled_model(input_ids) import 
torch.nn as nn class NoCacheModel(nn.Module): def __init__(self, base_model): super().__init__() self.base_model = base_model def forward(self, input_ids): # Return only logits to avoid HuggingFace output class pytree issues return self.base_model(input_ids, use_cache=False, return_dict=False)[0] # One-liner: compile the wrapped model compiled_model = torchax.compile(NoCacheModel(model)) # Use it like a normal PyTorch model with torch.no_grad(): logits = compiled_model(input_ids) import torch.nn as nn class NoCacheModel(nn.Module): def __init__(self, base_model): super().__init__() self.base_model = base_model def forward(self, input_ids): # Return only logits to avoid HuggingFace output class pytree issues return self.base_model(input_ids, use_cache=False, return_dict=False)[0] # One-liner: compile the wrapped model compiled_model = torchax.compile(NoCacheModel(model)) # Use it like a normal PyTorch model with torch.no_grad(): logits = compiled_model(input_ids) def classify_sentiment(text, model, tokenizer): prompt = f"""Classify the following text as POSITIVE or NEGATIVE. Text: "{text}" Classification:""" inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to("jax") with torch.no_grad(): outputs = model(input_ids, use_cache=False) # Get the predicted next token next_token_logits = outputs.logits[0, -1, :] next_token_id = torch.argmax(next_token_logits).item() prediction = tokenizer.decode([next_token_id]).strip() return prediction # Test it texts = [ "This movie was absolutely fantastic, I loved every minute!", "The -weight: 500;">service was terrible and the food was cold.", "A perfectly average experience, nothing special.", ] for text in texts: result = classify_sentiment(text, model, tokenizer) print(f"Text: {text[:50]}... => {result}") def classify_sentiment(text, model, tokenizer): prompt = f"""Classify the following text as POSITIVE or NEGATIVE. 
Text: "{text}" Classification:""" inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to("jax") with torch.no_grad(): outputs = model(input_ids, use_cache=False) # Get the predicted next token next_token_logits = outputs.logits[0, -1, :] next_token_id = torch.argmax(next_token_logits).item() prediction = tokenizer.decode([next_token_id]).strip() return prediction # Test it texts = [ "This movie was absolutely fantastic, I loved every minute!", "The -weight: 500;">service was terrible and the food was cold.", "A perfectly average experience, nothing special.", ] for text in texts: result = classify_sentiment(text, model, tokenizer) print(f"Text: {text[:50]}... => {result}") def classify_sentiment(text, model, tokenizer): prompt = f"""Classify the following text as POSITIVE or NEGATIVE. Text: "{text}" Classification:""" inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to("jax") with torch.no_grad(): outputs = model(input_ids, use_cache=False) # Get the predicted next token next_token_logits = outputs.logits[0, -1, :] next_token_id = torch.argmax(next_token_logits).item() prediction = tokenizer.decode([next_token_id]).strip() return prediction # Test it texts = [ "This movie was absolutely fantastic, I loved every minute!", "The -weight: 500;">service was terrible and the food was cold.", "A perfectly average experience, nothing special.", ] for text in texts: result = classify_sentiment(text, model, tokenizer) print(f"Text: {text[:50]}... => {result}") Iteration 1: input (1, n) → output (1, n) → pick token Iteration 2: input (1, n+1) → output (1, n+1) → pick token Iteration 3: input (1, n+2) → output (1, n+2) → pick token ... Iteration 1: input (1, n) → output (1, n) → pick token Iteration 2: input (1, n+1) → output (1, n+1) → pick token Iteration 3: input (1, n+2) → output (1, n+2) → pick token ... 
Iteration 1: input (1, n) → output (1, n) → pick token Iteration 2: input (1, n+1) → output (1, n+1) → pick token Iteration 3: input (1, n+2) → output (1, n+2) → pick token ... Iteration 1: input (1, n) → output + kv_cache(n) Iteration 2: input (1, 1) + cache(n) → output + kv_cache(n+1) Iteration 3: input (1, 1) + cache(n+1) → output + kv_cache(n+2) Iteration 1: input (1, n) → output + kv_cache(n) Iteration 2: input (1, 1) + cache(n) → output + kv_cache(n+1) Iteration 3: input (1, 1) + cache(n+1) → output + kv_cache(n+2) Iteration 1: input (1, n) → output + kv_cache(n) Iteration 2: input (1, 1) + cache(n) → output + kv_cache(n+1) Iteration 3: input (1, 1) + cache(n+1) → output + kv_cache(n+2) from transformers.cache_utils import StaticCache # Register StaticCache as a pytree def _flatten_static_cache(cache): return ( cache.key_cache, cache.value_cache ), (cache.config, cache.max_batch_size, cache.max_cache_len, getattr(cache, "device", None), getattr(cache, "dtype", None)) def _unflatten_static_cache(aux, children): config, max_batch_size, max_cache_len, device, dtype = aux kwargs = {} if device is not None: kwargs["device"] = device if dtype is not None: kwargs["dtype"] = dtype cache = StaticCache(config, max_batch_size, max_cache_len, **kwargs) cache.key_cache, cache.value_cache = children return cache register_pytree_node( StaticCache, _flatten_static_cache, _unflatten_static_cache, ) def generate_text(model, tokenizer, prompt, max_new_tokens=50): inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to("jax") batch_size, seq_length = input_ids.shape # Create a static cache with fixed maximum length past_key_values = StaticCache( config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device="jax", dtype=model.dtype, ) cache_position = torch.arange(seq_length, device="jax") # Prefill: process the full prompt with torch.no_grad(): logits, past_key_values = model( input_ids, cache_position=cache_position, 
past_key_values=past_key_values, return_dict=False, use_cache=True, ) next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] generated_ids = [next_token[:, 0].item()] cache_position = torch.tensor([seq_length], device="jax") # Decode: generate one token at a time for _ in range(max_new_tokens - 1): with torch.no_grad(): logits, past_key_values = model( next_token, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True, ) next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] token_id = next_token[:, 0].item() if token_id == tokenizer.eos_token_id: break generated_ids.append(token_id) cache_position += 1 return tokenizer.decode(generated_ids, skip_special_tokens=True) # Generate! result = generate_text(model, tokenizer, "The secret to baking a good cake is") print(result) from transformers.cache_utils import StaticCache # Register StaticCache as a pytree def _flatten_static_cache(cache): return ( cache.key_cache, cache.value_cache ), (cache.config, cache.max_batch_size, cache.max_cache_len, getattr(cache, "device", None), getattr(cache, "dtype", None)) def _unflatten_static_cache(aux, children): config, max_batch_size, max_cache_len, device, dtype = aux kwargs = {} if device is not None: kwargs["device"] = device if dtype is not None: kwargs["dtype"] = dtype cache = StaticCache(config, max_batch_size, max_cache_len, **kwargs) cache.key_cache, cache.value_cache = children return cache register_pytree_node( StaticCache, _flatten_static_cache, _unflatten_static_cache, ) def generate_text(model, tokenizer, prompt, max_new_tokens=50): inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to("jax") batch_size, seq_length = input_ids.shape # Create a static cache with fixed maximum length past_key_values = StaticCache( config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device="jax", dtype=model.dtype, ) cache_position = torch.arange(seq_length, device="jax") # Prefill: 
process the full prompt with torch.no_grad(): logits, past_key_values = model( input_ids, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True, ) next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] generated_ids = [next_token[:, 0].item()] cache_position = torch.tensor([seq_length], device="jax") # Decode: generate one token at a time for _ in range(max_new_tokens - 1): with torch.no_grad(): logits, past_key_values = model( next_token, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True, ) next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] token_id = next_token[:, 0].item() if token_id == tokenizer.eos_token_id: break generated_ids.append(token_id) cache_position += 1 return tokenizer.decode(generated_ids, skip_special_tokens=True) # Generate! result = generate_text(model, tokenizer, "The secret to baking a good cake is") print(result) from transformers.cache_utils import StaticCache # Register StaticCache as a pytree def _flatten_static_cache(cache): return ( cache.key_cache, cache.value_cache ), (cache.config, cache.max_batch_size, cache.max_cache_len, getattr(cache, "device", None), getattr(cache, "dtype", None)) def _unflatten_static_cache(aux, children): config, max_batch_size, max_cache_len, device, dtype = aux kwargs = {} if device is not None: kwargs["device"] = device if dtype is not None: kwargs["dtype"] = dtype cache = StaticCache(config, max_batch_size, max_cache_len, **kwargs) cache.key_cache, cache.value_cache = children return cache register_pytree_node( StaticCache, _flatten_static_cache, _unflatten_static_cache, ) def generate_text(model, tokenizer, prompt, max_new_tokens=50): inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to("jax") batch_size, seq_length = input_ids.shape # Create a static cache with fixed maximum length past_key_values = StaticCache( config=model.config, max_batch_size=1, 
max_cache_len=seq_length + max_new_tokens, device="jax", dtype=model.dtype, ) cache_position = torch.arange(seq_length, device="jax") # Prefill: process the full prompt with torch.no_grad(): logits, past_key_values = model( input_ids, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True, ) next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] generated_ids = [next_token[:, 0].item()] cache_position = torch.tensor([seq_length], device="jax") # Decode: generate one token at a time for _ in range(max_new_tokens - 1): with torch.no_grad(): logits, past_key_values = model( next_token, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True, ) next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] token_id = next_token[:, 0].item() if token_id == tokenizer.eos_token_id: break generated_ids.append(token_id) cache_position += 1 return tokenizer.decode(generated_ids, skip_special_tokens=True) # Generate! 
# Run generation on a sample prompt
result = generate_text(model, tokenizer, "The secret to baking a good cake is")
print(result)

# Step 6: shard the weights for tensor parallelism
import jax
import torchax
from jax.sharding import PartitionSpec as P, NamedSharding

# Create a one-dimensional device mesh spanning every available device
mesh = jax.make_mesh((jax.device_count(),), ("axis",))

def shard_weights(mesh, weights):
    sharded = {}
    for name, tensor in weights.items():
        # PyTorch Linear weights have shape (out_features, in_features),
        # so sharding dimension 0 splits the output dimension.
        if any(k in name for k in ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]):
            spec = P("axis", None)  # Column-parallel
        elif any(k in name for k in ["o_proj", "down_proj", "lm_head", "embed_tokens"]):
            spec = P(None, "axis")  # Row-parallel
        else:
            spec = P()  # Replicate (e.g., layer norms)
        sharded[name] = jax.device_put(tensor, NamedSharding(mesh, spec))
    return sharded

# Apply sharding
weights, jax_func = torchax.extract_jax(model)
weights = shard_weights(mesh, weights)

# Replicate the input across all devices
input_ids_sharded = jax.device_put(
    inputs["input_ids"], NamedSharding(mesh, P())
)

# Step 7: wrap generation in a chat function
def chat(model, tokenizer, user_message, max_new_tokens=100):
    # Gemma instruction format
    prompt = f"<start_of_turn>user\n{user_message}<end_of_turn>\n<start_of_turn>model\n"
    response = generate_text(model, tokenizer, prompt, max_new_tokens)
    return response

# Example conversation
questions = [
    "What is JAX and why would I use it?",
    "Explain tensor parallelism in simple terms.",
    "Write a haiku about machine learning.",
]

for q in questions:
    print(f"User: {q}")
    print(f"Gemma: {chat(model, tokenizer, q)}")
    print()

# Swapping to a larger model
# 7B model: needs more memory (Colab Pro or multi-device)
# Check the exact model ID on the HuggingFace Hub before running:
# Gemma 3 instruction-tuned checkpoints are published in 1B, 4B, 12B, and 27B sizes.
model_name = "google/gemma-3-7b-it"

Prerequisites:
- Python 3.10+
- Basic familiarity with PyTorch (loading models, running inference)
- A Google Colab account (free tier works for the 1B model)

Notes on Step 1 (your first forward pass):
- We load the model on CPU first, then call torchax.enable_globally(). This ordering matters: enabling torchax before the model is loaded can intercept unsupported initialization ops and cause errors.
- model.to("jax") moves every parameter from CPU to the JAX device, just like model.to("cuda") for GPUs.
- The forward pass runs through PyTorch's code path, but every operation is executed by JAX under the hood.
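The three Step 1 notes above can be put together as one function. This is a minimal sketch rather than the tutorial's exact code: the function name first_forward_pass and the default checkpoint ID "google/gemma-3-1b-it" are assumptions for illustration, and the imports are kept inside the function so that defining it does not load any libraries.

```python
def first_forward_pass(model_name="google/gemma-3-1b-it", prompt="Hello from JAX"):
    """Run one forward pass of a HuggingFace causal LM through torchax.

    The default model_name is an assumed example; substitute any causal LM
    checkpoint you have access to.
    """
    # Local imports: nothing is loaded until the function is actually called.
    import torch
    import torchax
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # 1. Load the tokenizer and model on CPU first, before torchax is active.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt")

    # 2. Only now enable torchax, so model initialization is not intercepted.
    torchax.enable_globally()

    # 3. Move parameters and inputs to the JAX device.
    model.to("jax")
    input_ids = inputs["input_ids"].to("jax")

    # 4. Call the model as usual; the operations execute on JAX.
    with torch.no_grad():
        outputs = model(input_ids)
    return outputs.logits
```

Calling first_forward_pass() downloads the checkpoint, so it needs network access, and Gemma checkpoints additionally require accepting the license on the HuggingFace Hub.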
torchax.extract_jax(model) returns:
- weights: the model's state_dict as a pytree of jax.Arrays
- jax_func: a function with signature jax_func(weights, args_tuple, kwargs_dict)

How tensor parallelism splits the weights:
- Column-parallel: the Q, K, V, Gate, and Up projections are split along the output dimension
- Row-parallel: the O and Down projections are split along the input dimension
- Because a column-parallel projection feeds directly into a row-parallel one, each such pair needs only a single all-reduce, after the row-parallel step

What we covered:
- Forward pass: moved a PyTorch model to JAX with model.to("jax")
- JIT compilation: compiled for 10-100x speedup with jax.jit
- Text classification: used prompt engineering for sentiment analysis
- Text generation: implemented autoregressive decoding with StaticCache
- Distributed inference: sharded weights across devices with tensor parallelism
- Chatbot: wrapped generation in an instruction-following chat function

Resources:
- torchax GitHub: library source and documentation
- torchax Docs: official getting started guide
- Original tutorial series by Han Qi: the 3-part blog series this tutorial builds on
- JAX Documentation: JIT compilation, pytrees, distributed arrays
- HuggingFace LLM Inference Optimization: StaticCache and torch.compile docs
- Companion GitHub repo: all code, notebooks, and diagrams

Acknowledgments:
- Han Qi (@qihqi): author of torchax and the original HuggingFace + JAX tutorial series
- The torchax team at Google: for building and maintaining the library
- The HuggingFace team: for the transformers ecosystem
- The JAX team at Google: for JAX, XLA, and TPU support