PyTorch テンソルをブロック（パッチ）に分割して処理した後、重複部分を平均して元の形に再結合するコード

2070 ワード

for i, data in enumerate(self.dataloader['test']):  # i is the 0-based batch index
                print(">> Testing model. Batch: {}".format(i + 1))

                # Patch-wise inference: run self.myModel on every
                # kernel_size x kernel_size patch of the batch, fold the patch
                # outputs back into an (H, W) image, and divide by a per-pixel
                # coverage count so that overlapping regions are averaged.

                # The unpacking requires data[0] to be a 4-D tensor, presumably
                # (B, C, H, W) since dims 2 and 3 are the ones unfolded below.
                B, C, H, W = data[0].shape

                kernel_size = 128
                stride = 32

                # Every pixel must be covered by at least one patch. Otherwise its
                # coverage count (mask_output) is 0 and the final division silently
                # produces NaN there, so fail loudly before doing any model work.
                if (
                    H < kernel_size
                    or W < kernel_size
                    or (H - kernel_size) % stride != 0
                    or (W - kernel_size) % stride != 0
                ):
                    raise ValueError(
                        f"Input size {H}x{W} is not tiled by kernel_size={kernel_size}, "
                        f"stride={stride}: (H - kernel_size) and (W - kernel_size) must be "
                        f"non-negative multiples of stride, otherwise border pixels are "
                        f"covered by no patch and the averaged output is NaN."
                    )

                # Number of patch positions along each axis (matches Tensor.unfold).
                nb_patches_h = (H - kernel_size) // stride + 1
                nb_patches_w = (W - kernel_size) // stride + 1

                input = data[0].to(self.device)
                # All-ones mask. Folding its patches back counts how many patches
                # overlap each pixel.
                # NOTE(review): mask and re_patches are created on the CPU even when
                # self.device is a GPU, so every patch output is copied to the CPU
                # and `output` ends up on the CPU. Confirm this is intended.
                mask = torch.ones(input.size())

                # Shape: [B, C, nb_patches_h, nb_patches_w, kernel_size, kernel_size]
                patches = input.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
                mask_patches = mask.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)

                # Run the model on one patch position at a time.
                # Assumes self.myModel returns a tuple whose first element has the
                # same shape as its [B, C, kernel_size, kernel_size] input. TODO confirm.
                re_patches = torch.empty(size=patches.size())
                for indx in range(nb_patches_h):
                    for indy in range(nb_patches_w):
                        curPatches4D = patches[:, :, indx, indy, :, :]
                        curOutput, _ = self.myModel(curPatches4D)
                        re_patches[:, :, indx, indy, :, :] = curOutput

                # F.fold expects [B, C * kernel_size * kernel_size, L], with the
                # channel dimension outermost and L = nb_patches_h * nb_patches_w in
                # row-major patch order. This is the order unfold produced above.
                re_patches = re_patches.contiguous().view(B, C, -1, kernel_size * kernel_size)
                mask_patches = mask_patches.contiguous().view(B, C, -1, kernel_size * kernel_size)

                re_patches = re_patches.permute(0, 1, 3, 2)
                mask_patches = mask_patches.permute(0, 1, 3, 2)

                re_patches = re_patches.contiguous().view(B, C * kernel_size * kernel_size, -1)
                mask_patches = mask_patches.contiguous().view(B, C * kernel_size * kernel_size, -1)

                # fold sums overlapping contributions, and mask_output holds each
                # pixel's overlap count. Dividing therefore averages the overlapping
                # patch outputs.
                output = F.fold(re_patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)
                mask_output = F.fold(mask_patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)

                output = output / mask_output