Skip to content

Commit

Permalink
Merge pull request #6 from KohlerHECTOR/KohlerHECTOR/issue1
Browse files (browse the repository at this point in the history)
Clarify perfs of ObliqueDTPolicy compared to DTPolicy
  • Loading branch information
KohlerHECTOR authored Jul 3, 2024
2 parents c00059d + a3e6969 commit 3447112
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ with open("tree_halfcheetah.pkl", "rb") as f:
clf = load(f)
# Render
evaluate_policy(
- DTPolicy(clf, env),
+ ObliqueDTPolicy(clf, env),
env=Monitor(gym.make("HalfCheetah-v4", render_mode="human")),
render=True,
)
Expand Down
2 changes: 1 addition & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Main functions
:toctree: generated/
:template: function.rst

- interpreter.interpreter.Interpreter.train
+ interpreter.Interpreter.train
interpreter.policies.ObliqueDTPolicy.get_oblique_data


2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
project = "interpreter"
copyright = "2024, Hector Kohler"
author = "Hector Kohler"
- release = "0.1.1"
+ release = "0.1.2"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
2 changes: 1 addition & 1 deletion docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ with open("tree_halfcheetah.pkl", "rb") as f:
clf = load(f)
# Render
evaluate_policy(
- DTPolicy(clf, env),
+ ObliqueDTPolicy(clf, env),
env=Monitor(gym.make("HalfCheetah-v4", render_mode="human")),
render=True,
)
Expand Down
2 changes: 1 addition & 1 deletion interpreter/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def train(self, nb_iter):
print("Fitting tree nb {} ...".format(t + 1))
S_new, A_new = self.tree_policy.generate_data(self.env, self.data_per_iter)
S = np.concatenate((S, S_new))
- A = np.concatenate((A, A_new))
+ A = np.concatenate((A, self.oracle.predict(S_new)[0]))

self.tree_policy.fit_tree(S, A)
tree_reward, _ = evaluate_policy(self.tree_policy, self.env)
Expand Down
2 changes: 1 addition & 1 deletion interpreter/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def generate_data(self, env, nb_data):
)
if isinstance(env.action_space, gym.spaces.Discrete):
assert env.action_space.n == self.action_space.n
- A = np.zeros((nb_data, 1))
+ A = np.zeros((nb_data))
elif isinstance(env.action_space, gym.spaces.Box):
assert env.action_space.shape == self.action_space.shape
A = np.zeros((nb_data, self.action_space.shape[0]))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

- __version__ = "0.1.1"
+ __version__ = "0.1.2"

packages = find_packages(
exclude=[
Expand Down

0 comments on commit 3447112

Please sign in to comment.