/**
 * Deep Q-Network (DQN)
 *
 * Implements DQN with enhancements:
 * - Experience replay
 * - Target network
 * - Double DQN (optional)
 * - Dueling architecture (optional)
 * - Epsilon-greedy exploration
 *
 * Performance Target: <10ms per update step
 */

/**
 * Default DQN configuration.
 *
 * NOTE(review): `duelingNetwork` is accepted but never read by DQNAlgorithm
 * (the network is a plain 2-layer MLP); `entropyCoef`, `valueLossCoef`, and
 * `epochs` are likewise unused here — presumably shared with sibling
 * algorithms. `maxGradNorm` is applied as an element-wise clamp, not a true
 * gradient-norm clip (see applyGradients).
 */
export const DEFAULT_DQN_CONFIG = {
  algorithm: 'dqn',
  learningRate: 0.0001,
  gamma: 0.99,
  entropyCoef: 0,
  valueLossCoef: 1,
  maxGradNorm: 10,
  epochs: 1,
  miniBatchSize: 32,
  bufferSize: 10000,
  explorationInitial: 1.0,
  explorationFinal: 0.01,
  explorationDecay: 10000,
  targetUpdateFreq: 100,
  doubleDQN: true,
  duelingNetwork: false,
};

/**
 * DQN Algorithm Implementation
 *
 * Q-network is a 2-layer MLP (inputDim -> hiddenDim -> numActions) stored as
 * flat Float32Array weight matrices, trained with SGD + momentum on squared
 * TD error over mini-batches sampled from a circular replay buffer.
 */
export class DQNAlgorithm {
  config;
  // Q-network weights: [w1 (inputDim x hiddenDim), w2 (hiddenDim x numActions)]
  qWeights;
  targetWeights;
  // Optimizer state: one momentum buffer per weight layer
  qMomentum;
  // Replay buffer (circular; bufferIdx is the next slot to overwrite once full)
  buffer = [];
  bufferIdx = 0;
  // Exploration (epsilon-greedy, linearly annealed per update step)
  epsilon;
  stepCount = 0;
  // Network dimensions. hiddenDim was previously a magic `64` duplicated in
  // three methods; it is now the single source of truth.
  numActions = 4;
  inputDim = 768;
  hiddenDim = 64;
  // Statistics
  updateCount = 0;
  avgLoss = 0;

  constructor(config = {}) {
    this.config = { ...DEFAULT_DQN_CONFIG, ...config };
    this.epsilon = this.config.explorationInitial;
    // Initialize Q-network (2 hidden layers) and sync the target network.
    this.qWeights = this.initializeNetwork();
    this.targetWeights = this.copyNetwork(this.qWeights);
    this.qMomentum = this.qWeights.map((w) => new Float32Array(w.length));
  }

  /**
   * Add experiences from a trajectory to the replay buffer, one per step.
   *
   * The last step of the trajectory is marked terminal (`done: true`).
   * NOTE(review): a trajectory truncated mid-episode would also be marked
   * done here, which biases its TD target — confirm trajectories always end
   * at true episode boundaries.
   *
   * @param {{steps: Array<{stateBefore: ArrayLike<number>, action: string,
   *   reward: number, stateAfter: ArrayLike<number>}>}} trajectory
   */
  addExperience(trajectory) {
    const { steps } = trajectory;
    for (let i = 0; i < steps.length; i++) {
      const step = steps[i];
      const experience = {
        state: step.stateBefore,
        action: this.hashAction(step.action),
        reward: step.reward,
        nextState: step.stateAfter,
        done: i === steps.length - 1,
      };
      // Circular buffer: grow until capacity, then overwrite the oldest slot.
      if (this.buffer.length < this.config.bufferSize) {
        this.buffer.push(experience);
      } else {
        this.buffer[this.bufferIdx] = experience;
      }
      this.bufferIdx = (this.bufferIdx + 1) % this.config.bufferSize;
    }
  }

