Skip to content

feat: add q-learning 算法 #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ node_modules/

# Package
package-lock.json
.history
1 change: 1 addition & 0 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ <h1>Welcome to T-Rex Run</h1>
<li><a href="neural-network.html">Neural Network</a></li>
<li><a href="neural-network-multiplayer.html">Neural Network for Multiplayers</a></li>
<li><a href="genetic-neural-network.html">Genetic + Neural Network</a></li>
<li><a href="q-learning.html">RL--q-learning</a></li>
</ul>
</body>
</html>
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"dependencies": {
"@tensorflow/tfjs": "^0.9.0",
"babel-polyfill": "^6.26.0",
"localforage": "^1.10.0",
"package.json": "^2.0.1"
}
}
31 changes: 31 additions & 0 deletions q-learning.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="X-UA-Compatible" content="ie=edge">
<title>q-learning - T-Rex Runner</title>
<style>
h3 {
padding-left: 20px;
}
#table{
height: 600px;
overflow: scroll;
padding: 20px;
}
</style>
<script src="assets/vendor.js"></script>
</head>
<body class="page">
<h1>q-learning</h1>
<div class="game">
<div class="generation"></div>
</div>
<h3>q-Table</h3>
<div id= "table">

</div>
<script src="assets/q-learning.js"></script>
</body>
</html>
134 changes: 134 additions & 0 deletions src/ai/models/qlearning/qlearningModel.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@

import localforage from 'localforage'
import Model from '../Model';

// Name used both for the localforage instance (config below) and as the
// item key under which the Q-table is stored and read back.
export const Key = 'qlearning-chrome://dino'
localforage.config({
name: Key
});
// In-memory Q-table shared by the whole module: maps a state key to an
// array holding one Q-value per action.
let globalQTable = {}

/**
 * Persist the in-memory Q-table under `Key` via localforage.
 *
 * The write is asynchronous. The promise is returned so callers may wait for
 * it, and a failed write is logged here so that callers which ignore the
 * return value never trigger an unhandled promise rejection.
 *
 * @returns {Promise<*>} Settles when the write attempt has finished; resolves
 *   to `undefined` when the write failed.
 */
export function saveDB() {
  return localforage.setItem(Key, globalQTable).catch((err) => {
    console.error('failed to persist q-table', err);
  });
}

/**
 * Replace the in-memory Q-table with an empty one and persist that state.
 *
 * @returns {Promise<*>} The promise returned by `saveDB()`.
 */
export function clearDB() {
  globalQTable = {};
  return saveDB();
}

/**
 * Draw a random integer between the two bounds, both inclusive.
 *
 * The lower bound is rounded up and the upper bound rounded down before the
 * draw, so fractional bounds are tightened to the integers they enclose.
 *
 * @param {number} min - Lower bound.
 * @param {number} max - Upper bound.
 * @returns {number} An integer in [ceil(min), floor(max)].
 */
function getRandomInt(min, max) {
  const low = Math.ceil(min);
  const high = Math.floor(max);
  const span = high - low + 1;
  return low + Math.floor(span * Math.random());
}
/**
 * Return the index of the largest entry of `arr`.
 *
 * Entries are converted with `Number()` before being compared. Q-table rows
 * may hold numeric strings (for example values produced by `toFixed`), and
 * comparing two strings with `>` is lexicographic, which would rank "9.000"
 * above "10.000".
 *
 * @param {Array<number|string>} arr - Values to scan.
 * @returns {number} Index of the first maximum, or -1 for an empty array.
 */
function indexOfMax(arr) {
  if (arr.length === 0) {
    return -1;
  }
  let maxIndex = 0;
  let max = Number(arr[0]);
  for (let i = 1; i < arr.length; i += 1) {
    const value = Number(arr[i]);
    // Strict comparison keeps the earliest index on ties.
    if (value > max) {
      maxIndex = i;
      max = value;
    }
  }
  return maxIndex;
}
/**
 * Tabular Q-learning agent.
 *
 * Q-values live in the module-level `globalQTable`, keyed by state, with one
 * entry per action in each row. The table is shared by every instance and is
 * loaded from localforage by `init()`.
 */
export default class QNetwork extends Model {
  /**
   * @param {number} actions - Number of available actions (0 running, 1 jump).
   * @param {*} states - Forwarded unchanged to the `Model` constructor.
   */
  constructor(actions, states) {
    super(actions, states);
    this.actions = actions; // 0 running, 1 jump
    this.epsilon = 0.5; // probability of taking the exploration branch in think()
    this.epsilon_decay = 0.001; // fraction of epsilon removed on every think()
    this.currentState = 0;
    this.alpha = 0.2; // weight of the new estimate in the Q update
    this.gamma = 0.9; // discount applied to the next state's best Q-value
    this.prevState = '0,0';
    this.preAction = 0;
  }

