Sök…


parametrar

Parameter detaljer
func pythonfunktion, som tar numpy arrays som sina ingångar och returnerar numpy arrays som dess utgångar
i P lista över Tensorer (ingångar)
Tout lista med tensorflowdatatyper för utgångar från func

Grundläggande exempel

tf.py_func(func, inp, Tout) skapar en TensorFlow-operation som kallar en Python-funktion, func på en lista med inp .

Se dokumentationen för tf.py_func(func, inp, Tout) .

Varning : tf.py_func() -operationen körs endast på CPU. Om du använder distribuerad TensorFlow tf.py_func() operationen tf.py_func() placeras på en CPU-enhet i samma process som klienten.

def func(x):
    return 2*x

x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1

Varför man använder tf.py_func

tf.py_func() låter dig köra godtycklig Python-kod mitt i ett TensorFlow-diagram. Det är särskilt bekvämt för inslagning av anpassade NumPy-operatörer för vilka ingen motsvarande TensorFlow-operatör (ännu) finns. Att lägga till tf.py_func() är ett alternativ till att använda sess.run() inuti diagrammet.

Ett annat sätt att göra det är att skära grafen i två delar:

# Part 1 of the graph
inputs = ...  # in the TF graph

# Get the numpy array and apply func
val = sess.run(inputs)  # get the value of inputs
output_val = func(val)  # numpy array

# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...

# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})

Med tf.py_func detta mycket lättare:

# Part 1 of the graph
inputs = ...

# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]

# Part 2 of the graph
train_op = ...

# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)


Modified text is an extract of the original Stack Overflow Documentation
Licensierat under CC BY-SA 3.0
Inte anslutet till Stack Overflow