pytorch tensorのブロック化操作後に平均コードをマージ
2070 ワード
# Tile each test image into overlapping patches, run the model on every
# patch, then stitch (F.fold) the per-patch outputs back together and
# normalize by the per-pixel overlap count to get an averaged full image.
for i, data in enumerate(self.dataloader['test']):  # i is the batch index
    print(">> Testing model. Batch: {}".format(i + 1))
    B, C, H, W = data[0].shape
    kernel_size = 128
    stride = 32
    nb_patches_h = (H - kernel_size) // stride + 1
    nb_patches_w = (W - kernel_size) // stride + 1
    images = data[0].to(self.device)  # renamed: `input` shadowed the builtin
    # All-ones tensor on the SAME device as the model output so the final
    # division is device-consistent (original torch.ones(...) was CPU-only).
    mask = torch.ones_like(images)
    # unfold twice -> [B, C, nb_patches_h, nb_patches_w, kernel_size, kernel_size]
    patches = images.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
    mask_patches = mask.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
    # empty_like keeps device/dtype aligned with `patches`; the original
    # torch.empty(size=...) allocated on CPU and broke the per-patch
    # assignment whenever self.device was CUDA.
    re_patches = torch.empty_like(patches)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        for indx in range(nb_patches_h):
            for indy in range(nb_patches_w):
                cur_patch = patches[:, :, indx, indy, :, :]
                cur_output, _ = self.myModel(cur_patch)
                re_patches[:, :, indx, indy, :, :] = cur_output
    # Reassemble: flatten the patch grid, move the per-patch pixel dim ahead
    # of the patch index, then fold back to (H, W). F.fold SUMS overlapping
    # contributions, so dividing by the folded all-ones mask (= overlap
    # count per pixel) yields the average.
    re_patches = re_patches.contiguous().view(B, C, -1, kernel_size * kernel_size)
    mask_patches = mask_patches.contiguous().view(B, C, -1, kernel_size * kernel_size)
    re_patches = re_patches.permute(0, 1, 3, 2)
    mask_patches = mask_patches.permute(0, 1, 3, 2)
    re_patches = re_patches.contiguous().view(B, C * kernel_size * kernel_size, -1)
    mask_patches = mask_patches.contiguous().view(B, C * kernel_size * kernel_size, -1)
    output = F.fold(re_patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)
    mask_output = F.fold(mask_patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)
    # NOTE(review): if (H - kernel_size) % stride != 0 (or likewise for W),
    # the image borders are never covered by any patch, mask_output is 0
    # there and this division yields inf/nan — confirm inputs tile exactly.
    output = output / mask_output