From 651fef48f2873da0e00ab0a33629974d00fb742d Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 17 Jun 2016 15:18:04 -0700 Subject: [PATCH] No FindEx for backward pass of test nets. --- src/caffe/layers/cudnn_conv_layer.cpp | 75 ++++++++++++++------------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp index 6f149fc6237..fd1a4f423cd 100644 --- a/src/caffe/layers/cudnn_conv_layer.cpp +++ b/src/caffe/layers/cudnn_conv_layer.cpp @@ -309,42 +309,45 @@ void CuDNNConvolutionLayer::FindExConvAlgo( fwd_algo_[i] = fwd_results[0].algo; workspace_fwd_sizes_[i] = fwd_results[0].memory; - // Find backward filter algorithm - CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx( - Caffe::cudnn_handle(), - bottom_descs_[i], - bottom[i]->gpu_data(), - top_descs_[i], - top[i]->gpu_diff(), - conv_descs_[i], - filter_desc_, - tmp_weights, - kRequestAlgoCount, - &filter_algo_count, - bwd_filter_results, - workspace.data(), - workspace.size())); - bwd_filter_algo_[i] = bwd_filter_results[0].algo; - workspace_bwd_filter_sizes_[i] = bwd_filter_results[0].memory; - - // Find backward data algorithm - CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx( - Caffe::cudnn_handle(), - filter_desc_, - this->blobs_[0]->gpu_data(), - top_descs_[i], - top[i]->gpu_diff(), - conv_descs_[i], - bottom_descs_[i], - bottom[i]->mutable_gpu_diff(), - kRequestAlgoCount, - &data_algo_count, - bwd_data_results, - workspace.data(), - workspace.size())); - - bwd_data_algo_[i] = bwd_data_results[0].algo; - workspace_bwd_data_sizes_[i] = bwd_data_results[0].memory; + // Only set backward-filter/data algorithms in training phase + if (this->phase_ == TRAIN) { + // Find backward filter algorithm + CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx( + Caffe::cudnn_handle(), + bottom_descs_[i], + bottom[i]->gpu_data(), + top_descs_[i], + top[i]->gpu_diff(), + conv_descs_[i], + filter_desc_, + tmp_weights, + kRequestAlgoCount, + &filter_algo_count, + bwd_filter_results, + workspace.data(), + workspace.size())); + bwd_filter_algo_[i] = bwd_filter_results[0].algo; + workspace_bwd_filter_sizes_[i] = bwd_filter_results[0].memory; + + // Find backward data algorithm + CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx( + Caffe::cudnn_handle(), + filter_desc_, + this->blobs_[0]->gpu_data(), + top_descs_[i], + top[i]->gpu_diff(), + conv_descs_[i], + bottom_descs_[i], + bottom[i]->mutable_gpu_diff(), + kRequestAlgoCount, + &data_algo_count, + bwd_data_results, + workspace.data(), + workspace.size())); + + bwd_data_algo_[i] = bwd_data_results[0].algo; + workspace_bwd_data_sizes_[i] = bwd_data_results[0].memory; + } } GPUMemory::deallocate(tmp_weights); workspace.release();