Tensorflow 2.8 實現 GRU 文本生成任務

語言: CN / TW / HK

本文正在參加「金石計劃 . 瓜分6萬現金大獎」 

前言

本文使用 cpu 的 tensorflow 2.8 來完成 GRU 文本生成任務。如果想要了解文本生成的相關概念,可以參考我之前寫的文章:https://juejin.cn/post/6973567782113771551

大綱

  1. 獲取數據
  2. 處理數據
  3. 搭建並訓練模型
  4. 生成文本邏輯
  5. 預測
  6. 保存和讀取模型

實現

1. 獲取數據

(1)我們使用到的數據是莎士比亞的作品,我們使用 TensorFlow 的內置函數從網絡下載到本地的磁盤,我們展現了部分內容,可以看到裏面都是一段一段對話形式的台詞。

(2)通過使用集合找出數據中總共出現了 65 個不同的字符。

import tensorflow as tf
import numpy as np
import os
import time
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
vocab = sorted(set(text))
print(text[:100])
print(f'{len(vocab)} unique characters')

結果輸出:

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You
65 unique characters

2. 處理數據

(1)在使用數據的時候我們需要將所有的字符都映射成對應的數字, StringLookup 這個內置函數剛好可以實現這個功能,使用這個函數之前要將文本都切分成字符。另外我們還可以使用 StringLookup 這個內置函數完成從數字到字符的映射轉換。我們自定義了函數 text_from_ids 可以實現將字符的序列還原回原始的文本。

(2)我們將莎士比亞數據中的文本使用 ids_from_chars 全部轉換為整數序列,然後使用 from_tensor_slices 創建 Dataset 對象。

(3)我們將數據都切分層每個 batch 大小為 seq_length+1 的長度,這樣是為了後面創建(input,target)這一樣本形式的。每個樣本 sample 的 input 序列選取文本中的前 seq_length 個字符 sample[:seq_length] 為輸入。對於每個 input ,相應的 target 也包含相同長度的文本,只是整體向右移動了一個字符,選取結果為 sample[1:seq_length+1]。例如 seq_length 是 4,我們的序列是“Hello”,那麼 input 序列為“hell”,目標序列為“ello”。

(4)我們展示了一個樣本,可以看到 input 和 label 的形成遵循上面的規則,其目的就是要讓 RNN 的每個時間步上都有對應的輸入字符和對應的目標字母,輸入字符是當前的字符,目標字符肯定就是後面一個相鄰的字符。

ids_from_chars = tf.keras.layers.StringLookup(vocabulary=list(vocab), mask_token=None)
chars_from_ids = tf.keras.layers.StringLookup(vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)
def text_from_ids(ids):
    return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)
all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)
seq_length = 64
sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)
def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text
dataset = sequences.map(split_input_target)

for input_example, target_example in dataset.take(1):
    print("Input :", text_from_ids(input_example).numpy())
    print("Label:", text_from_ids(target_example).numpy())

結果輸出:

Input : b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAl'
Label: b'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll'

(5)我們將所有處理好的樣本先進行混洗,保證樣本的隨機性,然後將樣本都進行分批,每個 batch 設置大小為 64 ,設置每個詞嵌入維度為 128 ,設置 GRU 的輸入為 128 維。

BATCH_SIZE = 64
BUFFER_SIZE = 10000
vocab_size = len(ids_from_chars.get_vocabulary())
embedding_dim = 128
gru_units = 128
dataset = (dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE))

3. 搭建並訓練模型

(1)第一層是詞嵌入層,主要是將用户輸入的序列中的每個證書轉換為模型需要的多維輸入。

(2)第二層是 GRU 層,主要是接收每個時間步的輸入,並且將前後狀態進行計算和保存,讓 GRU 可以記住文本序列規律。

(3)第三層是全連接層,主要是輸出一個字典大小維度的向量,表示的是每個字符對應的概率分佈。

(4)這裏有一些細節需要處理,如果 states 是空,那麼就直接隨機初始化 gru 的初始狀態,另外如果需要返回 states 結果,那麼就將全連接層的輸出和 states 一起返回。

class MyModel(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, gru_units):
        super().__init__(self)
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(gru_units, return_sequences=True,  return_state=True)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.embedding(x, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)

        if return_state:
            return x, states
        else:
            return x
model = MyModel( vocab_size=vocab_size, embedding_dim=embedding_dim,  gru_units=gru_units)

(5)我們隨機選取一個樣本,輸入到還沒有訓練的模型中,然後進行文本生成預測,可以看出目前的輸出毫無規。

for one_input, one_target in dataset.take(1):
    one_predictions = model(one_input)
    print(one_predictions.shape, "--> (batch_size, sequence_length, vocab_size)")
sampled_indices = tf.random.categorical(one_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()
print("Input:\n", text_from_ids(one_input[0]).numpy())
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())

結果輸出:

(64, 64, 66) --> (batch_size, sequence_length, vocab_size)
Input:
 b'\nBut let thy spiders, that suck up thy venom,\nAnd heavy-gaited t'
Next Char Predictions:
 b'ubH-I\nBxZReX!n\n$VBgkBqQxQEVaQ!-Siw uHoTaX!YT;vFYX,r:aLh h$fNRlEN'

(6)這裏主要是選擇損失函數和優化器,我們選取 SparseCategoricalCrossentropy 來作為損失函數,選取 Adam 作為優化器。

(7)我這裏還定義了一個回調函數,在每次 epoch 結束的時候,我們保存一次模型,總共執行 20 個 epoch 。

loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss)
checkpoint_dir = './my_training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( filepath=checkpoint_prefix, save_weights_only=True)
EPOCHS=20
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

結果輸出:

