Szukaj…


Uwagi

Gdy masz duży model, przydatne jest utworzenie na wykresie obliczeniowym kilku grup tensorów, które są ze sobą połączone. Na przykład klasa tf.GraphKeys zawiera takie standardowe kolekcje, jak:

tf.GraphKeys.VARIABLES
tf.GraphKeys.TRAINABLE_VARIABLES
tf.GraphKeys.SUMMARIES

Stwórz własną kolekcję i wykorzystaj ją do zebrania wszystkich strat.

Tutaj stworzymy kolekcję strat grafu obliczeniowego Neural Network.

Najpierw utwórz wykres obliczeniowy w taki sposób:

with tf.variable_scope("Layer"):
    W = tf.get_variable("weights", [m, k],
        initializer=tf.zeros_initializer([m, k], dtype=tf.float32))
    b1 = tf.get_variable("bias", [k],
        initializer = tf.zeros_initializer([k], dtype=tf.float32))
    z = tf.sigmoid((tf.matmul(input, W) + b1))
    
    with tf.variable_scope("Softmax"):
        U = tf.get_variable("weights", [k, r],
            initializer=tf.zeros_initializer([k,r], dtype=tf.float32))
        b2 = tf.get_variable("bias", [r],
            initializer=tf.zeros_initializer([r], dtype=tf.float32))
    out = tf.matmul(z, U) + b2
cross_entropy = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(out, labels))

Aby utworzyć nową kolekcję, możesz po prostu zacząć wywoływać tf.add_to_collection() - pierwsze wywołanie utworzy kolekcję.

tf.add_to_collection("my_losses", 
    self.config.l2 * (tf.add_n([tf.reduce_sum(U ** 2), tf.reduce_sum(W ** 2)])))
tf.add_to_collection("my_losses", cross_entropy)

I wreszcie możesz uzyskać tensory ze swojej kolekcji:

loss = sum(tf.get_collection("my_losses"))

Zauważ, że tf.get_collection() zwraca kopię kolekcji lub pustą listę, jeśli kolekcja nie istnieje. Ponadto NIE tworzy kolekcji, jeśli nie istnieje. Aby to zrobić, możesz użyć funkcji tf.get_collection_ref() która zwraca odwołanie do kolekcji i faktycznie tworzy pustą, jeśli jeszcze nie istnieje.

Zbieraj zmienne z zasięgów zagnieżdżonych

Poniżej znajduje się jedna ukryta warstwa Perceptron wielowarstwowy (MLP) z wykorzystaniem zagnieżdżonego zakresu zmiennych.

def weight_variable(shape):
    return tf.get_variable(name="weights", shape=shape,
                           initializer=tf.zeros_initializer(dtype=tf.float32))

def bias_variable(shape):
    return tf.get_variable(name="biases", shape=shape,
                           initializer=tf.zeros_initializer(dtype=tf.float32))

def fc_layer(input, in_dim, out_dim, layer_name):
    with tf.variable_scope(layer_name):
        W = weight_variable([in_dim, out_dim])
        b = bias_variable([out_dim])
        linear = tf.matmul(input, W) + b
        output = tf.sigmoid(linear)

with tf.variable_scope("MLP"):
    x = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="x")
    y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="y")
    fc1 = fc_layer(x, 1, 8, "fc1")
    fc2 = fc_layer(fc1, 8, 1, "fc2")

mse_loss = tf.reduce_mean(tf.reduce_sum(tf.square(fc2 - y), axis=1))

MLP używa nazwy ZAKRES najwyższego poziomu MLP i ma dwie warstwy z ich nazwami zakres fc1 i fc2 . Każda warstwa ma także własne zmienne weights i biases .

Zmienne można zebrać w następujący sposób:

trainable_var_key = tf.GraphKeys.TRAINABLE_VARIABLES
all_vars = tf.get_collection(key=trainable_var_key, scope="MLP")
fc1_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc1")
fc2_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc2")
fc1_weight_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc1/weights")
fc1_bias_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc1/biases")

Wartości zmiennych można zebrać za pomocą polecenia sess.run() . Na przykład, jeśli chcielibyśmy zebrać wartości fc1_weight_vars po treningu, moglibyśmy wykonać następujące czynności:

sess = tf.Session()
# add code to initialize variables
# add code to train the network
# add code to create test data x_test and y_test

fc1_weight_vals = sess.run(fc1, feed_dict={x: x_test, y: y_test})
print(fc1_weight_vals)  # This should be an ndarray with ndim=2 and shape=[1, 8]


Modified text is an extract of the original Stack Overflow Documentation
Licencjonowany na podstawie CC BY-SA 3.0
Nie związany z Stack Overflow