Bringing AI into Java: Using TensorFlow and ONNX for Machine Learning
Java isn’t the first language that comes to mind when developers think about machine learning. Python dominates the AI landscape with its rich ecosystem. But millions of enterprise applications run on Java, and bringing AI directly into these systems is often simpler than building and operating separate Python microservices.
TensorFlow and ONNX offer robust pathways for Java developers to integrate machine learning without abandoning their existing stack.
Why Java for Machine Learning?
The case for Java in ML isn’t about replacing Python for research—it’s about production deployment at scale.
Enterprise reality: Your recommendation engine needs to run inside a Spring Boot application serving millions of requests. Your fraud detection model must integrate with existing Java services handling transactions. Rewriting everything in Python isn’t realistic.
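Performance matters: The JVM offers excellent multi-threading, mature garbage collection, and battle-tested optimization. For inference, both TensorFlow for Java and ONNX Runtime execute the model in the same native C/C++ engines that back their Python APIs, so raw model speed is comparable. Around the model, the JVM's true multi-threading (no global interpreter lock) lets a Java service match or exceed a Python one when handling many concurrent requests.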
Operational simplicity: One runtime, one deployment pipeline, one monitoring stack. Adding Python microservices introduces complexity—multiple languages, container orchestration overhead, and network latency between services.
TensorFlow for Java: Direct Integration
TensorFlow for Java lets you load pre-trained models and run inference directly in your JVM applications.
Getting started with Maven:
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-platform</artifactId>
    <version>0.5.0</version>
</dependency>
Practical example – Image classification:
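import org.tensorflow.SavedModelBundle;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.types.TFloat32;

public class ImageClassifier implements AutoCloseable {
    private final SavedModelBundle model;

    public ImageClassifier(String modelPath) {
        // Load a SavedModel exported with the standard "serve" tag
        model = SavedModelBundle.load(modelPath, "serve");
    }

    public float[] classify(float[] imageData) {
        // Wrap the preprocessed pixels as a [1, 224, 224, 3] tensor (a batch of one image)
        try (TFloat32 input = TFloat32.tensorOf(
                     Shape.of(1, 224, 224, 3), DataBuffers.of(imageData));
             // "input_layer" / "output_layer" are placeholders: use the operation
             // names from your model's serving signature. In 0.5.0, run() returns List<Tensor>.
             TFloat32 result = (TFloat32) model.session()
                     .runner()
                     .feed("input_layer", input)
                     .fetch("output_layer")
                     .run()
                     .get(0)) {
            // Copy the scores for the single batch entry (e.g. 1000 ImageNet classes)
            float[] predictions = new float[(int) result.shape().size(1)];
            for (int i = 0; i < predictions.length; i++) {
                predictions[i] = result.getFloat(0, i);
            }
            return predictions;
        }
    }

    @Override
    public void close() {
        model.close(); // releases the native graph and session
    }
}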
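`classify` returns one raw score per class (1000 for an ImageNet-style model), so callers usually want the best few classes. A minimal JDK-only sketch of picking the top-k indices; `TopK` is a hypothetical helper name, and mapping indices to human-readable labels depends on the label file shipped with your model.

```java
import java.util.Arrays;
import java.util.Comparator;

final class TopK {
    private TopK() {}

    // Returns the indices of the k highest scores, best first (ties keep the lower index first).
    static int[] topK(float[] scores, int k) {
        Integer[] idx = new Integer[scores.length];
        for (int i = 0; i < idx.length; i++) {
            idx[i] = i;
        }
        // Stable sort of class indices by descending score
        Arrays.sort(idx, Comparator.comparingDouble((Integer i) -> scores[i]).reversed());
        int n = Math.min(k, idx.length);
        int[] result = new int[n];
        for (int i = 0; i < n; i++) {
            result[i] = idx[i];
        }
        return result;
    }
}
```

For example, `TopK.topK(classifier.classify(pixels), 5)` yields the five most likely class indices.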
Real-world use case: A Java-based content management system uses TensorFlow to automatically tag uploaded images. The model runs in-process, classifying images in under 50ms without external API calls.
Strengths:
- Native TensorFlow format support
- Official guides and API docs on tensorflow.org (the Java bindings are maintained by the TensorFlow SIG JVM community)
- Direct access to TensorFlow ecosystem
Limitations:
- Smaller community than Python TensorFlow
- Some advanced features lag behind Python API
- Model training in Java remains impractical
ONNX Runtime: The Universal Standard
ONNX (Open Neural Network Exchange) solves a bigger problem: framework interoperability. Train your model in PyTorch, export to ONNX, run it anywhere—including Java.
Why ONNX matters: You’re not locked into TensorFlow. Models from PyTorch, scikit-learn, XGBoost, and dozens of other frameworks convert to ONNX format. Your data scientists work with their preferred tools while you deploy standardized models.
Maven dependency:
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.17.1</version>
</dependency>
Practical example – Sentiment analysis:
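import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.Map;

public class SentimentAnalyzer implements AutoCloseable {
    private final OrtEnvironment env;
    private final OrtSession.SessionOptions options;
    private final OrtSession session;

    public SentimentAnalyzer(String modelPath) throws OrtException {
        env = OrtEnvironment.getEnvironment(); // process-wide singleton, not closed here
        options = new OrtSession.SessionOptions();
        session = env.createSession(modelPath, options);
    }

    public float analyzeSentiment(long[] tokenIds) throws OrtException {
        // Input shape [1, seqLen]; "input_ids" must match the model's declared input name.
        // Close both the input tensor and the results to free native memory.
        try (OnnxTensor tensor = OnnxTensor.createTensor(env, new long[][]{tokenIds});
             OrtSession.Result results = session.run(Map.of("input_ids", tensor))) {
            // Assumes an output of shape [1, 2] where index 1 is the positive class
            float[][] output = (float[][]) results.get(0).getValue();
            return output[0][1]; // Positive sentiment score
        }
    }

    @Override
    public void close() throws OrtException {
        session.close();
        options.close();
    }
}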
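Many exported classifiers return raw logits rather than probabilities, in which case `output[0][1]` is not confined to [0, 1]. If that holds for your model (check how it was exported), a numerically stable softmax over the output row turns the logits into probabilities. A minimal sketch; `Softmax` is a hypothetical helper, not part of ONNX Runtime.

```java
final class Softmax {
    private Softmax() {}

    // Numerically stable softmax: subtracting the max logit avoids overflow in exp()
    static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        double[] exps = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            exps[i] = Math.exp(logits[i] - max);
            sum += exps[i];
        }
        float[] probs = new float[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = (float) (exps[i] / sum);
        }
        return probs;
    }
}
```

Inside `analyzeSentiment`, the positive-class probability would then be `Softmax.softmax(output[0])[1]`.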
Real-world use case: An e-commerce platform analyzes customer reviews in real-time. Data scientists trained a PyTorch model, exported it to ONNX, and the Java backend runs inference without Python dependencies.
Strengths:
- Framework agnostic—use any ML tool
- Excellent performance with optimized runtime
- Strong Microsoft backing and active development
- Portable single-file models, which can be shrunk further with quantization
Limitations:
- Export process can require debugging
- Not every operator converts cleanly; custom or very new operations may have no ONNX equivalent
Choosing Your Approach
Use TensorFlow for Java when:
- Your models are already in TensorFlow format
- You need tight integration with TensorFlow ecosystem
- Your team has TensorFlow expertise
Use ONNX Runtime when:
- You want framework flexibility
- Data scientists prefer PyTorch or other tools
- You need maximum portability across platforms
- Performance is critical (ONNX Runtime is highly optimized)
Integration Patterns That Work
Pattern 1: Model as a dependency
Package your ONNX or TensorFlow model as a JAR resource and version it with your application. Simple and self-contained, but watch JAR size.
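A model inside a JAR is not a file on disk. ONNX Runtime's Java API can create a session directly from bytes (`OrtEnvironment.createSession(byte[], OrtSession.SessionOptions)`), whereas a TensorFlow SavedModel is a directory and would first have to be extracted to the filesystem. A minimal JDK-only sketch of reading a bundled model; `ModelResources` and the resource path are hypothetical.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;

final class ModelResources {
    private ModelResources() {}

    // Reads a model bundled on the classpath (e.g. from src/main/resources) into memory.
    static byte[] readModel(String resourcePath) {
        try (InputStream in = ModelResources.class.getResourceAsStream(resourcePath)) {
            if (in == null) {
                throw new IllegalStateException("Model resource not found: " + resourcePath);
            }
            return in.readAllBytes();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

The returned bytes could then be passed to `env.createSession(bytes, options)`, e.g. with a hypothetical path such as `/models/sentiment.onnx`.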
Pattern 2: External model loading
Load models from cloud storage or the filesystem at startup. This enables model updates without redeployment. Add validation to ensure model compatibility.
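One form of validation for externally loaded models is checking the file against a SHA-256 digest published alongside it before loading. This verifies integrity (you got the exact file you expected), not input/output compatibility, which still has to be checked separately. A JDK-only sketch; `ModelFiles.verifySha256` is a hypothetical name.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.DigestInputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

final class ModelFiles {
    private ModelFiles() {}

    // Throws IllegalStateException if the file's SHA-256 digest differs from expectedHex.
    static void verifySha256(Path modelFile, String expectedHex) throws IOException {
        MessageDigest digest;
        try {
            digest = MessageDigest.getInstance("SHA-256"); // required on every Java platform
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
        // Stream the file through the digest so large models are never fully on the heap
        try (InputStream in = new DigestInputStream(Files.newInputStream(modelFile), digest)) {
            in.transferTo(OutputStream.nullOutputStream());
        }
        StringBuilder actual = new StringBuilder();
        for (byte b : digest.digest()) {
            actual.append(String.format("%02x", b));
        }
        if (!actual.toString().equalsIgnoreCase(expectedHex)) {
            throw new IllegalStateException("Checksum mismatch for " + modelFile
                    + ": expected " + expectedHex + ", got " + actual);
        }
    }
}
```

Calling this before `createSession` means a corrupted or partially downloaded file fails fast at startup instead of at the first request.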
Pattern 3: Model registry
Use tools like MLflow or DVC to manage model versions. Java services pull specific versions at runtime. Best for production systems with frequent model updates.
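Pulling new model versions at runtime implies swapping the live model while requests are in flight. A minimal sketch of one approach, assuming the model object is `AutoCloseable` (both `OrtSession` and `SavedModelBundle` are): publish the current model through an `AtomicReference` so readers always see either the old or the new, fully loaded model. `ModelHolder` is hypothetical; it returns the old model instead of closing it because requests may still be using it.

```java
import java.util.concurrent.atomic.AtomicReference;

final class ModelHolder<M extends AutoCloseable> {
    private final AtomicReference<M> current;

    ModelHolder(M initial) {
        current = new AtomicReference<>(initial);
    }

    // Request threads read the live model; a concurrent swap never exposes a half-loaded one.
    M get() {
        return current.get();
    }

    // Publishes an already-loaded replacement and returns the previous model.
    // The caller should close the returned model once in-flight requests have drained.
    M swap(M replacement) {
        return current.getAndSet(replacement);
    }
}
```

Loading the replacement fully (and validating it) before calling `swap` keeps the switchover atomic from the request path's point of view.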
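Performance tip: Create sessions once, reuse them, and pre-allocate tensors where you can. Creating a new session for each inference reloads the model and reruns graph optimizations, which kills performance. Treat sessions the way you treat database connection pools: build them at startup and share them across requests. An ONNX Runtime session supports concurrent run() calls from multiple threads, so one shared instance is often enough.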
Preprocessing: The Hidden Challenge
Models need data in specific formats. Java applications must handle:
Text preprocessing:
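import java.util.Arrays;
import java.util.Map;

// Tokenization for NLP models (simplified: whitespace split + vocabulary lookup).
// Transformer models usually need the exact WordPiece/BPE tokenizer they were trained with.
public long[] tokenize(String text, Map<String, Long> vocab) {
    Long unkId = vocab.get("[UNK]");
    if (unkId == null) {
        throw new IllegalArgumentException("Vocabulary has no [UNK] entry");
    }
    return Arrays.stream(text.toLowerCase().trim().split("\\s+"))
            .filter(w -> !w.isEmpty())                     // empty input yields one empty token
            .mapToLong(w -> vocab.getOrDefault(w, unkId))  // unknown words map to [UNK]
            .toArray();
}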
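Transformer-style models exported to ONNX commonly expect fixed-length `input_ids` plus an `attention_mask` that marks real tokens (1) versus padding (0). Whether your model needs a mask, its sequence length, and its padding id all depend on how it was exported; the sketch below takes the padding id as a parameter, and `EncodedInput` and `Padding.padOrTruncate` are hypothetical names.

```java
import java.util.Arrays;

// Token ids and attention mask, both exactly maxLen long.
record EncodedInput(long[] inputIds, long[] attentionMask) {}

final class Padding {
    private Padding() {}

    // Right-pads with padId or truncates to maxLen. Naive truncation can drop
    // special end tokens (such as [SEP]), so real pipelines may need to re-append them.
    static EncodedInput padOrTruncate(long[] tokenIds, int maxLen, long padId) {
        int n = Math.min(tokenIds.length, maxLen);
        long[] ids = new long[maxLen];
        long[] mask = new long[maxLen];
        Arrays.fill(ids, padId);
        System.arraycopy(tokenIds, 0, ids, 0, n);
        Arrays.fill(mask, 0, n, 1L); // 1 = real token, 0 = padding
        return new EncodedInput(ids, mask);
    }
}
```

Each array would then be wrapped as a `[1, maxLen]` tensor and fed under the input names the model declares.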
Image preprocessing:
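import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

// Normalize image for vision models: resize to 224x224, scale each channel to [0, 1],
// then apply the ImageNet mean/std. Output is channels-last (HWC): R, G, B per pixel.
public float[] preprocessImage(BufferedImage img) {
    // Bilinear resize; ideally the same resize method used in training
    BufferedImage resized = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB);
    Graphics2D g = resized.createGraphics();
    g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
            RenderingHints.VALUE_INTERPOLATION_BILINEAR);
    g.drawImage(img, 0, 0, 224, 224, null);
    g.dispose();

    float[] pixels = new float[224 * 224 * 3];
    int idx = 0;
    for (int y = 0; y < 224; y++) {
        for (int x = 0; x < 224; x++) {
            Color c = new Color(resized.getRGB(x, y));
            pixels[idx++] = (c.getRed() / 255.0f - 0.485f) / 0.229f;
            pixels[idx++] = (c.getGreen() / 255.0f - 0.456f) / 0.224f;
            pixels[idx++] = (c.getBlue() / 255.0f - 0.406f) / 0.225f;
        }
    }
    return pixels;
}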
Match your Java preprocessing exactly to your training pipeline. Mismatches here cause subtle bugs that tank model accuracy.
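Tensor layout is a common source of such mismatches. The image loop emits pixels in channels-last (HWC) order, matching the `[1, 224, 224, 3]` input in the TensorFlow example, while models exported from PyTorch to ONNX typically expect channels-first `[1, 3, 224, 224]`. If your model's declared input shape is channels-first, the same values need reordering. A minimal sketch; `Layouts.hwcToChw` is a hypothetical helper.

```java
final class Layouts {
    private Layouts() {}

    // Reorders interleaved HWC data (RGBRGB...) into planar CHW (RR..GG..BB..).
    static float[] hwcToChw(float[] hwc, int height, int width, int channels) {
        if (hwc.length != height * width * channels) {
            throw new IllegalArgumentException("length does not match height * width * channels");
        }
        float[] chw = new float[hwc.length];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                for (int c = 0; c < channels; c++) {
                    chw[c * height * width + y * width + x] = hwc[(y * width + x) * channels + c];
                }
            }
        }
        return chw;
    }
}
```

For the 224x224 RGB case this would be `Layouts.hwcToChw(pixels, 224, 224, 3)`.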
Production Considerations
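Memory management: Models consume memory, but mostly outside the JVM heap. TensorFlow for Java and ONNX Runtime load model weights and tensors into native (off-heap) memory, so a 500MB model adds at least that much to the process's resident memory on top of -Xmx. Monitor process memory as well as heap usage, leave headroom in container memory limits, and close tensors and sessions promptly so native memory is released.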
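Threading: ONNX Runtime supports both intra-op parallelism (multiple threads inside a single operator) and inter-op parallelism (independent operators running concurrently, which only takes effect when the session's execution mode is set to parallel). Configure thread counts based on your workload: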
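OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
opts.setIntraOpNumThreads(4); // threads used inside a single operator
opts.setInterOpNumThreads(2); // threads for running independent operators concurrently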
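Thread settings interact with request concurrency: if many request threads call `run()` at once and each call also uses several intra-op threads, the total can oversubscribe the CPU. One hedged approach is to cap concurrent inference calls with a `Semaphore`, so that concurrent calls times intra-op threads stays near the core count. A JDK-only sketch; `BoundedInference` is hypothetical and the limit is something to tune under load.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.Semaphore;

final class BoundedInference {
    private final Semaphore permits;

    BoundedInference(int maxConcurrent) {
        permits = new Semaphore(maxConcurrent, true); // fair: waiting callers are served in order
    }

    // Runs the inference call only when a permit is free; other callers block until then.
    <T> T call(Callable<T> inference) throws Exception {
        permits.acquire();
        try {
            return inference.call();
        } finally {
            permits.release();
        }
    }
}
```

With 4 intra-op threads per call, a starting point might be `new BoundedInference(Math.max(1, Runtime.getRuntime().availableProcessors() / 4))`, wrapping calls as `limiter.call(() -> analyzer.analyzeSentiment(ids))`.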
Monitoring: Track inference latency, throughput, and errors. Integrate with existing Java monitoring—Micrometer, Prometheus, whatever your stack uses.
Fallback strategies: What happens when inference fails? Have graceful degradation. Return cached results, default predictions, or route to a backup service.
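A minimal sketch of graceful degradation, assuming the inference call can run on an executor: bound it with a timeout and return a default (or cached) value if it fails or runs too long. `Fallbacks.withFallback` and the timeout value are hypothetical choices; `CompletableFuture.orTimeout` needs Java 9+.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

final class Fallbacks {
    private Fallbacks() {}

    // Runs inference asynchronously; returns fallbackValue if it throws or exceeds the timeout.
    static <T> T withFallback(Supplier<T> inference, T fallbackValue,
                              long timeoutMillis, Executor executor) {
        return CompletableFuture.supplyAsync(inference, executor)
                .orTimeout(timeoutMillis, TimeUnit.MILLISECONDS)
                .exceptionally(error -> fallbackValue) // covers both failures and timeouts
                .join();
    }
}
```

The timeout only stops waiting; the underlying native inference keeps running until it finishes. Because `Supplier` cannot throw checked exceptions, a call like `analyzeSentiment` (which throws `OrtException`) has to be wrapped in a runtime exception inside the lambda.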
Learning Resources
Official Documentation:
- TensorFlow Java Guide
- ONNX Runtime Java API
- ONNX Model Zoo – Pre-trained models to experiment with
The Practical Reality
You won’t train models in Java. That ship has sailed—Python’s ecosystem is unbeatable for experimentation and training. But deploying those models into Java applications? That’s increasingly common and increasingly practical.
TensorFlow and ONNX give you production-ready inference without architectural gymnastics. Your data scientists keep their Python notebooks. Your Java services get ML capabilities. Everyone wins.
Start small. Pick one model—maybe a simple classifier or recommendation engine. Export it to ONNX or TensorFlow format. Integrate it into a non-critical service. Measure performance. Learn the quirks. Then scale up.
Machine learning in Java isn’t about beating Python at its own game. It’s about meeting enterprise applications where they already are.

