tensorflow
Wie verwende ich TensorFlow Graph Collections?
Suche…
Bemerkungen
Wenn Sie ein großes Modell haben, ist es nützlich, einige Gruppen von Tensoren in Ihrem Berechnungsgraph zu bilden, die miteinander verbunden sind. Zum Beispiel enthält die Klasse tf.GraphKeys solche Standardauflistungen wie:
tf.GraphKeys.VARIABLES
tf.GraphKeys.TRAINABLE_VARIABLES
tf.GraphKeys.SUMMARIES
Erstellen Sie Ihre eigene Sammlung und verwenden Sie sie, um alle Ihre Verluste einzutreiben.
Hier erstellen wir eine Sammlung für Verluste des Berechnungsgraphen von Neural Network.
Erstellen Sie zunächst einen Berechnungsgraph:
with tf.variable_scope("Layer"):
W = tf.get_variable("weights", [m, k],
initializer=tf.zeros_initializer([m, k], dtype=tf.float32))
b1 = tf.get_variable("bias", [k],
initializer = tf.zeros_initializer([k], dtype=tf.float32))
z = tf.sigmoid((tf.matmul(input, W) + b1))
with tf.variable_scope("Softmax"):
U = tf.get_variable("weights", [k, r],
initializer=tf.zeros_initializer([k,r], dtype=tf.float32))
b2 = tf.get_variable("bias", [r],
initializer=tf.zeros_initializer([r], dtype=tf.float32))
out = tf.matmul(z, U) + b2
cross_entropy = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(out, labels))
Um eine neue Sammlung zu erstellen, können Sie einfach tf.add_to_collection()
aufrufen. Der erste Aufruf erstellt die Sammlung.
tf.add_to_collection("my_losses",
self.config.l2 * (tf.add_n([tf.reduce_sum(U ** 2), tf.reduce_sum(W ** 2)])))
tf.add_to_collection("my_losses", cross_entropy)
Und zum Schluß können Sie Tensoren aus Ihrer Sammlung erhalten:
loss = sum(tf.get_collection("my_losses"))
Beachten Sie, dass tf.get_collection()
eine Kopie der Sammlung oder eine leere Liste zurückgibt, wenn die Sammlung nicht vorhanden ist. Außerdem wird die Sammlung NICHT erstellt, wenn sie nicht vorhanden ist. Dazu können Sie tf.get_collection_ref()
das einen Verweis auf die Auflistung zurückgibt und tatsächlich eine leere erstellt, wenn diese noch nicht vorhanden ist.
Sammeln Sie Variablen aus verschachtelten Bereichen
Nachfolgend finden Sie ein einzelnes Multilayer Perceptron (MLP) mit versteckter Ebene, das verschachtelte Bereiche von Variablen verwendet.
def weight_variable(shape):
return tf.get_variable(name="weights", shape=shape,
initializer=tf.zeros_initializer(dtype=tf.float32))
def bias_variable(shape):
return tf.get_variable(name="biases", shape=shape,
initializer=tf.zeros_initializer(dtype=tf.float32))
def fc_layer(input, in_dim, out_dim, layer_name):
with tf.variable_scope(layer_name):
W = weight_variable([in_dim, out_dim])
b = bias_variable([out_dim])
linear = tf.matmul(input, W) + b
output = tf.sigmoid(linear)
with tf.variable_scope("MLP"):
x = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="x")
y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="y")
fc1 = fc_layer(x, 1, 8, "fc1")
fc2 = fc_layer(fc1, 8, 1, "fc2")
mse_loss = tf.reduce_mean(tf.reduce_sum(tf.square(fc2 - y), axis=1))
Das MLP verwendet den Bereichsnamen MLP
der obersten Ebene und verfügt über zwei Ebenen mit den jeweiligen Bereichsnamen fc1
und fc2
. Jede Schicht hat auch ihre eigenen Variablen für weights
und biases
.
Die Variablen können wie folgt gesammelt werden:
trainable_var_key = tf.GraphKeys.TRAINABLE_VARIABLES
all_vars = tf.get_collection(key=trainable_var_key, scope="MLP")
fc1_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc1")
fc2_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc2")
fc1_weight_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc1/weights")
fc1_bias_vars = tf.get_collection(key=trainable_var_key, scope="MLP/fc1/biases")
Die Werte der Variablen können mit dem Befehl sess.run()
erfasst werden. Wenn wir zum Beispiel die Werte der fc1_weight_vars
nach dem Training sammeln fc1_weight_vars
, könnten wir Folgendes tun:
sess = tf.Session()
# add code to initialize variables
# add code to train the network
# add code to create test data x_test and y_test
fc1_weight_vals = sess.run(fc1, feed_dict={x: x_test, y: y_test})
print(fc1_weight_vals) # This should be an ndarray with ndim=2 and shape=[1, 8]