Pytorchテクニック1:DataLoaderのcollate_fnパラメータ

8130 ワード

Pytorchテクニック1:DataLoaderのcollate_fnパラメータ


DataLoaderのcollate_について説明します.fnパラメータは、カスタムbatch出力を実現します.
  • Pytorchテクニック1:DataLoaderのcollate_fnパラメータ
  • DataLoaderの完全なパラメータ表は次のとおりです:
  • 試験例

  • DataLoaderの完全なパラメータ表は次のとおりです。


    DataLoaderの完全なパラメータ表は次のとおりです.
    class torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
        collate_fn=,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None)

    DataLoaderはデータセットに単一プロセスまたはマルチプロセスの反復器のいくつかの重要なパラメータを提供します.意味:-shuffle:Trueに設定すると、各世代がデータセット-collate_を乱します.fn:どのようにサンプルを取るか、私たちは自分の関数を定義して、目的の機能-drop_を正確に実現することができます.Last:batch_を除くデータセットの長さの処理方法を示します.size残りのデータ.Trueは捨てて、さもなくば保留します

    テストの例

    import torch
    import torch.utils.data as Data
    import numpy as np
    
    test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])
    
    inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
    target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))
    
    torch_dataset = Data.TensorDataset(inputing,target)
    batch = 3
    
    loader = Data.DataLoader(
        dataset=torch_dataset,
        batch_size=batch, #  
        #  dataset batch_size , 
        collate_fn=lambda x:(
            torch.cat(
                [x[i][j].unsqueeze(0) for i in range(len(x))], 0
                ).unsqueeze(0) for j in range(len(x[0]))
            )
        )
    
    for (i,j) in loader:
        print(i)
        print(j)

    出力結果:
    tensor([[[ 0,  1,  2],
             [ 1,  2,  3],
             [ 2,  3,  4]]], dtype=torch.int32)
    tensor([[[ 0],
             [ 1],
             [ 2]]], dtype=torch.int32)
    tensor([[[ 3,  4,  5],
             [ 4,  5,  6],
             [ 5,  6,  7]]], dtype=torch.int32)
    tensor([[[ 3],
             [ 4],
             [ 5]]], dtype=torch.int32)
    tensor([[[  6,   7,   8],
             [  7,   8,   9],
             [  8,   9,  10]]], dtype=torch.int32)
    tensor([[[ 6],
             [ 7],
             [ 8]]], dtype=torch.int32)
    tensor([[[  9,  10,  11]]], dtype=torch.int32)
    tensor([[[ 9]]], dtype=torch.int32)

    collate_が要らないならfnの値、出力は
    tensor([[ 0,  1,  2],
            [ 1,  2,  3],
            [ 2,  3,  4]], dtype=torch.int32)
    tensor([[ 0],
            [ 1],
            [ 2]], dtype=torch.int32)
    tensor([[ 3,  4,  5],
            [ 4,  5,  6],
            [ 5,  6,  7]], dtype=torch.int32)
    tensor([[ 3],
            [ 4],
            [ 5]], dtype=torch.int32)
    tensor([[  6,   7,   8],
            [  7,   8,   9],
            [  8,   9,  10]], dtype=torch.int32)
    tensor([[ 6],
            [ 7],
            [ 8]], dtype=torch.int32)
    tensor([[  9,  10,  11]], dtype=torch.int32)
    tensor([[ 9]], dtype=torch.int32)

    だからcollate_fnは結果を複数次元にすることである.見てごらんfnの値はどういう意味ですか.次のように変更しました
    collate_fn=lambda x:x

    並列出力
    for i in loader:
        print(i)

    結果を得る
    [(tensor([ 0,  1,  2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1,  2,  3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2,  3,  4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
    [(tensor([ 3,  4,  5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4,  5,  6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5,  6,  7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
    [(tensor([ 6,  7,  8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7,  8,  9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([  8,   9,  10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
    [(tensor([  9,  10,  11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

    各iはリストであり、各リストにはbatch_が含まれる各タプルにはTensorDatasetの個別データが含まれているsize個のタプル.したがって、各batchに1*3*3を含むinputと1*3*1を含むtargetに再結合するには、パッケージを再解除してパッケージ化します.私たちのcollateを見てください.fn:
    collate_fn=lambda x:(
        torch.cat(
            [x[i][j].unsqueeze(0) for i in range(len(x))], 0
            ).unsqueeze(0) for j in range(len(x[0]))
        )

    jはinputとtargetの2つの変数をとる.i取ったのはbatch_size.次にunsqueeze(0)メソッドで1次元を前に追加します.torch.cat(,0)をパッケージ化します.次にunsqueeze(0)メソッドで前に1次元を追加します.完了します.