Building Predictive APIs with TensorFlow and Spring Boot
1. Why Combine AI/ML with Spring Boot?
Modern applications increasingly need smart capabilities – from recommendation engines to fraud detection. While Python dominates ML development, Java teams can leverage:
- TensorFlow Java for model inference
- Spring Boot for scalable API delivery
- DJL (Deep Java Library) as an alternative framework
This guide walks through serving a trained ML model via a REST API, with no Python dependency at serving time (Python is used only to train and export the model).
2. Architecture Overview
[Python Environment] -- Trains Model --> SavedModel (saved_model.pb)
        ↓
[Java Service] <-- Loads Model --> [Spring Boot REST API]
        ↓
[Client Apps] <-- Gets Predictions
Key components:
- TensorFlow SavedModel (exported from Python)
- Spring Boot web layer
- TensorFlow Java API for inference
Step 1: Train and Export Model (Python)
# train.py
import tensorflow as tf

# X_train and y_train are assumed to be NumPy arrays prepared elsewhere

# Sample neural network
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')
model.fit(X_train, y_train, epochs=10)

# Export for Java
tf.saved_model.save(model, "saved_model")
This creates a saved_model/ directory with:
- saved_model.pb (architecture)
- variables/ (trained weights)
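Because the Java side depends on this layout, it can help to check it before loading and fail with a clear message. The following is a minimal sketch using only the JDK; `SavedModelLayout` and its method names are hypothetical and not part of TensorFlow Java.

```java
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical helper: checks the expected SavedModel layout before loading. */
final class SavedModelLayout {
    private SavedModelLayout() {}

    /** Returns true if dir contains saved_model.pb and a variables/ subdirectory. */
    static boolean looksValid(Path dir) {
        return Files.isRegularFile(dir.resolve("saved_model.pb"))
                && Files.isDirectory(dir.resolve("variables"));
    }

    /** Throws with a descriptive message when the layout is incomplete. */
    static void requireValid(Path dir) {
        if (!looksValid(dir)) {
            throw new IllegalStateException(
                    "Not a SavedModel directory (need saved_model.pb and variables/): " + dir);
        }
    }
}
```

Calling `SavedModelLayout.requireValid(Path.of("saved_model"))` just before loading the model would surface a wrong path at startup rather than as a lower-level native error.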
Step 2: Spring Boot Integration
Dependencies (pom.xml)
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-platform</artifactId>
    <version>0.4.1</version>
</dependency>
Load Model in Java
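import org.springframework.stereotype.Component;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;
import javax.annotation.PostConstruct; // jakarta.annotation.* on Spring Boot 3+
import javax.annotation.PreDestroy;

@Component
public class Predictor {
    private SavedModelBundle model;

    @PostConstruct
    public void init() {
        // Filesystem path relative to the working directory;
        // it will not resolve from inside a packaged JAR
        this.model = SavedModelBundle.load("src/main/resources/saved_model", "serve");
    }

    @PreDestroy
    public void close() {
        model.close(); // release native resources
    }

    public float predict(float[] input) {
        // Shape [1, n]: a batch holding a single feature vector
        try (TFloat32 inputTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[][] {input}));
             Tensor output = model.session()
                     .runner()
                     // Operation names depend on the exported signature; verify them with
                     // `saved_model_cli show --dir saved_model --all`. TF2 exports often use
                     // names such as serving_default_dense_input and StatefulPartitionedCall.
                     .feed("dense_input", inputTensor)
                     .fetch("dense_1")
                     .run()
                     .get(0)) {
            // Output shape is [1, 1]: one prediction for the single input row
            return ((TFloat32) output).getFloat(0, 0);
        }
    }
}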
Step 3: Expose as REST API
@RestController
@RequestMapping("/api/predict")
public class PredictionController {
    @Autowired
    private Predictor predictor;

    @PostMapping
    public PredictionResponse predict(@RequestBody PredictionRequest request) {
        float result = predictor.predict(request.getFeatures());
        return new PredictionResponse(result);
    }
}
Sample request:
curl -X POST http://localhost:8080/api/predict \
-H "Content-Type: application/json" \
-d '{"features": [0.1, 0.5, 0.3]}'
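The controller refers to `PredictionRequest` and `PredictionResponse`, which are not shown here. A minimal sketch of what they might look like follows, assuming default JSON bean mapping and the `features` field from the sample request. The `prediction` field name in the response is an assumption, not something the article specifies.

```java
// Both classes are package-private here so they fit in one file;
// in a real project each would normally be a public class in its own file.

/** Request body matching the sample JSON: {"features": [0.1, 0.5, 0.3]}. */
final class PredictionRequest {
    private float[] features;

    public PredictionRequest() {} // no-arg constructor for JSON deserialization

    public float[] getFeatures() { return features; }

    public void setFeatures(float[] features) { this.features = features; }
}

/** Response body; the "prediction" property name is an illustrative assumption. */
final class PredictionResponse {
    private final float prediction;

    public PredictionResponse(float prediction) { this.prediction = prediction; }

    public float getPrediction() { return prediction; }
}
```

With these in place, the sample request above would be expected to return a body of the form `{"prediction": ...}`.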
3. Performance Optimization Tips
1. Batching Predictions
Process multiple inputs in one session run:
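float[][] batchInputs = ...;
// TFloat32 is itself the tensor type; StdArrays (org.tensorflow.ndarray) copies the Java array
TFloat32 batchTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(batchInputs));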
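A batch tensor has to be rectangular, so individual inputs need to be stacked and checked for a consistent length first. The sketch below shows one way to assemble `batchInputs` from separate feature vectors using only the JDK; `FeatureBatcher` is a hypothetical helper name.

```java
import java.util.List;

/** Hypothetical helper for assembling a rectangular batch of feature vectors. */
final class FeatureBatcher {
    private FeatureBatcher() {}

    /**
     * Stacks feature vectors into a [batchSize][featureCount] array.
     * Rejects empty or ragged input, since a tensor must be rectangular.
     */
    static float[][] toBatch(List<float[]> inputs) {
        if (inputs.isEmpty()) {
            throw new IllegalArgumentException("Batch must contain at least one input");
        }
        int featureCount = inputs.get(0).length;
        float[][] batch = new float[inputs.size()][];
        for (int i = 0; i < inputs.size(); i++) {
            float[] row = inputs.get(i);
            if (row.length != featureCount) {
                throw new IllegalArgumentException(
                        "Input " + i + " has " + row.length
                                + " features, expected " + featureCount);
            }
            batch[i] = row.clone(); // defensive copy so callers cannot mutate the batch
        }
        return batch;
    }
}
```

The returned array can then be turned into a single tensor and passed through one session run, with row i of the output corresponding to input i.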
2. GPU Acceleration
Use the GPU build of the platform artifact in place of the CPU one (requires an NVIDIA GPU with compatible driver and CUDA libraries):
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-platform-gpu</artifactId>
    <version>0.4.1</version>
</dependency>
3. Model Warmup
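Run a dummy prediction at startup so the first real request does not pay one-time initialization costs:
@Bean
public CommandLineRunner warmup(Predictor predictor) {
    // inputSize: the number of features the model expects
    return args -> predictor.predict(new float[inputSize]);
}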
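A single dummy call may not be enough to trigger every lazy initialization, so some teams run a few. The sketch below is a library-independent illustration of that idea; `ModelWarmup` is a hypothetical helper, and the prediction function is passed in rather than tied to any specific class.

```java
import java.util.function.Function;

/** Hypothetical warmup helper, independent of any particular inference library. */
final class ModelWarmup {
    private ModelWarmup() {}

    /**
     * Calls the prediction function `iterations` times with an all-zero input
     * of length inputSize and returns the duration of the final call in nanoseconds.
     */
    static long run(Function<float[], Float> predict, int inputSize, int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("iterations must be at least 1");
        }
        long lastNanos = 0L;
        for (int i = 0; i < iterations; i++) {
            float[] dummy = new float[inputSize]; // zero-filled by default
            long start = System.nanoTime();
            predict.apply(dummy);
            lastNanos = System.nanoTime() - start;
        }
        return lastNanos;
    }
}
```

For instance, the `CommandLineRunner` above could call `ModelWarmup.run(predictor::predict, inputSize, 5)` and log the returned duration as a rough steady-state figure.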
4. Alternative: DJL (Deep Java Library)
For more Java-native ML workflows:
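// Load a pre-trained model and run inference in Java
// (classes from ai.djl, ai.djl.inference, ai.djl.ndarray, ai.djl.translate; Paths from java.nio.file)
try (Model model = Model.newInstance("linear");
     NDManager manager = NDManager.newBaseManager()) {
    model.load(Paths.get("model.pt"));
    NDArray input = manager.create(new float[]{...});
    // NoopTranslator passes NDList in and out without pre- or post-processing
    try (Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator())) {
        NDArray result = predictor.predict(new NDList(input)).singletonOrThrow();
    }
}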
Advantages:
- Unified API for TensorFlow/PyTorch/MXNet
- Native bindings handled by DJL's engine modules (no hand-written JNI code)
- Built-in image preprocessing
5. Conclusion
Key takeaways:
✅ Serve TensorFlow models without Python in production
✅ Target low per-prediction latency (under 10 ms may be achievable for small models; measure on your own hardware)
✅ Scale horizontally like any Spring Boot service
Next Steps:
- Try the TensorFlow Java examples
- Explore DJL’s Spring Boot starter
- Monitor performance with Micrometer metrics




