diff --git a/.gitignore b/.gitignore index f2ea567..8759396 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ node_modules/ # Package package-lock.json +.history diff --git a/index.html b/index.html index 156df5f..701e3c9 100644 --- a/index.html +++ b/index.html @@ -14,6 +14,7 @@

Welcome to T-Rex Run

  • Neural Network
  • Neural Network for Multiplayers
  • Genetic + Neural Network
  • +
  • RL--q-learning
  • diff --git a/package.json b/package.json index 666ed03..7ed208c 100644 --- a/package.json +++ b/package.json @@ -47,6 +47,7 @@ "dependencies": { "@tensorflow/tfjs": "^0.9.0", "babel-polyfill": "^6.26.0", + "localforage": "^1.10.0", "package.json": "^2.0.1" } } diff --git a/q-learning.html b/q-learning.html new file mode 100644 index 0000000..93b89a2 --- /dev/null +++ b/q-learning.html @@ -0,0 +1,31 @@ + + + + + + + q-learning - T-Rex Runner + + + + +

    q-learning

    +
    +
    +
    +

    q-Table

    +
    + +
    + + + diff --git a/src/ai/models/qlearning/qlearningModel.js b/src/ai/models/qlearning/qlearningModel.js new file mode 100644 index 0000000..2033d4c --- /dev/null +++ b/src/ai/models/qlearning/qlearningModel.js @@ -0,0 +1,134 @@ + +import localforage from 'localforage' +import Model from '../Model'; + +export const Key = 'qlearning-chrome://dino' +localforage.config({ + name: Key +}); +let globalQTable = {} + +export function saveDB() { + localforage.setItem(Key, globalQTable) +} +export function clearDB() { + globalQTable = {} + saveDB() +} + +function getRandomInt(min, max) { + min = Math.ceil(min); + max = Math.floor(max); + return Math.floor(Math.random() * (max - min + 1)) + min; +} +function indexOfMax(arr) { + if (arr.length === 0) { + return -1; + } + let max = arr[0]; + let maxIndex = 0; + for (let i = 1; i < arr.length; i+=1) { + if (arr[i] > max) { + maxIndex = i; + max = arr[i]; + } + } + return maxIndex; +} +export default class QNetwork extends Model{ + constructor(actions, states) { + super(actions, states) + this.actions = actions; // 0 runing 1 jump + this.epsilon = 0.5; // greedy + this.epsilon_decay = 0.001; + this.currentState = 0; + this.alpha = 0.2 + this.gamma = 0.9 + this.prevState = '0,0' + this.preAction = 0 + } + initActions() { + return Array.from({length: 2}).map(item => 0) + } + async getQArrFromDB() { + const result = await localforage.getItem(Key) + return result + } + + async init(actionLen) { + const result = await this.getQArrFromDB() + console.info(`from indexedDB get qArr(${Object.keys(globalQTable).length})`, result) + if (result) { + globalQTable = result + } + } + showQArr() { + console.info(`>>>>>>> this.qTable(${Object.keys(globalQTable).length}): `) + } + + predict(inputXs) { + const inputX = inputXs[0] + // console.info('q-table >>>> ', globalQTable, inputX) + return this.think(inputX) + } + + think(state) { + this.currentState = globalQTable[state]; + if (!this.currentState) { + this.currentState = this.initActions() + globalQTable[state] = this.currentState + } + let action = null; + if (Math.random() < this.epsilon) { + action = 0 // getRandomInt(0,(this.actions - 1)) + } else { + action = indexOfMax(this.currentState); + } + this.epsilon = this.epsilon - (this.epsilon_decay * this.epsilon); + return action; + } + + giveReward(reward, state) { + const prevState = this.prevState + const action = this.preAction + if (reward !== 1) { + console.info('giveReward >>>>> ', reward, state, prevState, action); + } + let curArr = globalQTable[state]; + if (!curArr) { + curArr = this.initActions() + globalQTable[state] = curArr + } + let prevStateActions = globalQTable[prevState] + if (!prevStateActions) { + prevStateActions = this.initActions() + globalQTable[prevState] = prevStateActions + } + const maxQ = Math.max(...curArr); + const newQ = this.alpha * (this.gamma * maxQ + reward) + + (1 - this.alpha) * prevStateActions[action] + globalQTable[prevState][action] = newQ.toFixed(3); + this.prevState = state; + } + updateTable() { + const Keys = Object.keys(globalQTable); + const trs = Keys.map(key => ` + ${key} + ${globalQTable[key][0]} + ${globalQTable[key][1]} + `) + const tableText = ` + + + + + + + ${ + trs.join('') + } +
    Staterunningjump
    + ` + document.getElementById('table').innerHTML = tableText + } +} \ No newline at end of file diff --git a/src/apps/q-learning.js b/src/apps/q-learning.js new file mode 100644 index 0000000..3981f34 --- /dev/null +++ b/src/apps/q-learning.js @@ -0,0 +1,89 @@ +/* eslint-disable prefer-const */ +/* eslint-disable radix */ +import 'babel-polyfill'; + +import { CANVAS_WIDTH, CANVAS_HEIGHT } from '../game/constants'; +import { Runner } from '../game'; +import QlearningModel, { clearDB, saveDB } from '../ai/models/qlearning/qlearningModel'; + +const T_REX_COUNT = 1 +const InitPrevState = [0, 0] +const qlResolution = 8 + +let runner = null; + +function setup() { + // Initialize the game Runner. + runner = new Runner('.game', { + T_REX_COUNT, + onReset: handleReset, + onCrash: handleCrash, + onSuccess: handleSuccess, + onRunning: handleRunning + }); + // Set runner as a global variable if you need runtime debugging. + window.runner = runner; + // Initialize everything in the game and start the game. + runner.init(); +} + +let firstTime = true; +function handleReset({ tRexes }) { + if (firstTime) { + firstTime = false; + tRexes.forEach(async (tRex) => { + if (!tRex.model) { + // Initialize all the tRexes with random models + // for the very first time. + tRex.model = new QlearningModel(2); + await tRex.model.init(); + } + }); + } else { + saveDB() + } +} +window.clearDB = () => { + clearDB() +} + +function handleRunning({ tRex, state }) { + let action = 0; + const _state = convertStateToVector(state) + action = tRex.model.predictSingle(_state); + tRex.model.preAction = action + return action; +} + +function handleCrash({ tRex, state }) { + // tRex.model.train(); + const reward = -100 + const curState = convertStateToVector(state) + tRex.model.giveReward(reward, curState) + console.info('crash status >>>>>', curState) + + tRex.model.showQArr() + tRex.model.updateTable() +} + +function handleSuccess({ tRex, state }) { + const reward = 1 + const curState = convertStateToVector(state) + // eslint-disable-next-line prefer-const + tRex.model.giveReward(reward, curState) + // console.info('handleSuccess status >>>>>', curState) + +} + +function convertStateToVector(state) { + if (state) { + return [ + parseInt(state.obstacleX / qlResolution), + // parseInt(state.obstacleWidth / qlResolution), + state.speed.toFixed(1), + ].join(','); + } + return InitPrevState; +} + +document.addEventListener('DOMContentLoaded', setup); diff --git a/src/game/Runner.js b/src/game/Runner.js index 479eb2a..f7f4a7f 100644 --- a/src/game/Runner.js +++ b/src/game/Runner.js @@ -170,6 +170,7 @@ export default class Runner { this.tRexGroup = new TrexGroup(this.config.T_REX_COUNT, this.canvas, this.spriteDef.TREX); this.tRexGroup.onRunning = this.config.onRunning; this.tRexGroup.onCrash = this.config.onCrash; + this.tRexGroup.onSuccess = this.config.onSuccess || noop; this.tRex = this.tRexGroup.tRexes[0]; this.outerContainerEl.appendChild(this.containerEl); @@ -320,9 +321,9 @@ export default class Runner { const lives = this.tRexGroup.lives(); if (lives > 0) { - this.generationEl.innerText = `GENERATION #${Runner.generation} | LIVE x ${this.tRexGroup.lives()}`; + this.generationEl.innerText = `GENERATION(迭代次数) #${Runner.generation} | LIVE(存活个数) x ${this.tRexGroup.lives()}`; } else { - this.generationEl.innerHTML = `
    GENERATION #${Runner.generation} | GAME OVER
    `; + this.generationEl.innerHTML = `
    GENERATION(迭代次数) #${Runner.generation} | GAME OVER
    `; } } diff --git a/src/game/TrexGroup.js b/src/game/TrexGroup.js index a7f8048..687506c 100644 --- a/src/game/TrexGroup.js +++ b/src/game/TrexGroup.js @@ -5,6 +5,7 @@ export default class TrexGroup { onReset = noop; onRunning = noop; onCrash = noop; + onSuccess = noop; constructor(count, canvas, spriteDef) { this.tRexes = []; @@ -78,6 +79,7 @@ export default class TrexGroup { tRex.setDuck(true); } } + this.onSuccess({ tRex, state }); } } else { crashes += 1; diff --git a/webpack.config.js b/webpack.config.js index c8dab54..251d8c5 100644 --- a/webpack.config.js +++ b/webpack.config.js @@ -10,6 +10,7 @@ module.exports = { entry: { 'genetic': ['./apps/genetic.js'], 'genetic-nn': ['./apps/genetic-nn.js'], + 'q-learning': ['./apps/q-learning.js'], nn: ['./apps/nn.js'], nnm: ['./apps/nnm.js'], random: ['./apps/random.js'] diff --git a/yarn.lock b/yarn.lock index 4072ac6..465f3f5 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3413,6 +3413,11 @@ image-size@~0.5.0: resolved "https://registry.yarnpkg.com/image-size/-/image-size-0.5.5.tgz#09dfd4ab9d20e29eb1c3e80b8990378df9e3cb9c" integrity sha1-Cd/Uq50g4p6xw+gLiZA3jfnjy5w= +immediate@~3.0.5: + version "3.0.6" + resolved "https://registry.yarnpkg.com/immediate/-/immediate-3.0.6.tgz#9db1dbd0faf8de6fbe0f5dd5e56bb606280de69b" + integrity sha1-nbHb0Pr43m++D13V5Wu2BigN5ps= + import-local@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/import-local/-/import-local-2.0.0.tgz#55070be38a5993cf18ef6db7e961f5bee5c5a09d" @@ -3937,6 +3942,13 @@ levn@^0.3.0, levn@~0.3.0: prelude-ls "~1.1.2" type-check "~0.3.2" +lie@3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/lie/-/lie-3.1.1.tgz#9a436b2cc7746ca59de7a41fa469b3efb76bd87e" + integrity sha1-mkNrLMd0bKWd56QfpGmz77dr2H4= + dependencies: + immediate "~3.0.5" + limit-it@^3.0.0: version "3.2.10" resolved "https://registry.yarnpkg.com/limit-it/-/limit-it-3.2.10.tgz#a0e12007c9e7aeb46296309bca39bd7646d82887" @@ -3976,6 +3988,13 @@ loader-utils@^1.0.2, loader-utils@^1.1.0, loader-utils@^1.2.3, loader-utils@^1.4 emojis-list "^3.0.0" json5 "^1.0.1" +localforage@^1.10.0: + version "1.10.0" + resolved "https://registry.yarnpkg.com/localforage/-/localforage-1.10.0.tgz#5c465dc5f62b2807c3a84c0c6a1b1b3212781dd4" + integrity sha512-14/H1aX7hzBBmmh7sGPd+AOMkkIrHM3Z1PAyGgZigA1H1p5O5ANnMyWzvpAETtG68/dC4pC0ncy3+PPGzXZHPg== + dependencies: + lie "3.1.1" + locate-path@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-2.0.0.tgz#2b568b265eec944c6d9c0de9c3dbbbca0354cd8e"