  /**
   * Perform one DQN update over a sampled mini-batch.
   * Target: <10ms (logs a warning when exceeded).
   *
   * No-op (loss 0) until the buffer holds at least `miniBatchSize` entries.
   *
   * @returns {{loss: number, epsilon: number}} mean squared TD error of the
   *   batch and the current exploration rate.
   */
  update() {
    const startTime = performance.now();
    if (this.buffer.length < this.config.miniBatchSize) {
      return { loss: 0, epsilon: this.epsilon };
    }
    // Sample mini-batch
    const batch = this.sampleBatch();
    // Compute TD targets and accumulate gradients across the batch.
    let totalLoss = 0;
    const gradients = this.qWeights.map((w) => new Float32Array(w.length));
    for (const exp of batch) {
      // Current Q-value of the taken action under the online network.
      const qValues = this.forward(exp.state, this.qWeights);
      const currentQ = qValues[exp.action];
      // Bootstrapped target Q-value.
      let targetQ;
      if (exp.done) {
        targetQ = exp.reward;
      } else if (this.config.doubleDQN) {
        // Double DQN: online network selects the action, target evaluates it.
        const nextQOnline = this.forward(exp.nextState, this.qWeights);
        const bestAction = this.argmax(nextQOnline);
        const nextQTarget = this.forward(exp.nextState, this.targetWeights);
        targetQ = exp.reward + this.config.gamma * nextQTarget[bestAction];
      } else {
        // Standard DQN: target network both selects and evaluates.
        const nextQ = this.forward(exp.nextState, this.targetWeights);
        targetQ = exp.reward + this.config.gamma * Math.max(...nextQ);
      }
      // Squared TD error; its gradient sign is folded into tdError below.
      const tdError = targetQ - currentQ;
      totalLoss += tdError * tdError;
      this.accumulateGradients(gradients, exp.state, exp.action, tdError);
    }
    // Apply gradients (SGD + momentum, mean over the batch).
    this.applyGradients(gradients, batch.length);
    // Update target network periodically.
    this.stepCount++;
    if (this.stepCount % this.config.targetUpdateFreq === 0) {
      this.targetWeights = this.copyNetwork(this.qWeights);
    }
    // Linearly anneal exploration down to explorationFinal.
    this.epsilon = Math.max(
      this.config.explorationFinal,
      this.config.explorationInitial - this.stepCount / this.config.explorationDecay,
    );
    this.updateCount++;
    this.avgLoss = totalLoss / batch.length;
    const elapsed = performance.now() - startTime;
    if (elapsed > 10) {
      console.warn(`DQN update exceeded target: ${elapsed.toFixed(2)}ms > 10ms`);
    }
    return {
      loss: this.avgLoss,
      epsilon: this.epsilon,
    };
  }

  /**
   * Select an action via epsilon-greedy over the online Q-network.
   *
   * @param {ArrayLike<number>} state - state feature vector
   * @param {boolean} [explore=true] - when true, take a uniformly random
   *   action with probability epsilon
   * @returns {number} action index in [0, numActions)
   */
  getAction(state, explore = true) {
    if (explore && Math.random() < this.epsilon) {
      return Math.floor(Math.random() * this.numActions);
    }
    const qValues = this.forward(state, this.qWeights);
    return this.argmax(qValues);
  }

  /**
   * Get Q-values for a state from the online network.
   * @param {ArrayLike<number>} state
   * @returns {Float32Array} one Q-value per action
   */
  getQValues(state) {
    return this.forward(state, this.qWeights);
  }

  /**
   * Get training statistics.
   * @returns {{updateCount: number, bufferSize: number, epsilon: number,
   *   avgLoss: number, stepCount: number}}
   */
  getStats() {
    return {
      updateCount: this.updateCount,
      bufferSize: this.buffer.length,
      epsilon: this.epsilon,
      avgLoss: this.avgLoss,
      stepCount: this.stepCount,
    };
  }

  // ==========================================================================
  // Private Methods
  // ==========================================================================

  /**
   * Build the 2-layer network with He-style initialization
   * (uniform in [-scale/2, scale/2) where scale = sqrt(2 / fanIn)).
   * @returns {Float32Array[]} [w1, w2] flat weight matrices
   */
  initializeNetwork() {
    const weights = [];
    // Layer 1: input_dim -> hidden
    const w1 = new Float32Array(this.inputDim * this.hiddenDim);
    const scale1 = Math.sqrt(2 / this.inputDim);
    for (let i = 0; i < w1.length; i++) {
      w1[i] = (Math.random() - 0.5) * scale1;
    }
    weights.push(w1);
    // Layer 2: hidden -> num_actions
    const w2 = new Float32Array(this.hiddenDim * this.numActions);
    const scale2 = Math.sqrt(2 / this.hiddenDim);
    for (let i = 0; i < w2.length; i++) {
      w2[i] = (Math.random() - 0.5) * scale2;
    }
    weights.push(w2);
    return weights;
  }

  /** Deep-copy a weight list (used to snapshot the target network). */
  copyNetwork(weights) {
    return weights.map((w) => new Float32Array(w));
  }

