Recherche…


Remarques

Lorsque vous avez un modèle énorme, il est utile de former des groupes de tenseurs dans votre graphe de calcul, qui sont connectés entre eux. Par exemple, la classe tf.GraphKeys contient des collections standard telles que:

tf.GraphKeys.VARIABLES
tf.GraphKeys.TRAINABLE_VARIABLES
tf.GraphKeys.SUMMARIES

Créez votre propre collection et utilisez-la pour collecter toutes vos pertes.

Ici, nous allons créer une collection pour les pertes du graphe de calcul de Neural Network.

Commencez par créer un graphe de calcul comme suit:

with tf.variable_scope("Layer"):
    W = tf.get_variable("weights", [m, k],
        initializer=tf.zeros_initializer([m, k], dtype=tf.float32))
    b1 = tf.get_variable("bias", [k],
        initializer = tf.zeros_initializer([k], dtype=tf.float32))
    z = tf.sigmoid((tf.matmul(input, W) + b1))
    
    with tf.variable_scope("Softmax"):
        U = tf.get_variable("weights", [k, r],
            initializer=tf.zeros_initializer([k,r], dtype=tf.float32))
        b2 = tf.get_variable("bias", [r],
            initializer=tf.zeros_initializer([r], dtype=tf.float32))
    out = tf.matmul(z, U) + b2
cross_entropy = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(out, labels))

Pour créer une nouvelle collection, vous pouvez simplement appeler tf.add_to_collection() - le premier appel créera la collection.

tf.add_to_collection("my_losses", 
    self.config.l2 * (tf.add_n([tf.reduce_sum(U ** 2), tf.reduce_sum(W ** 2)])))
tf.add_to_collection("my_losses", cross_entropy)

Et enfin, vous pouvez obtenir des tenseurs de votre collection:

loss = sum(tf.get_collection("my_losses"))

Notez que tf.get_collection() renvoie une copie de la collection ou une liste vide si la collection n'existe pas. En outre, il ne crée pas la collection si elle n'existe pas. Pour ce faire, vous pouvez utiliser tf.get_collection_ref() qui renvoie une référence à la collection et en crée une si elle n'existe pas encore.

Recueillir des variables à partir de portées imbriquées

Vous trouverez ci-dessous une couche cachée unique multicéphale Perceptron (MLP) utilisant la portée imbriquée des variables.

def weight_variable(shape):
    return tf.get_variable(name="weights", shape=shape,
                           initializer=tf.zeros_initializer(dtype=tf.float32))

def bias_variable(shape):
    return tf.get_variable(name="biases", shape=shape,
                           initializer=tf.zeros_initializer(dtype=tf.float32))

def fc_layer(input, in_dim, out_dim, layer_name):
    with tf.variable_scope(layer_name):
        W = weight_variable([in_dim, out_dim])
        b = bias_variable([out_dim])
        linear = tf.matmul(input, W) + b
        output = tf.sigmoid(linear)

with tf.variable_scope("MLP"):
    x = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="x")
    y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="y")
    fc1 = fc_layer(x, 1, 8, "fc1")
    fc2 = fc_layer(fc1, 8, 1, "fc2")

mse_loss = tf.reduce_mean(tf.reduce_sum(tf.square(fc2 - y), axis=1))

Le MLP utilise le nom de portée de niveau supérieur MLP et il a deux couches avec leurs noms de portée respectifs fc1 et fc2 . Chaque couche a également ses propres variables de weights et de biases .

Les variables peuvent être collectées comme suit:

trainable_var_key = tf.GraphKeys.TRAINABLE_VARIABLES
all_vars = tf.get_collection(key=trainable_var_key, scope="MLP")
fc1_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc1")
fc2_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc2")
fc1_weight_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc1/weights")
fc1_bias_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc1/biases")

Les valeurs des variables peuvent être collectées à l'aide de la commande sess.run() . Par exemple, si nous souhaitons collecter les valeurs de fc1_weight_vars après l’entraînement, nous pourrions procéder comme suit:

sess = tf.Session()
# add code to initialize variables
# add code to train the network
# add code to create test data x_test and y_test

fc1_weight_vals = sess.run(fc1, feed_dict={x: x_test, y: y_test})
print(fc1_weight_vals)  # This should be an ndarray with ndim=2 and shape=[1, 8]


Modified text is an extract of the original Stack Overflow Documentation
Sous licence CC BY-SA 3.0
Non affilié à Stack Overflow