Zoeken…


Opmerkingen

Wanneer je een enorm model hebt, is het handig om enkele groepen tensoren in je computergrafiek te vormen, die met elkaar verbonden zijn. De klasse tf.GraphKeys bevat bijvoorbeeld standaardcollecties als:

tf.GraphKeys.VARIABLES
tf.GraphKeys.TRAINABLE_VARIABLES
tf.GraphKeys.SUMMARIES

Maak je eigen verzameling en gebruik deze om al je verliezen te innen.

Hier maken we een verzameling voor verliezen van de computergrafiek van Neural Network.

Maak eerst een rekenkundige grafiek als volgt:

with tf.variable_scope("Layer"):
    W = tf.get_variable("weights", [m, k],
        initializer=tf.zeros_initializer([m, k], dtype=tf.float32))
    b1 = tf.get_variable("bias", [k],
        initializer = tf.zeros_initializer([k], dtype=tf.float32))
    z = tf.sigmoid((tf.matmul(input, W) + b1))
    
    with tf.variable_scope("Softmax"):
        U = tf.get_variable("weights", [k, r],
            initializer=tf.zeros_initializer([k,r], dtype=tf.float32))
        b2 = tf.get_variable("bias", [r],
            initializer=tf.zeros_initializer([r], dtype=tf.float32))
    out = tf.matmul(z, U) + b2
cross_entropy = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(out, labels))

Om een nieuwe verzameling te maken, kunt u eenvoudig beginnen met het aanroepen van tf.add_to_collection() - de eerste aanroep maakt de verzameling aan.

tf.add_to_collection("my_losses", 
    self.config.l2 * (tf.add_n([tf.reduce_sum(U ** 2), tf.reduce_sum(W ** 2)])))
tf.add_to_collection("my_losses", cross_entropy)

En eindelijk kun je tensoren uit je verzameling halen:

loss = sum(tf.get_collection("my_losses"))

Merk op dat tf.get_collection() een kopie van de verzameling of een lege lijst retourneert als de verzameling niet bestaat. Ook wordt de verzameling NIET gemaakt als deze niet bestaat. Om dit te doen, zou je tf.get_collection_ref() kunnen gebruiken die een verwijzing naar de verzameling retourneert en in feite een lege creëert als deze nog niet bestaat.

Verzamel variabelen van geneste scopes

Hieronder is een enkele verborgen laag Multilayer Perceptron (MLP) met behulp van geneste scoping van variabelen.

def weight_variable(shape):
    return tf.get_variable(name="weights", shape=shape,
                           initializer=tf.zeros_initializer(dtype=tf.float32))

def bias_variable(shape):
    return tf.get_variable(name="biases", shape=shape,
                           initializer=tf.zeros_initializer(dtype=tf.float32))

def fc_layer(input, in_dim, out_dim, layer_name):
    with tf.variable_scope(layer_name):
        W = weight_variable([in_dim, out_dim])
        b = bias_variable([out_dim])
        linear = tf.matmul(input, W) + b
        output = tf.sigmoid(linear)

with tf.variable_scope("MLP"):
    x = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="x")
    y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="y")
    fc1 = fc_layer(x, 1, 8, "fc1")
    fc2 = fc_layer(fc1, 8, 1, "fc2")

mse_loss = tf.reduce_mean(tf.reduce_sum(tf.square(fc2 - y), axis=1))

De MLP gebruikt de hoogste fc1 MLP en heeft twee lagen met hun respectieve fc1 en fc2 . Elke laag heeft ook zijn eigen weights en biases variabelen.

De variabelen kunnen als volgt worden verzameld:

trainable_var_key = tf.GraphKeys.TRAINABLE_VARIABLES
all_vars = tf.get_collection(key=trainable_var_key, scope="MLP")
fc1_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc1")
fc2_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc2")
fc1_weight_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc1/weights")
fc1_bias_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc1/biases")

De waarden van de variabelen kunnen worden verzameld met de opdracht sess.run() . Als we bijvoorbeeld na de training de waarden van de fc1_weight_vars willen verzamelen, kunnen we het volgende doen:

sess = tf.Session()
# add code to initialize variables
# add code to train the network
# add code to create test data x_test and y_test

fc1_weight_vals = sess.run(fc1, feed_dict={x: x_test, y: y_test})
print(fc1_weight_vals)  # This should be an ndarray with ndim=2 and shape=[1, 8]


Modified text is an extract of the original Stack Overflow Documentation
Licentie onder CC BY-SA 3.0
Niet aangesloten bij Stack Overflow