खोज…


परिचय

Tensorflow एक ग्राफ में सभी चर के वर्तमान मूल्यों को सहेजने / बहाल करने और वास्तविक ग्राफ संरचना को बचाने / बहाल करने के बीच अंतर करता है। ग्राफ़ को पुनर्स्थापित करने के लिए, आप या तो Tensorflow के फ़ंक्शंस का उपयोग करने के लिए स्वतंत्र हैं या बस अपने कोड ऑफ़ कोड को फिर से कॉल करें, जिसने ग्राफ़ को पहले स्थान पर बनाया है। ग्राफ़ को परिभाषित करते समय, आपको यह भी सोचना चाहिए कि ग्राफ़ को सहेजने और पुनर्स्थापित करने के बाद कौन से और कैसे चर / ऑप्स पुनर्प्राप्त करने योग्य होने चाहिए।

टिप्पणियों

ऊपर के मॉडल अनुभाग को ठीक से समझने पर यदि आप मॉडल का निर्माण करते हैं और फिर चर को पुनर्स्थापित करते हैं। मेरा मानना है कि जब तक आप tf.add_to_collection() का उपयोग करके बचत करते हैं, तब तक प्रासंगिक tf.add_to_collection() / प्लेसहोल्डर को जोड़ने के लिए मॉडल का पुनर्निर्माण आवश्यक नहीं है। उदाहरण के लिए:

tf.add_to_collection('cost_op', cost_op)

फिर बाद में आप सहेजे गए ग्राफ़ को पुनर्स्थापित कर सकते हैं और उपयोग करके cost_op तक पहुँच प्राप्त कर सकते हैं

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')` 
    new_saver.restore(sess, 'model')
    cost_op = tf.get_collection('cost_op')[0]

यहां तक कि अगर आप tf.add_to_collection() नहीं चलाते हैं, तो आप अपने tf.add_to_collection() को पुनः प्राप्त कर सकते हैं, लेकिन प्रक्रिया थोड़ी अधिक बोझिल है, और आपको चीजों के लिए सही नाम खोजने के लिए कुछ खुदाई करनी पड़ सकती है। उदाहरण के लिए:

एक स्क्रिप्ट में जो टेंसरफ़्लो ग्राफ बनाता है, हम टेंसर्स के कुछ सेट को परिभाषित करते हैं lab_squeeze :

...
with tf.variable_scope("inputs"):
    y=tf.convert_to_tensor([[0,1],[1,0]])
    split_labels=tf.split(1,0,x,name='lab_split')
    split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
    saver=tf.train.Saver(sess,split_labels)
    saver.save("./checkpoint.chk")
    

हम उन्हें बाद में निम्नानुसार याद कर सकते हैं:

with tf.Session() as sess:
    g=tf.get_default_graph()
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')` 
    new_saver.restore(sess, './checkpoint.chk')
    split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']

    split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0') 
    split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")

एक टेंसर का नाम खोजने के कई तरीके हैं - आप इसे टेन्सर बोर्ड पर अपने ग्राफ में पा सकते हैं, या आप कुछ इस तरह से खोज सकते हैं:

sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph

मॉडल सहेजना

एक मॉडल को टेंसरफ्लो में सहेजना बहुत आसान है।

मान लें कि आपके पास इनपुट x साथ एक रैखिक मॉडल है और आउटपुट y भविष्यवाणी करना चाहते हैं। यहाँ नुकसान औसत वर्ग त्रुटि (MSE) है। बैच का आकार 16 है।

# Define the model
x = tf.placeholder(tf.float32, [16, 10])  # input
y = tf.placeholder(tf.float32, [16, 1])   # output

w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)

res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))

train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

यहां सेवर ऑब्जेक्ट आता है, जिसमें कई पैरामीटर (cf. doc ) हो सकते हैं।

# Define the tf.train.Saver object
# (cf. params section for all the parameters)    
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

अंत में हम 1000 पुनरावृत्तियों के लिए मॉडल को tf.Session() में प्रशिक्षित करते हैं। हम केवल मॉडल को हर 100 पुनरावृत्तियों को यहां सहेजते हैं।

# Start a session
max_steps = 1000
with tf.Session() as sess:
    # initialize the variables
    sess.run(tf.initialize_all_variables())

    for step in range(max_steps):
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}  # dummy input
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

        # Save the model every 100 iterations
        if step % 100 == 0:
            saver.save(sess, "./model", global_step=step)

इस कोड को चलाने के बाद, आपको अपनी निर्देशिका में अंतिम 5 चौकियों को देखना चाहिए:

  • model-500 और model-500.meta
  • model-600 और model-600.meta
  • model-700 और model-700.meta
  • model-800 और model-800.meta
  • model-900 और model-900.meta

ध्यान दें कि इस उदाहरण में, जबकि saver वास्तव में चेकपॉइंट और ग्राफ़ की संरचना ( *.meta ) के रूप में चर के दोनों वर्तमान मूल्यों को बचाता है, कोई विशेष ध्यान नहीं दिया गया था कि कैसे प्लेसहोल्डर x और y एक बार पुनः प्राप्त किया जाए। मॉडल को बहाल किया गया था। उदाहरण के लिए, यदि इस प्रशिक्षण स्क्रिप्ट की तुलना में कहीं और बहाल किया जाता है, तो यह बहाल ग्राफ से x और y को पुनर्प्राप्त करने के लिए बोझिल हो सकता है (विशेष रूप से अधिक जटिल मॉडल में)। उससे बचने के लिए, हमेशा अपने चर / प्लेसहोल्डर / ऑप्स को नाम दें या टिप्पणी में दिखाए गए अनुसार tf.collections का उपयोग करने के बारे में tf.collections

मॉडल को पुनर्स्थापित करना

रिस्टोर करना भी काफी अच्छा और आसान है।

यहाँ एक आसान सहायक कार्य है:

def restore_vars(saver, sess, chkpt_dir):
    """ Restore saved net, global score and step, and epsilons OR
    create checkpoint directory for later storage. """
    sess.run(tf.initialize_all_variables())

    checkpoint_dir = chkpt_dir 

    if not os.path.exists(checkpoint_dir):
        try:
            print("making checkpoint_dir")
            os.makedirs(checkpoint_dir)
            return False
        except OSError:
            raise

    path = tf.train.get_checkpoint_state(checkpoint_dir)
    print("path = ",path)
    if path is None:
        return False
    else:
        saver.restore(sess, path.model_checkpoint_path)
        return True

मुख्य कोड:

path_to_saved_model = './'
max_steps = 1

# Start a session
with tf.Session() as sess:

    ... define the model here ...

    print("define the param saver")
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

    # restore session if there is a saved checkpoint
    print("restoring model")
    restored = restore_vars(saver, sess, path_to_saved_model)
    print("model restored ",restored)

    # Now continue training if you so choose

    for step in range(max_steps):

        # do an update on the model (not needed)
        loss_value = sess.run([loss])
        # Now save the model
        saver.save(sess, "./model", global_step=step)


Modified text is an extract of the original Stack Overflow Documentation
के तहत लाइसेंस प्राप्त है CC BY-SA 3.0
से संबद्ध नहीं है Stack Overflow