歡迎您光臨本站 註冊首頁

keras 解決加載lstm+crf模型出錯的問題

←手機掃碼閱讀     e36605 @ 2020-06-12 , reply:0

錯誤展示

new_model = load_model(“model.h5”)

報錯:

1、keras load_model valueError: Unknown Layer :CRF

2、keras load_model valueError: Unknown loss function:crf_loss

錯誤修改

1、load_model修改源碼:custom_objects = None 改為 def load_model(filepath, custom_objects, compile=True):

2、new_model = load_model(“model.h5”,custom_objects={‘CRF': CRF,‘crf_loss': crf_loss,‘crf_viterbi_accuracy': crf_viterbi_accuracy}

以上修改後,即可運行。

補充知識:用keras搭建bilstm crf

使用 https://github.com/keras-team/keras-contrib實現的crf layer,

安裝 keras-contrib

pip install git+https://www.github.com/keras-team/keras-contrib.git

Code Example:

  # coding: utf-8  from keras.models import Sequential  from keras.layers import Embedding  from keras.layers import LSTM  from keras.layers import Bidirectional  from keras.layers import Dense  from keras.layers import TimeDistributed  from keras.layers import Dropout  from keras_contrib.layers.crf import CRF  from keras_contrib.utils import save_load_utils    VOCAB_SIZE = 2500  EMBEDDING_OUT_DIM = 128  TIME_STAMPS = 100  HIDDEN_UNITS = 200  DROPOUT_RATE = 0.3  NUM_CLASS = 5    def build_embedding_bilstm2_crf_model():   """   帶embedding的雙向LSTM + crf   """   model = Sequential()   model.add(Embedding(VOCAB_SIZE, output_dim=EMBEDDING_OUT_DIM, input_length=TIME_STAMPS))   model.add(Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True)))   model.add(Dropout(DROPOUT_RATE))   model.add(Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True)))   model.add(Dropout(DROPOUT_RATE))   model.add(TimeDistributed(Dense(NUM_CLASS)))   crf_layer = CRF(NUM_CLASS)   model.add(crf_layer)   model.compile('rmsprop', loss=crf_layer.loss_function, metrics=[crf_layer.accuracy])   return model    def save_embedding_bilstm2_crf_model(model, filename):   save_load_utils.save_all_weights(model,filename)    def load_embedding_bilstm2_crf_model(filename):   model = build_embedding_bilstm2_crf_model()   save_load_utils.load_all_weights(model, filename)   return model    if __name__ == '__main__':   model = build_embedding_bilstm2_crf_model()

 

注意:

如果執行build模型報錯,則很可能是keras版本的問題。在keras-contrib==2.0.8且keras==2.0.8時,上面代碼不會報錯。


[e36605 ] keras 解決加載lstm+crf模型出錯的問題已經有263次圍觀

http://coctec.com/docs/python/shhow-post-238275.html