tensorflow
variabler
Sök…
Deklarera och initiera variabla tensorer
Variabla tensorer används när värdena kräver uppdatering inom en session. Det är den typen av tensor som skulle användas för viktsmatrisen när du skapar nervnätverk, eftersom dessa värden kommer att uppdateras när modellen tränas.
Att förklara en variabel tensor kan göras med tf.Variable()
eller tf.get_variable()
. Det rekommenderas att använda tf.get_variable
, eftersom det ger mer flexibilitet, t.ex.:
# Declare a 2 by 3 tensor populated by ones a = tf.Variable(tf.ones([2,3], dtype=tf.float32)) a = tf.get_variable('a', shape=[2, 3], initializer=tf.constant_initializer(1))
Något att notera är att deklarering av en variabel tensor automatiskt inte initialiserar värdena. Värdena måste integreras uttryckligt när du startar en session med något av följande:
-
tf.global_variables_initializer().run()
-
session.run(tf.global_variables_initializer())
Följande exempel visar hela processen för att deklarera och initiera en variabel tensor.
# Build a graph graph = tf.Graph() with graph.as_default(): a = tf.get_variable('a', shape=[2,3], initializer=tf.constant_initializer(1), dtype=tf.float32)) # Create a variable tensor # Create a session, and run the graph with tf.Session(graph=graph) as session: tf.global_variables_initializer().run() # Initialize values of all variable tensors output_a = session.run(a) # Return the value of the variable tensor print(output_a) # Print this value
Vilket skriver ut följande:
[[ 1. 1. 1.]
[ 1. 1. 1.]]
Hämta värdet på en TensorFlow-variabel eller en Tensor
Ibland måste vi hämta och skriva ut värdet på en TensorFlow-variabel för att garantera att vårt program är korrekt.
Om vi till exempel har följande program:
import tensorflow as tf
import numpy as np
a = tf.Variable(tf.random_normal([2,3])) # declare a tensorflow variable
b = tf.random_normal([2,2]) #declare a tensorflow tensor
init = tf.initialize_all_variables()
om vi vill få värdet på a eller b kan följande procedurer användas:
with tf.Session() as sess:
sess.run(init)
a_value = sess.run(a)
b_value = sess.run(b)
print a_value
print b_value
eller
with tf.Session() as sess:
sess.run(init)
a_value = a.eval()
b_value = b.eval()
print a_value
print b_value