TensorFlow Model Optimization Toolkit — Pruning API

The TensorFlow Model Optimization Toolkit minimizes the complexity of optimizing inference. Inference efficiency is a critical concern when deploying machine learning models to mobile devices because of constraints on model size, latency, and power consumption.

Since we introduced the Model Optimization Toolkit — a suite of techniques that developers, both novice and advanced, can use to optimize machine learning models — we have been busy working on our roadmap to add several new approaches and tools. Today, we are happy to share the new weight pruning API.


Weight pruning

Optimizing machine learning programs can take very different forms. Fortunately, neural networks have proven resilient to a variety of transformations aimed at this goal.

One such family of optimizations aims to reduce the number of parameters and operations involved in the computation by removing connections, and thus parameters, in between neural network layers.


The weight pruning API is built on top of Keras, so it will be very easy for developers to apply this technique to any existing Keras training program. This API will be part of a new GitHub repository for the model optimization toolkit, along with many upcoming optimization techniques.

import tensorflow_model_optimization as tfmot

model = build_your_model()

# Ramp sparsity from 0% to 50% between training steps 2,000 and 4,000.
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.5,
    begin_step=2000, end_step=4000)

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(
    model, pruning_schedule=pruning_schedule)
...
# The UpdatePruningStep callback keeps the pruning schedule in sync with training.
model_for_pruning.fit(..., callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

What is weight pruning?

Weight pruning means eliminating unnecessary values in the weight tensors: in practice, we set to zero the values of parameters that we estimate represent unnecessary connections between the layers of a neural network. This is done during the training process so that the neural network can adapt to the changes.
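
To make this concrete, here is a minimal sketch of magnitude-based pruning applied to a single weight tensor, written in plain NumPy rather than with the toolkit itself:

import numpy as np

def prune_by_magnitude(weights, sparsity):
    # Zero out the fraction `sparsity` of entries with the smallest magnitude.
    k = int(sparsity * weights.size)
    threshold = np.sort(np.abs(weights), axis=None)[k]
    mask = np.abs(weights) >= threshold
    return weights * mask

w = np.random.randn(4, 4)
print(prune_by_magnitude(w, sparsity=0.5))  # roughly half the entries are now zero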


Why is weight pruning useful?

An immediate benefit from this work is disk compression: sparse tensors are amenable to compression. Thus, by applying simple file compression to the pruned TensorFlow checkpoint, or the converted TensorFlow Lite model, we can reduce the size of the model for its storage and/or transmission. For example, in the tutorial, we show how a 90% sparse model for MNIST can be compressed from 12MB to 2MB.
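
As a sketch of how to measure this effect yourself (the file names are illustrative, and model_for_pruning is the model from the snippet above):

import gzip
import os
import shutil

import tensorflow_model_optimization as tfmot

# Remove the training-time pruning wrappers; the zeroed weights remain.
final_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
final_model.save('pruned_model.h5')

# Standard file compression is where sparsity pays off on disk.
with open('pruned_model.h5', 'rb') as f_in, gzip.open('pruned_model.h5.gz', 'wb') as f_out:
    shutil.copyfileobj(f_in, f_out)

print('raw:   ', os.path.getsize('pruned_model.h5'), 'bytes')
print('zipped:', os.path.getsize('pruned_model.h5.gz'), 'bytes')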

Moreover, across several experiments, we found that weight pruning is compatible with quantization, resulting in compound benefits. In the same tutorial, we show how we can further compress the pruned model from 2MB to just 0.5MB by applying post-training quantization.
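
A sketch of that extra step, applying TensorFlow Lite's post-training quantization to the stripped model from the previous snippet:

import tensorflow as tf

# Convert the pruned Keras model to TensorFlow Lite with
# post-training quantization enabled.
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open('pruned_quantized.tflite', 'wb') as f:
    f.write(tflite_model)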

In the future, TensorFlow Lite will add first-class support for sparse representation and computation, thus expanding the compression benefit to the runtime memory and unlocking performance improvements, since sparse tensors allow us to skip otherwise unnecessary computations involving the zeroed values.


Results across several models

In our experiments, we have validated that this technique can be successfully applied to different types of models across distinct tasks, from convolutional neural networks for image processing to recurrent neural networks for speech processing. The following table shows a subset of these experimental results.


Sparsity results across different models and tasks.

How does it work?

Our Keras-based weight pruning API uses a straightforward, yet broadly applicable, algorithm designed to iteratively remove connections based on their magnitude during training. Fundamentally, you specify a final target sparsity (e.g. 90%), along with a schedule for performing the pruning (e.g. start pruning at step 2,000, stop at step 10,000, and prune every 100 steps), and an optional configuration for the pruning structure (e.g. apply to individual values or to blocks of values of a certain shape).


Example of tensors with no sparsity (left), sparsity in blocks of 1x1 (center), and sparsity in blocks of 1x2 (right).
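
In the API, these choices come together as arguments to prune_low_magnitude; here is a sketch with illustrative values (the 1x2 block_size mirrors the rightmost tensor in the figure above):

import tensorflow_model_optimization as tfmot

pruning_params = {
    # Ramp sparsity up to 90% between steps 2,000 and 10,000,
    # updating the pruning mask every 100 steps.
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0, final_sparsity=0.9,
        begin_step=2000, end_step=10000, frequency=100),
    # Optional structure: prune weights in 1x2 blocks instead of individually.
    'block_size': (1, 2),
}

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)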

As training proceeds, the pruning routine is scheduled to execute, eliminating (i.e. setting to zero) the weights with the lowest magnitudes (i.e. those closest to zero) until the current sparsity target is reached. Each time the pruning routine executes, the current sparsity target is recalculated according to a smooth ramp-up function, gradually increasing from 0% until it reaches the final target sparsity at the end of the pruning schedule.


Example of sparsity ramp-up function with a schedule to start pruning from step 0 until step 100, and a final target sparsity of 90%.
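
To make the figure's numbers concrete, here is a small reimplementation of a polynomial ramp-up function with the same schedule. This is a sketch for intuition, not the toolkit's own code; the exponent of 3 mirrors the library's default power parameter:

def current_sparsity(step, begin_step=0, end_step=100, final_sparsity=0.9, power=3):
    # Smoothly ramp the sparsity target from 0% up to final_sparsity.
    if step < begin_step:
        return 0.0
    if step >= end_step:
        return final_sparsity
    progress = (step - begin_step) / (end_step - begin_step)
    return final_sparsity * (1 - (1 - progress) ** power)

for step in (0, 25, 50, 75, 100):
    print(step, round(current_sparsity(step), 2))
# The target climbs steeply at first, then flattens out toward 90%.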

Just like the schedule, the ramp-up function can be tweaked as needed. For example, in certain cases it may be convenient to start pruning only after a certain step, once some level of convergence has been achieved, or to end pruning before the last step of your training program so that the system can be further fine-tuned at the final target sparsity level. For more details on these configurations, please refer to our tutorial and documentation.

At the end of the training procedure, the tensors corresponding to the “pruned” Keras layers will contain zeros according to the final sparsity target for the layer.
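
One way to verify this at the end of training (a sketch; strip_pruning removes the training-time wrappers, as in the compression snippet above):

import numpy as np
import tensorflow_model_optimization as tfmot

final_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
for layer in final_model.layers:
    for weights in layer.get_weights():
        print(f'{layer.name}: {np.mean(weights == 0):.1%} zeros')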


Animation of pruning applied to a tensor. Black cells indicate where the non-zero weights exist. Sparsity increases as training proceeds.

New documentation and GitHub repository

As mentioned earlier, the weight pruning API will be part of a new GitHub project and repository aimed at techniques that make machine learning models more efficient to execute and/or represent. This is a great project to star if you are interested in this exciting area of machine learning or just want to have the resources to optimize your models.

Given the importance of this area, we are also creating a new sub-site under tensorflow.org/model_optimization with relevant documentation and resources. We encourage you to give this a try right away and welcome your feedback. 

Originally published on https://medium.com
