```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-it", torch_dtype="bfloat16"
)

import torchax
torchax.enable_globally()  # Enable AFTER loading the model

model.to("jax")  # That's it. Now running on JAX.
```
```
PyTorch Model
      |
      v
torchax.Tensor (looks like torch.Tensor)
      |
      v
jax.Array (actual computation on TPU/GPU)
```
```shell
# 1. Install PyTorch (CPU version — torchax handles the accelerator)
pip install torch --index-url https://download.pytorch.org/whl/cpu  # Linux
# pip install torch  # macOS

# 2. Install JAX for your accelerator
pip install -U jax[tpu]  # Google Cloud TPU
# pip install -U jax[cuda12]  # NVIDIA GPU
# pip install -U jax  # CPU only

# 3. Install torchax, transformers, and flax (for JAX compatibility)
pip install -U torchax transformers flax
```
```
First call:  Python code → trace → compile → execute (slow)
Second call: compiled code → execute (fast!)
```
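This trace-once, run-many behavior is easy to observe with a Python side effect. The sketch below uses a hypothetical `double` function (not part of the tutorial's model code): the counter increments only while `jax.jit` traces, so the second call with the same input shape never re-enters the Python body.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effects run only during tracing
    return x * 2

double(jnp.arange(3))  # first call: trace, compile, execute
double(jnp.arange(3))  # second call: reuses the compiled code, no re-trace
print(trace_count)     # 1
```

Note that a call with a *different* input shape would trigger a fresh trace and compilation, which is why the tutorial later uses a fixed-size cache for generation.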
```python
import time

import torch
import torchax
import jax
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model_name = "google/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cpu"
)

# Enable torchax globally AFTER model loading
# This prevents intercepting unsupported initialization ops
torchax.enable_globally()

# Move model weights to the JAX device
model.to("jax")

# Tokenize an input prompt
prompt = "The secret to baking a good cake is"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to("jax")

# Run a forward pass (eager mode)
start = time.perf_counter()
with torch.no_grad():
    outputs = model(input_ids, use_cache=False)
elapsed = time.perf_counter() - start

print(f"Output logits shape: {outputs.logits.shape}")
print(f"Eager forward pass: {elapsed:.3f}s")
```
```python
# Extract a JAX-callable function and the model weights as a pytree
weights, jax_func = torchax.extract_jax(model)
```
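The split into a weights pytree plus a pure function is exactly the shape `jax.jit` wants: all state goes in as an argument, nothing is hidden in Python objects. A toy stand-in (a hypothetical `apply_fn`, not the real extracted Gemma function) shows the pattern:

```python
import jax
import jax.numpy as jnp

# Weights as a pytree, computation as a pure function of (weights, inputs)
weights = {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)}

def apply_fn(weights, x):
    return x @ weights["w"] + weights["b"]

jitted = jax.jit(apply_fn)
out = jitted(weights, jnp.ones((1, 3)))
print(out)  # [[3. 3. 3.]]
```

Because the weights are an explicit argument, they can later be sharded or swapped without recompiling the function logic.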
```python
from jax.tree_util import register_pytree_node
from transformers import modeling_outputs, cache_utils

# Register CausalLMOutputWithPast
def output_flatten(v):
    return v.to_tuple(), None

def output_unflatten(aux, children):
    return modeling_outputs.CausalLMOutputWithPast(*children)

register_pytree_node(
    modeling_outputs.CausalLMOutputWithPast,
    output_flatten,
    output_unflatten,
)

# Register DynamicCache
def _flatten_dynamic_cache(cache):
    return (cache.key_cache, cache.value_cache), None

def _unflatten_dynamic_cache(aux, children):
    c = cache_utils.DynamicCache()
    c.key_cache, c.value_cache = children
    return c

register_pytree_node(
    cache_utils.DynamicCache,
    _flatten_dynamic_cache,
    _unflatten_dynamic_cache,
)
```
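The same registration mechanism works for any container class. A minimal sketch with a hypothetical `Point` class shows what flatten/unflatten buy you: once registered, instances flow through `jax.tree_util.tree_map` (and hence across `jax.jit` boundaries) like any built-in container.

```python
import jax
from jax.tree_util import register_pytree_node

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),           # flatten: children + aux data
    lambda aux, children: Point(*children), # unflatten: rebuild from children
)

p = jax.tree_util.tree_map(lambda v: v * 2, Point(1.0, 2.0))
print(p.x, p.y)  # 2.0 4.0
```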
```python
def forward_no_cache(weights, input_ids):
    return jax_func(weights, (input_ids,), {"use_cache": False})

jitted_forward = jax.jit(forward_no_cache)
```
```python
# Convert input to a native JAX array for jax.jit
jax_input_ids = jax.device_put(inputs["input_ids"].numpy())

# Warm up (first call triggers compilation)
res = jitted_forward(weights, jax_input_ids)
jax.block_until_ready(res)

# Benchmark 3 runs
for i in range(3):
    start = time.perf_counter()
    res = jitted_forward(weights, jax_input_ids)
    jax.block_until_ready(res)
    elapsed = time.perf_counter() - start
    print(f"Run {i}: {elapsed:.4f}s")
```
```
Run 0: 0.0142s  # Already compiled from warm-up
Run 1: 0.0038s
Run 2: 0.0035s
```
```python
import torch.nn as nn

class NoCacheModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, input_ids):
        # Return only logits to avoid HuggingFace output class pytree issues
        return self.base_model(input_ids, use_cache=False, return_dict=False)[0]

# One-liner: compile the wrapped model
compiled_model = torchax.compile(NoCacheModel(model))

# Use it like a normal PyTorch model
with torch.no_grad():
    logits = compiled_model(input_ids)
```
```python
def classify_sentiment(text, model, tokenizer):
    prompt = f"""Classify the following text as POSITIVE or NEGATIVE.

Text: "{text}"

Classification:"""
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to("jax")
    with torch.no_grad():
        outputs = model(input_ids, use_cache=False)
    # Get the predicted next token
    next_token_logits = outputs.logits[0, -1, :]
    next_token_id = torch.argmax(next_token_logits).item()
    prediction = tokenizer.decode([next_token_id]).strip()
    return prediction

# Test it
texts = [
    "This movie was absolutely fantastic, I loved every minute!",
    "The service was terrible and the food was cold.",
    "A perfectly average experience, nothing special.",
]
for text in texts:
    result = classify_sentiment(text, model, tokenizer)
    print(f"Text: {text[:50]}... => {result}")
```
```
Iteration 1: input (1, n)   → output (1, n)   → pick token
Iteration 2: input (1, n+1) → output (1, n+1) → pick token
Iteration 3: input (1, n+2) → output (1, n+2) → pick token
...
```
```
Iteration 1: input (1, n)              → output + kv_cache(n)
Iteration 2: input (1, 1) + cache(n)   → output + kv_cache(n+1)
Iteration 3: input (1, 1) + cache(n+1) → output + kv_cache(n+2)
```
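The savings are easy to quantify with a back-of-envelope count. The numbers below (a 10-token prompt, 50 generated tokens) are illustrative, not taken from the benchmark above:

```python
prompt_len, new_tokens = 10, 50

# Without a cache: every step re-processes the whole growing sequence
no_cache = sum(prompt_len + i for i in range(new_tokens))

# With a KV cache: process the prompt once, then one token per step
with_cache = prompt_len + (new_tokens - 1)

print(no_cache, with_cache)  # 1725 59
```

The no-cache cost grows quadratically with output length, while the cached cost grows linearly, which is why every serious generation loop uses a KV cache.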
```python
from transformers.cache_utils import StaticCache

# Register StaticCache as a pytree
def _flatten_static_cache(cache):
    return (cache.key_cache, cache.value_cache), (
        cache.config,
        cache.max_batch_size,
        cache.max_cache_len,
        getattr(cache, "device", None),
        getattr(cache, "dtype", None),
    )

def _unflatten_static_cache(aux, children):
    config, max_batch_size, max_cache_len, device, dtype = aux
    kwargs = {}
    if device is not None:
        kwargs["device"] = device
    if dtype is not None:
        kwargs["dtype"] = dtype
    cache = StaticCache(config, max_batch_size, max_cache_len, **kwargs)
    cache.key_cache, cache.value_cache = children
    return cache

register_pytree_node(
    StaticCache,
    _flatten_static_cache,
    _unflatten_static_cache,
)

def generate_text(model, tokenizer, prompt, max_new_tokens=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to("jax")
    batch_size, seq_length = input_ids.shape

    # Create a static cache with fixed maximum length
    past_key_values = StaticCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=seq_length + max_new_tokens,
        device="jax",
        dtype=model.dtype,
    )
    cache_position = torch.arange(seq_length, device="jax")

    # Prefill: process the full prompt
    with torch.no_grad():
        logits, past_key_values = model(
            input_ids,
            cache_position=cache_position,
            past_key_values=past_key_values,
            return_dict=False,
            use_cache=True,
        )
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    generated_ids = [next_token[:, 0].item()]
    cache_position = torch.tensor([seq_length], device="jax")

    # Decode: generate one token at a time
    for _ in range(max_new_tokens - 1):
        with torch.no_grad():
            logits, past_key_values = model(
                next_token,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=False,
                use_cache=True,
            )
        next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        token_id = next_token[:, 0].item()
        if token_id == tokenizer.eos_token_id:
            break
        generated_ids.append(token_id)
        cache_position += 1

    return tokenizer.decode(generated_ids, skip_special_tokens=True)

# Generate!
result = generate_text(model, tokenizer, "The secret to baking a good cake is")
print(result)
```
```python
from jax.sharding import PartitionSpec as P, NamedSharding

# Create a device mesh
mesh = jax.make_mesh((jax.device_count(),), ("axis",))

def shard_weights(mesh, weights):
    sharded = {}
    for name, tensor in weights.items():
        if any(k in name for k in ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]):
            spec = P("axis", None)  # Column-parallel
        elif any(k in name for k in ["o_proj", "down_proj", "lm_head", "embed_tokens"]):
            spec = P(None, "axis")  # Row-parallel
        else:
            spec = P()  # Replicate (e.g., layer norms)
        sharded[name] = jax.device_put(tensor, NamedSharding(mesh, spec))
    return sharded

# Apply sharding
weights, jax_func = torchax.extract_jax(model)
weights = shard_weights(mesh, weights)

# Replicate the input across all devices
input_ids_sharded = jax.device_put(
    inputs["input_ids"], NamedSharding(mesh, P())
)
```
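The column-then-row split can be sanity-checked on a single host with plain NumPy. This is a toy with two simulated "devices" and made-up shapes, not the actual sharded execution path:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 4))
W1 = rng.normal(size=(4, 6))  # stands in for a column-parallel projection
W2 = rng.normal(size=(6, 4))  # stands in for a row-parallel projection

# Column-parallel: each "device" holds half the columns of W1
h0, h1 = x @ W1[:, :3], x @ W1[:, 3:]

# Row-parallel: each "device" holds the matching rows of W2.
# Summing the partial products is the single all-reduce per layer.
y = h0 @ W2[:3, :] + h1 @ W2[3:, :]

assert np.allclose(y, (x @ W1) @ W2)
```

Pairing a column-parallel layer with a row-parallel one is what keeps communication down to that one sum: the intermediate activations never have to be gathered.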
```python
def chat(model, tokenizer, user_message, max_new_tokens=100):
    # Gemma instruction format
    prompt = f"<start_of_turn>user\n{user_message}<end_of_turn>\n<start_of_turn>model\n"
    response = generate_text(model, tokenizer, prompt, max_new_tokens)
    return response

# Example conversation
questions = [
    "What is JAX and why would I use it?",
    "Explain tensor parallelism in simple terms.",
    "Write a haiku about machine learning.",
]
for q in questions:
    print(f"User: {q}")
    print(f"Gemma: {chat(model, tokenizer, q)}")
    print()
```
```python
# 7B model — needs more memory (Colab Pro or multi-device)
model_name = "google/gemma-3-7b-it"
```
- JIT Compilation — JAX can compile your Python code into optimized machine code using the XLA compiler. The first run is slower (compilation), but every subsequent call is dramatically faster.
- TPU Support — JAX is the native programming model for Google's Tensor Processing Units. If you want to use TPUs, JAX is the most natural path.
- Automatic Parallelism — JAX can automatically distribute computation across multiple devices (TPUs or GPUs) using a single-program model called gSPMD. You describe what should be sharded; the compiler figures out how.

- Python 3.10+
- Basic familiarity with PyTorch (loading models, running inference)
- A Google Colab account (free tier works for the 1B model)

- We load the model on CPU first, then call torchax.enable_globally(). This ordering is important — enabling torchax before model loading can intercept unsupported initialization ops and cause errors.
- model.to("jax") moves every parameter from CPU to the JAX device — just like model.to("cuda") for GPUs.
- The forward pass runs through PyTorch's code path, but every operation is executed by JAX under the hood.

- weights — the model's state_dict as a pytree of jax.Arrays
- jax_func — a function with signature jax_func(weights, args_tuple, kwargs_dict)

- Column-parallel: Q, K, V, Gate, and Up projections are split along the output dimension
- Row-parallel: O and Down projections are split along the input dimension
- Between these two, only a single all-reduce operation is needed per layer

- Forward pass — moved a PyTorch model to JAX with model.to("jax")
- JIT compilation — compiled for 10-100x speedup with jax.jit
- Text classification — used prompt engineering for sentiment analysis
- Text generation — implemented autoregressive decoding with StaticCache
- Distributed inference — sharded weights across devices with tensor parallelism
- Chatbot — wrapped generation in an instruction-following chat function

- torchax GitHub — library source and documentation
- torchax Docs — official getting started guide
- Original tutorial series by Han Qi — the 3-part blog series this tutorial builds on
- JAX Documentation — JIT compilation, pytrees, distributed arrays
- HuggingFace LLM Inference Optimization — StaticCache and torch.compile docs
- Companion GitHub repo — all code, notebooks, and diagrams

- Han Qi (@qihqi) — author of torchax and the original HuggingFace + JAX tutorial series
- The torchax team at Google — for building and maintaining the library
- The HuggingFace team — for the transformers ecosystem
- The JAX team at Google — for JAX, XLA, and TPU support