pytorch DataLoaderカスタムデータセット
4836 ワード
pytorchはmini-batchのデータを生成し、トレーニングやテスト時にマルチスレッド処理を行い、データの準備を速めるデータ処理方式を提供しています.この関数ツールは
その中でDatasetは私たちが自分のマルチスレッドデータ処理フレームワークを定義する親クラスで、私たちが定義したフレームワークはこのクラスの下で簡単にデータ準備を定義するフレームワークを継承しましょう!!!
次に、データと組み合わせて、次のコードを説明します.
train_を呼び出すことができますloaderとtest_loaderはmini-batchのデータを呼び出し、batch_sizeはtrain_loaderとtest_loaderの値はすでに設定されており、以下のコードで一括して呼び出されます.
torch.utils.data import Dataset, DataLoader
その中でDatasetは私たちが自分のマルチスレッドデータ処理フレームワークを定義する親クラスで、私たちが定義したフレームワークはこのクラスの下で簡単にデータ準備を定義するフレームワークを継承しましょう!!!
from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset):
def __init__(self, filepath, transform=None,keys = None, target_transform=None):
pass
'''
,filepath ,transform (features) ,target_transform (labels) ,keys , , , ndarray ,
'''
def __getitem__(self,index):
pass
def __len__(self):
pass
次に、データと組み合わせて、次のコードを説明します.
class MyDataset(Dataset):
def __init__(self, filepath, transform=None,keys = None, target_transform=None):
with open(filepath,'rb') as f:
self.data = pickle.load(f)
self.keys = keys
self.input_seq = self.data[self.keys[0]] ###
self.output_seq = self.data[self.keys[1]] ####
self.transform = transform ####
self.target_transform = target_transform ######
def __getitem__(self, index):
input_seq,output_seq = self.input_seq[index],self.output_seq[index] ##
if self.transform is not None:
input_seq = self.transform(input_seq)
output_seq = self.transform(output_seq)
return input_seq,output_seq ###
def __len__(self):
return self.data[self.keys[0]].shape[0] ### ,
train_data = MyDataset(filepath = 'train30.pickle',keys = ['aa','bb'])
test_data = MyDataset(filepath = 'test30.pickle',keys = ['aa','bb'])
train_loader = DataLoader(dataset = train_data,batch_size = 32,shuffle = False)
test_loader = DataLoader(dataset = test_data,batch_size = 32,shuffle = False)
train_を呼び出すことができますloaderとtest_loaderはmini-batchのデータを呼び出し、batch_sizeはtrain_loaderとtest_loaderの値はすでに設定されており、以下のコードで一括して呼び出されます.
from torch.autograd import Variable
for i,(input_seq,out_seq) in enumerate(train_loader):
input_seq = Variable(input_seq.cuda())
output_seq = Variable(output_seq.cuda())