Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Implement SerializedSequenceSimulatedEnvProblem
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 263020747
  • Loading branch information
koz4k authored and copybara-github committed Aug 12, 2019
1 parent 41726d4 commit f7f8549
Show file tree
Hide file tree
Showing 5 changed files with 579 additions and 64 deletions.
6 changes: 3 additions & 3 deletions tensor2tensor/trax/rlax/ppo_training_loop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@ def loss(*args, **kwargs):
)
trainer.train_epoch(epoch_steps=1, eval_steps=1)

# Repeat the initial observations over and over again.
# Repeat the history over and over again.
stream = itertools.repeat(np.zeros(history_shape))
env_fn = functools.partial(
simulated_env_problem.SimulatedEnvProblem,
simulated_env_problem.RawSimulatedEnvProblem,
model=model,
history_length=history_shape[1],
trajectory_length=3,
Expand All @@ -184,7 +184,7 @@ def loss(*args, **kwargs):
action_space=gym.spaces.Discrete(n=n_actions),
reward_range=(-1, 1),
discrete_rewards=False,
initial_observation_stream=stream,
history_stream=stream,
output_dir=output_dir,
)

Expand Down
Loading

0 comments on commit f7f8549

Please # to comment.