TensorFlow 1.x Hands-On Tutorial (11): Saving and Restoring Models

Goal

This series introduces TensorFlow fundamentals through hands-on examples, so that newcomers come away comfortable with the basic TensorFlow operations.

Saving the model

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST", one_hot=True)
batch_size = 64
n_batches = mnist.train.num_examples // batch_size

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

w = tf.Variable(tf.random_normal([784, 10], stddev=0.1))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, w) + b
predict = tf.nn.softmax(logits)

# pass the unscaled logits to the loss: feeding the softmax output into
# softmax_cross_entropy_with_logits would apply softmax twice and hurt training
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
opt = tf.train.AdamOptimizer(0.001).minimize(loss)
correct = tf.equal(tf.argmax(y, 1), tf.argmax(predict, 1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    total_batch = 0  # counts epochs, despite the name
    last = 0         # epoch of the most recent accuracy improvement
    best = 0         # best test accuracy seen so far
    for epoch in range(100):
        for _ in range(n_batches):
            xx,yy = mnist.train.next_batch(batch_size)
            sess.run(opt, {x:xx, y:yy})
        acc, l = sess.run([accuracy, loss], {x:mnist.test.images, y:mnist.test.labels})
        if acc > best:
            best = acc
            last = total_batch
            saver.save(sess, 'saved_model/model')  # keep only the best checkpoint so far
            print(epoch, acc, l)
        if total_batch - last > 5:  # no improvement for more than 5 epochs
            print('early stop')
            break
        total_batch += 1

Output

0 0.9035 1.5953374
1 0.9147 1.5688152
2 0.9212 1.5580758
3 0.9234 1.552525
4 0.9239 1.5495663
5 0.9264 1.5462393
6 0.9271 1.5441632
7 0.9288 1.5419955
8 0.9302 1.5403246
12 0.9308 1.5376735
14 0.9324 1.5360526
19 0.9333 1.534032
25 0.9338 1.5329739
26 0.934 1.5326717
early stop

Restoring the model

with tf.Session() as sess:
    # saver.restore assigns every saved variable, so no initializer call is needed
    saver.restore(sess, 'saved_model/model')
    acc, l = sess.run([accuracy, loss], {x:mnist.test.images, y:mnist.test.labels})
    print(acc, l)

Output

0.934 1.5326717

Key point

Because we only ever save the best-performing checkpoint, restoring the model and evaluating it on the same test data reproduces exactly the last (best) result printed during training.
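The save-best-and-stop logic of the training loop can be isolated from TensorFlow. The sketch below is illustrative (the function name and the precomputed per-epoch accuracy list are assumptions, not part of the original code), but it follows the same rule as the loop above: record the epoch of the last improvement, and stop once more than `patience` epochs pass without one.

```python
def train_with_early_stopping(accuracies, patience=5):
    """Sketch of the save-best / early-stopping loop, with the TensorFlow
    session replaced by a precomputed list of per-epoch test accuracies.
    Returns (best_accuracy, epochs_actually_run)."""
    best = 0.0
    last_improvement = 0
    epochs_run = 0
    for epoch, acc in enumerate(accuracies):
        epochs_run += 1
        if acc > best:
            best = acc
            last_improvement = epoch
            # this is where saver.save(sess, ...) would run
        if epoch - last_improvement > patience:
            break  # more than `patience` epochs without improvement
    return best, epochs_run
```

With a patience of 5, the loop stops in the seventh epoch after the last improvement, matching the printed log above (last improvement at epoch 26, then "early stop").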

Reference

https://blog.csdn.net/qq_19672707/article/details/106082917