歡迎您光臨本站 註冊首頁

淺談tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意點

←手機掃碼閱讀     niceskyabc @ 2020-06-12 , reply:0

batch很好理解,就是batch size。注意在一個epoch中最後一個batch大小可能小於等於batch size
 

dataset.repeat就是俗稱epoch,但在tf中與dataset.shuffle的使用順序可能會導致個epoch的混合
 

dataset.shuffle就是說維持一個buffer size 大小的 shuffle buffer,圖中所需的每個樣本從shuffle buffer中獲取,取得一個樣本後,就從源資料集中加入一個樣本到shuffle buffer中。

  import os  os.environ['CUDA_VISIBLE_DEVICES'] = ""  import numpy as np  import tensorflow as tf  np.random.seed(0)  x = np.random.sample((11,2))  # make a dataset from a numpy array  print(x)  print()  dataset = tf.data.Dataset.from_tensor_slices(x)  dataset = dataset.shuffle(3)  dataset = dataset.batch(4)  dataset = dataset.repeat(2)    # create the iterator  iter = dataset.make_one_shot_iterator()  el = iter.get_next()    with tf.Session() as sess:    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))

 

  #源資料集  [[ 0.5488135  0.71518937]   [ 0.60276338 0.54488318]   [ 0.4236548  0.64589411]   [ 0.43758721 0.891773 ]   [ 0.96366276 0.38344152]   [ 0.79172504 0.52889492]   [ 0.56804456 0.92559664]   [ 0.07103606 0.0871293 ]   [ 0.0202184  0.83261985]   [ 0.77815675 0.87001215]   [ 0.97861834 0.79915856]]    # 透過shuffle batch後取得的樣本  [[ 0.4236548  0.64589411]   [ 0.60276338 0.54488318]   [ 0.43758721 0.891773 ]   [ 0.5488135  0.71518937]]  [[ 0.96366276 0.38344152]   [ 0.56804456 0.92559664]   [ 0.0202184  0.83261985]   [ 0.79172504 0.52889492]]  [[ 0.07103606 0.0871293 ]   [ 0.97861834 0.79915856]   [ 0.77815675 0.87001215]] #最後一個batch樣本個數為3  [[ 0.60276338 0.54488318]   [ 0.5488135  0.71518937]   [ 0.43758721 0.891773 ]   [ 0.79172504 0.52889492]]  [[ 0.4236548  0.64589411]   [ 0.56804456 0.92559664]   [ 0.0202184  0.83261985]   [ 0.07103606 0.0871293 ]]  [[ 0.77815675 0.87001215]   [ 0.96366276 0.38344152]   [ 0.97861834 0.79915856]] #最後一個batch樣本個數為3

 

1、按照shuffle中設定的buffer size,首先從源資料集取得三個樣本:
 shuffle buffer:
 [ 0.5488135 0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548 0.64589411]
 2、從buffer中取一個樣本到batch中得:
 shuffle buffer:
 [ 0.5488135 0.71518937]
 [ 0.60276338 0.54488318]
 batch:
 [ 0.4236548 0.64589411]
 3、shuffle buffer不足三個樣本,從源資料集提取一個樣本:
 shuffle buffer:
 [ 0.5488135 0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.43758721 0.891773 ]
 4、從buffer中取一個樣本到batch中得:
 shuffle buffer:
 [ 0.5488135 0.71518937]
 [ 0.43758721 0.891773 ]
 batch:
 [ 0.4236548 0.64589411]
 [ 0.60276338 0.54488318]
 5、如此反覆。這就意味中如果shuffle 的buffer size=1,資料集不打亂。如果shuffle 的buffer size=資料集樣本數量,隨機打亂整個資料集

  import os  os.environ['CUDA_VISIBLE_DEVICES'] = ""  import numpy as np  import tensorflow as tf  np.random.seed(0)  x = np.random.sample((11,2))  # make a dataset from a numpy array  print(x)  print()  dataset = tf.data.Dataset.from_tensor_slices(x)  dataset = dataset.shuffle(1)  dataset = dataset.batch(4)  dataset = dataset.repeat(2)    # create the iterator  iter = dataset.make_one_shot_iterator()  el = iter.get_next()    with tf.Session() as sess:    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    [[ 0.5488135  0.71518937]   [ 0.60276338 0.54488318]   [ 0.4236548  0.64589411]   [ 0.43758721 0.891773 ]   [ 0.96366276 0.38344152]   [ 0.79172504 0.52889492]   [ 0.56804456 0.92559664]   [ 0.07103606 0.0871293 ]   [ 0.0202184  0.83261985]   [ 0.77815675 0.87001215]   [ 0.97861834 0.79915856]]    [[ 0.5488135  0.71518937]   [ 0.60276338 0.54488318]   [ 0.4236548  0.64589411]   [ 0.43758721 0.891773 ]]  [[ 0.96366276 0.38344152]   [ 0.79172504 0.52889492]   [ 0.56804456 0.92559664]   [ 0.07103606 0.0871293 ]]  [[ 0.0202184  0.83261985]   [ 0.77815675 0.87001215]   [ 0.97861834 0.79915856]]  [[ 0.5488135  0.71518937]   [ 0.60276338 0.54488318]   [ 0.4236548  0.64589411]   [ 0.43758721 0.891773 ]]  [[ 0.96366276 0.38344152]   [ 0.79172504 0.52889492]   [ 0.56804456 0.92559664]   [ 0.07103606 0.0871293 ]]  [[ 0.0202184  0.83261985]   [ 0.77815675 0.87001215]   [ 0.97861834 0.79915856]]

 

