歡迎您光臨本站 註冊首頁

pytorch隨機採樣操作SubsetRandomSampler()

←手機掃碼閱讀     niceskyabc @ 2020-07-08 , reply:0

這篇文章記錄一個採樣器都隨機地從原始的數據集中抽樣數據。抽樣數據採用permutation。 生成任意一個下標重排,從而利用下標來提取dataset中的數據的方法

需要的庫

import torch

使用方法

這裡以MNIST舉例

  train_dataset = dsets.MNIST(root='./data', #文件存放路徑                train=True,  #提取訓練集                transform=transforms.ToTensor(), #將圖像轉化為Tensor                download=True)    sample_size = len(train_dataset)  sampler1 = torch.utils.data.sampler.SubsetRandomSampler(    np.random.choice(range(len(train_dataset)), sample_size))

 

代碼詳解

  np.random.choice()    #numpy.random.choice(a, size=None, replace=True, p=None)  #從a(只要是ndarray都可以,但必須是一維的)中隨機抽取數字,並組成指定大小(size)的數組  #replace:True表示可以取相同數字,False表示不可以取相同數字  #數組p:與數組a相對應,表示取數組a中每個元素的概率,默認為選取每個元素的概率相同。

 

那麼這裡就相當於抽取了一個全排列

  torch.utils.data.sampler.SubsetRandomSampler    # 會根據後面給的列表從數據集中按照下標取元素  # class torch.utils.data.SubsetRandomSampler(indices):無放回地按照給定的索引列表採樣樣本元素。

 

所以就可以了。

補充知識:Pytorch學習之torch----隨機抽樣、序列化、並行化

1. torch.manual_seed(seed)

說明:設置生成隨機數的種子,返回一個torch._C.Generator對象。使用隨機數種子之後,生成的隨機數是相同的。

參數:

seed(int or long) -- 種子

  >>> import torch  >>> torch.manual_seed(1)  <torch._C.Generator object at 0x0000019684586350>  >>> a = torch.rand(2, 3)  >>> a  tensor([[0.7576, 0.2793, 0.4031],      [0.7347, 0.0293, 0.7999]])  >>> torch.manual_seed(1)  <torch._C.Generator object at 0x0000019684586350>  >>> b = torch.rand(2, 3)  >>> b  tensor([[0.7576, 0.2793, 0.4031],      [0.7347, 0.0293, 0.7999]])  >>> a == b  tensor([[1, 1, 1],      [1, 1, 1]], dtype=torch.uint8)

 

2. torch.initial_seed()

說明:返回生成隨機數的原始種子值

  >>> torch.manual_seed(4)  <torch._C.Generator object at 0x0000019684586350>  >>> torch.initial_seed()  4

 

3. torch.get_rng_state()

說明:返回隨機生成器狀態(ByteTensor)

  >>> torch.initial_seed()  4  >>> torch.get_rng_state()  tensor([4, 0, 0, ..., 0, 0, 0], dtype=torch.uint8)

 

4. torch.set_rng_state()

說明:設定隨機生成器狀態

參數:

new_state(ByteTensor) -- 期望的狀態

5. torch.default_generator

說明:默認的隨機生成器。等於<torch._C.Generator object>

6. torch.bernoulli(input, out=None)

說明:從伯努利分佈中抽取二元隨機數(0或1)。輸入張量包含用於抽取二元值的概率。因此,輸入中的所有值都必須在[0,1]區間內。輸出張量的第i個元素值,將會以輸入張量的第i個概率值等於1。返回值將會是與輸入相同大小的張量,每個值為0或者1.

參數:

input(Tensor) -- 輸入為伯努利分佈的概率值

out(Tensor,可選) -- 輸出張量

  >>> a = torch.Tensor(3, 3).uniform_(0, 1)  >>> a  tensor([[0.5596, 0.5591, 0.0915],      [0.2100, 0.0072, 0.0390],      [0.9929, 0.9131, 0.6186]])  >>> torch.bernoulli(a)  tensor([[0., 1., 0.],      [0., 0., 0.],      [1., 1., 1.]])

 

7. torch.multinomial(input, num_samples, replacement=False, out=None)

說明:返回一個張量,每行包含從input相應行中定義的多項分佈中抽取的num_samples個樣本。要求輸入input每行的值不需要總和為1,但是必須非負且總和不能為0。當抽取樣本時,依次從左到右排列(第一個樣本對應第一列)。如果輸入input是一個向量,輸出out也是一個相同長度num_samples的向量。如果輸入input是m行的矩陣,輸出out是形如m x n的矩陣。並且如果參數replacement為True,則樣本抽取可以重複。否則,一個樣本在每行不能被重複。

參數:

input(Tensor) -- 包含概率的張量

num_samples(int) -- 抽取的樣本數

replacement(bool) -- 布爾值,決定是否能重複抽取

out(Tensor) -- 結果張量

  >>> weights = torch.Tensor([0, 10, 3, 0])  >>> weights  tensor([ 0., 10., 3., 0.])  >>> torch.multinomial(weights, 4, replacement=True)  tensor([1, 1, 1, 1])

 

8. torch.normal(means, std, out=None)

說明:返回一個張量,包含從給定參數means,std的離散正態分佈中抽取隨機數。均值means是一個張量,包含每個輸出元素相關的正態分佈的均值。std是一個張量。包含每個輸出元素相關的正態分佈的標準差。均值和標準差的形狀不須匹配,但每個張量的元素個數必須想聽。

參數:

means(Tensor) -- 均值

std(Tensor) -- 標準差

out(Tensor) -- 輸出張量

  >>> n_data = torch.ones(5, 2)  >>> n_data  tensor([[1., 1.],      [1., 1.],      [1., 1.],      [1., 1.],      [1., 1.]])  >>> x0 = torch.normal(2 * n_data, 1)  >>> x0  tensor([[1.6544, 0.9805],      [2.1114, 2.7113],      [1.0646, 1.9675],      [2.7652, 3.2138],      [1.1204, 2.0293]])

 

9. torch.save(obj, f, pickle_module=<module 'pickle' from '/home/lzjs/...)

說明:保存一個對象到一個硬盤文件上。

參數:

obj -- 保存對象

f -- 類文件對象或一個保存文件名的字符串

pickle_module -- 用於pickling源數據和對象的模塊

pickle_protocol -- 指定pickle protocal可以覆蓋默認參數

10. torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/home/lzjs/...)

說明:從磁盤文件中讀取一個通過torch.save()保存的對象。torch.load()可通過參數map_location動態地進行內存重映射,使其能從不動設備中讀取文件。一般調用時,需兩個參數:storage和location tag。返回不同地址中的storage,或者返回None。如果這個參數是字典的話,意味著從文件的地址標記到當前系統的地址標記的映射。

參數:

f -- l類文件對象或一個保存文件名的字符串

map_location -- 一個函數或字典規定如何remap存儲位置

pickle_module -- 用於unpickling元數據和對象的模塊

  torch.load('tensors.pt')  # 加載所有的張量到CPU  torch.load('tensor.pt', map_location=lambda storage, loc:storage)  # 加載張量到GPU  torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

 

11. torch.get_num_threads()

說明:獲得用於並行化CPU操作的OpenMP線程數

12. torch.set_num_threads()

說明:設定用於並行化CPU操作的OpenMP線程數

 

   


[niceskyabc ] pytorch隨機採樣操作SubsetRandomSampler()已經有868次圍觀

http://coctec.com/docs/python/shhow-post-241834.html