  /**
   * Build a fresh, zero-filled row of Q-values.
   *
   * @returns {number[]} One zero per action; two entries when `this.actions`
   *   is not a positive integer.
   */
  initActions() {
    const count = Number.isInteger(this.actions) && this.actions > 0 ? this.actions : 2;
    return Array.from({ length: count }, () => 0);
  }

  /**
   * Read the persisted Q-table.
   *
   * @returns {Promise<*>} The value stored under `Key`.
   */
  async getQArrFromDB() {
    const result = await localforage.getItem(Key);
    return result;
  }

  /**
   * Load the persisted Q-table, if there is one, into the shared table.
   *
   * @param {*} actionLen - Unused; kept so existing callers stay valid.
   * @returns {Promise<void>}
   */
  async init(actionLen) {
    const result = await this.getQArrFromDB();
    if (result) {
      // Rows may contain numeric strings; convert them so that later
      // comparisons between Q-values are numeric rather than lexicographic.
      Object.values(result).forEach((row) => {
        if (Array.isArray(row)) {
          row.forEach((q, i) => {
            row[i] = Number(q);
          });
        }
      });
      globalQTable = result;
    }
    // Logged after the assignment so the count reflects the loaded table.
    console.info(`from indexedDB get qArr(${Object.keys(globalQTable).length})`, result);
  }

  /** Log how many states the shared Q-table currently holds. */
  showQArr() {
    console.info(`>>>>>>> this.qTable(${Object.keys(globalQTable).length}): `);
  }

  /**
   * Choose an action for the first entry of `inputXs`.
   *
   * @param {Array} inputXs - Inputs; only `inputXs[0]` is used, as the state key.
   * @returns {number} The chosen action index.
   */
  predict(inputXs) {
    const inputX = inputXs[0];
    return this.think(inputX);
  }

  /**
   * Pick an action for `state`, creating its Q-table row on first sight, and
   * decay epsilon.
   *
   * @param {string} state - State key into the Q-table.
   * @returns {number} The chosen action index.
   */
  think(state) {
    this.currentState = globalQTable[state];
    if (!this.currentState) {
      this.currentState = this.initActions();
      globalQTable[state] = this.currentState;
    }
    let action;
    if (Math.random() < this.epsilon) {
      // NOTE(review): the exploration branch always picks action 0; a uniform
      // random pick via getRandomInt(0, this.actions - 1) appears to have been
      // disabled on purpose - confirm before changing.
      action = 0;
    } else {
      action = indexOfMax(this.currentState);
    }
    this.epsilon -= this.epsilon_decay * this.epsilon;
    return action;
  }

  /**
   * Apply the Q-learning update to the previous state/action pair and make
   * `state` the new previous state.
   *
   * Q(prev, a) = alpha * (reward + gamma * max Q(state, .)) + (1 - alpha) * Q(prev, a)
   *
   * @param {number} reward - Reward received for the previous action.
   * @param {string} state - State key reached after the previous action.
   */
  giveReward(reward, state) {
    const prevState = this.prevState;
    const action = this.preAction;
    if (reward !== 1) {
      console.info('giveReward >>>>> ', reward, state, prevState, action);
    }
    let curArr = globalQTable[state];
    if (!curArr) {
      curArr = this.initActions();
      globalQTable[state] = curArr;
    }
    let prevStateActions = globalQTable[prevState];
    if (!prevStateActions) {
      prevStateActions = this.initActions();
      globalQTable[prevState] = prevStateActions;
    }
    const maxQ = Math.max(...curArr);
    const newQ = this.alpha * (this.gamma * maxQ + reward)
      + (1 - this.alpha) * Number(prevStateActions[action]);
    // toFixed() returns a string; store a number so the table stays numeric.
    prevStateActions[action] = Number(newQ.toFixed(3));
    this.prevState = state;
  }

  /**
   * Render the Q-table into the element with id `table`. Does nothing when
   * that element is not in the document.
   */
  updateTable() {
    const container = document.getElementById('table');
    if (!container) {
      return;
    }
    const trs = Object.keys(globalQTable).map(key => `<tr>
      <td>${key}</td>
      <td>${globalQTable[key][0]}</td>
      <td>${globalQTable[key][1]}</td>
      </tr>`);
    container.innerHTML = `
      <table>
      <tr>
      <td>State</td>
      <td>running</td>
      <td>jump</td>
      </tr>
      ${trs.join('')}
      </table>
      `;
  }
}
89 changes: 89 additions & 0 deletions src/apps/q-learning.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/* eslint-disable prefer-const */
/* eslint-disable radix */
import 'babel-polyfill';

