Sök…


parametrar

Parameter detaljer
datatyp (dtype) specifikt en av de datatyper som tillhandahålls av tensorflow-paketet. tensorflow.float32
dataform (form) Dimensioner på platshållare som lista eller tupel. None kan användas för dimensioner som är okända. Exempelvis (Ingen, 30) skulle definiera en (? X 30) dimension platshållare
namn (namn) Ett namn för operationen (valfritt).

Grunderna för platshållare

Med platshållare kan du mata värden i ett tensorflödesdiagram. Tillägg De tillåter dig att ange begränsningar för dimensioner och datatyp för värden som matas in. Som sådana är de användbara när du skapar ett neuralt nätverk för att mata nya träningsexempel.

Följande exempel deklarerar en platshållare för en 3 x 4-tensor med element som är (eller kan skrivas till) 32-bitars flottörer.

a = tf.placeholder(tf.float32, shape=[3,4], name='a')

Platshållare kommer inte att innehålla några värden på egen hand, så det är viktigt att mata dem med värden när du kör en session annars får du ett felmeddelande. Detta kan göras med feed_dict argumentet när du ringer session.run() , t.ex.:

# run the graph up to node b, feeding the placeholder `a` with values in my_array 
session.run(b, feed_dict={a: my_array})

Här är ett enkelt exempel som visar hela processen för att deklarera och mata en placeholer.

import tensorflow as tf
import numpy as np

# Build a graph
graph = tf.Graph()
with graph.as_default():
    # declare a placeholder that is 3 by 4 of type float32
    a = tf.placeholder(tf.float32, shape=(3, 4), name='a')
    
    # Perform some operation on the placeholder
    b = a * 2
    
# Create an array to be fed to `a`
input_array = np.ones((3,4))

# Create a session, and run the graph
with tf.Session(graph=graph) as session:
    # run the session up to node b, feeding an array of values into a
    output = session.run(b, feed_dict={a: input_array})
    print(output)

Platshållaren tar en grupp 3 med 4, och att tensor multipliceras sedan med 2 vid nod b, som sedan returnerar och skriver ut följande:

[[ 2.  2.  2.  2.]
 [ 2.  2.  2.  2.]
 [ 2.  2.  2.  2.]]

Platshållare med standard

Ofta vill man intermittent köra en eller flera valideringspartier under utbildningen av ett djupt nätverk. Typiskt träningsdata matas av en kö medan valideringsdata kan föras genom feed_dict parameter i sess.run() . tf.placeholder_with_default() är utformad för att fungera bra i den här situationen:

import numpy as np
import tensorflow as tf

IMG_SIZE = [3, 3]
BATCH_SIZE_TRAIN = 2
BATCH_SIZE_VAL = 1

def get_training_batch(batch_size):
    ''' training data pipeline '''
    image = tf.random_uniform(shape=IMG_SIZE)
    label = tf.random_uniform(shape=[])
    min_after_dequeue = 100
    capacity = min_after_dequeue + 3 * batch_size
    images, labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return images, labels

# define the graph
images_train, labels_train = get_training_batch(BATCH_SIZE_TRAIN)
image_batch = tf.placeholder_with_default(images_train, shape=None)
label_batch = tf.placeholder_with_default(labels_train, shape=None)
new_images = tf.mul(image_batch, -1)
new_labels = tf.mul(label_batch, -1)

# start a session
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # typical training step where batch data are drawn from the training queue
    py_images, py_labels = sess.run([new_images, new_labels])
    print('Data from queue:')
    print('Images: ', py_images)  # returned values in range [-1.0, 0.0]
    print('\nLabels: ', py_labels) # returned values [-1, 0.0]

    # typical validation step where batch data are supplied through feed_dict
    images_val = np.random.randint(0, 100, size=np.hstack((BATCH_SIZE_VAL, IMG_SIZE)))
    labels_val = np.ones(BATCH_SIZE_VAL)
    py_images, py_labels = sess.run([new_images, new_labels],
                      feed_dict={image_batch:images_val, label_batch:labels_val})
    print('\n\nData from feed_dict:')
    print('Images: ', py_images) # returned values are integers in range [-100.0, 0.0]
    print('\nLabels: ', py_labels) # returned values are -1.0

    coord.request_stop()
    coord.join(threads)

I detta exempel image_batch och label_batch av get_training_batch() inte motsvarande värden skickas som feed_dict parametern under ett samtal till sess.run() .



Modified text is an extract of the original Stack Overflow Documentation
Licensierat under CC BY-SA 3.0
Inte anslutet till Stack Overflow