TensorFlow 2 was released in September 2019 and introduced major changes to the TensorFlow interface: eager execution by default, support for imperative code, and the removal of tf.Session. These upgrades promised a developer experience comparable to PyTorch. In practice, however, TensorFlow 2 still has a far steeper learning curve than PyTorch before it can be used effectively.
Why is this the case? Part of the issue is that TensorFlow was designed around compiling user code into a computational graph, which can then be optimized (for example with XLA). Eager execution was added on top of this design in TensorFlow 2, and tf.function relies on AutoGraph to convert imperative Python code into TF graph ops. This adds a layer of complexity, and AutoGraph doesn’t support every kind of imperative code. Furthermore, TensorFlow 2 suffers from feature bloat: there are often multiple ways of doing the same thing, and it’s not always clear which is best.
Despite these complexities, there are still many reasons to use TF. The ecosystem offers some of the best tools for managing large deployments, and models can target different runtimes and hardware such as TFLite, TensorFlow.js, and Apple M1 chips. It has top-notch support for distributed training and, aside from JAX, is the only framework with truly usable TPU support.
In this blog post, we will walk through some solutions to the common problems you might encounter when using TF on non-trivial projects and discuss best practices for optimal productivity. Ultimately, there is no single “best” framework: you should use the right tool for the job at hand. For example, if you’re writing research code, you may want to use PyTorch or JAX; if you’re training a very large model or want to deploy to different runtimes, TensorFlow 2 may be the better choice.
The difference between Modules, Models, and Layers
As in other imperative frameworks, it’s common practice in TF 2 to encapsulate logic inside container classes that hold internal state and expose methods that use that state. TF 2 offers three ways of doing this:
- tf.Module: the base container class, akin to nn.Module in PyTorch. Subclasses of tf.Module get a few basics for free: the variables, trainable_variables, and submodules properties automatically collect anything assigned to attributes (typically in __init__), and the object can be saved in the SavedModel format (more on that later). It’s very lightweight and doesn’t try to do much “magic”.
- tf.keras.layers.Layer: extends tf.Module and is the base class for all tf.keras layers. You should ONLY use this if you plan on doing anything related to the tf.keras high-level API. By subclassing it, you get the plumbing that makes the magic of tf.keras work behind the scenes.
- tf.keras.Model: this is what you should use if you plan on chaining together multiple tf.keras.layers.Layer objects. Most of the actual tf.keras high-level API methods (compile, fit, evaluate, predict) live on tf.keras.Model.
If you can, you should probably use Layer and Model with the tf.keras high-level API. However, Keras sometimes gets in the way when you need extensive custom work, and in those cases you should stick to the raw tf.Module and custom training loops.
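To make the trade-off concrete, here is a minimal sketch of both styles. The Dense and MLP classes, the train_step helper, and the toy data are hypothetical names invented for illustration; the manual SGD update is just one simple way to write a custom loop, not a recommended recipe.

```python
import tensorflow as tf


class Dense(tf.Module):
    """Hypothetical minimal dense layer built on the raw tf.Module."""

    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.random.normal([in_features, out_features]), name="w")
        self.b = tf.Variable(tf.zeros([out_features]), name="b")

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


class MLP(tf.Module):
    """Hypothetical two-layer network composed of Dense submodules."""

    def __init__(self, name=None):
        super().__init__(name=name)
        self.hidden = Dense(3, 8)
        self.out = Dense(8, 1)

    def __call__(self, x):
        return self.out(tf.nn.relu(self.hidden(x)))


def train_step(model, x, y, learning_rate=0.01):
    """One step of a custom training loop using plain gradient descent."""
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    for var, grad in zip(model.trainable_variables, grads):
        var.assign_sub(learning_rate * grad)
    return loss


# Toy regression data (an assumption for this sketch).
x = tf.random.normal([16, 3])
y = tf.reduce_sum(x, axis=1, keepdims=True)

# Style 1: raw tf.Module with a hand-written training step.
mlp = MLP()
num_variables = len(mlp.trainable_variables)  # w and b for each Dense
num_submodules = len(mlp.submodules)          # hidden and out
module_loss = float(train_step(mlp, x, y))

# Style 2: the tf.keras high-level API does the loop for you.
keras_model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])
keras_model.compile(optimizer="sgd", loss="mse")
history = keras_model.fit(x, y, epochs=1, verbose=0)
```

The tf.Module version is more code, but every step of the update is visible and can be changed; the Keras version hides the loop behind compile and fit.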
SavedModel format vs Checkpoints
TF2 offers two ways of saving models to disk: SavedModel and Checkpoint. SavedModel is analogous to saving a model in PyTorch after it has been traced with torch.jit. It serializes the entire graph representation including the model weights, which can then be optimized and/or converted for specific runtimes. The SavedModel abstraction was designed specifically with model serving in mind: it’s easily portable and self-contained. However, because it saves a compiled graph, it will only work if your TensorFlow code is graph compatible (more on that later). If possible, prefer SavedModel: it also saves a checkpoint inside the output directory, meaning you can load it and use it in the same way as a Checkpoint, getting the best of both worlds.
A Checkpoint contains just the model weights, so the original code that produced the checkpoint is needed to run inference later on. It’s designed for cases where you want to go back to an earlier state (e.g. resuming training after a spot instance is preempted) or for pre-training/fine-tuning. The only reason to save standalone Checkpoints (instead of a SavedModel) is if your code is not graph compatible.
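The difference can be sketched with a toy module. The Scaler class, the temporary directory layout, and the variable names below are assumptions made for this example; only tf.saved_model.save, tf.saved_model.load, and tf.train.Checkpoint are TensorFlow APIs.

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical toy model: multiplies its input by a stored scale."""

    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(2.0)

    # The input_signature lets TensorFlow trace the graph that gets exported.
    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.scale


model = Scaler()
workdir = tempfile.mkdtemp()
inputs = tf.constant([1.0, 2.0])

# SavedModel: graph plus weights. Loading does not need the Scaler class.
saved_model_dir = os.path.join(workdir, "saved_model")
tf.saved_model.save(model, saved_model_dir)
restored = tf.saved_model.load(saved_model_dir)
saved_model_output = restored(inputs)

# Checkpoint: weights only. The Scaler class must be instantiated again
# before the saved values can be restored into it.
checkpoint_path = tf.train.Checkpoint(model=model).save(os.path.join(workdir, "ckpt"))
fresh = Scaler()
fresh.scale.assign(0.0)  # deliberately different, to show restore overwrites it
tf.train.Checkpoint(model=fresh).restore(checkpoint_path)
checkpoint_output = fresh(inputs)
```

Both paths end up producing the same outputs, but only the SavedModel path works in a process that has no access to the Scaler source code.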
How to avoid common pitfalls with graph mode
AutoGraph is the “glue” that converts imperative-style code into TF graph operations. However, it only supports a subset of imperative programming and should be used with care. If you stray outside that subset, you will probably get cryptic error messages from AutoGraph complaining that it can’t compile the graph.
If you haven’t already, read the official guide on TensorFlow 2 first. The tips below add advice on top of what’s already covered there and correct some of the claims the AutoGraph authors made about its performance.
Tip #1: Avoid placing complex looping logic inside the graph
If you have complex looping logic in your training loop, try placing that logic outside of the graph. For example, the guide shows this example of doing loop aggregation with tf.TensorArray:
```python
import tensorflow as tf

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
    return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    max_seq_len = input_data.shape[0]

    # TensorArray is the graph-compatible stand-in for a growing Python list.
    states = tf.TensorArray(tf.float32, size=max_seq_len)
    state = initial_state
    for i in tf.range(max_seq_len):
        state = rnn_step(input_data[i], state)
        states = states.write(i, state)
    return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
```
Does this work in this specific example? Yes. Will this work most of the time in other non-trivial settings? No. Often, the easier solution would be to try to only convert the parts that do the heavy computation to graph mode. We could refactor this to something like the following (where rnn_step presumably calls out to a neural network):
```python
import tensorflow as tf

batch_size = 2
seq_len = 3
feature_size = 4

@tf.function
def rnn_step(inp, state):
    # Only the per-step computation is compiled to a graph.
    return inp + state

def dynamic_rnn(rnn_step, input_data, initial_state):
    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    max_seq_len = input_data.shape[0]

    # This loop runs eagerly, so a plain Python list works for aggregation.
    states = []
    state = initial_state
    for i in range(max_seq_len):
        state = rnn_step(input_data[i], state)
        states.append(state)
    # Stack to [time, batch, features], then back to [batch, time, features].
    return tf.transpose(tf.stack(states), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
```
In this case, we get the meat of the computation converted to graph mode, while the outer training loop is free to use imperative style list aggregation. If you can’t move things out of the tf.function, try using non-imperative ops like tf.cond or tf.map_fn instead of if and for statements.
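For the case where the logic has to stay inside the tf.function, the following is a small sketch of the two ops mentioned above. The function names normalize_if_large and row_norms are hypothetical and the computations are arbitrary; the point is only the shape of the control flow.

```python
import tensorflow as tf


@tf.function
def normalize_if_large(x, threshold):
    # tf.cond stands in for a Python `if` on a tensor-valued condition.
    total = tf.reduce_sum(x)
    return tf.cond(total > threshold,
                   lambda: x / total,  # branch taken when the sum is large
                   lambda: x)          # otherwise return the input unchanged


@tf.function
def row_norms(matrix):
    # tf.map_fn stands in for a Python `for` over the leading dimension.
    return tf.map_fn(lambda row: tf.norm(row), matrix)


normalized = normalize_if_large(tf.constant([1.0, 3.0]), tf.constant(2.0))
unchanged = normalize_if_large(tf.constant([1.0, 3.0]), tf.constant(10.0))
norms = row_norms(tf.constant([[3.0, 4.0], [6.0, 8.0]]))
```

Both branches of tf.cond must return tensors of matching dtype and structure, which is one of the constraints a plain Python if statement does not impose.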
Tip #2: Prefer dynamic shape to inferred shape
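TensorFlow offers two ways of retrieving the shape of a tensor: tf.shape(x) (dynamic) and x.shape (static, or inferred). In eager mode the two always agree, although tf.shape returns a tensor and x.shape a TensorShape. Within a tf.function, not all dimensions may be known until execution time, so the inferred shape reports None for dynamic dimensions such as the batch size (which may vary from call to call). Code that does arithmetic on such a dimension, or passes it to ops like tf.reshape, should therefore use tf.shape(x), which is always fully defined when the graph runs.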
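A short sketch of the difference, assuming a hypothetical flatten function whose batch dimension is left unspecified through an input_signature:

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
def flatten(x):
    # While tracing, x.shape is (None, 4): the batch dimension is unknown,
    # so an expression like x.shape[0] * 4 would fail on None.
    # tf.shape(x)[0] is a scalar tensor that is resolved when the graph runs.
    batch = tf.shape(x)[0]
    return tf.reshape(x, [batch * 4])


# The same traced graph handles different batch sizes.
small = flatten(tf.ones([3, 4]))
large = flatten(tf.ones([5, 4]))

# In eager mode both ways of asking for the shape agree.
eager_x = tf.ones([3, 4])
static_batch = eager_x.shape[0]             # Python int
dynamic_batch = int(tf.shape(eager_x)[0])   # tensor converted to int
```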
Tip #3: Watch out for retracing
Always pad tensor inputs to a common shape before passing them to a graph function to avoid retracing. A tf.function traces a new graph whenever it is called with an input signature (shapes and dtypes) it has not seen before, so inputs whose shapes keep changing trigger repeated tracing, which can cause huge slowdowns.
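The effect is easy to observe by counting traces. In this sketch, make_counted_sum and pad_to are hypothetical helpers; the list append inside the tf.function is a Python side effect that only runs while TensorFlow is tracing, which is what makes it usable as a trace counter.

```python
import tensorflow as tf


def make_counted_sum():
    traces = []  # gets one entry every time the function below is traced

    @tf.function
    def total(x):
        traces.append(x.shape)  # Python side effect: runs only during tracing
        return tf.reduce_sum(x)

    return total, traces


def pad_to(x, length):
    # Right-pad a 1-D tensor with zeros up to `length`.
    return tf.pad(x, [[0, length - x.shape[0]]])


inputs = [tf.ones([n]) for n in (2, 3, 4)]

# Three different input shapes lead to three separate traces.
unpadded_fn, unpadded_traces = make_counted_sum()
unpadded_results = [float(unpadded_fn(x)) for x in inputs]

# After padding, every call has shape [4], so one trace is reused.
padded_fn, padded_traces = make_counted_sum()
padded_results = [float(padded_fn(pad_to(x, 4))) for x in inputs]
```

Zero padding is harmless for a sum; for other computations the padded positions usually need to be masked out.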
Don’t reinvent the wheel: consider using Tensorflow Addons and Orbit
TensorFlow has a well-established home for extensions to the core framework called TensorFlow Addons. If it feels like you’re implementing something that someone has already built in another project, it might already exist in TensorFlow Addons.
If you use types in your TensorFlow code, TensorFlow Addons recently added a great set of Python types that can be used to organize your code and improve readability and type safety.
Consider using Orbit as an alternative to the Keras high-level API. You can think of Orbit as Keras’s much more customizable, powerful older sibling. Like other training loop frameworks, it breaks things down into two common interfaces: Runnable, which handles the inner training loop via train and evaluate methods, and Controller, which handles the outer training loop and distribution strategy. Controller supports four outer training loop strategies:
- train, which trains until a specified number of global steps is reached
- evaluate, for one-off model evaluation
- train_and_evaluate, for interleaved training and evaluation
- evaluate_continuously, for monitoring a given directory and running evaluations on new model checkpoints.