tensorflow
Sla het Tensorflow-model op in Python en laad met Java
Zoeken…
Invoering
Het bouwen en vooral trainen van een model kan het eenvoudigst worden gedaan in Python, dus hoe laad en gebruik je het getrainde model in Java?
Opmerkingen
Het model kan elk willekeurig aantal ingangen accepteren, dus wijzig de NUM_PREDICTIONS als u meer voorspellingen wilt uitvoeren dan één. Realiseer je dat de Java JNI gebruikt om het C ++ tensorflow-model aan te roepen, dus je zult enkele info-berichten uit het model zien wanneer je dit uitvoert.
Maak en bewaar een model met Python
import tensorflow as tf
# good idea
tf.reset_default_graph()
# DO MODEL STUFF
# Pretrained weighting of 2.0
W = tf.get_variable('w', shape=[], initializer=tf.constant(2.0), dtype=tf.float32)
# Model input x
x = tf.placeholder(tf.float32, name='x')
# Model output y = W*x
y = tf.multiply(W, x, name='y')
# DO SESSION STUFF
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# SAVE THE MODEL
builder = tf.saved_model.builder.SavedModelBuilder("/tmp/model" )
builder.add_meta_graph_and_variables(
sess,
[tf.saved_model.tag_constants.SERVING]
)
builder.save()
Laad en gebruik het model in Java.
public static void main( String[] args ) throws IOException
{
// good idea to print the version number, 1.2.0 as of this writing
System.out.println(TensorFlow.version());
final int NUM_PREDICTIONS = 1;
// load the model Bundle
try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {
// create the session from the Bundle
Session sess = b.session();
// create an input Tensor, value = 2.0f
Tensor x = Tensor.create(
new long[] {NUM_PREDICTIONS},
FloatBuffer.wrap( new float[] {2.0f} )
);
// run the model and get the result, 4.0f.
float[] y = sess.runner()
.feed("x", x)
.fetch("y")
.run()
.get(0)
.copyTo(new float[NUM_PREDICTIONS]);
// print out the result.
System.out.println(y[0]);
}
}
Modified text is an extract of the original Stack Overflow Documentation
Licentie onder CC BY-SA 3.0
Niet aangesloten bij Stack Overflow