Skip to content

Commit 5c04446

Browse files
committed
correct imports
1 parent 3f19583 commit 5c04446

File tree

5 files changed

+18
-15
lines changed

5 files changed

+18
-15
lines changed

tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc

+6-1
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,13 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "tensorflow/contrib/correlation_cost/kernels/correlation_cost_op.h"
16+
#define EIGEN_USE_THREADS
1717

18+
#if GOOGLE_CUDA
19+
#define EIGEN_USE_GPU
20+
#endif // GOOGLE_CUDA
21+
22+
#include "tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.h"
1823
#include "tensorflow/core/framework/op_kernel.h"
1924
#include "tensorflow/core/framework/register_types.h"
2025
#include "tensorflow/core/framework/types.h"

tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ limitations under the License.
1717

1818
#define EIGEN_USE_GPU
1919

20-
#include "tensorflow/contrib/correlation_cost/kernels/correlation_cost_op.h"
21-
20+
#include "tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.h"
2221
#include "external/cub_archive/cub/device/device_reduce.cuh"
2322
#include "tensorflow/core/framework/tensor.h"
2423
#include "tensorflow/core/framework/tensor_shape.h"

tensorflow_addons/custom_ops/opticalflow/cc/ops/correlation_cost_op.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ REGISTER_OP("CorrelationCost")
9090
Compute Correlation costs.
9191
9292
This layer implements the correlation operation from
93-
FlowNet: Learning Optical Flow with Convolutional Networks (Fischer et al.)
93+
FlowNet Learning Optical Flow with Convolutional Networks (Fischer et al.)
9494
9595
input_a: A `Tensor` of the format specified by `data_format`.
9696
input_b: A `Tensor` of the format specified by `data_format`.

tensorflow_addons/opticalflow/correlation_cost.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
from tensorflow.contrib.correlation_cost.ops import gen_correlation_cost_op
22-
from tensorflow.contrib.util import loader
21+
import tensorflow as tf
2322
from tensorflow.python.framework import ops
2423
from tensorflow.python.ops import array_ops
25-
from tensorflow.python.platform import resource_loader
24+
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
2625

27-
_correlation_cost_op_so = loader.load_op_library(
28-
resource_loader.get_path_to_datafile("_correlation_cost_op.so"))
26+
_correlation_cost_op_so = tf.load_op_library(
27+
get_path_to_datafile("custom_ops/opticalflow/_correlation_cost_ops.so"))
2928

3029
# pylint: disable=redefined-builtin
3130

@@ -81,7 +80,7 @@ def correlation_cost(input_a,
8180
"""
8281

8382
with ops.name_scope(name, "correlation_cost"):
84-
op_call = gen_correlation_cost_op.correlation_cost
83+
op_call = _correlation_cost_op_so.correlation_cost
8584
ret = op_call(
8685
input_a,
8786
input_b,
@@ -98,7 +97,7 @@ def correlation_cost(input_a,
9897
return ret
9998

10099

101-
correlation_cost_grad = gen_correlation_cost_op.correlation_cost_grad
100+
correlation_cost_grad = _correlation_cost_op_so.correlation_cost_grad
102101

103102

104103
@ops.RegisterGradient("CorrelationCost")
@@ -114,7 +113,7 @@ def _correlation_cost_grad(op, grad_output):
114113
input_b = ops.convert_to_tensor(op.inputs[1], name="input_b")
115114
grad_output_tensor = ops.convert_to_tensor(grad_output, name="grad_output")
116115

117-
op_call = gen_correlation_cost_op.correlation_cost_grad
116+
op_call = _correlation_cost_op_so.correlation_cost_grad
118117
grads = op_call(
119118
input_a,
120119
input_b,

tensorflow_addons/opticalflow/correlation_cost_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import numpy as np
2121

22-
from tensorflow.contrib.correlation_cost.python.ops import correlation_cost_op
22+
from tensorflow_addons.opticalflow import correlation_cost
2323
from tensorflow.python.framework import dtypes
2424
from tensorflow.python.ops import array_ops
2525
from tensorflow.python.framework import ops
@@ -49,7 +49,7 @@ def _forward(self,
4949
stride_2 = 2
5050
pad = 4
5151

52-
call_op = correlation_cost_op.correlation_cost
52+
call_op = correlation_cost
5353
actual_op = call_op(
5454
input_a_op,
5555
input_b_op,
@@ -179,7 +179,7 @@ def _gradients(self, data_format='NCHW', use_gpu=False):
179179
input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32)
180180
input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32)
181181

182-
call_op = correlation_cost_op.correlation_cost
182+
call_op = correlation_cost
183183
actual_op = call_op(
184184
input_a_op,
185185
input_b_op,

0 commit comments

Comments
 (0)