From c2d3c6559ef96f423721392e0efa7a7e4a3c3d5a Mon Sep 17 00:00:00 2001 From: Jonathan Hunt Date: Tue, 7 Jan 2020 17:07:03 +0000 Subject: [PATCH] TF2 fixes. --- dopamine/discrete_domains/atari_lib.py | 10 ++++++++-- dopamine/replay_memory/circular_replay_buffer.py | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/dopamine/discrete_domains/atari_lib.py b/dopamine/discrete_domains/atari_lib.py index fff476e1..6d0467e2 100644 --- a/dopamine/discrete_domains/atari_lib.py +++ b/dopamine/discrete_domains/atari_lib.py @@ -51,8 +51,14 @@ class has two main functions: `.__init__` and `.call`. When we create our import tensorflow.compat.v1 as tf import cv2 -from tensorflow.contrib import layers as contrib_layers -from tensorflow.contrib import slim as contrib_slim +from tensorflow.compat.v1 import layers as contrib_layers + +# Allow failure on this import (not in tf2). This means atari won't be +# available but other domains will. +try: + from tensorflow.contrib import slim as contrib_slim +except: + pass NATURE_DQN_OBSERVATION_SHAPE = (84, 84) # Size of downscaled Atari 2600 frame. diff --git a/dopamine/replay_memory/circular_replay_buffer.py b/dopamine/replay_memory/circular_replay_buffer.py index 1a5020fa..096fffd7 100644 --- a/dopamine/replay_memory/circular_replay_buffer.py +++ b/dopamine/replay_memory/circular_replay_buffer.py @@ -34,7 +34,7 @@ import tensorflow.compat.v1 as tf import gin.tf -from tensorflow.contrib import staging as contrib_staging +from tensorflow.python.ops import data_flow_ops # Defines a type describing part of the tuple returned by the replay # memory. Each element of the tuple is a tensor of shape [batch, ...] where @@ -855,7 +855,7 @@ def _set_up_staging(self, transition): transition_type = self.memory.get_transition_elements() # Create the staging area in CPU. - prefetch_area = contrib_staging.StagingArea( + prefetch_area = data_flow_ops.StagingArea( [shape_with_type.type for shape_with_type in transition_type]) # Store prefetch op for tests, but keep it private -- users should not be