注意如果repeat在shuffle之前使用:
 

官方說repeat在shuffle之前使用能提高效能,但模糊了資料樣本的epoch關係

  import os  os.environ['CUDA_VISIBLE_DEVICES'] = ""  import numpy as np  import tensorflow as tf  np.random.seed(0)  x = np.random.sample((11,2))  # make a dataset from a numpy array  print(x)  print()  dataset = tf.data.Dataset.from_tensor_slices(x)  dataset = dataset.repeat(2)  dataset = dataset.shuffle(11)  dataset = dataset.batch(4)    # create the iterator  iter = dataset.make_one_shot_iterator()  el = iter.get_next()    with tf.Session() as sess:    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    print(sess.run(el))    [[ 0.5488135  0.71518937]   [ 0.60276338 0.54488318]   [ 0.4236548  0.64589411]   [ 0.43758721 0.891773 ]   [ 0.96366276 0.38344152]   [ 0.79172504 0.52889492]   [ 0.56804456 0.92559664]   [ 0.07103606 0.0871293 ]   [ 0.0202184  0.83261985]   [ 0.77815675 0.87001215]   [ 0.97861834 0.79915856]]    [[ 0.56804456 0.92559664]   [ 0.5488135  0.71518937]   [ 0.60276338 0.54488318]   [ 0.07103606 0.0871293 ]]  [[ 0.96366276 0.38344152]   [ 0.43758721 0.891773 ]   [ 0.43758721 0.891773 ]   [ 0.77815675 0.87001215]]  [[ 0.79172504 0.52889492]  #出現相同樣本出現在同一個batch中   [ 0.79172504 0.52889492]   [ 0.60276338 0.54488318]   [ 0.4236548  0.64589411]]  [[ 0.07103606 0.0871293 ]   [ 0.4236548  0.64589411]   [ 0.96366276 0.38344152]   [ 0.5488135  0.71518937]]  [[ 0.97861834 0.79915856]   [ 0.0202184  0.83261985]   [ 0.77815675 0.87001215]   [ 0.56804456 0.92559664]]  [[ 0.0202184  0.83261985]   [ 0.97861834 0.79915856]]     #可以看到最後個batch為2,而前面都是4

 

使用案例:
 

  def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):    print('Parsing', filenames)    def decode_libsvm(line):      #columns = tf.decode_csv(value, record_defaults=CSV_COLUMN_DEFAULTS)      #features = dict(zip(CSV_COLUMNS, columns))      #labels = features.pop(LABEL_COLUMN)      columns = tf.string_split([line], ' ')      labels = tf.string_to_number(columns.values[0], out_type=tf.float32)      splits = tf.string_split(columns.values[1:], ':')      id_vals = tf.reshape(splits.values,splits.dense_shape)      feat_ids, feat_vals = tf.split(id_vals,num_or_size_splits=2,axis=1)      feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)      feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)      #feat_ids = tf.reshape(feat_ids,shape=[-1,FLAGS.field_size])      #for i in range(splits.dense_shape.eval()[0]):      #  feat_ids.append(tf.string_to_number(splits.values[2*i], out_type=tf.int32))      #  feat_vals.append(tf.string_to_number(splits.values[2*i+1]))      #return tf.reshape(feat_ids,shape=[-1,field_size]), tf.reshape(feat_vals,shape=[-1,field_size]), labels      return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels      # Extract lines from input files using the Dataset API, can pass one filename or filename list    dataset = tf.data.TextLineDataset(filenames).map(decode_libsvm, num_parallel_calls=10).prefetch(500000)  # multi-thread pre-process then prefetch      # Randomizes input using a window of 256 elements (read into memory)    if perform_shuffle:      dataset = dataset.shuffle(buffer_size=256)      # epochs from blending together.    dataset = dataset.repeat(num_epochs)    dataset = dataset.batch(batch_size) # Batch size to use      #return dataset.make_one_shot_iterator()    iterator = dataset.make_one_shot_iterator()    batch_features, batch_labels = iterator.get_next()    #return tf.reshape(batch_ids,shape=[-1,field_size]), tf.reshape(batch_vals,shape=[-1,field_size]), batch_labels    return batch_features, batch_labels

 

      

   


[niceskyabc ] 淺談tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意點已經有236次圍觀

http://coctec.com/docs/program/show-post-238259.html