Skip to content

Commit

Permalink
Merge pull request #173 from pooyadavoodi/findex_no_bwd_for_testnet
Browse files: browse the repository at this point in the history
No FindEx for backward pass of test nets.
  • Loading branch information
drnikolaev authored Jun 18, 2016
2 parents c4d3722 + 651fef4 commit 28e0f13
Showing 1 changed file with 39 additions and 36 deletions.
75 changes: 39 additions & 36 deletions src/caffe/layers/cudnn_conv_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,42 +309,45 @@ void CuDNNConvolutionLayer<Dtype>::FindExConvAlgo(
fwd_algo_[i] = fwd_results[0].algo;
workspace_fwd_sizes_[i] = fwd_results[0].memory;

// Find backward filter algorithm
CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(
Caffe::cudnn_handle(),
bottom_descs_[i],
bottom[i]->gpu_data(),
top_descs_[i],
top[i]->gpu_diff(),
conv_descs_[i],
filter_desc_,
tmp_weights,
kRequestAlgoCount,
&filter_algo_count,
bwd_filter_results,
workspace.data(),
workspace.size()));
bwd_filter_algo_[i] = bwd_filter_results[0].algo;
workspace_bwd_filter_sizes_[i] = bwd_filter_results[0].memory;

// Find backward data algorithm
CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(
Caffe::cudnn_handle(),
filter_desc_,
this->blobs_[0]->gpu_data(),
top_descs_[i],
top[i]->gpu_diff(),
conv_descs_[i],
bottom_descs_[i],
bottom[i]->mutable_gpu_diff(),
kRequestAlgoCount,
&data_algo_count,
bwd_data_results,
workspace.data(),
workspace.size()));

bwd_data_algo_[i] = bwd_data_results[0].algo;
workspace_bwd_data_sizes_[i] = bwd_data_results[0].memory;
// Only set backward-filter/data algorithms in training phase
if (this->phase_ == TRAIN) {
// Find backward filter algorithm
CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(
Caffe::cudnn_handle(),
bottom_descs_[i],
bottom[i]->gpu_data(),
top_descs_[i],
top[i]->gpu_diff(),
conv_descs_[i],
filter_desc_,
tmp_weights,
kRequestAlgoCount,
&filter_algo_count,
bwd_filter_results,
workspace.data(),
workspace.size()));
bwd_filter_algo_[i] = bwd_filter_results[0].algo;
workspace_bwd_filter_sizes_[i] = bwd_filter_results[0].memory;

// Find backward data algorithm
CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(
Caffe::cudnn_handle(),
filter_desc_,
this->blobs_[0]->gpu_data(),
top_descs_[i],
top[i]->gpu_diff(),
conv_descs_[i],
bottom_descs_[i],
bottom[i]->mutable_gpu_diff(),
kRequestAlgoCount,
&data_algo_count,
bwd_data_results,
workspace.data(),
workspace.size()));

bwd_data_algo_[i] = bwd_data_results[0].algo;
workspace_bwd_data_sizes_[i] = bwd_data_results[0].memory;
}
}
GPUMemory::deallocate(tmp_weights);
workspace.release();
Expand Down

0 comments on commit 28e0f13

Please sign in to comment.