import { CANVAS_WIDTH, CANVAS_HEIGHT } from '../game/constants';
import { Runner } from '../game';
import QlearningModel, { clearDB, saveDB } from '../ai/models/qlearning/qlearningModel';

// Number of T-Rex instances passed to the Runner configuration.
const T_REX_COUNT = 1
// Fallback state used by convertStateToVector when no game state is available.
const InitPrevState = [0, 0]
// Divisor applied to the obstacle's x position when building a state key.
const qlResolution = 8

// Runner instance; created in setup().
let runner = null;

/**
 * Create the game Runner with the Q-learning callbacks, publish it on
 * `window` for runtime debugging, and start the game.
 */
function setup() {
  const options = {
    T_REX_COUNT,
    onReset: handleReset,
    onCrash: handleCrash,
    onSuccess: handleSuccess,
    onRunning: handleRunning
  };

  runner = new Runner('.game', options);
  // Kept global so the running game can be inspected at runtime.
  window.runner = runner;
  runner.init();
}

// True until the first reset has attached models to the tRexes.
let firstTime = true;

/**
 * Runner reset callback.
 *
 * On the first reset every tRex without a model gets a new Q-learning model,
 * whose stored Q-table is then loaded in the background. On every later
 * reset the current Q-table is saved.
 *
 * @param {{ tRexes: Array }} param0 - Reset payload carrying the tRexes.
 */
function handleReset({ tRexes }) {
  if (!firstTime) {
    saveDB();
    return;
  }
  firstTime = false;
  tRexes.forEach((tRex) => {
    if (tRex.model) {
      return;
    }
    // The model is attached synchronously; only loading the persisted table
    // is asynchronous, so a failure there is reported instead of being left
    // as an unhandled rejection.
    tRex.model = new QlearningModel(2);
    tRex.model.init().catch((err) => {
      console.error('failed to load the persisted q-table', err);
    });
  });
}
// Expose the Q-table reset on window (presumably so it can be triggered from
// the browser console - confirm).
window.clearDB = () => {
clearDB()
}

/**
 * Runner per-frame callback: ask the tRex's model for an action in the
 * current state and record that action on the model.
 *
 * @param {{ tRex: Object, state: Object }} param0 - The tRex and its game state.
 * @returns {*} The action returned by the model.
 */
function handleRunning({ tRex, state }) {
  const stateKey = convertStateToVector(state);
  const chosen = tRex.model.predictSingle(stateKey);
  // Stored on the model so the next reward is credited to this action.
  tRex.model.preAction = chosen;
  return chosen;
}

/**
 * Runner crash callback: give the model a reward of -100 for the state in
 * which the crash happened, then log and re-render the Q-table.
 *
 * @param {{ tRex: Object, state: Object }} param0 - The crashed tRex and its state.
 */
function handleCrash({ tRex, state }) {
  const crashReward = -100;
  const stateKey = convertStateToVector(state);

  tRex.model.giveReward(crashReward, stateKey);
  console.info('crash status >>>>>', stateKey);

  tRex.model.showQArr();
  tRex.model.updateTable();
}

/**
 * Runner success callback: give the model a reward of 1 for the current state.
 *
 * @param {{ tRex: Object, state: Object }} param0 - The tRex and its state.
 */
function handleSuccess({ tRex, state }) {
  const stateKey = convertStateToVector(state);
  const survivalReward = 1;
  tRex.model.giveReward(survivalReward, stateKey);
}

/**
 * Turn a game state into the string key used to index the Q-table.
 *
 * The key is "<bucket>,<speed>": the obstacle's x position divided by
 * `qlResolution` and truncated to an integer, followed by the speed with one
 * decimal. `Math.trunc` is used instead of `parseInt` because `parseInt`
 * converts a number to a string first, which gives wrong results for values
 * that are formatted with an exponent.
 *
 * @param {?{ obstacleX: number, speed: number }} state - Current game state.
 * @returns {string} The state key; `InitPrevState` joined with "," when
 *   `state` is falsy, so both branches return the same kind of key.
 */
function convertStateToVector(state) {
  if (!state) {
    return InitPrevState.join(',');
  }
  const obstacleBucket = Math.trunc(state.obstacleX / qlResolution);
  return [obstacleBucket, state.speed.toFixed(1)].join(',');
}

