Qlib RL framework (stage 2) - trainer (#1125)
* checkpoint (cherry picked from commit 1a8e0bd) * Not a workable version (cherry picked from commit 3498e18) * vessel * ckpt * . * vessel * . * . * checkpoint callback * . * cleanup * logger * . * test * . * add test * . * . * . * . * New reward * Add train API * fix mypy * fix lint * More comment * 3.7 compat * fix test * fix test * . * Resolve comments * fix typehint
Showing 17 changed files with 1,410 additions and 145 deletions.
Three files were deleted by this commit; their contents (and names) are not shown here.
New file (46 lines; file name not shown): the PAPenaltyReward reward class.

@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import cast

import numpy as np
from qlib.rl.reward import Reward

from .simulator_simple import SAOEState, SAOEMetrics

__all__ = ["PAPenaltyReward"]


class PAPenaltyReward(Reward[SAOEState]):
    """Encourage higher PAs, but penalize stacking all the amounts within a very short time.

    Formally, for each time step, the reward is :math:`(PA_t * vol_t / target - vol_t^2 * penalty)`.

    Parameters
    ----------
    penalty
        The penalty for large volume in a short time.
    """

    def __init__(self, penalty: float = 100.0):
        self.penalty = penalty

    def reward(self, simulator_state: SAOEState) -> float:
        whole_order = simulator_state.order.amount
        assert whole_order > 0
        last_step = cast(SAOEMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict())
        pa = last_step["pa"] * last_step["amount"] / whole_order

        # Inspect the "break-down" of the latest step: trading amount at every tick
        last_step_breakdown = simulator_state.history_exec.loc[last_step["datetime"] :]
        penalty = -self.penalty * ((last_step_breakdown["amount"] / whole_order) ** 2).sum()

        reward = pa + penalty

        # Throw error in case of NaN
        assert not (np.isnan(reward) or np.isinf(reward)), f"Invalid reward for simulator state: {simulator_state}"

        self.log("reward/pa", pa)
        self.log("reward/penalty", penalty)
        return reward
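
For intuition, here is a small self-contained sketch (not part of the commit) that reproduces the same arithmetic with plain pandas and NumPy on made-up numbers. The column layout (a datetime index with "amount" and "pa" columns) mirrors what the code above reads from history_steps and history_exec; all concrete values are illustrative assumptions.

# Illustrative only: mirrors the PAPenaltyReward arithmetic without building a SAOEState.
import numpy as np
import pandas as pd

whole_order = 1000.0  # hypothetical total order amount

# One row per simulation step; the last row is the "latest" step.
history_steps = pd.DataFrame(
    {
        "datetime": pd.to_datetime(["2024-01-02 09:30", "2024-01-02 10:00"]),
        "amount": [400.0, 300.0],
        "pa": [2.0, 5.0],
    }
).set_index("datetime")

# Tick-level execution records; the latest step (10:00 onwards) traded in two ticks.
history_exec = pd.DataFrame(
    {
        "datetime": pd.to_datetime(
            ["2024-01-02 09:30", "2024-01-02 09:31", "2024-01-02 10:00", "2024-01-02 10:01"]
        ),
        "amount": [200.0, 200.0, 100.0, 200.0],
    }
).set_index("datetime")

penalty_coef = 100.0
last_step = history_steps.reset_index().iloc[-1].to_dict()

pa = last_step["pa"] * last_step["amount"] / whole_order  # 5.0 * 300 / 1000 = 1.5
breakdown = history_exec.loc[last_step["datetime"]:]
penalty = -penalty_coef * ((breakdown["amount"] / whole_order) ** 2).sum()  # -100 * (0.1**2 + 0.2**2) = -5.0
reward = pa + penalty  # 1.5 - 5.0 = -3.5
assert not (np.isnan(reward) or np.isinf(reward))
print(reward)

The quadratic term is what discourages dumping everything in one tick: trading the same 300 in a single tick would cost -100 * 0.3**2 = -9.0 instead of -5.0, so spreading the volume over more ticks is rewarded.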
New file (9 lines; file name not shown): the trainer subpackage __init__, exposing its public API.

@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Train, test, inference utilities."""

from .api import backtest, train
from .callbacks import EarlyStopping, Checkpoint
from .trainer import Trainer
from .vessel import TrainingVessel, TrainingVesselBase
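
With this __init__ in place, downstream code can import the trainer's public names from one place. A minimal sketch, assuming the subpackage is importable as qlib.rl.trainer (the diff does not show the file path):

# Assumed import path; this listing shows only the subpackage __init__, not its location on disk.
from qlib.rl.trainer import (
    Trainer,             # orchestrates training/testing (see trainer.fit / trainer.test below)
    TrainingVessel,      # bundles simulator, interpreters, policy and reward
    TrainingVesselBase,
    Checkpoint,          # callbacks
    EarlyStopping,
    train,               # high-level helpers defined in api.py
    backtest,
)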
New file (118 lines; file name not shown): high-level train and backtest helpers.

@@ -0,0 +1,118 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Callable, Sequence, cast, Any

from tianshou.policy import BasePolicy

from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import FiniteEnvType, LogWriter

from .vessel import TrainingVessel
from .trainer import Trainer


def train(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    reward: Reward,
    vessel_kwargs: dict[str, Any],
    trainer_kwargs: dict[str, Any],
) -> None:
    """Train a policy with the parallelism provided by RL framework.

    Experimental API. Parameters might change shortly.

    Parameters
    ----------
    simulator_fn
        Callable receiving initial seed, returning a simulator.
    state_interpreter
        Interprets the state of simulators.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to train against.
    reward
        Reward function.
    vessel_kwargs
        Keyword arguments passed to :class:`TrainingVessel`, like ``episode_per_iter``.
    trainer_kwargs
        Keyword arguments passed to :class:`Trainer`, like ``finite_env_type``, ``concurrency``.
    """

    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        train_initial_states=initial_states,
        reward=reward,
        **vessel_kwargs,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(vessel)


def backtest(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    logger: LogWriter | list[LogWriter],
    reward: Reward | None = None,
    finite_env_type: FiniteEnvType = "subproc",
    concurrency: int = 2,
) -> None:
    """Backtest with the parallelism provided by RL framework.

    Experimental API. Parameters might change shortly.

    Parameters
    ----------
    simulator_fn
        Callable receiving initial seed, returning a simulator.
    state_interpreter
        Interprets the state of simulators.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to test against.
    logger
        Logger to record the backtest results. Logger must be present because
        without logger, all information will be lost.
    reward
        Optional reward function. For backtest, this is for testing the rewards
        and logging them only.
    finite_env_type
        Type of finite env implementation.
    concurrency
        Parallel workers.
    """

    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        test_initial_states=initial_states,
        reward=cast(Reward, reward),  # ignore none
    )
    trainer = Trainer(
        finite_env_type=finite_env_type,
        concurrency=concurrency,
        loggers=logger,
    )
    trainer.test(vessel)
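
To tie the pieces together, a hedged sketch of how train and backtest are meant to be called. MySimulator, MyStateInterpreter, MyActionInterpreter, my_policy, my_orders, test_orders and my_log_writer are hypothetical stand-ins, not names from this commit; only episode_per_iter, finite_env_type and concurrency come from the docstrings above, and the numeric values are illustrative.

# Illustrative call pattern only; every My*/my_* name is a hypothetical placeholder.
train(
    simulator_fn=lambda order: MySimulator(order),   # factory: initial state -> simulator
    state_interpreter=MyStateInterpreter(),
    action_interpreter=MyActionInterpreter(),
    initial_states=my_orders,                        # sequence of initial states, each run once per iteration
    policy=my_policy,                                # a tianshou BasePolicy
    reward=PAPenaltyReward(penalty=100.0),
    vessel_kwargs={"episode_per_iter": 500},         # forwarded to TrainingVessel
    trainer_kwargs={"finite_env_type": "subproc", "concurrency": 4},  # forwarded to Trainer
)

backtest(
    simulator_fn=lambda order: MySimulator(order),
    state_interpreter=MyStateInterpreter(),
    action_interpreter=MyActionInterpreter(),
    initial_states=test_orders,                      # held-out states
    policy=my_policy,
    logger=my_log_writer,                            # a LogWriter (or list of LogWriters); required
    reward=PAPenaltyReward(),                        # optional here: rewards are only computed and logged
    finite_env_type="subproc",
    concurrency=2,
)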