How to use TensorFlow models in Java

A skeleton example

Posted by Xuan on October 31, 2017

A Skeleton Showing How to Use TF Models in Java

0. Requirements

Note: Maven is not strictly necessary; you could instead add TensorFlow’s Java binding JAR to your project classpath directly. But if you spend some time figuring out Maven, it does make dependency management easier and neater.
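
Whichever route you take, the binding's classes have to end up on the runtime classpath. Below is a small sanity check. It assumes the binding's entry-point class is org.tensorflow.TensorFlow (for example, as pulled in by the org.tensorflow:tensorflow Maven artifact), and ClasspathCheck is a hypothetical helper, not part of the project. It only checks the Java classes. The native JNI library is a separate runtime requirement.

```java
// Sanity check that a class is visible on the runtime classpath.
// Assumption: org.tensorflow.TensorFlow is the TF Java binding's entry-point class.
final class ClasspathCheck {
    private ClasspathCheck() {}

    /** Returns true if the named class can be located by this class's loader. */
    static boolean isOnClasspath(String className) {
        try {
            // initialize = false: locate the class without running its static
            // initializer, so no native code is loaded just to perform the check.
            Class.forName(className, false, ClasspathCheck.class.getClassLoader());
            return true;
        } catch (ClassNotFoundException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        String tfClass = "org.tensorflow.TensorFlow"; // assumed entry-point class name
        System.out.println(isOnClasspath(tfClass)
                ? "TensorFlow Java classes found on the classpath."
                : "TensorFlow Java classes not found: add the JAR or the Maven dependency.");
    }
}
```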

1. Train and Save a Model in Python with TF

See model_py/model.py for a toy example that takes a matrix of shape (2, 2), adds 1 to each element, and outputs the result.
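
Because the toy model only adds 1 to each element, you can reproduce its expected output in plain Java. This is handy for checking the numbers you later get back from the loaded graph. The sketch below assumes float32 input of shape (2, 2). ToyModelReference, addOne, and matches are hypothetical names, not code from the repository.

```java
import java.util.Arrays;

// Plain-Java reference for the toy model (assumed: float32 input, shape (2, 2), y = x + 1).
final class ToyModelReference {
    private ToyModelReference() {}

    /** Expected model output: a copy of x with 1 added to every element. */
    static float[][] addOne(float[][] x) {
        if (x.length != 2 || x[0].length != 2 || x[1].length != 2) {
            throw new IllegalArgumentException("toy model expects shape (2, 2)");
        }
        float[][] y = new float[2][2];
        for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 2; j++) {
                y[i][j] = x[i][j] + 1.0f;
            }
        }
        return y;
    }

    /** True if both arrays have the same shape and every element differs by at most tol. */
    static boolean matches(float[][] actual, float[][] expected, float tol) {
        if (actual.length != expected.length) return false;
        for (int i = 0; i < actual.length; i++) {
            if (actual[i].length != expected[i].length) return false;
            for (int j = 0; j < actual[i].length; j++) {
                if (Math.abs(actual[i][j] - expected[i][j]) > tol) return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        float[][] x = {{1f, 2f}, {3f, 4f}};
        System.out.println(Arrays.deepToString(addOne(x))); // prints [[2.0, 3.0], [4.0, 5.0]]
    }
}
```

Comparing with a tolerance rather than exact equality keeps the check usable if the model is later changed to something with rounding error.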

2. Freeze the Model

This step converts all trained variables into constant tensors, so that when you later load the model in Java, you don’t need to initialize the variables.
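
To make the idea concrete, here is a toy model of what freezing does. It uses a hypothetical node representation (FreezeSketch, Node) rather than TensorFlow's real GraphDef classes. Each variable node is swapped for a constant node carrying its trained value, and it keeps the same name so downstream nodes still resolve. In TF 1.x the real conversion is usually done with tf.graph_util.convert_variables_to_constants. That function also extracts only the subgraph the outputs need, which this sketch ignores.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Toy illustration of freezing. Node is a hypothetical stand-in for a GraphDef node.
final class FreezeSketch {
    private FreezeSketch() {}

    // Op types treated as variables (names borrowed from TF 1.x for readability).
    static final Set<String> VARIABLE_OPS = new HashSet<>(Arrays.asList("Variable", "VariableV2"));

    static final class Node {
        final String name;
        final String op;
        final List<String> inputs; // names of the nodes this node reads from
        final float[] value;       // constant payload; null unless op is "Const"

        Node(String name, String op, List<String> inputs, float[] value) {
            this.name = name;
            this.op = op;
            this.inputs = inputs;
            this.value = value;
        }
    }

    /** Replaces every variable node with a Const node holding its trained value. */
    static List<Node> freeze(List<Node> graph, Map<String, float[]> trainedValues) {
        List<Node> frozen = new ArrayList<>(graph.size());
        for (Node n : graph) {
            if (VARIABLE_OPS.contains(n.op)) {
                float[] v = trainedValues.get(n.name);
                if (v == null) {
                    throw new IllegalStateException("no trained value for variable " + n.name);
                }
                // Same name, so consumers' input references stay valid.
                frozen.add(new Node(n.name, "Const", Collections.<String>emptyList(), v.clone()));
            } else {
                frozen.add(n); // non-variable nodes are kept unchanged
            }
        }
        return frozen;
    }

    public static void main(String[] args) {
        List<Node> graph = Arrays.asList(
                new Node("x", "Placeholder", Collections.<String>emptyList(), null),
                new Node("b", "VariableV2", Collections.<String>emptyList(), null),
                new Node("y", "Add", Arrays.asList("x", "b"), null));
        // Hypothetical trained values; in practice they come from the saved model.
        Map<String, float[]> trained = Collections.singletonMap("b", new float[] {1f, 1f, 1f, 1f});
        for (Node n : freeze(graph, trained)) {
            System.out.println(n.name + ": " + n.op); // x: Placeholder, b: Const, y: Add
        }
    }
}
```
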
You can also optimize the graph for inference using TensorFlow’s optimize_for_inference(), which applies a series of basic optimizations, including:

  • strip unused nodes: at inference time we only need the input and output nodes, plus the nodes between them
  • remove unnecessary nodes: nodes such as numerics checks (CheckNumerics) and Identity ops can be pruned for inference

Please see model_py/freeze.py.
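
Both optimizations are graph rewrites. The sketch below uses a hypothetical map from node name to input names rather than TensorFlow's GraphDef. It strips unused nodes by walking backwards from the outputs and stopping at the inputs. It then removes Identity nodes by rewiring each consumer to the Identity's own input, keeping any Identity node that is itself an output. This illustrates the ideas and is not a reimplementation of optimize_for_inference(). PruneSketch and its methods are illustrative names.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Toy graph pruning on a hypothetical representation: node name -> input node names.
final class PruneSketch {
    private PruneSketch() {}

    /** Keeps only nodes some output depends on; traversal does not go past input nodes. */
    static Map<String, List<String>> stripUnused(Map<String, List<String>> graph,
                                                 Set<String> inputs, Set<String> outputs) {
        Set<String> keep = new HashSet<>();
        Deque<String> work = new ArrayDeque<>(outputs);
        while (!work.isEmpty()) {
            String name = work.pop();
            if (!keep.add(name)) continue;       // already visited
            if (inputs.contains(name)) continue; // graph inputs become sources
            for (String input : graph.getOrDefault(name, Collections.<String>emptyList())) {
                work.push(input);
            }
        }
        Map<String, List<String>> pruned = new LinkedHashMap<>();
        for (Map.Entry<String, List<String>> e : graph.entrySet()) {
            if (keep.contains(e.getKey())) {
                pruned.put(e.getKey(),
                        inputs.contains(e.getKey()) ? new ArrayList<String>() : e.getValue());
            }
        }
        return pruned;
    }

    /** Follows chains of (non-output) Identity nodes back to the first real producer. */
    private static String skipIdentities(String name, Map<String, List<String>> graph,
                                         Map<String, String> ops, Set<String> outputs) {
        while ("Identity".equals(ops.get(name)) && !outputs.contains(name)) {
            name = graph.get(name).get(0); // an Identity has exactly one input
        }
        return name;
    }

    /** Drops non-output Identity nodes and rewires their consumers. */
    static Map<String, List<String>> removeIdentities(Map<String, List<String>> graph,
                                                      Map<String, String> ops, Set<String> outputs) {
        Map<String, List<String>> result = new LinkedHashMap<>();
        for (Map.Entry<String, List<String>> e : graph.entrySet()) {
            String name = e.getKey();
            if ("Identity".equals(ops.get(name)) && !outputs.contains(name)) continue;
            List<String> rewired = new ArrayList<>();
            for (String input : e.getValue()) {
                rewired.add(skipIdentities(input, graph, ops, outputs));
            }
            result.put(name, rewired);
        }
        return result;
    }
}
```

Running the two passes in this order means the Identity pass only has to look at nodes that survived the reachability pass.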

3. Load in Java and Run

You still need to 1) prepare your inference data, 2) import the graph in Java, 3) launch a session to execute the graph, and 4) feed the graph with data.
Please see src/main/java/edu/nyu/jetlite/tf_intergration.java.
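
Steps 1) and 2) mostly come down to getting data into the form the TensorFlow Java API expects: a shape as a long[], a flat row-major buffer, and the serialized graph as bytes. The sketch below builds these in plain Java. InferenceInput and its methods are hypothetical helpers, not the code in tf_intergration.java.

```java
import java.io.IOException;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Path;

// Plain-Java preparation of inference inputs (hypothetical helper, no TF dependency).
final class InferenceInput {
    private InferenceInput() {}

    /** Shape of a rectangular 2-D array as {rows, cols}; rejects ragged arrays. */
    static long[] shapeOf(float[][] m) {
        int cols = m.length == 0 ? 0 : m[0].length;
        for (float[] row : m) {
            if (row.length != cols) throw new IllegalArgumentException("ragged array");
        }
        return new long[] {m.length, cols};
    }

    /** Copies a 2-D array into a flat FloatBuffer in row-major order, ready for reading. */
    static FloatBuffer toRowMajorBuffer(float[][] m) {
        long[] shape = shapeOf(m);
        FloatBuffer buf = FloatBuffer.allocate((int) (shape[0] * shape[1]));
        for (float[] row : m) {
            buf.put(row);
        }
        buf.flip(); // position = 0, limit = number of elements written
        return buf;
    }

    /** Reads the serialized (frozen) graph file produced in step 2. */
    static byte[] readGraphDef(Path pbFile) throws IOException {
        byte[] bytes = Files.readAllBytes(pbFile);
        if (bytes.length == 0) throw new IOException("empty graph file: " + pbFile);
        return bytes;
    }
}
```

Under the same assumptions, these pieces would feed the TF Java 1.x calls for the remaining steps. The bytes go to Graph.importGraphDef(byte[]), the shape and buffer go to Tensor.create(long[] shape, FloatBuffer data), and the session runs via session.runner().feed(inputName, tensor).fetch(outputName).run(). The input and output node names depend on how model.py named them, which is not shown here.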