Ricerca…


introduzione

Costruire e soprattutto addestrare un modello può essere più semplice in Python, quindi come caricare e utilizzare il modello addestrato in Java?

Osservazioni

Il modello può accettare qualsiasi numero di input, quindi modifica NUM_PREDICTIONS se desideri eseguire più previsioni di una. Renditi conto che Java sta usando JNI per chiamare nel modello tensorflow di C ++, quindi vedrai alcuni messaggi di informazione provenienti dal modello quando lo esegui.

Crea e salva un modello con Python

import tensorflow as tf
# good idea
tf.reset_default_graph()

# DO MODEL STUFF
# Pretrained weighting of 2.0
W = tf.get_variable('w', shape=[], initializer=tf.constant(2.0), dtype=tf.float32)
# Model input x
x = tf.placeholder(tf.float32, name='x')
# Model output y = W*x
y = tf.multiply(W, x, name='y')

# DO SESSION STUFF
sess = tf.Session()
sess.run(tf.global_variables_initializer()) 

# SAVE THE MODEL
builder = tf.saved_model.builder.SavedModelBuilder("/tmp/model" )
builder.add_meta_graph_and_variables(
  sess, 
  [tf.saved_model.tag_constants.SERVING]
)
builder.save()

Carica e utilizza il modello in Java.

public static void main( String[] args ) throws IOException
{
    // good idea to print the version number, 1.2.0 as of this writing
    System.out.println(TensorFlow.version());        
    final int NUM_PREDICTIONS = 1;

    // load the model Bundle
    try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {

        // create the session from the Bundle
        Session sess = b.session();
        // create an input Tensor, value = 2.0f
        Tensor x = Tensor.create(
            new long[] {NUM_PREDICTIONS}, 
            FloatBuffer.wrap( new float[] {2.0f} ) 
        );
        
        // run the model and get the result, 4.0f.
        float[] y = sess.runner()
            .feed("x", x)
            .fetch("y")
            .run()
            .get(0)
            .copyTo(new float[NUM_PREDICTIONS]);

        // print out the result.
        System.out.println(y[0]);
    }                
}


Modified text is an extract of the original Stack Overflow Documentation
Autorizzato sotto CC BY-SA 3.0
Non affiliato con Stack Overflow