torch.backends.cudnn.benchmark
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
PyTorchでモデル内の畳み込み層を事前に最適化することができる。すなわち、cuDNNが提供するすべての畳み込み実装アルゴリズムを各畳み込み層でテストし、最も速いものを選択する。これにより、モデル起動時にわずかな前処理時間が追加されるだけで、トレーニング時間を大幅に削減できる。ネットワークへの入力サイズ(batchsizeを含む)が安定して固定されている場合に使用する。
// Source: https://github.com/pytorch/pytorch/blob/b5fa9a340a0d174131ad0a452c395860d571b5b0/aten/src/ATen/native/cudnn/Conv.cpp#L701
template
void findAlgorithm(const ConvolutionArgs& args, bool benchmark, perf_t* algoPerf) {
using search = algorithm_search;
auto& cache = search::cache();
// , ,
if (cache.find(args.params, algoPerf)) {
return;
}
// PyTorch torch.backends.cudnn.deterministic=True,
// cudnn.benchmark == False , ,
if (args.params.deterministic && !benchmark) {
algoPerf->algo = search::DEFAULT_ALGO;
if (args.params.dataType == CUDNN_DATA_HALF) {
algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
} else {
algoPerf->mathType = CUDNN_DEFAULT_MATH;
}
search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
return;
}
// ,
// recheck
if (benchmark) {
if (cache.find(args.params, algoPerf)) {
// re-check cache since another thread may have benchmarked the algorithm
return;
}
}
// , , ,
// search::findAlgorithm benchmarking。
// search::findAlgorithm , 。
auto perfResults = search::findAlgorithm(args, benchmark);
// findAlgorithm , determinnistic,
// findAlgorithm
// , deterministic,
// for deterministic algo, look at all the perf results and return the best
// deterministic algo
if (perfResults.status == CUDNN_STATUS_SUCCESS &&
!(args.params.deterministic && perfResults.determinism != CUDNN_DETERMINISTIC)) {
// if benchmarking, map the original params with the found algo+math type for re-use
if (benchmark) {
// cache benchmark
cache.insert(args.params, perfResults);
// Free the cached blocks in our caching allocator. They are
// needed here because the above benchmarking uses a huge amount of memory,
// e.g. a few GBs.
c10::cuda::CUDACachingAllocator::emptyCache();
}
*algoPerf = perfResults;
} else {
algoPerf->algo = search::DEFAULT_ALGO;
if (args.params.dataType == CUDNN_DATA_HALF) {
algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
} else {
algoPerf->mathType = CUDNN_DEFAULT_MATH;
}
search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
}
}
// Forward-convolution specialization of algorithm_search.
// Source: https://github.com/pytorch/pytorch/blob/b5fa9a340a0d174131ad0a452c395860d571b5b0/aten/src/ATen/native/cudnn/Conv.cpp#L504
template<>
struct algorithm_search {
using perf_t = cudnnConvolutionFwdAlgoPerf_t;
using algo_t = cudnnConvolutionFwdAlgo_t;
// !
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
static BenchmarkCache& cache() { return fwd_algos; }
static perf_t findAlgorithm(const ConvolutionArgs& args, bool benchmark) {
// CuDNN forward , :
static const algo_t algos[] = {
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
};
static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
"Missing cuDNN convolution forward algorithms");
int perf_count;
std::unique_ptr perf_results(new perf_t[num_algos]);
// benchmark , ,PyTorch ,
// cudnnGetConvolutionForwardAlgorithm_v7 !
if (!benchmark) {
AT_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
args.handle,
args.idesc.desc(),
args.wdesc.desc(),
args.cdesc.desc(),
args.odesc.desc(),
num_algos,
&perf_count,
perf_results.get()));
} else { // benchmark, cudnnFindConvolutionForwardAlgorithmEx !
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
Workspace ws(max_ws_size);
AT_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
args.handle,
args.idesc.desc(), args.input.data_ptr(),
args.wdesc.desc(), args.weight.data_ptr(),
args.cdesc.desc(),
args.odesc.desc(), args.output.data_ptr(),
num_algos,
&perf_count,
perf_results.get(),
ws.data,
ws.size));
}
return getBestAlgorithm(perf_results.get(), args, perf_count);
}