How to use TensorFlow in Java

Introduction

Machine Learning is gaining popularity and usage around the globe. It has already drastically changed the way certain applications are built and will likely continue to be a huge (and increasing) part of our daily lives.

There’s no sugarcoating it: Machine Learning isn’t simple. It’s pretty daunting and can seem very complex to many.

Companies such as Google took it upon themselves to bring Machine Learning concepts closer to developers and to help them, with plenty of support, gradually take their first steps.

Thus, frameworks such as TensorFlow were born.

What is TensorFlow?

TensorFlow is an open-source Machine Learning framework developed by Google in Python and C++.

It helps developers easily acquire data, prepare and train models, predict future states, and perform large-scale machine learning.

With it, we can train and run deep neural networks, which are most often used for Optical Character Recognition, Image Recognition/Classification, Natural Language Processing, etc.

Tensors and Operations

TensorFlow is based on computational graphs, which you can imagine as a classic graph with nodes and edges.

Each node is referred to as an operation. An operation takes zero or more tensors in and produces zero or more tensors out. Operations can be very simple, such as basic addition, but they can also be very complex.

Tensors are depicted as the edges of the graph and are the core data unit. We perform different functions on these tensors as we feed them to operations. A tensor can have any number of dimensions, and the number of dimensions is called its rank (scalar: rank 0, vector: rank 1, matrix: rank 2).

Data flows through the computational graph in the form of tensors, transformed by operations along the way, hence the name TensorFlow.

Tensors can store data in any number of dimensions, and there are three main types of tensors: placeholders, variables, and constants.
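To make the idea of rank concrete, here is a minimal sketch that creates tensors from Java values and reads their rank with Tensor.numDimensions(). It assumes the 1.13.1 Maven dependency from the installation section is on the classpath; the class name TensorRanks and the helper rankOf are hypothetical.

```java
import org.tensorflow.Tensor;

public class TensorRanks {
    // Creates a tensor from a Java value, reads its rank (number of dimensions),
    // and frees the tensor's native memory when the try block ends.
    static int rankOf(Object value) {
        try (Tensor<?> tensor = Tensor.create(value)) {
            return tensor.numDimensions();
        }
    }

    public static void main(String[] args) {
        System.out.println(rankOf(3.0f));                               // scalar: 0
        System.out.println(rankOf(new float[] {1f, 2f, 3f}));           // vector: 1
        System.out.println(rankOf(new float[][] {{1f, 2f}, {3f, 4f}})); // matrix: 2
    }
}
```

Tensors hold native (off-heap) memory, which is why the sketch closes each one via try-with-resources instead of relying on garbage collection.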

Installing TensorFlow

Using Maven, installing TensorFlow is as easy as including the dependency:

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow</artifactId>
  <version>1.13.1</version>
</dependency>

If your machine has a GPU that TensorFlow supports, use these dependencies instead:

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>libtensorflow</artifactId>
  <version>1.13.1</version>
</dependency>

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>libtensorflow_jni_gpu</artifactId>
  <version>1.13.1</version>
</dependency>

You can check which version of TensorFlow is installed by calling the static version() method of the TensorFlow class:

System.out.println(TensorFlow.version());
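The line above is a fragment; a minimal complete program around it might look like the following sketch, assuming the Maven dependency shown earlier. The class name VersionCheck is hypothetical.

```java
import org.tensorflow.TensorFlow;

public class VersionCheck {
    public static void main(String[] args) {
        // TensorFlow.version() reports the version of the native TensorFlow library that was loaded
        System.out.println("TensorFlow version: " + TensorFlow.version());
    }
}
```

If this runs without an UnsatisfiedLinkError, the native library was found and loaded correctly.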

TensorFlow Java API

The Java API that TensorFlow offers is contained within the org.tensorflow package. It’s currently experimental, so it’s not guaranteed to be stable.

Please note that the only fully supported language for TensorFlow is Python and that the Java API isn’t nearly as functional.

Graphs

As mentioned before, TensorFlow is based on computational graphs, and org.tensorflow.Graph is Java’s implementation of them.

Note: Graph instances are thread-safe, but we need to explicitly release the resources used by a Graph by calling its close() method once we’re finished with it.

Let’s start off with an empty graph:

Graph graph = new Graph();
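Because Graph implements AutoCloseable, one convenient way to make sure close() is always called is try-with-resources. The sketch below is a hypothetical helper (GraphLifecycle.hasOperation is not part of TensorFlow) that creates a graph, looks up an operation by name with Graph.operation(), which returns null when no such operation exists, and releases the graph automatically.

```java
import org.tensorflow.Graph;

public class GraphLifecycle {
    // Creates a fresh graph, checks whether it contains an operation with the given name,
    // and releases the graph's native resources when the try block ends.
    static boolean hasOperation(String name) {
        try (Graph graph = new Graph()) {
            // Graph.operation(name) returns null if the graph has no operation with that name
            return graph.operation(name) != null;
        }
    }

    public static void main(String[] args) {
        System.out.println(hasOperation("x")); // false, since a new graph is empty
    }
}
```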

This graph doesn’t mean much yet, since it’s empty. To do anything with it, we first need to load it up with Operations.

To add operations, we use the graph’s opBuilder() method, which takes the operation’s type and a name and returns an OperationBuilder object. The operation is added to the graph once we call the builder’s build() method.

Constants

Let’s add a constant to our graph:

Operation x = graph.opBuilder("Const", "x")
               .setAttr("dtype", DataType.FLOAT)
               .setAttr("value", Tensor.create(3.0f))
               .build(); 
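To check that the constant really ended up in the graph, the sketch below builds the same Const operation and reads back its name and type through Operation.name() and Operation.type(). The class and method names are hypothetical; unlike the snippet above, it also closes the Tensor passed to setAttr, since the value is copied into the graph definition when the attribute is set.

```java
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Tensor;

public class ConstantExample {
    // Adds a scalar float constant named "x" and returns "<name>:<type>" for the new operation.
    static String describeConstant() {
        try (Graph graph = new Graph();
             Tensor<?> value = Tensor.create(3.0f)) {
            Operation x = graph.opBuilder("Const", "x")
                    .setAttr("dtype", DataType.FLOAT)
                    .setAttr("value", value) // the tensor's contents are copied into the graph
                    .build();
            return x.name() + ":" + x.type();
        }
    }

    public static void main(String[] args) {
        System.out.println(describeConstant()); // x:Const
    }
}
```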

Placeholders

Placeholders are a “type” of variable that don’t have a value at declaration; their values are supplied later, when the graph is run. This allows us to build graphs with operations without any actual data:

Operation y = graph.opBuilder("Placeholder", "y")
        .setAttr("dtype", DataType.FLOAT)
        .build();
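A placeholder can optionally also declare the shape of the value it expects. The sketch below is an assumption-laden variation on the snippet above: it adds a "shape" attribute using Shape.scalar() from org.tensorflow.Shape (the article itself doesn’t set a shape) and reads the placeholder output’s data type back through Output.dataType(). The class and method names are hypothetical.

```java
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Shape;

public class PlaceholderExample {
    // Declares a float placeholder "y" that holds no data yet and returns the data type of its output.
    static DataType placeholderType() {
        try (Graph graph = new Graph()) {
            Operation y = graph.opBuilder("Placeholder", "y")
                    .setAttr("dtype", DataType.FLOAT)
                    // Optional: restrict the value fed later to a scalar
                    .setAttr("shape", Shape.scalar())
                    .build();
            return y.output(0).dataType();
        }
    }

    public static void main(String[] args) {
        System.out.println(placeholderType()); // FLOAT
    }
}
```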

Functions

Finally, to round this up, we need to add some functions. These could be as simple as multiplication, division, or addition, or as complex as matrix multiplication. As before, we define functions using the opBuilder() method:

Operation xy = graph.opBuilder("Mul", "xy")
  .addInput(x.output(0))
  .addInput(y.output(0))
  .build();         

Note: We’re using output(0) because an operation can have more than one output, so we have to specify which one to pass as an input.
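Putting the constant, the placeholder, and the multiplication together, the following sketch rebuilds the whole x * y graph and asks the Mul operation how many outputs it has via Operation.numOutputs(). Mul produces a single output, which is why output(0) is the right index here. The class and method names are hypothetical.

```java
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Tensor;

public class MulOutputs {
    // Rebuilds the x * y graph and returns the number of outputs of the "Mul" operation.
    static int mulOutputCount() {
        try (Graph graph = new Graph();
             Tensor<?> three = Tensor.create(3.0f)) {
            Operation x = graph.opBuilder("Const", "x")
                    .setAttr("dtype", DataType.FLOAT)
                    .setAttr("value", three)
                    .build();
            Operation y = graph.opBuilder("Placeholder", "y")
                    .setAttr("dtype", DataType.FLOAT)
                    .build();
            Operation xy = graph.opBuilder("Mul", "xy")
                    .addInput(x.output(0)) // the only output of the constant
                    .addInput(y.output(0)) // the only output of the placeholder
                    .build();
            return xy.numOutputs();
        }
    }

    public static void main(String[] args) {
        System.out.println(mulOutputCount()); // 1
    }
}
```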

Graph Visualisation

Sadly, the Java API doesn’t yet include any tools that allow you to visualize graphs as you would in Python. When the Java API gets updated, so will this article.
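One possible workaround, which is not part of the original workflow, is to serialize the graph with Graph.toGraphDef(), which returns the graph as a GraphDef protocol buffer in bytes. That file can then be loaded with Python tooling and inspected there (those Python steps are not shown). The sketch below assumes a hypothetical output file name, graph.pb, and hypothetical class and method names.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Tensor;

public class GraphExport {
    // Builds the x * y graph and returns it serialized as GraphDef bytes.
    static byte[] serializeGraph() {
        try (Graph graph = new Graph();
             Tensor<?> three = Tensor.create(3.0f)) {
            Operation x = graph.opBuilder("Const", "x")
                    .setAttr("dtype", DataType.FLOAT)
                    .setAttr("value", three)
                    .build();
            Operation y = graph.opBuilder("Placeholder", "y")
                    .setAttr("dtype", DataType.FLOAT)
                    .build();
            graph.opBuilder("Mul", "xy")
                    .addInput(x.output(0))
                    .addInput(y.output(0))
                    .build();
            return graph.toGraphDef();
        }
    }

    public static void main(String[] args) throws IOException {
        byte[] graphDef = serializeGraph();
        // "graph.pb" is a hypothetical file name for the exported GraphDef
        Files.write(Paths.get("graph.pb"), graphDef);
        System.out.println("Wrote " + graphDef.length + " bytes to graph.pb");
    }
}
```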

Sessions

A Session is the driver for a Graph’s execution. It encapsulates the environment in which Operations and Graphs are executed to compute Tensors.

This means that the tensors in the graph we constructed don’t actually hold any values yet, since we haven’t run the graph within a session.

Let’s first create a session for the graph:

Session session = new Session(graph);

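Our computation simply multiplies the x and y values. To run the graph and compute the result, we fetch() the xy operation and feed() values for x and y. Since x is a constant, feeding it a value overrides the 3.0 we gave it, so the result is 5.0 * 2.0: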

Tensor<?> tensor = session.runner().fetch("xy")
        .feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f))
        .run().get(0);
System.out.println(tensor.floatValue());
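For reference, here is a self-contained sketch of the whole flow: build the graph, run it in a session, and close every Graph, Session, and Tensor to release native memory. Unlike the snippet above, it feeds only the placeholder y, so the constant x keeps its value of 3.0. The class name MultiplyExample and the helper multiplyByConstant are hypothetical.

```java
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class MultiplyExample {
    // Builds xy = x * y with x as the constant 3.0f and y as a placeholder,
    // then runs the graph feeding only y and returns the result.
    static float multiplyByConstant(float yValue) {
        try (Graph graph = new Graph();
             Tensor<?> three = Tensor.create(3.0f)) {
            Operation x = graph.opBuilder("Const", "x")
                    .setAttr("dtype", DataType.FLOAT)
                    .setAttr("value", three)
                    .build();
            Operation y = graph.opBuilder("Placeholder", "y")
                    .setAttr("dtype", DataType.FLOAT)
                    .build();
            graph.opBuilder("Mul", "xy")
                    .addInput(x.output(0))
                    .addInput(y.output(0))
                    .build();

            // The session is closed before the graph, and every tensor is closed after use
            try (Session session = new Session(graph);
                 Tensor<?> yTensor = Tensor.create(yValue);
                 Tensor<?> result = session.runner().feed("y", yTensor).fetch("xy").run().get(0)) {
                return result.floatValue();
            }
        }
    }

    public static void main(String[] args) {
        System.out.println(multiplyByConstant(2.0f)); // 6.0
    }
}
```

Closing the inner session before the graph matters because a graph cannot be fully released while a session still refers to it.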

Running this piece of code will yield:

10.0

Saving Models in Python and Loading in Java

This may sound a bit odd, but since Python is the only well-supported language, the Java API still doesn’t have the functionality to save models.

This means that the Java API is meant only for the serving use-case, at least until it’s fully supported by TensorFlow. Fortunately, we can train and save models in Python and then load them in Java to serve them, using the SavedModelBundle class:

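try (SavedModelBundle model = SavedModelBundle.load("./model", "serve")) {
    // "./model" is the SavedModel directory, "serve" its tag; "x", "y" and "xy" must be operation names in that model
    Tensor<?> tensor = model.session().runner().fetch("xy")
            .feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0);
    System.out.println(tensor.floatValue());
}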

Conclusion

TensorFlow is a powerful, robust and widely-used framework. It’s constantly being improved and has recently been brought to new languages, including Java and JavaScript.

Although the Java API doesn’t yet have nearly as much functionality as TensorFlow for Python, it can still serve as a good intro to TensorFlow for Java developers.

I hope this tutorial helps. If you liked it, please consider sharing it with others.


Originally published on stackabuse.com
