tensorflow
Создание настраиваемой операции с помощью tf.py_func (только для CPU)
Поиск…
параметры
параметр | подробности |
---|---|
FUNC | python, которая принимает numpy массивы в качестве своих входов и возвращает массивы numpy как свои выходы |
вх | список тензоров (входов) |
шпионить | список типов данных tensorflow для выходов func |
Основной пример
Оператор tf.py_func(func, inp, Tout)
создает операцию TensorFlow, которая вызывает функцию Python, func
в списке тензоров inp
.
См. Документацию по tf.py_func(func, inp, Tout)
.
Предупреждение : tf.py_func()
будет выполняться только на CPU. Если вы используете распределенный tf.py_func()
операция tf.py_func()
должна быть размещена на устройстве ЦП в том же процессе, что и клиент.
def func(x):
return 2*x
x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1
Зачем использовать tf.py_func
Оператор tf.py_func()
позволяет запускать произвольный код Python в середине графика TensorFlow. Это особенно удобно для упаковки пользовательских операторов NumPy, для которых не существует эквивалентного оператора TensorFlow (пока). Добавление tf.py_func()
является альтернативой использованию sess.run()
внутри графика.
Другой способ сделать это - вырезать график в двух частях:
# Part 1 of the graph
inputs = ... # in the TF graph
# Get the numpy array and apply func
val = sess.run(inputs) # get the value of inputs
output_val = func(val) # numpy array
# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...
# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})
С tf.py_func
это намного проще:
# Part 1 of the graph
inputs = ...
# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]
# Part 2 of the graph
train_op = ...
# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)