Sök…


Deklarera och initiera variabla tensorer

Variabla tensorer används när värdena kräver uppdatering inom en session. Det är den typen av tensor som skulle användas för viktsmatrisen när du skapar nervnätverk, eftersom dessa värden kommer att uppdateras när modellen tränas.

Att förklara en variabel tensor kan göras med tf.Variable() eller tf.get_variable() . Det rekommenderas att använda tf.get_variable , eftersom det ger mer flexibilitet, t.ex.:

# Declare a 2 by 3 tensor populated by ones
a = tf.Variable(tf.ones([2,3], dtype=tf.float32))
a = tf.get_variable('a', shape=[2, 3], initializer=tf.constant_initializer(1))

Något att notera är att deklarering av en variabel tensor automatiskt inte initialiserar värdena. Värdena måste integreras uttryckligt när du startar en session med något av följande:

  • tf.global_variables_initializer().run()
  • session.run(tf.global_variables_initializer())

Följande exempel visar hela processen för att deklarera och initiera en variabel tensor.

# Build a graph
graph = tf.Graph()
with graph.as_default():
    a = tf.get_variable('a', shape=[2,3], initializer=tf.constant_initializer(1), dtype=tf.float32))     # Create a variable tensor

# Create a session, and run the graph
with tf.Session(graph=graph) as session:
    tf.global_variables_initializer().run()  # Initialize values of all variable tensors
    output_a = session.run(a)            # Return the value of the variable tensor
    print(output_a)                      # Print this value

Vilket skriver ut följande:

[[ 1.  1.  1.]
 [ 1.  1.  1.]]

Hämta värdet på en TensorFlow-variabel eller en Tensor

Ibland måste vi hämta och skriva ut värdet på en TensorFlow-variabel för att garantera att vårt program är korrekt.

Om vi till exempel har följande program:

import tensorflow as tf
import numpy as np
a = tf.Variable(tf.random_normal([2,3])) # declare a tensorflow variable
b = tf.random_normal([2,2]) #declare a tensorflow tensor
init = tf.initialize_all_variables()

om vi vill få värdet på a eller b kan följande procedurer användas:

with tf.Session() as sess:
    sess.run(init)
    a_value = sess.run(a)
    b_value = sess.run(b)
    print a_value
    print b_value

eller

with tf.Session() as sess:
    sess.run(init)
    a_value = a.eval()
    b_value = b.eval()
    print a_value
    print b_value


Modified text is an extract of the original Stack Overflow Documentation
Licensierat under CC BY-SA 3.0
Inte anslutet till Stack Overflow