수색…


소개

Tensorflow는 그래프의 모든 변수의 현재 값 저장 / 복원과 실제 그래프 구조 저장 / 복원을 구분합니다. 그래프를 복원하려면 Tensorflow의 기능을 자유롭게 사용하거나 처음부터 그래프를 작성한 코드 조각을 다시 호출하면됩니다. 그래프를 정의 할 때 그래프가 저장되고 복원되면 변수 및 연산을 검색 할 수있는 방법과 방법을 고려해야합니다.

비고

위의 모델 복원 섹션에서 모델을 올바르게 작성한 다음 변수를 복원하십시오. tf.add_to_collection() 사용하여 저장하는 경우 관련 텐서 / 자리 표시자를 추가하면 모델을 다시 작성할 필요가 없다고 생각합니다. 예 :

tf.add_to_collection('cost_op', cost_op)

그런 다음 나중에 저장된 그래프를 복원하고 다음을 사용하여 cost_op 액세스 할 수 있습니다.

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')` 
    new_saver.restore(sess, 'model')
    cost_op = tf.get_collection('cost_op')[0]

tf.add_to_collection() 실행하지 않더라도 텐서를 가져올 수는 있지만 프로세스가 좀 더 번거롭기 때문에 올바른 이름을 찾으려면 약간의 파고를해야 할 수도 있습니다. 예 :

tensorflow 그래프를 작성하는 스크립트에서 몇 가지 텐서 집합을 정의합니다. lab_squeeze :

...
with tf.variable_scope("inputs"):
    y=tf.convert_to_tensor([[0,1],[1,0]])
    split_labels=tf.split(1,0,x,name='lab_split')
    split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
    saver=tf.train.Saver(sess,split_labels)
    saver.save("./checkpoint.chk")
    

우리는 나중에 다음과 같이 그들을 기억할 수 있습니다 :

with tf.Session() as sess:
    g=tf.get_default_graph()
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')` 
    new_saver.restore(sess, './checkpoint.chk')
    split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']

    split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0') 
    split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")

텐서의 이름을 찾는 방법은 여러 가지가 있습니다. 텐서 보드의 그래프에서 찾을 수 있습니다. 또는 다음과 같이 검색 할 수 있습니다.

sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph

모델 저장

tensorflow에서 모델을 저장하는 것은 매우 쉽습니다.

입력 x 가있는 선형 모델을 가지고 출력 y 를 예측하려고한다고 가정 해 보겠습니다. 손실은 MSE (Mean Square Error)입니다. 배치 크기는 16입니다.

# Define the model
x = tf.placeholder(tf.float32, [16, 10])  # input
y = tf.placeholder(tf.float32, [16, 1])   # output

w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)

res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))

train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

여러 매개 변수를 가질 수있는 Saver 객체가 있습니다 (cf. doc ).

# Define the tf.train.Saver object
# (cf. params section for all the parameters)    
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

마지막으로 tf.Session() 에서 1000 회 반복하여 모델을 학습합니다. 100 회 반복 할 때마다 모델을 저장합니다.

# Start a session
max_steps = 1000
with tf.Session() as sess:
    # initialize the variables
    sess.run(tf.initialize_all_variables())

    for step in range(max_steps):
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}  # dummy input
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

        # Save the model every 100 iterations
        if step % 100 == 0:
            saver.save(sess, "./model", global_step=step)

이 코드를 실행하면 디렉토리에 마지막 5 개의 체크 포인트가 표시됩니다.

  • model-500model-500.meta
  • model-600model-600.meta
  • model-700model-700.meta
  • model-800model-800.meta
  • model-900model-900.meta

이 예제에서 saver 실제로 변수의 현재 값을 체크 포인트와 그래프의 구조 ( *.meta )로 저장하지만, 예를 들어 placeholder xy 를 검색하는 방법은 특별히 신경 쓰지 않았습니다 모델이 복원되었습니다. 예를 들어,이 교육 스크립트 이외의 다른 곳에서 복원을 수행하는 경우 복원 된 그래프에서 xy 를 검색하는 것이 번거로울 수 있습니다 (특히 더 복잡한 모델에서). 이를 방지하기 위해 항상 변수 / placeholder / ops에 이름을 지정하거나 tf.collections 중 하나에 표시된대로 tf.collections 사용에 대해 생각하십시오.

모델 복원

복원도 아주 쉽고 쉽습니다.

다음은 편리한 도우미 함수입니다.

def restore_vars(saver, sess, chkpt_dir):
    """ Restore saved net, global score and step, and epsilons OR
    create checkpoint directory for later storage. """
    sess.run(tf.initialize_all_variables())

    checkpoint_dir = chkpt_dir 

    if not os.path.exists(checkpoint_dir):
        try:
            print("making checkpoint_dir")
            os.makedirs(checkpoint_dir)
            return False
        except OSError:
            raise

    path = tf.train.get_checkpoint_state(checkpoint_dir)
    print("path = ",path)
    if path is None:
        return False
    else:
        saver.restore(sess, path.model_checkpoint_path)
        return True

주요 코드 :

path_to_saved_model = './'
max_steps = 1

# Start a session
with tf.Session() as sess:

    ... define the model here ...

    print("define the param saver")
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

    # restore session if there is a saved checkpoint
    print("restoring model")
    restored = restore_vars(saver, sess, path_to_saved_model)
    print("model restored ",restored)

    # Now continue training if you so choose

    for step in range(max_steps):

        # do an update on the model (not needed)
        loss_value = sess.run([loss])
        # Now save the model
        saver.save(sess, "./model", global_step=step)


Modified text is an extract of the original Stack Overflow Documentation
아래 라이선스 CC BY-SA 3.0
와 제휴하지 않음 Stack Overflow