tasq/node_modules/@claude-flow/neural/dist/algorithms/q-learning.js

259 lines
8.2 KiB
JavaScript

/**
* Tabular Q-Learning
*
* Classic Q-learning algorithm with:
* - Epsilon-greedy exploration
* - State hashing for continuous states
* - Eligibility traces (optional)
* - Batch updates from whole trajectories (no replay buffer is maintained)
*
* Suitable for smaller state spaces or discretized environments.
* Performance Target: <1ms per update
*/
/**
 * Baseline settings for the tabular Q-learning agent.
 *
 * The QLearning constructor copies these values and overlays any
 * caller-supplied overrides on top of them.
 */
export const DEFAULT_QLEARNING_CONFIG = {
    // Identifier of the algorithm this configuration describes.
    algorithm: 'q-learning',

    // Step size applied to each TD error when adjusting a Q-value.
    learningRate: 0.1,
    // Discount applied to the best next-state Q-value (also scales trace decay).
    gamma: 0.99,

    // The five settings below are not read by the QLearning class in this file.
    entropyCoef: 0,
    valueLossCoef: 1,
    maxGradNorm: 1,
    epochs: 1,
    miniBatchSize: 1,

    // Epsilon-greedy schedule: epsilon starts at explorationInitial, drops by
    // (steps processed / explorationDecay), and never goes below explorationFinal.
    explorationInitial: 1,
    explorationFinal: 0.01,
    explorationDecay: 10000,

    // The Q-table is pruned once it holds more than this many states.
    maxStates: 10000,

    // Eligibility traces are off by default; traceDecay is multiplied with
    // gamma to decay every trace on each step when they are enabled.
    useEligibilityTraces: false,
    traceDecay: 0.9,
};
/**
 * Q-Learning Algorithm Implementation
 *
 * Tabular agent: every state is discretized into a string key (see
 * hashState) and mapped to a Float32Array holding one Q-value per action.
 * Learning happens in update(), which walks a trajectory step by step;
 * action selection happens in getAction() using an epsilon-greedy policy.
 *
 * Configuration is DEFAULT_QLEARNING_CONFIG overlaid with the caller's
 * overrides. An optional `numActions` override (positive integer, default 4)
 * sets the size of the action space.
 */
