Skip to content

Commit

Permalink
Merge pull request #8 from KohlerHECTOR/KohlerHECTOR/issue7
Browse files (browse the repository at this point in the history)
Dagger weights
  • Loading branch information
KohlerHECTOR authored Jul 13, 2024
2 parents 7b3e565 + b5e60c4 commit 231b958
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
1 change: 0 additions & 1 deletion examples/half_cheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from stable_baselines3.common.monitor import Monitor

import gymnasium as gym
from gymnasium.wrappers.time_limit import TimeLimit
from sklearn.tree import DecisionTreeRegressor
from huggingface_sb3 import load_from_hub

Expand Down
11 changes: 7 additions & 4 deletions interpreter/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class Interpreter:
"""
A class to interpret a neural net policy using a decision tree policy.
It follows algorithm 1 from https://arxiv.org/abs/2405.14956
By default, the trajectories will be sampled in a DAgger-like way.
Parameters
----------
Expand Down Expand Up @@ -75,11 +76,13 @@ def train(self, nb_iter):
self.tree_policies = [deepcopy(self.tree_policy)]
self.tree_policies_rewards = [tree_reward]

for t in range(nb_iter - 1):
for t in range(1, nb_iter + 1):
print("Fitting tree nb {} ...".format(t + 1))
S_new, A_new = self.tree_policy.generate_data(self.env, self.data_per_iter)
S = np.concatenate((S, S_new))
A = np.concatenate((A, self.oracle.predict(S_new)[0]))
S_tree, _ = self.tree_policy.generate_data(self.env, int((t/nb_iter) * self.data_per_iter))
S_oracle, A_oracle = self.oracle.generate_data(self.env, int((1 - t/nb_iter) * self.data_per_iter))

S = np.concatenate((S, S_tree, S_oracle))
A = np.concatenate((A, self.oracle.predict(S_tree)[0], A_oracle))

self.tree_policy.fit_tree(S, A)
tree_reward, _ = evaluate_policy(self.tree_policy, self.env)
Expand Down
2 changes: 1 addition & 1 deletion interpreter/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def generate_data(self, env, nb_data):
The generated actions.
"""
assert (
nb_data > 0 and env.observation_space.shape == self.observation_space.shape
nb_data >= 0 and env.observation_space.shape == self.observation_space.shape
)
if isinstance(env.action_space, gym.spaces.Discrete):
assert env.action_space.n == self.action_space.n
Expand Down

0 comments on commit 231b958

Please sign in to comment.