From ce5063a3465d851f2e03be74f412bf0f8586e4a2 Mon Sep 17 00:00:00 2001 From: Florian Felten Date: Tue, 4 Apr 2023 11:32:34 +0200 Subject: [PATCH 1/2] Fix OLS --- examples/pcn_minecart.py | 1 + .../multi_policy/linear_support/linear_support.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/pcn_minecart.py b/examples/pcn_minecart.py index 0126d5ee..aabc577f 100644 --- a/examples/pcn_minecart.py +++ b/examples/pcn_minecart.py @@ -24,6 +24,7 @@ def make_env(): ) agent.train( + eval_env=make_env(), total_timesteps=int(1e7), ref_point=np.array([0, 0, -200.0]), num_er_episodes=20, diff --git a/morl_baselines/multi_policy/linear_support/linear_support.py b/morl_baselines/multi_policy/linear_support/linear_support.py index f88fc352..bfa9ac48 100644 --- a/morl_baselines/multi_policy/linear_support/linear_support.py +++ b/morl_baselines/multi_policy/linear_support/linear_support.py @@ -6,6 +6,7 @@ import cdd import cvxpy as cp import numpy as np +from cvxpy import SolverError from gymnasium.core import Env from morl_baselines.common.evaluation import policy_evaluation_mo @@ -289,7 +290,12 @@ def max_value_lp(self, w_new: np.ndarray) -> float: # such that it is consistent with other optimal values for other visited weights constraints = [W @ v <= V] prob = cp.Problem(objective, constraints) - return prob.solve(verbose=False) + try: + result = prob.solve(verbose=False) + except SolverError: + print("SCS solver error, trying another one.") + result = prob.solve(solver=cp.SCS, verbose=False) + return result def compute_corner_weights(self) -> List[np.ndarray]: """Returns the corner weights for the current set of values. From f1f84ec1b5bcfb986e78dea232d6deed805a3920 Mon Sep 17 00:00:00 2001 From: Florian Felten Date: Tue, 4 Apr 2023 11:34:35 +0200 Subject: [PATCH 2/2] Fix typo --- morl_baselines/multi_policy/linear_support/linear_support.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/morl_baselines/multi_policy/linear_support/linear_support.py b/morl_baselines/multi_policy/linear_support/linear_support.py index bfa9ac48..08c12ca9 100644 --- a/morl_baselines/multi_policy/linear_support/linear_support.py +++ b/morl_baselines/multi_policy/linear_support/linear_support.py @@ -293,7 +293,7 @@ def max_value_lp(self, w_new: np.ndarray) -> float: try: result = prob.solve(verbose=False) except SolverError: - print("SCS solver error, trying another one.") + print("ECOS solver error, trying another one.") result = prob.solve(solver=cp.SCS, verbose=False) return result