export class QLearning {
    // Effective configuration (defaults + caller overrides)
    config;
    // Q-table: state key -> { qValues: Float32Array, visits, lastUpdate }
    qTable = new Map();
    // Current exploration rate for the epsilon-greedy policy
    epsilon;
    // Total number of trajectory steps processed by update()
    stepCount = 0;
    // Number of discrete actions (replaced by config.numActions when supplied)
    numActions = 4;
    // Eligibility traces: state key -> Float32Array with one trace per action
    traces = new Map();
    // Statistics: number of update() calls and the mean |TD error| of the latest one
    updateCount = 0;
    avgTDError = 0;
    /**
     * @param {object} [config] - Partial configuration; missing keys fall back
     *   to DEFAULT_QLEARNING_CONFIG. `config.numActions`, when present, must be
     *   a positive integer.
     * @throws {RangeError} If `config.numActions` is supplied but is not a
     *   positive integer.
     */
    constructor(config = {}) {
        this.config = { ...DEFAULT_QLEARNING_CONFIG, ...config };
        this.epsilon = this.config.explorationInitial;
        const requestedActions = this.config.numActions;
        if (requestedActions !== undefined && requestedActions !== null) {
            if (!Number.isInteger(requestedActions) || requestedActions < 1) {
                throw new RangeError(`numActions must be a positive integer, got ${String(requestedActions)}`);
            }
            this.numActions = requestedActions;
        }
    }
    /**
     * Update Q-values from trajectory
     *
     * Each step contributes a TD error of `target - Q(state, action)`, where
     * the target is `reward + gamma * max Q(nextState, .)`. The final step of
     * the trajectory is treated as terminal, so its target is the reward alone.
     * Afterwards epsilon is decayed, the table is pruned if it outgrew
     * `maxStates`, and a warning is logged when the call took longer than 1ms.
     *
     * @param {object} trajectory - Object with a `steps` array; every step
     *   provides `stateBefore`, `stateAfter`, `action` (string or integer
     *   action index) and `reward`.
     * @returns {{tdError: number}} Mean absolute TD error over the steps
     *   (0 for an empty trajectory).
     */
    update(trajectory) {
        const startTime = performance.now();
        const steps = trajectory.steps;
        if (steps.length === 0) {
            return { tdError: 0 };
        }
        let totalTDError = 0;
        // Reset eligibility traces for new trajectory
        if (this.config.useEligibilityTraces) {
            this.traces.clear();
        }
        const lastIndex = steps.length - 1;
        for (let i = 0; i < steps.length; i++) {
            const step = steps[i];
            const stateKey = this.hashState(step.stateBefore);
            const action = this.hashAction(step.action);
            // Get or create Q-entry
            const qEntry = this.getOrCreateEntry(stateKey);
            // Current Q-value
            const currentQ = qEntry.qValues[action];
            // Compute target Q-value
            let targetQ;
            if (i === lastIndex) {
                // Last step is treated as terminal: no bootstrapping
                targetQ = step.reward;
            }
            else {
                const nextStateKey = this.hashState(step.stateAfter);
                const nextEntry = this.getOrCreateEntry(nextStateKey);
                const maxNextQ = Math.max(...nextEntry.qValues);
                targetQ = step.reward + this.config.gamma * maxNextQ;
            }
            // TD error
            const tdError = targetQ - currentQ;
            totalTDError += Math.abs(tdError);
            if (this.config.useEligibilityTraces) {
                // Mark the current state-action pair, then spread the TD error
                // over every state that still carries a trace
                this.updateTrace(stateKey, action);
                this.updateWithTraces(tdError);
            }
            else {
                // Simple Q-learning update
                qEntry.qValues[action] += this.config.learningRate * tdError;
                qEntry.visits++;
                qEntry.lastUpdate = Date.now();
            }
        }
        // Decay exploration linearly with the number of processed steps,
        // never dropping below explorationFinal
        this.stepCount += steps.length;
        this.epsilon = Math.max(this.config.explorationFinal, this.config.explorationInitial - this.stepCount / this.config.explorationDecay);
        // Prune Q-table if too large
        if (this.qTable.size > this.config.maxStates) {
            this.pruneQTable();
        }
        this.updateCount++;
        this.avgTDError = totalTDError / steps.length;
        const elapsed = performance.now() - startTime;
        if (elapsed > 1) {
            console.warn(`Q-learning update exceeded target: ${elapsed.toFixed(2)}ms > 1ms`);
        }
        return { tdError: this.avgTDError };
    }
    /**
     * Get action using epsilon-greedy policy
     *
     * @param {ArrayLike<number>} state - State vector to act on.
     * @param {boolean} [explore=true] - When true, a uniformly random action
     *   is returned with probability epsilon.
     * @returns {number} Action index in [0, numActions). States that are not
     *   in the Q-table yet also get a uniformly random action; otherwise the
     *   action with the highest Q-value is returned (lowest index on ties).
     */
    getAction(state, explore = true) {
        if (explore && Math.random() < this.epsilon) {
            return Math.floor(Math.random() * this.numActions);
        }
        const stateKey = this.hashState(state);
        const entry = this.qTable.get(stateKey);
        if (!entry) {
            return Math.floor(Math.random() * this.numActions);
        }
        return this.argmax(entry.qValues);
    }
    /**
     * Get Q-values for a state
     *
     * @param {ArrayLike<number>} state - State vector to look up.
     * @returns {Float32Array} A copy of the stored Q-values, or all zeros when
     *   the state has no entry. Mutating the result does not affect the table.
     */
    getQValues(state) {
        const stateKey = this.hashState(state);
        const entry = this.qTable.get(stateKey);
        if (!entry) {
            return new Float32Array(this.numActions);
        }
        return new Float32Array(entry.qValues);
    }
    /**
     * Get statistics
     *
     * @returns {{updateCount: number, qTableSize: number, epsilon: number,
     *   avgTDError: number, stepCount: number}} Snapshot of the counters;
     *   `avgTDError` is the mean absolute TD error of the most recent update.
     */
    getStats() {
        return {
            updateCount: this.updateCount,
            qTableSize: this.qTable.size,
            epsilon: this.epsilon,
            avgTDError: this.avgTDError,
            stepCount: this.stepCount,
        };
    }
    /**
     * Reset Q-table
     *
     * Also clears the traces, restores epsilon to explorationInitial and
     * zeroes the counters. The configuration and numActions are kept.
     */
    reset() {
        this.qTable.clear();
        this.traces.clear();
        this.epsilon = this.config.explorationInitial;
        this.stepCount = 0;
        this.updateCount = 0;
        this.avgTDError = 0;
    }
    // ==========================================================================
    // Private Methods
    // ==========================================================================
    /**
     * Discretize a state vector into a Q-table key.
     *
     * Only the first 8 dimensions are used. Each value is assumed to lie in
     * [-1, 1], is mapped onto 10 bins and clamped to the first/last bin when
     * it falls outside that range. Bins are joined with commas.
     *
     * @param {ArrayLike<number>} state
     * @returns {string}
     */
    hashState(state) {
        const bins = 10;
        const parts = [];
        const dims = Math.min(8, state.length);
        for (let i = 0; i < dims; i++) {
            const normalized = (state[i] + 1) / 2; // Assume [-1, 1] range
            const bin = Math.floor(Math.max(0, Math.min(bins - 1, normalized * bins)));
            parts.push(bin);
        }
        return parts.join(',');
    }
    /**
     * Map an action onto an index in [0, numActions).
     *
     * Integer actions (such as the indices returned by getAction) are used
     * directly, wrapped into range. Strings are folded into an index with a
     * base-31 rolling hash over their character codes.
     *
     * @param {string|number} action
     * @returns {number}
     */
    hashAction(action) {
        if (typeof action === 'number' && Number.isInteger(action)) {
            // Double modulo keeps negative indices inside [0, numActions)
            return ((action % this.numActions) + this.numActions) % this.numActions;
        }
        let hash = 0;
        for (let i = 0; i < action.length; i++) {
            hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
        }
        return hash;
    }
    /**
     * Return the Q-table entry for a state key, inserting a zero-initialised
     * entry first when the key is new.
     *
     * @param {string} stateKey
     * @returns {{qValues: Float32Array, visits: number, lastUpdate: number}}
     */
    getOrCreateEntry(stateKey) {
        let entry = this.qTable.get(stateKey);
        if (!entry) {
            entry = {
                qValues: new Float32Array(this.numActions),
                visits: 0,
                lastUpdate: Date.now(),
            };
            this.qTable.set(stateKey, entry);
        }
        return entry;
    }
    /**
     * Decay every existing trace by gamma * traceDecay, drop states whose
     * largest trace fell below 0.001, then set the trace of the given
     * state-action pair to 1.
     *
     * @param {string} stateKey
     * @param {number} action - Action index.
     */
    updateTrace(stateKey, action) {
        // Decay all existing traces
        for (const [key, trace] of this.traces) {
            for (let a = 0; a < this.numActions; a++) {
                trace[a] *= this.config.gamma * this.config.traceDecay;
            }
            // Remove near-zero traces
            const maxTrace = Math.max(...trace);
            if (maxTrace < 0.001) {
                this.traces.delete(key);
            }
        }
        // Set trace for current state-action
        let trace = this.traces.get(stateKey);
        if (!trace) {
            trace = new Float32Array(this.numActions);
            this.traces.set(stateKey, trace);
        }
        trace[action] = 1.0;
    }
    /**
     * Apply a TD error to every traced state, scaled by the learning rate and
     * the per-action trace. Traced states that are no longer in the Q-table
     * are skipped.
     *
     * @param {number} tdError
     */
    updateWithTraces(tdError) {
        const lr = this.config.learningRate;
        for (const [stateKey, trace] of this.traces) {
            const entry = this.qTable.get(stateKey);
            if (entry) {
                for (let a = 0; a < this.numActions; a++) {
                    entry.qValues[a] += lr * tdError * trace[a];
                }
                entry.visits++;
                entry.lastUpdate = Date.now();
            }
        }
    }
    /**
     * Shrink the Q-table to 80% of maxStates by deleting the entries with the
     * oldest lastUpdate timestamps.
     */
    pruneQTable() {
        // Oldest entries first
        const entries = Array.from(this.qTable.entries())
            .sort((a, b) => a[1].lastUpdate - b[1].lastUpdate);
        const toRemove = entries.length - Math.floor(this.config.maxStates * 0.8);
        for (let i = 0; i < toRemove; i++) {
            this.qTable.delete(entries[i][0]);
        }
    }
    /**
     * Index of the largest value; the lowest index wins on ties.
     *
     * @param {ArrayLike<number>} values
     * @returns {number}
     */
    argmax(values) {
        let maxIdx = 0;
        let maxVal = values[0];
        for (let i = 1; i < values.length; i++) {
            if (values[i] > maxVal) {
                maxVal = values[i];
                maxIdx = i;
            }
        }
        return maxIdx;
    }
}
/**
 * Factory function
 *
 * Builds a QLearning instance, forwarding `config` unchanged to its
 * constructor.
 */
export function createQLearning(config) {
    const agent = new QLearning(config);
    return agent;
}
//# sourceMappingURL=q-learning.js.map