歡迎您光臨本站 註冊首頁

Pytorch損失函數nn.NLLLoss2d()用法說明

←手機掃碼閱讀     kyec555 @ 2020-07-08 , reply:0

最近做顯著星檢測用到了NLL損失函數

對於NLL函數,需要自己計算log和softmax的概率值,然後從才能作為輸入

輸入 [batch_size, channel , h, w]

目標 [batch_size, h, w]

輸入的目標矩陣,每個像素必須是類型.舉個例子。第一個像素是0,代表著類別屬於輸入的第1個通道;第二個像素是0,代表著類別屬於輸入的第0個通道,以此類推。

  x = Variable(torch.Tensor([[[1, 2, 1],         [2, 2, 1],         [0, 1, 1]],         [[0, 1, 3],         [2, 3, 1],         [0, 0, 1]]]))    x = x.view([1, 2, 3, 3])  print("x輸入", x)

 

這裡輸入x,並改成[batch_size, channel , h, w]的格式。

soft = nn.Softmax(dim=1)

log_soft = nn.LogSoftmax(dim=1)

然後使用softmax函數計算每個類別的概率,這裡dim=1表示從在1維度

上計算,也就是channel維度。logsoftmax是計算完softmax後在計算log值

手動計算舉個栗子:第一個元素

  y = Variable(torch.LongTensor([[1, 0, 1],         [0, 0, 1],         [1, 1, 1]]))    y = y.view([1, 3, 3])

 

輸入label y,改變成[batch_size, h, w]格式

  loss = nn.NLLLoss2d()  out = loss(x, y)  print(out)

 

輸入函數,得到loss=0.7947

來手動計算

第一個label=1,則 loss=-1.3133

第二個label=0, 則loss=-0.3133

  .  …  …  loss= -(-1.3133-0.3133-0.1269-0.6931-1.3133-0.6931-0.6931-1.3133-0.6931)/9 =0.7947222222222223

 

是一致的

注意:這個函數會對每個像素做平均,每個batch也會做平均,這裡有9個像素,1個batch_size。

補充知識:PyTorch:NLLLoss2d

我就廢話不多說了,大家還是直接看代碼吧~

  import torch  import torch.nn as nn  from torch import autograd  import torch.nn.functional as F     inputs_tensor = torch.FloatTensor([  [[2, 4],   [1, 2]],  [[5, 3],   [3, 0]],  [[5, 3],   [5, 2]],  [[4, 2],   [3, 2]],   ])  inputs_tensor = torch.unsqueeze(inputs_tensor,0)  # inputs_tensor = torch.unsqueeze(inputs_tensor,1)  print '--input size(nBatch x nClasses x height x width): ', inputs_tensor.shape     targets_tensor = torch.LongTensor([   [0, 2],   [2, 3]  ])     targets_tensor = torch.unsqueeze(targets_tensor,0)  print '--target size(nBatch x height x width): ', targets_tensor.shape     inputs_variable = autograd.Variable(inputs_tensor, requires_grad=True)  inputs_variable = F.log_softmax(inputs_variable)  targets_variable = autograd.Variable(targets_tensor)     loss = nn.NLLLoss2d()  output = loss(inputs_variable, targets_variable)  print '--NLLLoss2d: {}'.format(output)

 


[kyec555 ] Pytorch損失函數nn.NLLLoss2d()用法說明已經有250次圍觀

http://coctec.com/docs/python/shhow-post-241816.html