From a3e696932a714e95664e39bafe6d15feba6e7f54 Mon Sep 17 00:00:00 2001 From: KohlerHECTOR Date: Wed, 3 Jul 2024 20:20:20 +0200 Subject: [PATCH] Clarify perfs of ObliqueDTPolicy compared to DTPolicy Fixes #1 --- README.md | 2 +- docs/api.rst | 2 +- docs/conf.py | 2 +- docs/usage.md | 2 +- interpreter/interpreter.py | 2 +- interpreter/policies.py | 2 +- setup.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index fff052a..a619062 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ with open("tree_halfcheetah.pkl", "rb") as f: clf = load(f) # Render evaluate_policy( - DTPolicy(clf, env), + ObliqueDTPolicy(clf, env), env=Monitor(gym.make("HalfCheetah-v4", render_mode="human")), render=True, ) diff --git a/docs/api.rst b/docs/api.rst index ffb6ba3..efc4810 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -23,7 +23,7 @@ Main functions :toctree: generated/ :template: function.rst - interpreter.interpreter.Interpreter.train + interpreter.Interpreter.train interpreter.policies.ObliqueDTPolicy.get_oblique_data diff --git a/docs/conf.py b/docs/conf.py index 0366340..35f7914 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -9,7 +9,7 @@ project = "interpreter" copyright = "2024, Hector Kohler" author = "Hector Kohler" -release = "0.1.1" +release = "0.1.2" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/usage.md b/docs/usage.md index 6ad59cd..0e22570 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -56,7 +56,7 @@ with open("tree_halfcheetah.pkl", "rb") as f: clf = load(f) # Render evaluate_policy( - DTPolicy(clf, env), + ObliqueDTPolicy(clf, env), env=Monitor(gym.make("HalfCheetah-v4", render_mode="human")), render=True, ) diff --git a/interpreter/interpreter.py b/interpreter/interpreter.py index 94279e3..d2e103f 100644 --- a/interpreter/interpreter.py +++ b/interpreter/interpreter.py @@ -79,7 +79,7 @@ def train(self, nb_iter): print("Fitting tree nb {} ...".format(t + 1)) S_new, A_new = self.tree_policy.generate_data(self.env, self.data_per_iter) S = np.concatenate((S, S_new)) - A = np.concatenate((A, A_new)) + A = np.concatenate((A, self.oracle.predict(S_new)[0])) self.tree_policy.fit_tree(S, A) tree_reward, _ = evaluate_policy(self.tree_policy, self.env) diff --git a/interpreter/policies.py b/interpreter/policies.py index c8dc941..54fcb1b 100644 --- a/interpreter/policies.py +++ b/interpreter/policies.py @@ -77,7 +77,7 @@ def generate_data(self, env, nb_data): ) if isinstance(env.action_space, gym.spaces.Discrete): assert env.action_space.n == self.action_space.n - A = np.zeros((nb_data, 1)) + A = np.zeros((nb_data)) elif isinstance(env.action_space, gym.spaces.Box): assert env.action_space.shape == self.action_space.shape A = np.zeros((nb_data, self.action_space.shape[0])) diff --git a/setup.py b/setup.py index 4d65cb7..450c087 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, find_packages -__version__ = "0.1.1" +__version__ = "0.1.2" packages = find_packages( exclude=[