歡迎您光臨本站 註冊首頁

MxNet預訓練模型到Pytorch模型的轉換方式

←手機掃碼閱讀     qp18502452 @ 2020-06-08 , reply:0

預訓練模型在不同深度學習框架中的轉換是一種常見的任務。今天剛好DPN預訓練模型轉換問題,順手將這個過程記錄一下。

核心轉換函數如下所示:

  def convert_from_mxnet(model, checkpoint_prefix, debug=False):   _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)   remapped_state = {}   for state_key in model.state_dict().keys():    k = state_key.split('.')    aux = False    mxnet_key = ''    if k[0] == 'features':     if k[1] == 'conv1_1':      # input block      mxnet_key += 'conv1_x_1__'      if k[2] == 'bn':       mxnet_key += 'relu-sp__bn_'       aux, key_add = _convert_bn(k[3])       mxnet_key += key_add      else:       assert k[3] == 'weight'       mxnet_key += 'conv_' + k[3]     elif k[1] == 'conv5_bn_ac':      # bn + ac at end of features block      mxnet_key += 'conv5_x_x__relu-sp__bn_'      assert k[2] == 'bn'      aux, key_add = _convert_bn(k[3])      mxnet_key += key_add     else:      # middle blocks      if model.b and 'c1x1_c' in k[2]:       bc_block = True # b-variant split c-block special treatment      else:       bc_block = False      ck = k[1].split('_')      mxnet_key += ck[0] + '_x__' + ck[1] + '_'      ck = k[2].split('_')      mxnet_key += ck[0] + '-' + ck[1]      if ck[1] == 'w' and len(ck) > 2:       mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'      mxnet_key += '__'      if k[3] == 'bn':       mxnet_key += 'bn_' if bc_block else 'bn__bn_'       aux, key_add = _convert_bn(k[4])       mxnet_key += key_add      else:       ki = 3 if bc_block else 4       assert k[ki] == 'weight'       mxnet_key += 'conv_' + k[ki]    elif k[0] == 'classifier':     if 'fc6-1k_weight' in mxnet_weights:      mxnet_key += 'fc6-1k_'     else:      mxnet_key += 'fc6_'     mxnet_key += k[1]    else:     assert False, 'Unexpected token'       if debug:     print(mxnet_key, '=> ', state_key, end=' ')       mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]    torch_tensor = torch.from_numpy(mxnet_array.asnumpy())    if k[0] == 'classifier' and k[1] == 'weight':     torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))    remapped_state[state_key] = torch_tensor       if debug:     print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())      model.load_state_dict(remapped_state)      return model

 

從中可以看出,其轉換步驟如下:

(1)創建pytorch的網絡結構模型,設為model

(2)利用mxnet來讀取其存儲的預訓練模型,得到mxnet_weights;

(3)遍歷加載後模型mxnet_weights的state_dict().keys

(4)對一些指定的key值,需要進行相應的處理和轉換

(5)對修改鍵名之後的key利用numpy之間的轉換來實現加載。

為了實現上述轉換,首先pip安裝mxnet,現在新版的mxnet安裝還是非常方便的。

第二步,運行轉換程序,實現預訓練模型的轉換。

可以看到在相當的文件夾下已經出現了轉換後的模型。

 

   


[qp18502452 ] MxNet預訓練模型到Pytorch模型的轉換方式已經有232次圍觀

http://coctec.com/docs/python/shhow-post-237495.html