tensorflow
Speichern Sie das Tensorflow-Modell in Python und laden Sie es mit Java
Suche…
Einführung
Das Erstellen und speziell Trainieren eines Modells kann in Python am einfachsten durchgeführt werden.
Bemerkungen
Das Modell kann eine beliebige Anzahl von Eingaben akzeptieren. Ändern Sie daher NUM_PREDICTIONS, wenn Sie mehr Vorhersagen als eine ausführen möchten. Stellen Sie fest, dass Java JNI für den Aufruf des C ++ Tensorflow-Modells verwendet. Wenn Sie dies ausführen, werden einige Infomeldungen aus dem Modell angezeigt.
Erstellen und speichern Sie ein Modell mit Python
import tensorflow as tf
# good idea
tf.reset_default_graph()
# DO MODEL STUFF
# Pretrained weighting of 2.0
W = tf.get_variable('w', shape=[], initializer=tf.constant(2.0), dtype=tf.float32)
# Model input x
x = tf.placeholder(tf.float32, name='x')
# Model output y = W*x
y = tf.multiply(W, x, name='y')
# DO SESSION STUFF
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# SAVE THE MODEL
builder = tf.saved_model.builder.SavedModelBuilder("/tmp/model" )
builder.add_meta_graph_and_variables(
sess,
[tf.saved_model.tag_constants.SERVING]
)
builder.save()
Laden und verwenden Sie das Modell in Java.
public static void main( String[] args ) throws IOException
{
// good idea to print the version number, 1.2.0 as of this writing
System.out.println(TensorFlow.version());
final int NUM_PREDICTIONS = 1;
// load the model Bundle
try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {
// create the session from the Bundle
Session sess = b.session();
// create an input Tensor, value = 2.0f
Tensor x = Tensor.create(
new long[] {NUM_PREDICTIONS},
FloatBuffer.wrap( new float[] {2.0f} )
);
// run the model and get the result, 4.0f.
float[] y = sess.runner()
.feed("x", x)
.fetch("y")
.run()
.get(0)
.copyTo(new float[NUM_PREDICTIONS]);
// print out the result.
System.out.println(y[0]);
}
}
Modified text is an extract of the original Stack Overflow Documentation
Lizenziert unter CC BY-SA 3.0
Nicht angeschlossen an Stack Overflow