Gin is a lightweight configuration framework for Python, based on dependency injection. While T5X does not employ gin in its core libraries, it is used to configure runs of the `train`, `eval`, and `infer` scripts. This usage is a bit different (and more limited) than how gin is typically applied, so this primer should be useful even for those who may be familiar with gin from other libraries (e.g., T5 or Mesh TensorFlow). Nevertheless, you may still find it helpful to refer to the gin documentation for more background.
[TOC]
Rather than plumbing run arguments and hyperparameters through a limited set of command-line flags or a flat configuration schema, T5X's gin integration allows you to parameterize the top-level run functions (`train`, `evaluate`, and `infer`) as well as any object or function that is passed to them. This enables a vast amount of flexibility over your runs without needing to modify any code within the core T5X library.
For example, you can implement a Python class in your own codebase (e.g., a custom model or trainer) and use gin to pass an instance of it to the T5X XM launcher without having to fork any code. Previously you needed to implement every experimental idea in the core library (no matter how widely used it would be) and add a ConfigDict flag to enable/disable it, resulting in significant code debt over time.
On the other hand, gin can sometimes be too powerful: it allows users to bind arguments anywhere in a codebase, which makes it difficult or impossible to update "private" internal interfaces. However, by limiting configurability to a single top-level function and its arguments, we can better restrict the configurable surface to public interfaces and user-owned code, and also avoid unintended side effects.
Let's look at the `evaluate` call signature from eval.py as an example:
```python
from t5x import models
from t5x import partitioning
from t5x import utils


def evaluate(*,
             model: models.BaseModel,
             dataset_cfg: utils.DatasetConfig,
             restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
             partitioner: partitioning.BasePartitioner,
             output_dir: str):
  """Evaluation function.

  Args:
    model: The model object to use for inference.
    dataset_cfg: Specification for the dataset to infer based on.
    restore_checkpoint_cfg: Specification for the model parameter checkpoint to
      load.
    partitioner: The partitioner for the model parameters and
      data across devices.
    output_dir: Path to directory to write temporary files and final results.
  """
  ...
```
In the binary, the user-provided gin configuration file will be parsed. It specifies which values should be bound to the `evaluate` arguments, after which we can directly call the fully-bound function without any arguments. Basically, we are creating a custom closure of `evaluate` (à la `functools.partial`) but specifying the arguments via gin instead of Python.
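As a rough analogy only (plain Python, not how gin is implemented internally), the fully bound `evaluate` behaves like a `functools.partial` with every keyword argument already filled in. The stand-in `evaluate` below and all of its argument values are hypothetical placeholders.

```python
import functools


def evaluate(*, model, dataset_cfg, restore_checkpoint_cfg, partitioner,
             output_dir):
  """Hypothetical stand-in for eval.evaluate that just echoes its inputs."""
  return {
      "model": model,
      "dataset_cfg": dataset_cfg,
      "restore_checkpoint_cfg": restore_checkpoint_cfg,
      "partitioner": partitioner,
      "output_dir": output_dir,
  }


# Roughly what parsing the gin file achieves: every argument is bound up front.
bound_evaluate = functools.partial(
    evaluate,
    model="placeholder-model",
    dataset_cfg={"mixture_or_task_name": "natural_questions_open"},
    restore_checkpoint_cfg={"mode": "specific"},
    partitioner={"num_partitions": 1},
    output_dir="/tmp/t5x_eval",
)

# The binary can now call the closure with no arguments at all.
result = bound_evaluate()
print(result["output_dir"])  # /tmp/t5x_eval
```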
Furthermore, this ability to bind custom arguments is recursive. Not only can we bind the arguments of `evaluate`, but we can also bind the constructor and method arguments of the instance of `models.BaseModel` that we pass to `evaluate`.
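Continuing the same plain-Python analogy, the recursion amounts to nesting bound callables: the model's constructor and one of its methods get their own bindings, and the resulting instance is then bound into `evaluate`. `ToyModel`, `ConfiguredToyModel`, `vocab_size`, and `predict` are hypothetical; only `num_decodes` mirrors a binding from the example configuration that follows.

```python
import functools


class ToyModel:
  """Hypothetical model with a constructor argument and a method argument."""

  def __init__(self, vocab_size):
    self.vocab_size = vocab_size

  def predict(self, batch, num_decodes=4):
    return [f"{x} (vocab={self.vocab_size}, decodes={num_decodes})"
            for x in batch]


class ConfiguredToyModel(ToyModel):
  # Rebinds a method default, similar in spirit to
  # `models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1`.
  predict = functools.partialmethod(ToyModel.predict, num_decodes=1)


def evaluate(*, model, batch):
  return model.predict(batch)


# Level 1: bind the constructor arguments.
make_model = functools.partial(ConfiguredToyModel, vocab_size=100)
# Level 2: bind the constructed instance into evaluate.
bound_evaluate = functools.partial(evaluate, model=make_model(), batch=["q1"])

print(bound_evaluate())  # ['q1 (vocab=100, decodes=1)']
```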
Let's now look at an example of a gin configuration for parameterizing `evaluate`, specifically evaluating a T5 model fine-tuned for closed-book question answering on Natural Questions Open:
```
from __gin__ import dynamic_registration

import __main__ as eval_script
from t5x import models
from t5x import partitioning
from t5x import utils

MODEL = %gin.REQUIRED

eval_script.evaluate:
  model = %MODEL
  output_dir = '/tmp/t5x_eval'
  dataset_cfg = @utils.DatasetConfig()
  partitioner = @partitioning.PjitPartitioner()
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()

# Load model with overrides.
include 't5x/examples/t5/t5_1_1/large.gin'
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1

utils.DatasetConfig:
  mixture_or_task_name = 'natural_questions_open'
  split = 'test'
  task_feature_lengths = None
  batch_size = 32
  shuffle = False
  seed = 0
  use_cached = False
  pack = False
  use_custom_packing_ops = False
  module = 'google_research.t5_closed_book_qa.t5_cbqa.tasks'

partitioning.PjitPartitioner:
  num_partitions = 1

utils.RestoreCheckpointConfig:
  mode = 'specific'
  path = 'gs://t5-data/pretrained_models/cbqa/large_ssm_nqo'
  assignment_map = None
  strict = True
  dtype = None
```
Let's go through this block by block.
```
from __gin__ import dynamic_registration
```
The first line imports a new gin feature (see cl/372624800 for more details) to allow us to register functions and objects for configuration from within the gin file itself without having to modify or decorate functions from the imported packages.
```
import __main__ as eval_script
from t5x import models
from t5x import partitioning
from t5x import utils
```
The second block imports the modules containing the components we plan to configure in this file and is required for dynamic registration. Note that only those functions and objects that we specify below will actually be configured, not everything in the module. Also, as is the case in Python, the binary module is referred to as `__main__`, although we rename it to `eval_script` for clarity in the rest of the config.
```
MODEL = %gin.REQUIRED
```
The third block creates a gin macro (essentially a lazy reference) and for now sets it to refer to the special macro `gin.REQUIRED`, which will cause a failure during parsing of the configuration if not updated via a later assignment in the config file or command-line flags (see below).
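To make "lazy reference" concrete, here is a toy model (not gin's actual implementation) of a macro table: references are resolved only when the value is needed, later assignments overwrite earlier ones, and a value still equal to a required sentinel raises an error. The `REQUIRED` sentinel and the `MacroTable` class are hypothetical; the error text only loosely echoes the gin error shown later in this primer.

```python
REQUIRED = object()  # Toy stand-in for gin.REQUIRED.


class MacroTable:
  """Toy macro store: assignments are recorded, references resolved lazily."""

  def __init__(self):
    self._values = {}

  def assign(self, name, value):
    # Later assignments (e.g. included files or flags) overwrite earlier ones.
    self._values[name] = value

  def ref(self, name):
    # Return a thunk instead of a value, so resolution is deferred.
    return lambda: self._resolve(name)

  def _resolve(self, name):
    value = self._values[name]
    if value is REQUIRED:
      raise ValueError(f"{name} set to REQUIRED but not subsequently overridden.")
    return value


macros = MacroTable()
macros.assign("MODEL", REQUIRED)    # MODEL = %gin.REQUIRED
model_ref = macros.ref("MODEL")     # model = %MODEL (captured before MODEL is set)
macros.assign("MODEL", "t5-large")  # e.g. set later by an included file or a flag
print(model_ref())  # t5-large
```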
```
eval_script.evaluate:
  model = %MODEL
  output_dir = '/tmp/t5x_eval'
  dataset_cfg = @utils.DatasetConfig()
  partitioner = @partitioning.PjitPartitioner()
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
```
The fourth block specifies the bindings for the `evaluate` function. For `model`, we pass the value of the `MODEL` macro (to be defined later). For `output_dir`, we pass a string path. For `dataset_cfg`, `restore_checkpoint_cfg`, and `partitioner`, we pass instantiations of `DatasetConfig`, `RestoreCheckpointConfig`, and `PjitPartitioner`, which are defined in utils.py and partitioning.py, respectively. The `@` prefix tells gin that the following is a configured function or class, and the `()` suffix signifies that it should be called (in the case of a class, this means calling the constructor). If we wanted to pass in the closure (i.e., the partially bound function) instead of its return value, we would leave off the parentheses.
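In plain-Python terms (an analogy, not gin's mechanism), `@utils.DatasetConfig()` corresponds to calling the configured constructor, while `@utils.DatasetConfig` corresponds to passing the configured callable itself. The `DatasetConfig` dataclass below is a hypothetical, heavily trimmed stand-in for the real one in utils.py, and its field defaults are assumptions.

```python
import dataclasses
import functools


@dataclasses.dataclass
class DatasetConfig:
  """Trimmed, hypothetical stand-in for utils.DatasetConfig."""
  mixture_or_task_name: str
  split: str
  batch_size: int
  shuffle: bool = True


# Bindings from a `utils.DatasetConfig:` block, captured as a configured callable.
configured = functools.partial(
    DatasetConfig,
    mixture_or_task_name="natural_questions_open",
    split="test",
    batch_size=32,
    shuffle=False,
)

with_parens = configured()   # like `@utils.DatasetConfig()`: an instance
without_parens = configured  # like `@utils.DatasetConfig`: the callable itself

print(type(with_parens).__name__)  # DatasetConfig
print(callable(without_parens))    # True
```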
The remainder of the file deals with defining the `MODEL` macro and fully binding these constructors.
```
# Load model with overrides.
include 't5x/examples/t5/t5_1_1/large.gin'
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
```
Although we could define `MODEL = @models.EncoderDecoderModel()` here, we prefer to create a separate gin file that defines it. This makes it easier to reuse parts of the common configurations. All of the bindings in the newly included file are read and override any conflicting ones defined so far in this file. It's equivalent to copying and pasting the contents of the included file at this location in the config. If you want to see how the model itself is instantiated, you can refer to t5_1_1/large.gin (which simply overrides a few values from t5_1_1/base.gin).
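One way to picture `include` is as splicing the included file's bindings into the current position of an ordered list, where the last binding for a given key wins. This is a simplified model that ignores macros, scopes, and the rest of gin's grammar; `apply_bindings` and the file contents below are hypothetical.

```python
def apply_bindings(lines):
  """Toy evaluator: applies `key = value` lines in order; the last write wins."""
  bindings = {}
  for line in lines:
    key, value = (part.strip() for part in line.split("=", 1))
    bindings[key] = value
  return bindings


KEY = "models.EncoderDecoderModel.predict_batch_with_aux.num_decodes"

# Hypothetical contents of an included model file.
included_file = [
    "MODEL = @models.EncoderDecoderModel()",
    f"{KEY} = 4",
]

# The including file: a binding before `include`, the spliced file, an override.
main_file = (
    [f"{KEY} = 8"]
    + included_file  # `include '...'` behaves like a copy-paste at this point
    + [f"{KEY} = 1"]
)

result = apply_bindings(main_file)
print(result[KEY])  # 1
```

The same in-order, last-write-wins reading applies to the multiple `--gin_file` flags and `--gin.` overrides described later in this primer.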
The final line of this block shows an example of how you can modify the default arguments of the `EncoderDecoderModel` instance referenced by `%MODEL`, in this case changing the default beam size (`num_decodes`) that its `predict_batch_with_aux` method will use during prediction. Notice that since we are only binding one argument here, we choose to write it on a single line instead of using the block binding syntax used elsewhere in the file.
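The block syntax and the single-line syntax express the same bindings. The toy expander below is a hypothetical helper that only handles simple `name = value` lines (not gin's full grammar); it rewrites a block into the equivalent single-line form.

```python
def expand_block(text):
  """Rewrites `prefix:` blocks into `prefix.name = value` single-line bindings."""
  prefix = None
  out = []
  for raw in text.splitlines():
    line = raw.strip()
    if not line:
      continue
    if not raw[0].isspace() and line.endswith(":"):
      prefix = line[:-1]  # Start of a new block, e.g. `utils.DatasetConfig:`.
    else:
      name, value = (part.strip() for part in line.split("=", 1))
      out.append(f"{prefix}.{name} = {value}")
  return out


block = """
models.EncoderDecoderModel.predict_batch_with_aux:
  num_decodes = 1
"""
print(expand_block(block))
# ['models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1']
```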
```
utils.DatasetConfig:
  mixture_or_task_name = 'natural_questions_open'
  split = 'test'
  task_feature_lengths = None
  batch_size = 32
  shuffle = False
  seed = 0
  use_cached = False
  pack = False
  use_custom_packing_ops = False
  module = 'google_research.t5_closed_book_qa.t5_cbqa.tasks'

partitioning.PjitPartitioner:
  num_partitions = 1

utils.RestoreCheckpointConfig:
  mode = 'specific'
  path = 'gs://t5-data/pretrained_models/cbqa/large_ssm_nqo'
  assignment_map = None
  strict = True
  dtype = None
```
The last 3 blocks are fairly straightforward. They effectively set the attributes of these dataclasses by binding values to their constructors, which will be used when they are instantiated and passed to `evaluate`, as specified in the fourth block.
The above example lacks one key component of gin: scopes.
What happens if you need to use a class or function multiple times but with different bound values?
A clear example of this is in the top-level `train` function (in train.py). Its call signature includes 3 different instances of `utils.DatasetConfig`: one for the train dataset, one for the "train-eval" dataset (used for evaluation with teacher forcing), and one for the "infer-eval" dataset (used for evaluation with inference/decoding).
The solution is to prefix each instance with a unique identifier, both when specifying where it is to be passed to `train` and when binding its arguments.
For example, the gin file might look like the following (skipping the irrelevant
bits):
```
...
train_script.train:
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
...

train/utils.DatasetConfig:
  mixture_or_task_name = 'train_mixture'
  split = 'train'
  ...

train_eval/utils.DatasetConfig:
  mixture_or_task_name = 'eval_mixture'
  split = 'validation'
  ...

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = 'eval_mixture'
  split = 'test'
  ...
```
We have therefore configured 3 different scoped versions of `utils.DatasetConfig`, producing 3 separate instances that are passed to `train`.
Note that these three scopes will all inherit from the base scope, so if you
want to set a shared binding, you may directly configure utils.DatasetConfig
without a scope prefix.
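The lookup order for scoped bindings can be sketched with a toy resolver: a scoped binding such as `train/utils.DatasetConfig.split` takes precedence, and otherwise the unscoped binding is inherited. This is a simplification (for instance, it ignores nested scopes), not gin's actual algorithm; `resolve` and the binding values below are illustrative assumptions.

```python
def resolve(bindings, scope, configurable, param):
  """Toy lookup: prefer `scope/configurable.param`, else fall back to the base."""
  scoped_key = f"{scope}/{configurable}.{param}"
  base_key = f"{configurable}.{param}"
  if scoped_key in bindings:
    return bindings[scoped_key]
  return bindings[base_key]


bindings = {
    # Shared (unscoped) binding inherited by every scope.
    "utils.DatasetConfig.batch_size": 32,
    # Per-scope bindings.
    "train/utils.DatasetConfig.split": "train",
    "train_eval/utils.DatasetConfig.split": "validation",
    "infer_eval/utils.DatasetConfig.split": "test",
}

for scope in ("train", "train_eval", "infer_eval"):
  split = resolve(bindings, scope, "utils.DatasetConfig", "split")
  batch_size = resolve(bindings, scope, "utils.DatasetConfig", "batch_size")
  print(scope, split, batch_size)
# train train 32
# train_eval validation 32
# infer_eval test 32
```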
So now that you have a gin config, how do you pass it to the script? There are two ways: gin files and override flags.
- **Gin Files**: You have already seen an example of a gin file above. You can specify the gin file(s) to use in your script via the `--gin_file` flag. If you want to load multiple gin files, you can set the flag multiple times and the files will be loaded in order, with the second potentially overriding the first when there are conflicts. It is possible to supply a comma-separated list of search prefixes via `--gin_search_paths` and then only specify the relative path to the `--gin_file` flags. However, we strongly recommend against using `--gin_search_paths`. Using absolute paths via the `--gin_file` flags will reduce sources of ambiguity and improve the consistency of your scripts.

- **Override Flags**: Gin flags allow for more fine-grained overrides of any configurable aspect of your run. These flags follow the single-line binding format from the above example with the addition of a `--gin.` prefix. For example, if you want to override the dataset shuffling, you can set `--gin.utils.DatasetConfig.shuffle=False`. In the train setting, where there are multiple datasets, you must supply the appropriate scope to target a single dataset, e.g., `--gin.train/utils.DatasetConfig.shuffle=False`. These bindings are processed in order after the gin files are loaded, and therefore overwrite any value previously assigned in the gin files.
Note: when supplying a string, dict, list, or tuple value via a flag, you must put it in quotes. In the case of strings, it requires escaped quotes (`\"<string>\"`). For example: `--gin.utils.DatasetConfig.split=\"validation\"`, `--gin.utils.DatasetConfig.task_feature_lengths="{'inputs': 512, 'targets': 84}"`, and `--gin.dense.MlpBlock.activations="('dense', 'gelu')"`.
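Why the extra quotes? The shell strips one layer of quoting before the flag reaches Python, and the value that remains must still parse as a literal. The sketch below uses `shlex.split` to mimic a POSIX shell and `ast.literal_eval` as a rough approximation of how a literal flag value is read. Gin's parser is not literally `ast.literal_eval` (it also accepts references such as `@...` and `%...`), so treat this as an illustration only; `value_seen_by_python` is a hypothetical helper.

```python
import ast
import shlex


def value_seen_by_python(command_line):
  """Applies POSIX shell quote removal, then returns the text after the first `=`."""
  (flag,) = shlex.split(command_line)
  return flag.split("=", 1)[1]


escaped = r'--gin.utils.DatasetConfig.split=\"validation\"'
unescaped = '--gin.utils.DatasetConfig.split="validation"'
dict_flag = '''--gin.utils.DatasetConfig.task_feature_lengths="{'inputs': 512, 'targets': 84}"'''

print(value_seen_by_python(escaped))    # "validation"  (quotes survive the shell)
print(value_seen_by_python(unescaped))  # validation    (quotes were stripped)

# ast.literal_eval approximates how a literal value is interpreted.
print(ast.literal_eval(value_seen_by_python(escaped)))    # validation
print(ast.literal_eval(value_seen_by_python(dict_flag)))  # {'inputs': 512, 'targets': 84}
# ast.literal_eval(value_seen_by_python(unescaped)) raises ValueError,
# because a bare `validation` is not a literal.
```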
An example where you may need multiple files is with the `train` script.
You can first specify which model you want to train by supplying a gin file containing its definition, for example: t5_1_1/small.gin.
You may then specify a run config that supplies some of the common defaults. For example, if you are doing pretraining you can use runs/pretrain.gin, and if you are doing finetuning, you can use runs/finetune.gin.
We can apply these two files with the following command:
```sh
python -m t5x.train_unfragmented \
  --gin_file=t5x/examples/t5/t5_1_1/small.gin \
  --gin_file=t5x/configs/runs/finetune.gin \
  --logtostderr
```
However, running this command will give you an error like the following:
```
ValueError: MODEL_DIR/macro.value set to `%gin.REQUIRED` but not subsequently overridden.
```
This is because the config still includes some `gin.REQUIRED` macros that you'll need to override with the details of your run. At the top of runs/finetune.gin you'll see the list of required overrides, which we will populate for finetuning on WMT in the updated launch command here:
```sh
python -m t5x.train_unfragmented \
  --gin_file=t5x/examples/t5/t5_1_1/small.gin \
  --gin_file=t5x/configs/runs/finetune.gin \
  --gin.MIXTURE_OR_TASK_NAME=\"wmt_t2t_ende_v003\" \
  --gin.MIXTURE_OR_TASK_MODULE=\"t5.data.mixtures\" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 256}" \
  --gin.TRAIN_STEPS=1_020_000 \
  --gin.MODEL_DIR=\"/tmp/t5_1_1_small_finetune_gin\" \
  --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000\" \
  --logtostderr
```
Note that you may still override any registered bindings. For example, to disable inference evaluation you may add `--gin.train.infer_eval_dataset_cfg=None`.
At the beginning of the primer, we saw a fully-specified run config. We can do something similar with the previous example to create a self-contained run configuration. t5_1_1/examples/small_wmt_finetune.gin is one such example; it lets you reproduce the previous launch command simply by calling:
```sh
python -m t5x.train_unfragmented \
  --gin_file=t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin \
  --gin.MODEL_DIR=\"/tmp/t5_1_1_small_finetune_gin\" \
  --logtostderr
```
After your gin files and flag overrides are parsed, the complete configuration will be logged to INFO, written to `config.gin` in the output directory, and added to a TensorBoard summary.
It is highly recommended that you review this generated config to ensure that your overrides are working as expected.