tensorflow
Сохранить модель Tensorflow в Python и загрузить с помощью Java
Поиск…
Вступление
Построение и особенно обучение модели может быть проще всего сделано на Python, так как вы загружаете и используете обученную модель на Java?
замечания
Модель может принимать любое количество входов, поэтому измените NUM_PREDICTIONS, если вы хотите запустить больше прогнозов, чем один. Поймите, что Java использует JNI для вызова модели C ++ tensorflow, поэтому вы увидите некоторые информационные сообщения, поступающие из модели при ее запуске.
Создание и сохранение модели с помощью Python
import tensorflow as tf
# good idea
tf.reset_default_graph()
# DO MODEL STUFF
# Pretrained weighting of 2.0
W = tf.get_variable('w', shape=[], initializer=tf.constant(2.0), dtype=tf.float32)
# Model input x
x = tf.placeholder(tf.float32, name='x')
# Model output y = W*x
y = tf.multiply(W, x, name='y')
# DO SESSION STUFF
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# SAVE THE MODEL
builder = tf.saved_model.builder.SavedModelBuilder("/tmp/model" )
builder.add_meta_graph_and_variables(
sess,
[tf.saved_model.tag_constants.SERVING]
)
builder.save()
Загрузите и используйте модель на Java.
public static void main( String[] args ) throws IOException
{
// good idea to print the version number, 1.2.0 as of this writing
System.out.println(TensorFlow.version());
final int NUM_PREDICTIONS = 1;
// load the model Bundle
try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {
// create the session from the Bundle
Session sess = b.session();
// create an input Tensor, value = 2.0f
Tensor x = Tensor.create(
new long[] {NUM_PREDICTIONS},
FloatBuffer.wrap( new float[] {2.0f} )
);
// run the model and get the result, 4.0f.
float[] y = sess.runner()
.feed("x", x)
.fetch("y")
.run()
.get(0)
.copyTo(new float[NUM_PREDICTIONS]);
// print out the result.
System.out.println(y[0]);
}
}
Modified text is an extract of the original Stack Overflow Documentation
Лицензировано согласно CC BY-SA 3.0
Не связан с Stack Overflow