Inference with Gemma using JAX and Flax
Overview
Gemma is a family of lightweight, state-of-the-art open large language models, built from the research and technology behind Google DeepMind's Gemini models. This tutorial demonstrates how to perform basic sampling/inference with the Gemma 2B Instruct model using Google DeepMind's gemma library, which is built with JAX (a high-performance numerical computing library), Flax (the JAX-based neural network library), Orbax (a JAX-based library for training utilities such as checkpointing), and SentencePiece (a tokenizer/detokenizer library). Although Flax is not used directly in this notebook, Flax was used to create Gemma.
This notebook can run on Google Colab with a free T4 GPU (go to Edit > Notebook settings and, under Hardware accelerator, select T4 GPU).
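Before loading any model weights, it can help to confirm that JAX actually sees the GPU accelerator. The snippet below is a minimal sketch: the commented-out install line assumes the gemma library is installed directly from its GitHub repository, which may differ from the exact version this tutorial expects.

```python
# Assumption: installing the gemma library from its GitHub repository;
# the tutorial may pin a specific release instead.
# !pip install -q git+https://github.com/google-deepmind/gemma.git

import jax

# List the accelerators JAX can see; on a Colab T4 runtime this should
# report a single GPU device, e.g. [CudaDevice(id=0)].
print(jax.devices())
```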
Setup
1. Set up Kaggle access for Gemma
To complete this tutorial, you first need to follow the setup instructions at Gemma setup, which show you how to do the following:
- Get access to Gemma on kaggle.com.
- Select a Colab runtime with sufficient resources to run the Gemma model.
- Generate and configure a Kaggle username and API key.
After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.
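For example, in Colab you can load your Kaggle credentials from Colab secrets into the environment variables that Kaggle's tooling reads. This is a minimal sketch, assuming you have stored secrets named KAGGLE_USERNAME and KAGGLE_KEY via the Secrets panel in the Colab sidebar:

```python
import os
from google.colab import userdata  # Colab-only helper for reading secrets

# Assumption: KAGGLE_USERNAME and KAGGLE_KEY were added as Colab secrets.
# The Kaggle API client reads these environment variables to authenticate.
os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
```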