tensorflow
Tworzenie niestandardowej operacji za pomocą tf.py_func (tylko procesor)
Szukaj…
Parametry
Parametr | Detale |
---|---|
func | funkcja python, która pobiera tablice numpy jako dane wejściowe i zwraca tablice numpy jako dane wyjściowe |
w p | lista tensorów (wejścia) |
Naganiacz | lista typów danych tensorflow dla wyników func |
Podstawowy przykład
Operator tf.py_func(func, inp, Tout)
tworzy operację TensorFlow, która wywołuje funkcję Pythona, func
na liście inp
tensorów.
Zobacz dokumentację tf.py_func(func, inp, Tout)
.
Ostrzeżenie : Operacja tf.py_func()
będzie działać tylko na procesorze. Jeśli używasz rozproszonego TensorFlow, tf.py_func()
musi być umieszczona na urządzeniu CPU w tym samym procesie co klient.
def func(x):
return 2*x
x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1
Dlaczego warto korzystać z tf.py_func
Operator tf.py_func()
umożliwia uruchomienie dowolnego kodu w języku Python na środku wykresu TensorFlow. Jest to szczególnie wygodne do owijania niestandardowych operatorów NumPy, dla których nie istnieje (jeszcze) równoważny operator TensorFlow. Dodanie tf.py_func()
jest alternatywą dla używania sess.run()
wewnątrz wykresu.
Innym sposobem na to jest przecięcie wykresu na dwie części:
# Part 1 of the graph
inputs = ... # in the TF graph
# Get the numpy array and apply func
val = sess.run(inputs) # get the value of inputs
output_val = func(val) # numpy array
# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...
# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})
Z tf.py_func
jest to o wiele łatwiejsze:
# Part 1 of the graph
inputs = ...
# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]
# Part 2 of the graph
train_op = ...
# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)