A Skeleton Showing How to Use TF Models in Java
0. Requirements
Note: Maven is not required; you can instead add the TensorFlow Java binding JAR directly to your project classpath. That said, once you have spent some time learning Maven, it makes managing the dependency easier and tidier.
1. Train and Save a Model in Python with TF
Please see model_py/model.py for a toy example: it takes a matrix of shape (2, 2), adds 1 to each element, and outputs the result.
2. Freeze the Model
This step converts all trained variables into constant tensors, so that you do not need to initialize any variables when you later load the model in Java.
You can also optimize the graph for inference using TensorFlow's optimize_for_inference(). It applies a series of basic optimizations, including:
- strip unused nodes: inference only needs the input and output nodes, together with the nodes on the paths between them
- remove unnecessary nodes: nodes such as numerics checks and identity ops can be pruned for inference
Please see model_py/freeze.py.
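To make the "strip unused nodes" idea concrete, here is a minimal plain-Java sketch on a toy graph. It is not TensorFlow's implementation: the graph representation (a map from node name to input names), the class name StripUnusedSketch, and the node names are all made up for illustration. The idea is simply to walk backwards from the output node and keep only what it depends on.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Toy illustration only (hypothetical names): a graph is a map from a node
// name to the names of the nodes it takes as inputs.
final class StripUnusedSketch {

    // Returns the nodes the given outputs depend on (outputs included),
    // found by following input edges backwards from each output.
    static Set<String> reachableFrom(Map<String, List<String>> inputsOf,
                                     List<String> outputs) {
        Set<String> keep = new LinkedHashSet<>();
        Deque<String> pending = new ArrayDeque<>(outputs);
        while (!pending.isEmpty()) {
            String node = pending.pop();
            if (!keep.add(node)) {
                continue; // already visited
            }
            for (String in : inputsOf.getOrDefault(node, Collections.emptyList())) {
                pending.push(in);
            }
        }
        return keep;
    }

    public static void main(String[] args) {
        Map<String, List<String>> graph = new LinkedHashMap<>();
        graph.put("input", Collections.emptyList());
        graph.put("one", Collections.emptyList());
        graph.put("output", Arrays.asList("input", "one")); // output = input + one
        graph.put("summary", Arrays.asList("output"));      // made-up training-only node
        graph.put("save", Arrays.asList("one"));            // made-up checkpoint node

        // Only the nodes needed to compute "output" survive.
        System.out.println(reachableFrom(graph, Arrays.asList("output")));
    }
}
```

In this toy graph, "summary" and "save" consume other nodes but nothing needed for "output" consumes them, so they are dropped.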
3. Load in Java and Run
You still need to 1) prepare your inference data, 2) import the graph in Java, 3) launch a session to run the data flow, and 4) feed the graph with data.
Please see src/main/java/edu/nyu/jetlite/tf_intergration.java.
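As a rough sketch of step 1 (preparing the inference data), the plain-Java snippet below flattens a (2, 2) matrix into a row-major buffer with an explicit shape, and computes what the toy model should return so you can compare it against the session output. It deliberately uses no TensorFlow classes; the class and method names (InferenceDataSketch, flatten, expectedOutput) are hypothetical, and whether the repository's Java file prepares its data this way is an assumption.

```java
import java.nio.FloatBuffer;

// Hypothetical helper, standard library only: prepares input for the toy model
// and computes the result expected from it.
final class InferenceDataSketch {

    // Shape of the toy model's input: a 2 x 2 matrix.
    static final long[] SHAPE = {2, 2};

    // Flattens a rectangular matrix into a row-major buffer,
    // positioned at 0 and ready to be read.
    static FloatBuffer flatten(float[][] matrix) {
        int rows = matrix.length;
        int cols = rows == 0 ? 0 : matrix[0].length;
        FloatBuffer buf = FloatBuffer.allocate(rows * cols);
        for (float[] row : matrix) {
            if (row.length != cols) {
                throw new IllegalArgumentException("ragged matrix");
            }
            buf.put(row);
        }
        buf.flip(); // switch from writing to reading
        return buf;
    }

    // Plain-Java reference for the toy model: every element plus 1.
    // The input matrix is left unchanged.
    static float[][] expectedOutput(float[][] matrix) {
        float[][] out = new float[matrix.length][];
        for (int i = 0; i < matrix.length; i++) {
            out[i] = new float[matrix[i].length];
            for (int j = 0; j < matrix[i].length; j++) {
                out[i][j] = matrix[i][j] + 1f;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] input = {{1f, 2f}, {3f, 4f}};
        FloatBuffer data = flatten(input);
        System.out.println("elements to feed: " + data.remaining());
        System.out.println("expected[1][1]: " + expectedOutput(input)[1][1]);
    }
}
```

Keeping a reference result like expectedOutput around is a cheap sanity check that the frozen graph was imported and fed correctly.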