diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp index 6f149fc6237..fd1a4f423cd 100644 --- a/src/caffe/layers/cudnn_conv_layer.cpp +++ b/src/caffe/layers/cudnn_conv_layer.cpp @@ -309,42 +309,45 @@ void CuDNNConvolutionLayer::FindExConvAlgo( fwd_algo_[i] = fwd_results[0].algo; workspace_fwd_sizes_[i] = fwd_results[0].memory; - // Find backward filter algorithm - CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx( - Caffe::cudnn_handle(), - bottom_descs_[i], - bottom[i]->gpu_data(), - top_descs_[i], - top[i]->gpu_diff(), - conv_descs_[i], - filter_desc_, - tmp_weights, - kRequestAlgoCount, - &filter_algo_count, - bwd_filter_results, - workspace.data(), - workspace.size())); - bwd_filter_algo_[i] = bwd_filter_results[0].algo; - workspace_bwd_filter_sizes_[i] = bwd_filter_results[0].memory; - - // Find backward data algorithm - CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx( - Caffe::cudnn_handle(), - filter_desc_, - this->blobs_[0]->gpu_data(), - top_descs_[i], - top[i]->gpu_diff(), - conv_descs_[i], - bottom_descs_[i], - bottom[i]->mutable_gpu_diff(), - kRequestAlgoCount, - &data_algo_count, - bwd_data_results, - workspace.data(), - workspace.size())); - - bwd_data_algo_[i] = bwd_data_results[0].algo; - workspace_bwd_data_sizes_[i] = bwd_data_results[0].memory; + // Only set backward-filter/data algorithms in training phase + if (this->phase_ == TRAIN) { + // Find backward filter algorithm + CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx( + Caffe::cudnn_handle(), + bottom_descs_[i], + bottom[i]->gpu_data(), + top_descs_[i], + top[i]->gpu_diff(), + conv_descs_[i], + filter_desc_, + tmp_weights, + kRequestAlgoCount, + &filter_algo_count, + bwd_filter_results, + workspace.data(), + workspace.size())); + bwd_filter_algo_[i] = bwd_filter_results[0].algo; + workspace_bwd_filter_sizes_[i] = bwd_filter_results[0].memory; + + // Find backward data algorithm + CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx( + Caffe::cudnn_handle(), + filter_desc_, + this->blobs_[0]->gpu_data(), + top_descs_[i], + top[i]->gpu_diff(), + conv_descs_[i], + bottom_descs_[i], + bottom[i]->mutable_gpu_diff(), + kRequestAlgoCount, + &data_algo_count, + bwd_data_results, + workspace.data(), + workspace.size())); + + bwd_data_algo_[i] = bwd_data_results[0].algo; + workspace_bwd_data_sizes_[i] = bwd_data_results[0].memory; + } } GPUMemory::deallocate(tmp_weights); workspace.release();