TensorFlow Model Optimization Toolkit — Pruning API

Since we introduced the Model Optimization Toolkit — a suite of techniques that developers, both novice and advanced, can use to optimize machine learning models — we have been busy working on our roadmap to add several new approaches and tools. Today, we are happy to share the new weight pruning API.

Weight pruning

Optimizing machine learning programs can take very different forms. Fortunately, neural networks have proven resilient to different transformations aimed at this goal.

One such family of optimizations aims to reduce the number of parameters and operations involved in the computation by removing connections, and thus parameters, between neural network layers.

The weight pruning API is built on top of Keras, so it will be very easy for developers to apply this technique to any existing Keras training program. This API will be part of a new GitHub repository for the model optimization toolkit, along with many upcoming optimization techniques.

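import tensorflow_model_optimization as tfmot

# build_your_model() is a placeholder for any function that returns a Keras model.
model = build_your_model()

# Ramp sparsity from 0% to 50% between training steps 2,000 and 4,000.
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.5,
    begin_step=2000, end_step=4000)

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(
    model, pruning_schedule=pruning_schedule)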

What is weight pruning?

Weight pruning means eliminating unnecessary values in the weight tensors. In practice, we set neural network parameters to zero to remove what we estimate are unnecessary connections between the layers of a neural network. This is done during the training process so that the network can adapt to the changes.

Why is weight pruning useful?

An immediate benefit from this work is disk compression: sparse tensors are amenable to compression. Thus, by applying simple file compression to the pruned TensorFlow checkpoint, or the converted TensorFlow Lite model, we can reduce the size of the model for its storage and/or transmission. For example, in the tutorial, we show how a 90% sparse model for MNIST can be compressed from 12MB to 2MB.
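To see why zeros help compression, here is a minimal sketch that uses only the Python standard library. It assumes synthetic Gaussian weights rather than a real checkpoint, and `compressed_size` is a hypothetical helper, not part of the toolkit. The byte counts it prints are illustrative only and are unrelated to the MNIST figures above.

```python
import random
import struct
import zlib


def compressed_size(values):
    """Pack floats as 32-bit values and return the zlib-compressed byte count."""
    raw = struct.pack(f"{len(values)}f", *values)
    return len(zlib.compress(raw))


random.seed(0)
dense = [random.gauss(0.0, 1.0) for _ in range(10_000)]

# Keep only the 10% of weights with the largest magnitude (90% sparsity).
cutoff = sorted(abs(w) for w in dense)[len(dense) * 9 // 10]
sparse = [w if abs(w) >= cutoff else 0.0 for w in dense]

dense_bytes = compressed_size(dense)
sparse_bytes = compressed_size(sparse)
print("dense:", dense_bytes, "bytes; 90% sparse:", sparse_bytes, "bytes")
```

The long runs of identical zero bytes are what a general-purpose compressor exploits; the surviving non-zero weights compress about as poorly as before.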

Moreover, across several experiments, we found that weight pruning is compatible with quantization, resulting in compound benefits. In the same tutorial, we show how we can further compress the pruned model from 2MB to just 0.5MB by applying post-training quantization.
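The compounding effect can be sketched in a few lines of standard-library Python. This is an illustration under stated assumptions, not the toolkit's quantization path: the weights are synthetic, and the quantizer is a simple symmetric 8-bit scheme with a single shared scale, chosen here only because it is easy to read.

```python
import random
import struct
import zlib

random.seed(0)
weights = [random.gauss(0.0, 1.0) for _ in range(10_000)]

# Prune: keep only the 10% of weights with the largest magnitude.
cutoff = sorted(abs(w) for w in weights)[len(weights) * 9 // 10]
pruned = [w if abs(w) >= cutoff else 0.0 for w in weights]

# Quantize (assumed scheme): map each float to a signed 8-bit integer
# using one scale for the whole tensor. Zeros stay exactly zero.
scale = max(abs(w) for w in pruned) / 127
quantized = [round(w / scale) for w in pruned]

# Compare compressed sizes of the 32-bit float and 8-bit integer encodings.
float_bytes = len(zlib.compress(struct.pack(f"{len(pruned)}f", *pruned)))
int8_bytes = len(zlib.compress(struct.pack(f"{len(quantized)}b", *quantized)))
print("pruned float32:", float_bytes, "bytes; pruned int8:", int8_bytes, "bytes")
```

Because quantization preserves the zeros introduced by pruning, the two savings stack rather than cancel each other out.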

In the future, TensorFlow Lite will add first-class support for sparse representation and computation, thus expanding the compression benefit to the runtime memory and unlocking performance improvements, since sparse tensors allow us to skip otherwise unnecessary computations involving the zeroed values.

Results across several models

In our experiments, we have validated that this technique can be successfully applied to different types of models across distinct tasks, from image processing convolutional-based neural networks to speech processing ones using recurrent neural networks. The following table shows a subset of some of these experimental results.

Sparsity results across different models and tasks.

How does it work?

Our Keras-based weight pruning API uses a straightforward yet broadly applicable algorithm designed to iteratively remove connections based on their magnitude during training. Fundamentally, a final target sparsity is specified (e.g. 90%), along with a schedule to perform the pruning (e.g. start pruning at step 2,000, stop at step 10,000, and prune every 100 steps), and an optional configuration for the pruning structure (e.g. apply pruning to individual values or to blocks of values of a given shape).

Example of tensors with no sparsity (left), sparsity in blocks of 1x1 (center), and sparsity in blocks of 1x2 (right).
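Block-structured pruning can be sketched as follows. This is a simplified illustration, not the toolkit's implementation: `prune_blocks_1x2` is a hypothetical helper, and scoring each block by the average magnitude of its two values is an assumption made for this example.

```python
def prune_blocks_1x2(matrix, sparsity):
    """Zero the 1x2 blocks with the smallest average magnitude.

    Assumes every row has an even number of columns.
    """
    # Score each block by the mean absolute value of its two weights.
    scores = []
    for r, row in enumerate(matrix):
        for c in range(0, len(row), 2):
            scores.append(((abs(row[c]) + abs(row[c + 1])) / 2, r, c))

    # Zero out the lowest-scoring fraction of blocks, both values at once.
    num_to_prune = int(round(sparsity * len(scores)))
    pruned = [list(row) for row in matrix]
    for _, r, c in sorted(scores)[:num_to_prune]:
        pruned[r][c] = 0.0
        pruned[r][c + 1] = 0.0
    return pruned


weights = [
    [0.9, -0.8, 0.05, 0.02],
    [0.1, -0.1, 0.7, 0.6],
]
pruned = prune_blocks_1x2(weights, sparsity=0.5)
print(pruned)
```

With blocks of 1x1 the same procedure reduces to pruning individual values; larger blocks trade some flexibility for zeros that sit next to each other.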

As training proceeds, the pruning routine is scheduled to execute periodically, eliminating (i.e. setting to zero) the weights with the lowest magnitude (i.e. those closest to zero) until the current sparsity target is reached. Each time the routine runs, the current sparsity target is recalculated: it starts at 0% and increases gradually, following a smooth ramp-up function, until it reaches the final target sparsity at the end of the pruning schedule.
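A single pruning step on a flat list of weights might look like the sketch below. `zero_smallest_weights` is a hypothetical name used for illustration; the real API operates on Keras layers and also maintains pruning masks during training, which this example leaves out.

```python
def zero_smallest_weights(weights, target_sparsity):
    """Return a copy of `weights` with the smallest-magnitude values set to zero."""
    num_to_prune = int(target_sparsity * len(weights))
    # Indices ordered from the weight closest to zero to the one farthest from it.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:num_to_prune]:
        pruned[i] = 0.0
    return pruned


weights = [0.5, -0.01, 0.3, 0.02, -0.7, 0.1, -0.04, 0.9, 0.06, -0.2]
pruned = zero_smallest_weights(weights, target_sparsity=0.5)
print(pruned)
```

Calling this repeatedly with a rising `target_sparsity` mirrors the iterative process described above, with training steps in between giving the remaining weights a chance to compensate.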

Example of sparsity ramp-up function with a schedule to start pruning from step 0 until step 100, and a final target sparsity of 90%.
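A ramp-up of this kind can be sketched with a polynomial schedule. The cubic exponent is an assumption for this example, since the text does not state the exact curve, and `sparsity_at_step` is a hypothetical helper; the defaults match the schedule in the figure (steps 0 to 100, final sparsity 90%).

```python
def sparsity_at_step(step, begin_step=0, end_step=100,
                     initial_sparsity=0.0, final_sparsity=0.9, power=3):
    """Target sparsity at a training step, rising quickly and then flattening."""
    if step <= begin_step:
        return initial_sparsity
    if step >= end_step:
        return final_sparsity
    # Fraction of the pruning window that has elapsed, between 0 and 1.
    progress = (step - begin_step) / (end_step - begin_step)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


for step in (0, 25, 50, 75, 100):
    print(step, sparsity_at_step(step))
```

The curve prunes aggressively early on, when many weights are redundant, and slows down as the network approaches the final target.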

Just like the schedule, the ramp-up function can be tweaked as needed. For example, in certain cases it may be convenient to schedule pruning to start after a certain step, once some convergence level has been achieved, or to end pruning earlier than the total number of training steps so that the remaining steps fine-tune the model at the final target sparsity level. For more details on these configurations, please refer to our tutorial and documentation.

At the end of the training procedure, the tensors corresponding to the “pruned” Keras layers will contain zeros according to the final sparsity target for the layer.
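One way to check the outcome is to measure the fraction of zeros in a weight tensor. The sketch below works on a nested Python list and uses a hypothetical helper, `measure_sparsity`; with real Keras layers you would first read the weights out as arrays.

```python
def measure_sparsity(tensor):
    """Fraction of entries in a 2-D nested list that are exactly zero."""
    flat = [value for row in tensor for value in row]
    return sum(1 for value in flat if value == 0.0) / len(flat)


layer_weights = [
    [0.0, 0.5, 0.0, 0.0],
    [0.0, 0.0, -0.3, 0.0],
]
print(measure_sparsity(layer_weights))
```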

Animation of pruning applied to a tensor. Black cells indicate where the non-zero weights exist. Sparsity increases as training proceeds.

New documentation and GitHub repository

As mentioned earlier, the weight pruning API will be part of a new GitHub project and repository aimed at techniques that make machine learning models more efficient to execute and/or represent. This is a great project to star if you are interested in this exciting area of machine learning or just want to have the resources to optimize your models.

Given the importance of this area, we are also creating a new sub-site under tensorflow.org/model_optimization with relevant documentation and resources. We encourage you to give this a try right away and welcome your feedback. 

Originally published on https://medium.com
