1. 程式人生 > >Tensorflow nmt的資料預處理過程

Tensorflow nmt的資料預處理過程

tensorflow nmt的資料預處理過程  

在tensorflow/nmt專案中,訓練資料和推斷資料的輸入使用了新的Dataset API,應該是tensorflow 1.2之後引入的API,方便資料的操作。如果你還在使用老的Queue和Coordinator的方式,建議升級高版本的tensorflow並且使用Dataset API。

本教程將從訓練資料推斷資料兩個方面,詳解解析資料的具體處理過程,你將看到文字資料如何轉化為模型所需要的實數,以及中間的張量的維度是怎麼樣的,batch_size和其他超引數又是如何作用的。

訓練資料的處理

先來看看訓練資料的處理。訓練資料的處理比推斷資料的處理稍微複雜一些,弄懂了訓練資料的處理過程,就可以很輕鬆地理解推斷資料的處理。
訓練資料

的處理程式碼位於nmt/utils/iterator_utils.py檔案內的get_iterator函式。我們先來看看這個函式所需要的引數是什麼意思:

引數 解釋
src_dataset 源資料集
tgt_dataset 目標資料集
src_vocab_table 源資料單詞查詢表,就是個單詞和int型別資料的對應表
tgt_vocab_table 目標資料單詞查詢表,就是個單詞和int型別資料的對應表
batch_size 批大小
sos 句子開始標記
eos 句子結尾標記
random_seed
隨機種子,用來打亂資料集的
num_buckets 桶數量
src_max_len 源資料最大長度
tgt_max_len 目標資料最大長度
num_parallel_calls 併發處理資料的併發數
output_buffer_size 輸出緩衝區大小
skip_count 跳過資料行數
num_shards 將資料集分片的數量,分散式訓練中有用
shard_index 資料集分片後的id
reshuffle_each_iteration 是否每次迭代都重新打亂順序

上面的解釋,如果有不清楚的,可以檢視我之前一片介紹超引數的文章:

tensorflow_nmt的超引數詳解

該函式處理訓練資料的主要程式碼如下:

if not output_buffer_size:
    output_buffer_size = batch_size * 1000
  src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)
  tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
  tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)

  src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))

  src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index)
  if skip_count is not None:
    src_tgt_dataset = src_tgt_dataset.skip(skip_count)

  src_tgt_dataset = src_tgt_dataset.shuffle(
      output_buffer_size, random_seed, reshuffle_each_iteration)

  src_tgt_dataset = src_tgt_dataset.map(
      lambda src, tgt: (
          tf.string_split([src]).values, tf.string_split([tgt]).values),
      num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

  # Filter zero length input sequences.
  src_tgt_dataset = src_tgt_dataset.filter(
      lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))

  if src_max_len:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src[:src_max_len], tgt),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
  if tgt_max_len:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src, tgt[:tgt_max_len]),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
  # Convert the word strings to ids.  Word strings that are not in the
  # vocab get the lookup table's default_value integer.
  src_tgt_dataset = src_tgt_dataset.map(
      lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),
                        tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
      num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
  # Create a tgt_input prefixed with <sos> and a tgt_output suffixed with <eos>.
  src_tgt_dataset = src_tgt_dataset.map(
      lambda src, tgt: (src,
                        tf.concat(([tgt_sos_id], tgt), 0),
                        tf.concat((tgt, [tgt_eos_id]), 0)),
      num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
  # Add in sequence lengths.
  src_tgt_dataset = src_tgt_dataset.map(
      lambda src, tgt_in, tgt_out: (
          src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
      num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

我們逐步來分析,這個過程到底做了什麼,資料張量又是如何變化的。

如何對齊資料

num_buckets到底起什麼作用

num_buckets起作用的程式碼如下:  

 if num_buckets > 1:

    def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
      # Calculate bucket_width by maximum source sequence length.
      # Pairs with length [0, bucket_width) go to bucket 0, length
      # [bucket_width, 2 * bucket_width) go to bucket 1, etc.  Pairs with length
      # over ((num_bucket-1) * bucket_width) words all go into the last bucket.
      if src_max_len:
        bucket_width = (src_max_len + num_buckets - 1) // num_buckets
      else:
        bucket_width = 10

      # Bucket sentence pairs by the length of their source sentence and target
      # sentence.
      bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width)
      return tf.to_int64(tf.minimum(num_buckets, bucket_id))

    def reduce_func(unused_key, windowed_data):
      return batching_func(windowed_data)

    batched_dataset = src_tgt_dataset.apply(
        tf.contrib.data.group_by_window(
            key_func=key_func, reduce_func=reduce_func, window_size=batch_size))