/**
 * Tabular Q-Learning
 *
 * Classic Q-learning algorithm with:
 * - Epsilon-greedy exploration
 * - State hashing for continuous states
 * - Eligibility traces (optional)
 * - Experience replay
 *
 * Suitable for smaller state spaces or discretized environments.
 * Performance Target: <1ms per update
 */

/**
 * Default Q-Learning configuration
 */
export const DEFAULT_QLEARNING_CONFIG = {
  algorithm: 'q-learning',
  learningRate: 0.1,
  gamma: 0.99,
  // Not read by the tabular implementation below (kept for config-shape compatibility)
  entropyCoef: 0,
  valueLossCoef: 1,
  maxGradNorm: 1,
  epochs: 1,
  miniBatchSize: 1,
  // Epsilon-greedy schedule: linear decay from explorationInitial toward
  // explorationFinal over roughly explorationDecay steps
  explorationInitial: 1.0,
  explorationFinal: 0.01,
  explorationDecay: 10000,
  maxStates: 10000,
  useEligibilityTraces: false,
  traceDecay: 0.9,
};

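// Example (illustrative only): overriding a few defaults at construction time.
// Any field omitted here falls back to DEFAULT_QLEARNING_CONFIG via the spread
// in the constructor.
//
//   const agent = new QLearning({ learningRate: 0.05, useEligibilityTraces: true });
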
/**
 * Q-Learning Algorithm Implementation
 */
export class QLearning {
  config;

  // Q-table: state key -> { qValues, visits, lastUpdate }
  qTable = new Map();

  // Exploration
  epsilon;
  stepCount = 0;

  // Number of discrete actions (action strings are hashed into this many buckets)
  numActions = 4;

  // Eligibility traces: state key -> per-action trace values
  traces = new Map();

  // Statistics
  updateCount = 0;
  avgTDError = 0;

  constructor(config = {}) {
    this.config = { ...DEFAULT_QLEARNING_CONFIG, ...config };
    this.epsilon = this.config.explorationInitial;
  }

  /**
   * Update Q-values from a trajectory of environment steps.
   * Returns the mean absolute TD error over the trajectory.
   */
  update(trajectory) {
    const startTime = performance.now();
    if (trajectory.steps.length === 0) {
      return { tdError: 0 };
    }
    let totalTDError = 0;
    // Reset eligibility traces for the new trajectory
    if (this.config.useEligibilityTraces) {
      this.traces.clear();
    }
    for (let i = 0; i < trajectory.steps.length; i++) {
      const step = trajectory.steps[i];
      const stateKey = this.hashState(step.stateBefore);
      const action = this.hashAction(step.action);
      // Get or create the Q-table entry for this state
      const qEntry = this.getOrCreateEntry(stateKey);
      // Current Q-value
      const currentQ = qEntry.qValues[action];
      // Compute the target Q-value
      let targetQ;
      if (i === trajectory.steps.length - 1) {
        // The final step of the trajectory is treated as terminal: no bootstrapping
        targetQ = step.reward;
      }
      else {
        const nextStateKey = this.hashState(step.stateAfter);
        const nextEntry = this.getOrCreateEntry(nextStateKey);
        const maxNextQ = Math.max(...nextEntry.qValues);
        targetQ = step.reward + this.config.gamma * maxNextQ;
      }
      // TD error
      const tdError = targetQ - currentQ;
      totalTDError += Math.abs(tdError);
      if (this.config.useEligibilityTraces) {
        // Decay existing traces and mark the current state-action pair
        this.updateTrace(stateKey, action);
        // Propagate the TD error to all states with non-zero traces
        this.updateWithTraces(tdError);
      }
      else {
        // Plain one-step Q-learning update
        qEntry.qValues[action] += this.config.learningRate * tdError;
        qEntry.visits++;
        qEntry.lastUpdate = Date.now();
      }
    }
    // Linear epsilon decay from explorationInitial toward explorationFinal
    this.stepCount += trajectory.steps.length;
    this.epsilon = Math.max(this.config.explorationFinal, this.config.explorationInitial - this.stepCount / this.config.explorationDecay);
    // Prune the Q-table if it has grown too large
    if (this.qTable.size > this.config.maxStates) {
      this.pruneQTable();
    }
    this.updateCount++;
    this.avgTDError = totalTDError / trajectory.steps.length;
    const elapsed = performance.now() - startTime;
    if (elapsed > 1) {
      console.warn(`Q-learning update exceeded target: ${elapsed.toFixed(2)}ms > 1ms`);
    }
    return { tdError: this.avgTDError };
  }

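  // Expected trajectory shape, inferred from the field accesses in update() above
  // (only these fields are read; any extra fields on a step are ignored):
  //
  //   {
  //     steps: [
  //       {
  //         stateBefore: number[] | Float32Array,  // observation before the action
  //         action: string,                        // hashed into a discrete index
  //         reward: number,
  //         stateAfter: number[] | Float32Array,   // observation after the action
  //       },
  //       // ...
  //     ],
  //   }
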
  /**
   * Get action using an epsilon-greedy policy.
   * Unknown states (not yet in the Q-table) fall back to a random action.
   */
  getAction(state, explore = true) {
    if (explore && Math.random() < this.epsilon) {
      return Math.floor(Math.random() * this.numActions);
    }
    const stateKey = this.hashState(state);
    const entry = this.qTable.get(stateKey);
    if (!entry) {
      return Math.floor(Math.random() * this.numActions);
    }
    return this.argmax(entry.qValues);
  }

  /**
   * Get Q-values for a state (a zero-filled array for unseen states)
   */
  getQValues(state) {
    const stateKey = this.hashState(state);
    const entry = this.qTable.get(stateKey);
    if (!entry) {
      return new Float32Array(this.numActions);
    }
    return new Float32Array(entry.qValues);
  }

  /**
   * Get statistics
   */
  getStats() {
    return {
      updateCount: this.updateCount,
      qTableSize: this.qTable.size,
      epsilon: this.epsilon,
      avgTDError: this.avgTDError,
      stepCount: this.stepCount,
    };
  }

  /**
   * Reset the Q-table, traces, exploration schedule, and statistics
   */
  reset() {
    this.qTable.clear();
    this.traces.clear();
    this.epsilon = this.config.explorationInitial;
    this.stepCount = 0;
    this.updateCount = 0;
    this.avgTDError = 0;
  }

  // ==========================================================================
  // Private Methods
  // ==========================================================================

  hashState(state) {
    // Discretize the continuous state by binning each value into one of `bins`
    // buckets; values are assumed to lie in [-1, 1].
    const bins = 10;
    const parts = [];
    // Only the first 8 dimensions contribute to the hash
    for (let i = 0; i < Math.min(8, state.length); i++) {
      const normalized = (state[i] + 1) / 2; // map [-1, 1] to [0, 1]
      const bin = Math.floor(Math.max(0, Math.min(bins - 1, normalized * bins)));
      parts.push(bin);
    }
    return parts.join(',');
  }

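  // Worked example (illustrative values): hashState([0.25, -0.5, 1.0]) maps
  // 0.25 -> normalized 0.625 -> bin 6, -0.5 -> 0.25 -> bin 2, and 1.0 -> 1.0 -> bin 9
  // (clamped to bins - 1), producing the key "6,2,9". States that land in the same
  // bins share a Q-table entry.
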
  hashAction(action) {
    // Hash the action string into a discrete index in [0, numActions).
    // Distinct action strings can collide and share an index.
    let hash = 0;
    for (let i = 0; i < action.length; i++) {
      hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
    }
    return hash;
  }

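  // Worked example (illustrative string): with numActions = 4, hashAction('move')
  // folds the char codes 109, 111, 118, 101 through hash = (hash * 31 + code) % 4
  // and returns index 1.
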
  getOrCreateEntry(stateKey) {
    let entry = this.qTable.get(stateKey);
    if (!entry) {
      entry = {
        qValues: new Float32Array(this.numActions),
        visits: 0,
        lastUpdate: Date.now(),
      };
      this.qTable.set(stateKey, entry);
    }
    return entry;
  }

  updateTrace(stateKey, action) {
    // Decay all existing traces by gamma * lambda
    for (const [key, trace] of this.traces) {
      for (let a = 0; a < this.numActions; a++) {
        trace[a] *= this.config.gamma * this.config.traceDecay;
      }
      // Drop traces that have decayed to near zero
      const maxTrace = Math.max(...trace);
      if (maxTrace < 0.001) {
        this.traces.delete(key);
      }
    }
    // Replacing trace: set the current state-action trace to 1
    let trace = this.traces.get(stateKey);
    if (!trace) {
      trace = new Float32Array(this.numActions);
      this.traces.set(stateKey, trace);
    }
    trace[action] = 1.0;
  }

  updateWithTraces(tdError) {
    // Every state-action pair with a non-zero trace receives a share of the
    // TD error, weighted by its trace.
    const lr = this.config.learningRate;
    for (const [stateKey, trace] of this.traces) {
      const entry = this.qTable.get(stateKey);
      if (entry) {
        for (let a = 0; a < this.numActions; a++) {
          entry.qValues[a] += lr * tdError * trace[a];
        }
        entry.visits++;
        entry.lastUpdate = Date.now();
      }
    }
  }

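  // Together, updateTrace and updateWithTraces implement an eligibility-trace
  // update with replacing traces:
  //   e(s,a) <- gamma * lambda * e(s,a)   for all (s,a), then e(s_t, a_t) <- 1
  //   Q(s,a) <- Q(s,a) + alpha * delta * e(s,a)
  // where lambda is config.traceDecay and delta is the TD error for the current step.
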
  pruneQTable() {
    // Evict the least recently updated states until the table is back
    // to 80% of maxStates
    const entries = Array.from(this.qTable.entries())
      .sort((a, b) => a[1].lastUpdate - b[1].lastUpdate);
    const toRemove = entries.length - Math.floor(this.config.maxStates * 0.8);
    for (let i = 0; i < toRemove; i++) {
      this.qTable.delete(entries[i][0]);
    }
  }

  argmax(values) {
    let maxIdx = 0;
    let maxVal = values[0];
    for (let i = 1; i < values.length; i++) {
      if (values[i] > maxVal) {
        maxVal = values[i];
        maxIdx = i;
      }
    }
    return maxIdx;
  }
}

/**
 * Factory function
 */
export function createQLearning(config) {
  return new QLearning(config);
}

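// Minimal usage sketch (illustrative only: the state vectors, action strings, and
// reward values are made up, the import path is assumed from the source map name,
// and the trajectory shape follows the fields read by update()):
//
//   import { createQLearning } from './q-learning.js';
//
//   const agent = createQLearning({ learningRate: 0.05, useEligibilityTraces: true });
//
//   // Act with epsilon-greedy exploration on a state in [-1, 1]^n
//   const state = [0.2, -0.4, 0.9, 0.0];
//   const actionIndex = agent.getAction(state);
//
//   // Learn from a completed episode
//   agent.update({
//     steps: [
//       { stateBefore: [0.2, -0.4, 0.9, 0.0], action: 'move', reward: 0.1, stateAfter: [0.3, -0.4, 0.8, 0.0] },
//       { stateBefore: [0.3, -0.4, 0.8, 0.0], action: 'stop', reward: 1.0, stateAfter: [0.3, -0.4, 0.8, 0.0] },
//     ],
//   });
//
//   console.log(agent.getStats());
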
//# sourceMappingURL=q-learning.js.map