Overview
Gemma is a family of lightweight, state-of-the-art open models built from the research and technology used to create the Google Gemini models. Gemma can be further finetuned to suit specific needs. But large language models such as Gemma can be very large, and some of them may not fit on a single accelerator for finetuning. In this case there are two general approaches for finetuning them:
- Parameter Efficient Fine-Tuning (PEFT), which shrinks the number of trainable parameters at the cost of some fidelity. LoRA falls in this category, and the Finetune Gemma models in Keras using LoRA tutorial demonstrates how to finetune the Gemma 2B model gemma_2b_en with LoRA using KerasNLP on a single GPU.
- Full-parameter finetuning with model parallelism. Model parallelism distributes a single model's weights across multiple devices and enables horizontal scaling. You can find out more about distributed training in this Keras guide.
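As a back-of-envelope illustration of why PEFT methods like LoRA shrink the training footprint, compare the trainable parameters of one dense layer under full finetuning versus a low-rank adapter. The layer size and rank below are illustrative assumptions, not exact Gemma figures:

```python
# Rough comparison of trainable parameters: full finetuning vs. LoRA
# on a single weight matrix. Sizes here are illustrative only.

def lora_params(d_in, d_out, rank):
    """Parameters LoRA trains for one d_in x d_out weight matrix:
    two low-rank factors, A (d_in x rank) and B (rank x d_out)."""
    return d_in * rank + rank * d_out

full = 4096 * 4096                        # full finetuning of one layer
lora = lora_params(4096, 4096, rank=4)    # same layer with rank-4 LoRA
print(full, lora, full // lora)           # prints: 16777216 32768 512
```

At rank 4, the adapter trains roughly 1/512th of the layer's weights, which is what lets a large model fit on a single accelerator for tuning.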
This tutorial walks you through using Keras with a JAX backend to finetune the Gemma 7B model with LoRA and model-parallelism distributed training on Google's Tensor Processing Units (TPUs). Note that LoRA can be turned off in this tutorial for slower but more accurate full-parameter tuning.
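The backend selection, sharding setup, and LoRA switch described above might look roughly like the following configuration sketch. It assumes the KerasNLP preset name gemma_7b_en and the Keras 3 distribution API; exact constructor signatures vary across Keras versions, and running it requires TPU/GPU hardware plus access to the Gemma weights:

```python
import os

# Keras reads the backend before it is first imported; "jax" enables
# the JAX-based distribution API used below.
os.environ["KERAS_BACKEND"] = "jax"

import keras
import keras_nlp

# Build a mesh over all available accelerators (e.g. 8 TPU v3 cores)
# with a "model" axis for sharding weights across devices.
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices)

# A LayoutMap tells Keras how to place weights on the mesh; entries
# mapping weight-name patterns to layouts would be added here.
layout_map = keras.distribution.LayoutMap(device_mesh)
keras.distribution.set_distribution(
    keras.distribution.ModelParallel(layout_map=layout_map,
                                     batch_dim_name="batch"))

model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
# LoRA on for faster tuning; omit this line for full-parameter tuning.
model.backbone.enable_lora(rank=4)
```

The key design point is that the distribution is set globally before the model is instantiated, so the 7B weights are sharded across the mesh as they are loaded rather than materialized on one device first.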
Using accelerators
Technically you can use either TPU or GPU for this tutorial.
Notes on TPU environments
Google has 3 products that provide TPUs:
- Colab provides TPU v2, which is not sufficient for this tutorial.
- Kaggle offers TPU v3 accelerators for free, and they work for this tutorial.
- Cloud TPU offers TPU v3 and newer generations. One way to set it up is:
  - Create a new TPU VM
  - Set up SSH port forwarding for your intended Jupyter server port
  - Install Jupyter and start it on the TPU VM, then connect to Colab through "Connect to a local runtime"
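The Cloud TPU steps above can be sketched with gcloud as follows. The VM name, zone, accelerator type, and runtime version are placeholder assumptions; substitute values available in your project:

```shell
# 1. Create a new TPU VM (hypothetical name/zone; check availability).
gcloud compute tpus tpu-vm create my-tpu-vm \
    --zone=us-central1-a \
    --accelerator-type=v3-8 \
    --version=tpu-vm-base

# 2. SSH in, forwarding Jupyter's default port 8888 to your machine.
gcloud compute tpus tpu-vm ssh my-tpu-vm \
    --zone=us-central1-a -- -L 8888:localhost:8888

# 3. On the TPU VM: install and start Jupyter, then paste the printed
#    http://localhost:8888/?token=... URL into Colab's
#    "Connect to a local runtime".
pip install jupyter
jupyter notebook --no-browser --port=8888
```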
Notes on multi-GPU setup
Although this tutorial foc