  /**
   * Layer-1 forward pass: hidden = ReLU(W1' * state).
   *
   * Extracted so forward() and accumulateGradients() share one
   * implementation (previously duplicated inline in both). States shorter
   * than inputDim are zero-padded implicitly by truncating the loop.
   */
  computeHidden(state, weights) {
    const hidden = new Float32Array(this.hiddenDim);
    const dims = Math.min(state.length, this.inputDim); // loop-invariant, hoisted
    for (let h = 0; h < this.hiddenDim; h++) {
      let sum = 0;
      for (let i = 0; i < dims; i++) {
        sum += state[i] * weights[0][i * this.hiddenDim + h];
      }
      hidden[h] = Math.max(0, sum); // ReLU
    }
    return hidden;
  }

  /**
   * Full forward pass: Q(state) = W2' * ReLU(W1' * state).
   * No activation on the output layer (raw Q-values).
   */
  forward(state, weights) {
    const hidden = this.computeHidden(state, weights);
    const output = new Float32Array(this.numActions);
    for (let a = 0; a < this.numActions; a++) {
      let sum = 0;
      for (let h = 0; h < this.hiddenDim; h++) {
        sum += hidden[h] * weights[1][h * this.numActions + a];
      }
      output[a] = sum;
    }
    return output;
  }

  /**
   * Accumulate gradients of the squared TD error for a single transition.
   *
   * Only the selected action's output unit receives a layer-2 gradient; the
   * layer-1 gradient is backpropagated through the ReLU. The sign convention
   * (+tdError) pairs with the `+=` ascent in applyGradients to yield descent
   * on (target - Q)^2.
   */
  accumulateGradients(gradients, state, action, tdError) {
    // Recompute layer-1 activations under the current online weights.
    const hidden = this.computeHidden(state, this.qWeights);
    // Gradient for layer 2 (only for the selected action).
    for (let h = 0; h < this.hiddenDim; h++) {
      gradients[1][h * this.numActions + action] += hidden[h] * tdError;
    }
    // Gradient for layer 1 (backprop through ReLU: zero where hidden <= 0).
    const dims = Math.min(state.length, this.inputDim);
    for (let h = 0; h < this.hiddenDim; h++) {
      if (hidden[h] > 0) {
        const grad = tdError * this.qWeights[1][h * this.numActions + action];
        for (let i = 0; i < dims; i++) {
          gradients[0][i * this.hiddenDim + h] += state[i] * grad;
        }
      }
    }
  }

  /**
   * Apply accumulated gradients: element-wise clamp to +/-maxGradNorm
   * (NOTE: a per-element clamp, not a true gradient-norm clip, despite the
   * config name), then an SGD + momentum step scaled by lr / batchSize.
   */
  applyGradients(gradients, batchSize) {
    const lr = this.config.learningRate / batchSize;
    const beta = 0.9;
    for (let layer = 0; layer < gradients.length; layer++) {
      for (let i = 0; i < gradients[layer].length; i++) {
        // Element-wise gradient clamping.
        const grad = Math.max(
          Math.min(gradients[layer][i], this.config.maxGradNorm),
          -this.config.maxGradNorm,
        );
        // Momentum update (exponential moving average of gradients).
        this.qMomentum[layer][i] = beta * this.qMomentum[layer][i] + (1 - beta) * grad;
        this.qWeights[layer][i] += lr * this.qMomentum[layer][i];
      }
    }
  }

  /**
   * Sample up to miniBatchSize distinct experiences uniformly at random.
   * Uses a Set of indices to sample without replacement.
   */
  sampleBatch() {
    const batch = [];
    const indices = new Set();
    while (indices.size < this.config.miniBatchSize && indices.size < this.buffer.length) {
      indices.add(Math.floor(Math.random() * this.buffer.length));
    }
    for (const idx of indices) {
      batch.push(this.buffer[idx]);
    }
    return batch;
  }

  /**
   * Deterministically hash an action string into [0, numActions).
   * Distinct actions may collide by design (fixed small action space).
   */
  hashAction(action) {
    let hash = 0;
    for (let i = 0; i < action.length; i++) {
      hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
    }
    return hash;
  }

  /** Index of the maximum value (first occurrence wins on ties). */
  argmax(values) {
    let maxIdx = 0;
    let maxVal = values[0];
    for (let i = 1; i < values.length; i++) {
      if (values[i] > maxVal) {
        maxVal = values[i];
        maxIdx = i;
      }
    }
    return maxIdx;
  }
}

/**
 * Factory function.
 * @param {object} [config] - partial config merged over DEFAULT_DQN_CONFIG
 * @returns {DQNAlgorithm}
 */
export function createDQN(config) {
  return new DQNAlgorithm(config);
}
//# sourceMappingURL=dqn.js.map