tensorflow
Sauvegarde du modèle Tensorflow en Python et chargement avec Java
Recherche…
Introduction
Construire et surtout former un modèle peut être plus facile en Python. Alors, comment charger et utiliser le modèle formé en Java?
Remarques
Le modèle peut accepter n'importe quel nombre d'entrées. Modifiez donc NUM_PREDICTIONS si vous souhaitez exécuter plus de prédictions qu'une seule. Sachez que Java utilise JNI pour appeler le modèle de tensorflow C ++. Vous verrez donc des messages d’information provenant du modèle lors de son exécution.
Créer et enregistrer un modèle avec Python
import tensorflow as tf
# good idea
tf.reset_default_graph()
# DO MODEL STUFF
# Pretrained weighting of 2.0
W = tf.get_variable('w', shape=[], initializer=tf.constant(2.0), dtype=tf.float32)
# Model input x
x = tf.placeholder(tf.float32, name='x')
# Model output y = W*x
y = tf.multiply(W, x, name='y')
# DO SESSION STUFF
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# SAVE THE MODEL
builder = tf.saved_model.builder.SavedModelBuilder("/tmp/model" )
builder.add_meta_graph_and_variables(
sess,
[tf.saved_model.tag_constants.SERVING]
)
builder.save()
Chargez et utilisez le modèle en Java.
public static void main( String[] args ) throws IOException
{
// good idea to print the version number, 1.2.0 as of this writing
System.out.println(TensorFlow.version());
final int NUM_PREDICTIONS = 1;
// load the model Bundle
try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {
// create the session from the Bundle
Session sess = b.session();
// create an input Tensor, value = 2.0f
Tensor x = Tensor.create(
new long[] {NUM_PREDICTIONS},
FloatBuffer.wrap( new float[] {2.0f} )
);
// run the model and get the result, 4.0f.
float[] y = sess.runner()
.feed("x", x)
.fetch("y")
.run()
.get(0)
.copyTo(new float[NUM_PREDICTIONS]);
// print out the result.
System.out.println(y[0]);
}
}
Modified text is an extract of the original Stack Overflow Documentation
Sous licence CC BY-SA 3.0
Non affilié à Stack Overflow