Epoch 1/20
268/268 [==============================] - 12s 39ms/step - loss: 2.7113
Epoch 2/20
268/268 [==============================] - 11s 39ms/step - loss: 2.1106
...
Epoch 19/20
268/268 [==============================] - 11s 40ms/step - loss: 1.4723
Epoch 20/20
268/268 [==============================] - 11s 38ms/step - loss: 1.4668

4. 生成文本邏輯

(1)這裏為我們主要是定義了一個類,可以使用已經訓練好的模型進行文本生成的任務,在初始化的時候我們需要將字符到數字的映射 chars_from_ids,以及數字到字符的映射 ids_from_chars 都進行傳入。

(2)這裏需要注意的是我們新增了一個 prediction_mask ,最後將其與模型輸出的 predicted_logits 進行相加,其實就是將 [UNK] 對應概率降到無限小,這樣就不會在採樣的時候採集 [UNK] 。

(3)在進行預測時候我們只要把每個序列上的最後一個時間步的輸出拿到即可,這其實就是所有字符對應的概率分佈,我們只需要通過 categorical 函數進行隨機採樣,概率越大的字符被採集到的可能性越大。

class OneStep(tf.keras.Model):
    def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
        super().__init__()
        self.temperature = temperature
        self.model = model
        self.chars_from_ids = chars_from_ids
        self.ids_from_chars = ids_from_chars

        skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
        sparse_mask = tf.SparseTensor( values=[-float('inf')]*len(skip_ids), indices=skip_ids, dense_shape=[len(ids_from_chars.get_vocabulary())])
        self.prediction_mask = tf.sparse.to_dense(sparse_mask)

    @tf.function
    def generate_one_step(self, inputs, states=None):
        input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
        input_ids = self.ids_from_chars(input_chars).to_tensor()
        predicted_logits, states = self.model(inputs=input_ids, states=states, return_state=True)
        predicted_logits = predicted_logits[:, -1, :]
        predicted_logits = predicted_logits/self.temperature
        predicted_logits = predicted_logits + self.prediction_mask
        predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
        predicted_ids = tf.squeeze(predicted_ids, axis=-1)
        predicted_chars = self.chars_from_ids(predicted_ids)
        return predicted_chars, states

one_step_model = OneStep(model, chars_from_ids, ids_from_chars)

5. 預測

(1)我們可以對一個樣本進行文本生成預測,也可以對批量的樣本進行文本預測工作。下面分別使用例子進行了效果展示。

(2)我們可以發現,在不仔細檢查的情況下,模型生成的文本在格式上和原作是類似的,而且也形成了“單詞”和“句子”,儘管有的根本壓根就不符合語法,想要增強效果的最簡單方法就是增大模型的(尤其是 GRU)的神經元個數,或者增加訓練的 epoch 次數。

states = None
next_char = tf.constant(['First Citizen:'])
result = [next_char]
for n in range(300):
    next_char, states = one_step_model.generate_one_step(next_char, states=states)
    result.append(next_char)
result = tf.strings.join(result)
print(result[0].numpy().decode('utf-8'))

結果輸出:

First Citizen: I kome flower as murtelys bease her sovereign!

DUKE VINCENTIO:
More life, I say your pioused in joid thune:
I am crebles holy for lien'd; she will. If helps an Gaod questilford
And reive my hearted
At you be so but to-deaks' BAPtickly Romeo, myself then saddens my wiflious wine creple.
Now if you

進行批量預測:

states = None
next_char = tf.constant(['First Citizen:', 'Second Citizen:', 'Third Citizen:'])
result = [next_char]

for n in range(300):
    next_char, states = one_step_model.generate_one_step(next_char, states=states)
    result.append(next_char)

result = tf.strings.join(result)
end = time.time()
print(result)

結果:

tf.Tensor(
[b"First Citizen: stors, not not-became mother, you reachtrall eight.\n\nBUCKINGHAM:\nI net\nShmo'ens from him thy haplay. So ready,\nCantent'd should now to thy keep upon thy king.\nWhat shall play you just-my mountake\nPanch his lord, ey? Of thou!\n\nDUKE VINCENTIO:\nThus vilided,\nSome side of this? I though he is heart the"
 b"Second Citizen:\nThen I'll were her thee exceacies even you laggined.\n\nHENRY BOLINGBROKE:\nMet like her safe.\n\nGLOUCESTER:\nSoet a spired\nThat withal?\n\nJULIET,\nA rable senul' thmest thou wilt the saper and a Came; or like a face shout thy worsh is tortument we shyaven?\nLet it take your at swails,\nAnd will cosoprorate"
 b'Third Citizen:\nDishall your wife, is thus?\n\nQUEEN ELIZABETH:\nNo morrot\nAny bring it bedies did be got have it,\nPervart put two food the gums: and my monst her,\nYou complike your noble lies. An must against man\nDreaming times on you.\nIt were you. I was charm on the contires in breath\nAs turning: gay, sir, Margaret'], shape=(3,), dtype=string)

6. 保存和讀取模型

我們對模型的權重進行保存,方便下次調用。

tf.saved_model.save(one_step_model, 'one_step')
one_step_reloaded = tf.saved_model.load('one_step')

使用加載的模型進行文本生成預測。

states = None
next_char = tf.constant(['First Citizen:'])
result = [next_char]

for n in range(300):
    next_char, states = one_step_reloaded.generate_one_step(next_char, states=states)
    result.append(next_char)

print(tf.strings.join(result)[0].numpy().decode("utf-8"))

結果輸出:

First Citizen:
Let me shet
Of mortal prince! BJuiting late and fublings.
Art like could not, thou quiclay of all that changes
Whose himit offent and montagueing: therefore, and their ledion:
Proceed thank you; and never night.

GRUMIO:
Nell hath us to the friend'st though, sighness?

GLOUCESSE:
How'd hang
A littl