// Run setup once the document has been parsed.
document.addEventListener('DOMContentLoaded', setup);
5 changes: 3 additions & 2 deletions src/game/Runner.js
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ export default class Runner {
this.tRexGroup = new TrexGroup(this.config.T_REX_COUNT, this.canvas, this.spriteDef.TREX);
this.tRexGroup.onRunning = this.config.onRunning;
this.tRexGroup.onCrash = this.config.onCrash;
this.tRexGroup.onSuccess = this.config.onSuccess || noop;
this.tRex = this.tRexGroup.tRexes[0];

this.outerContainerEl.appendChild(this.containerEl);
Expand Down Expand Up @@ -320,9 +321,9 @@ export default class Runner {

const lives = this.tRexGroup.lives();
if (lives > 0) {
this.generationEl.innerText = `GENERATION #${Runner.generation} | LIVE x ${this.tRexGroup.lives()}`;
this.generationEl.innerText = `GENERATION(迭代次数) #${Runner.generation} | LIVE(存活个数) x ${this.tRexGroup.lives()}`;
} else {
this.generationEl.innerHTML = `<div style="color: red;">GENERATION #${Runner.generation} | GAME OVER</div>`;
this.generationEl.innerHTML = `<div style="color: red;">GENERATION(迭代次数) #${Runner.generation} | GAME OVER</div>`;
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/game/TrexGroup.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export default class TrexGroup {
onReset = noop;
onRunning = noop;
onCrash = noop;
onSuccess = noop;

constructor(count, canvas, spriteDef) {
this.tRexes = [];
Expand Down Expand Up @@ -78,6 +79,7 @@ export default class TrexGroup {
tRex.setDuck(true);
}
}
this.onSuccess({ tRex, state });
}
} else {
crashes += 1;
Expand Down
1 change: 1 addition & 0 deletions webpack.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ module.exports = {
entry: {
'genetic': ['./apps/genetic.js'],
'genetic-nn': ['./apps/genetic-nn.js'],
'q-learning': ['./apps/q-learning.js'],
nn: ['./apps/nn.js'],
nnm: ['./apps/nnm.js'],
random: ['./apps/random.js']
Expand Down
19 changes: 19 additions & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3413,6 +3413,11 @@ image-size@~0.5.0:
resolved "https://registry.yarnpkg.com/image-size/-/image-size-0.5.5.tgz#09dfd4ab9d20e29eb1c3e80b8990378df9e3cb9c"
integrity sha1-Cd/Uq50g4p6xw+gLiZA3jfnjy5w=

immediate@~3.0.5:
version "3.0.6"
resolved "https://registry.yarnpkg.com/immediate/-/immediate-3.0.6.tgz#9db1dbd0faf8de6fbe0f5dd5e56bb606280de69b"
integrity sha1-nbHb0Pr43m++D13V5Wu2BigN5ps=

import-local@^2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/import-local/-/import-local-2.0.0.tgz#55070be38a5993cf18ef6db7e961f5bee5c5a09d"
Expand Down Expand Up @@ -3937,6 +3942,13 @@ levn@^0.3.0, levn@~0.3.0:
prelude-ls "~1.1.2"
type-check "~0.3.2"

lie@3.1.1:
version "3.1.1"
resolved "https://registry.yarnpkg.com/lie/-/lie-3.1.1.tgz#9a436b2cc7746ca59de7a41fa469b3efb76bd87e"
integrity sha1-mkNrLMd0bKWd56QfpGmz77dr2H4=
dependencies:
immediate "~3.0.5"

limit-it@^3.0.0:
version "3.2.10"
resolved "https://registry.yarnpkg.com/limit-it/-/limit-it-3.2.10.tgz#a0e12007c9e7aeb46296309bca39bd7646d82887"
Expand Down Expand Up @@ -3976,6 +3988,13 @@ loader-utils@^1.0.2, loader-utils@^1.1.0, loader-utils@^1.2.3, loader-utils@^1.4
emojis-list "^3.0.0"
json5 "^1.0.1"

localforage@^1.10.0:
version "1.10.0"
resolved "https://registry.yarnpkg.com/localforage/-/localforage-1.10.0.tgz#5c465dc5f62b2807c3a84c0c6a1b1b3212781dd4"
integrity sha512-14/H1aX7hzBBmmh7sGPd+AOMkkIrHM3Z1PAyGgZigA1H1p5O5ANnMyWzvpAETtG68/dC4pC0ncy3+PPGzXZHPg==
dependencies:
lie "3.1.1"

locate-path@^2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-2.0.0.tgz#2b568b265eec944c6d9c0de9c3dbbbca0354cd8e"
Expand Down