torch.backends.cudnn.benchmark




import torch.backends.cudnn as cudnn
cudnn.benchmark = True

このフラグを有効にすると、PyTorchはモデル内の畳み込み層を事前に最適化します。すなわち、cuDNNが提供するすべての畳み込み実装アルゴリズムを各畳み込み層で試し、最も速いものを選択します。これにより、モデルの起動時にわずかな前処理時間がかかるだけで、トレーニング時間を大幅に削減できます。ネットワークの入力サイズ(バッチサイズを含む)が固定されている場合に使用してください。入力サイズが変わると、その設定に対するアルゴリズムの探索が改めて行われます。
// Source: https://github.com/pytorch/pytorch/blob/b5fa9a340a0d174131ad0a452c395860d571b5b0/aten/src/ATen/native/cudnn/Conv.cpp#L701
template
void findAlgorithm(const ConvolutionArgs& args, bool benchmark, perf_t* algoPerf) {
  using search = algorithm_search;
  auto& cache = search::cache();

  //                     ,       ,   
  if (cache.find(args.params, algoPerf)) {
    return;
  }

  //     PyTorch        torch.backends.cudnn.deterministic=True,
  //    cudnn.benchmark == False  ,             ,  
  if (args.params.deterministic && !benchmark) {
    algoPerf->algo = search::DEFAULT_ALGO;
    if (args.params.dataType == CUDNN_DATA_HALF) {
      algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
    } else {
      algoPerf->mathType = CUDNN_DEFAULT_MATH;
    }
    search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
    return;
  }

  //                         ,
  // recheck                     
  if (benchmark) {
    if (cache.find(args.params, algoPerf)) {
      // re-check cache since another thread may have benchmarked the algorithm
      return;
    }
  }

  //  ,           ,               ,
  //      search::findAlgorithm    benchmarking。
  //      search::findAlgorithm   ,     。
  auto perfResults = search::findAlgorithm(args, benchmark);

  //    findAlgorithm        ,        determinnistic,
  //    findAlgorithm    
  //     ,   deterministic,           
  // for deterministic algo, look at all the perf results and return the best
  // deterministic algo
  if (perfResults.status == CUDNN_STATUS_SUCCESS &&
      !(args.params.deterministic && perfResults.determinism != CUDNN_DETERMINISTIC)) {

      // if benchmarking, map the original params with the found algo+math type for re-use
      if (benchmark) {
        // cache      benchmark    
        cache.insert(args.params, perfResults);

        // Free the cached blocks in our caching allocator. They are
        // needed here because the above benchmarking uses a huge amount of memory,
        // e.g. a few GBs.
        c10::cuda::CUDACachingAllocator::emptyCache();
      }

      *algoPerf = perfResults;
  } else {
      algoPerf->algo = search::DEFAULT_ALGO;
      if (args.params.dataType == CUDNN_DATA_HALF) {
        algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
      } else {
        algoPerf->mathType = CUDNN_DEFAULT_MATH;
      }
      search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
  }
}


//      forward      
//        : https://github.com/pytorch/pytorch/blob/b5fa9a340a0d174131ad0a452c395860d571b5b0/aten/src/ATen/native/cudnn/Conv.cpp#L504
template<>
struct algorithm_search {
  using perf_t = cudnnConvolutionFwdAlgoPerf_t;
  using algo_t = cudnnConvolutionFwdAlgo_t;

  //       !
  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
  // Returns the cache of forward-convolution results (`fwd_algos`, which is
  // defined outside this excerpt).
  static BenchmarkCache<perf_t>& cache() { return fwd_algos; }

  // Returns the perf record of the forward algorithm selected for `args`.
  //
  // Without benchmarking, cuDNN is queried through
  // cudnnGetConvolutionForwardAlgorithm_v7 using only the descriptors. With
  // benchmarking, cudnnFindConvolutionForwardAlgorithmEx is called with the
  // real input/weight/output buffers and a scratch workspace. In both cases
  // the returned candidates are handed to getBestAlgorithm, whose result is
  // returned. A failing cuDNN call is reported through AT_CUDNN_CHECK.
  static perf_t findAlgorithm(const ConvolutionArgs& args, bool benchmark) {
    // The forward algorithms cuDNN offers; used to size the benchmark workspace.
    static const algo_t algos[] = {
         CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
         CUDNN_CONVOLUTION_FWD_ALGO_FFT,
         CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
         CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
         CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
         CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
         CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
         CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
    };
    static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
    // Fails the build if cuDNN adds an algorithm that is not listed above.
    static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
                  "Missing cuDNN convolution forward algorithms");

    // Number of entries cuDNN actually wrote into perf_results.
    int perf_count = 0;
    // Array form of unique_ptr so the buffer is released with delete[].
    std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);

    if (!benchmark) {
      // Benchmarking disabled: let cuDNN propose algorithms from the
      // descriptors alone via cudnnGetConvolutionForwardAlgorithm_v7.
      AT_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
          args.handle,
          args.idesc.desc(),
          args.wdesc.desc(),
          args.cdesc.desc(),
          args.odesc.desc(),
          num_algos,
          &perf_count,
          perf_results.get()));
    } else {
      // Benchmarking enabled: cudnnFindConvolutionForwardAlgorithmEx works on
      // the actual tensors, so give it a workspace large enough for any of
      // the listed algorithms.
      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
      Workspace ws(max_ws_size);
      AT_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
          args.handle,
          args.idesc.desc(), args.input.data_ptr(),
          args.wdesc.desc(), args.weight.data_ptr(),
          args.cdesc.desc(),
          args.odesc.desc(), args.output.data_ptr(),
          num_algos,
          &perf_count,
          perf_results.get(),
          ws.data,
          ws.size));
    }
    return getBestAlgorithm(perf_results.get(), args, perf_count);
  }