歡迎您光臨本站 註冊首頁

解決Keras TensorFlow 混編中 trainable=False設置無效問題

←手機掃碼閱讀     zmcjlove @ 2020-06-29 , reply:0

這是最近碰到一個問題,先描述下問題:

首先我有一個訓練好的模型(例如vgg16),我要對這個模型進行一些改變,例如添加一層全連接層,用於種種原因,我只能用TensorFlow來進行模型優化,tf的優化器,默認情況下對所有tf.trainable_variables()進行權值更新,問題就出在這,明明將vgg16的模型設置為trainable=False,但是tf的優化器仍然對vgg16做權值更新

以上就是問題描述,經過谷歌百度等等,終於找到了解決辦法,下面我們一點一點的來複原整個問題。

trainable=False 無效

首先,我們導入訓練好的模型vgg16,對其設置成trainable=False

  from keras.applications import VGG16  import tensorflow as tf  from keras import layers

 

  # 導入模型  base_mode = VGG16(include_top=False)  # 查看可訓練的變量  tf.trainable_variables()

 

  [,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,]

 

  # 設置 trainable=False  # base_mode.trainable = False似乎也是可以的  for layer in base_mode.layers:    layer.trainable = False

 

設置好trainable=False後,再次查看可訓練的變量,發現並沒有變化,也就是說設置無效

# 再次查看可訓練的變量
 tf.trainable_variables()

  [,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,]

 

解決的辦法

解決的辦法就是在導入模型的時候建立一個variable_scope,將需要訓練的變量放在另一個variable_scope,然後通過tf.get_collection獲取需要訓練的變量,最後通過tf的優化器中var_list指定需要訓練的變量

  from keras import models  with tf.variable_scope('base_model'):    base_model = VGG16(include_top=False, input_shape=(224,224,3))  with tf.variable_scope('xxx'):    model = models.Sequential()    model.add(base_model)    model.add(layers.Flatten())    model.add(layers.Dense(10))

 

  # 獲取需要訓練的變量  trainable_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'xxx')  trainable_var

 

[,
 ]

  # 定義tf優化器進行訓練,這裡假設有一個loss  loss = model.output / 2; # 隨便定義的,方便演示  train_step = tf.train.AdamOptimizer().minimize(loss, var_list=trainable_var)

 

總結

在keras與TensorFlow混編中,keras中設置trainable=False對於TensorFlow而言並不起作用

解決的辦法就是通過variable_scope對變量進行區分,在通過tf.get_collection來獲取需要訓練的變量,最後通過tf優化器中var_list指定訓練



[zmcjlove ] 解決Keras TensorFlow 混編中 trainable=False設置無效問題已經有222次圍觀

http://coctec.com/docs/python/shhow-post-240282.html