// 303 lines, 10 KiB, JavaScript
/**
 * Deep Q-Network (DQN)
 *
 * Implements DQN with enhancements:
 * - Experience replay
 * - Target network
 * - Double DQN (optional)
 * - Dueling architecture (optional)
 * - Epsilon-greedy exploration
 *
 * Performance Target: <10ms per update step
 */

/**
 * Default DQN configuration.
 *
 * DQNAlgorithm spreads caller overrides on top of this object, so every key
 * here is the value used when the caller does not supply one.
 */
export const DEFAULT_DQN_CONFIG = {
  // Algorithm identifier (not read by DQNAlgorithm itself).
  algorithm: 'dqn',
  // Step size of the weight update; divided by the batch size when applied.
  learningRate: 0.0001,
  // Discount factor applied to the bootstrapped next-state Q-value.
  gamma: 0.99,
  // Not read by DQNAlgorithm in this file.
  entropyCoef: 0,
  // Not read by DQNAlgorithm in this file.
  valueLossCoef: 1,
  // Bound used to clamp each accumulated gradient element to [-max, +max].
  maxGradNorm: 10,
  // Not read by DQNAlgorithm in this file.
  epochs: 1,
  // Number of distinct replay entries sampled per update.
  miniBatchSize: 32,
  // Capacity of the circular replay buffer.
  bufferSize: 10000,
  // Starting value of epsilon for epsilon-greedy exploration.
  explorationInitial: 1.0,
  // Lower bound epsilon is clamped to.
  explorationFinal: 0.01,
  // Epsilon is reduced by stepCount / explorationDecay from its initial value.
  explorationDecay: 10000,
  // The target network is re-synchronised every this many update steps.
  targetUpdateFreq: 100,
  // Select the next action with the online network, evaluate it with the target network.
  doubleDQN: true,
  // Flag only; DQNAlgorithm in this file does not branch on it.
  duelingNetwork: false,
};

/**
 * DQN Algorithm Implementation
 *
 * A Deep Q-Network agent backed by a small fully connected network
 * (input -> ReLU hidden layer -> linear Q-value output, no bias terms).
 * It keeps a circular replay buffer, a periodically synchronised target
 * network, optional Double DQN targets and epsilon-greedy exploration.
 *
 * Network dimensions default to 768 inputs, 64 hidden units and 4 actions.
 * They may be overridden through the optional `inputDim`, `hiddenDim` and
 * `numActions` config keys.
 */
export class DQNAlgorithm {
  // Effective configuration: DEFAULT_DQN_CONFIG overlaid with caller overrides.
  config;

  // Q-network weights: [inputDim x hiddenDim, hiddenDim x numActions],
  // both stored row-major in Float32Arrays.
  qWeights;
  targetWeights;

  // Optimizer state (one momentum buffer per weight matrix)
  qMomentum;

  // Replay buffer (circular)
  buffer = [];
  bufferIdx = 0;

  // Exploration
  epsilon;
  stepCount = 0;

  // Network dimensions (defaults; the constructor may override them from config)
  numActions = 4;
  inputDim = 768;
  hiddenDim = 64;

  // Statistics
  updateCount = 0;
  avgLoss = 0;

  /**
   * @param {object} [config] Overrides merged on top of DEFAULT_DQN_CONFIG.
   *   Optional `numActions`, `inputDim` and `hiddenDim` keys size the network.
   * @throws {RangeError} If a supplied dimension is not a positive integer.
   */
  constructor(config = {}) {
    this.config = { ...DEFAULT_DQN_CONFIG, ...config };
    this.epsilon = this.config.explorationInitial;

    // Dimensions must be settled before the weight matrices are allocated.
    this.numActions = DQNAlgorithm.resolveDimension('numActions', this.config.numActions, this.numActions);
    this.inputDim = DQNAlgorithm.resolveDimension('inputDim', this.config.inputDim, this.inputDim);
    this.hiddenDim = DQNAlgorithm.resolveDimension('hiddenDim', this.config.hiddenDim, this.hiddenDim);

    // Initialize Q-network, an identical target copy and zeroed momentum.
    this.qWeights = this.initializeNetwork();
    this.targetWeights = this.copyNetwork(this.qWeights);
    this.qMomentum = this.qWeights.map((w) => new Float32Array(w.length));
  }

  /**
   * Validate an optional network dimension taken from the config.
   *
   * @param {string} name Config key, used in the error message.
   * @param {*} value Caller-supplied value; null/undefined selects the fallback.
   * @param {number} fallback Default dimension.
   * @returns {number} The dimension to use.
   * @throws {RangeError} If `value` is given but is not a positive integer.
   */
  static resolveDimension(name, value, fallback) {
    if (value == null) {
      return fallback;
    }
    if (!Number.isInteger(value) || value <= 0) {
      throw new RangeError(`DQN config "${name}" must be a positive integer, got ${String(value)}`);
    }
    return value;
  }

  /**
   * Add experience from trajectory.
   *
   * Every step becomes one replay entry; the last step of the trajectory is
   * marked `done`. Once the buffer holds `bufferSize` entries the oldest slot
   * is overwritten.
   *
   * @param {{steps: Array<{stateBefore: *, action: string, reward: number, stateAfter: *}>}} trajectory
   */
  addExperience(trajectory) {
    const { steps } = trajectory;
    for (let i = 0; i < steps.length; i++) {
      const step = steps[i];
      const experience = {
        state: step.stateBefore,
        action: this.hashAction(step.action),
        reward: step.reward,
        nextState: step.stateAfter,
        done: i === steps.length - 1,
      };

      // Add to circular buffer
      if (this.buffer.length < this.config.bufferSize) {
        this.buffer.push(experience);
      } else {
        this.buffer[this.bufferIdx] = experience;
      }
      this.bufferIdx = (this.bufferIdx + 1) % this.config.bufferSize;
    }
  }

  /**
   * Perform DQN update
   * Target: <10ms (a warning is logged when a step takes longer).
   *
   * Samples a mini-batch, accumulates squared-TD-error gradients, applies one
   * momentum step, refreshes the target network every `targetUpdateFreq`
   * steps and decays epsilon.
   *
   * @returns {{loss: number, epsilon: number}} Mean squared TD error of the
   *   batch (0 when no update was performed) and the current epsilon.
   */
  update() {
    const startTime = performance.now();
    if (this.buffer.length < this.config.miniBatchSize) {
      return { loss: 0, epsilon: this.epsilon };
    }

    // Sample mini-batch
    const batch = this.sampleBatch();
    if (batch.length === 0) {
      // A non-positive miniBatchSize yields an empty batch. Dividing the
      // learning rate by zero would turn every weight into NaN, so skip.
      return { loss: 0, epsilon: this.epsilon };
    }

    let totalLoss = 0;
    const gradients = this.qWeights.map((w) => new Float32Array(w.length));
    for (const exp of batch) {
      // Current Q-value; hidden activations are kept for the backward pass.
      const hidden = this.hiddenActivations(exp.state, this.qWeights);
      const currentQ = this.outputLayer(hidden, this.qWeights)[exp.action];

      // TD error
      const tdError = this.computeTarget(exp) - currentQ;
      totalLoss += tdError * tdError;

      // Accumulate gradients
      this.accumulateGradients(gradients, exp.state, exp.action, tdError, hidden);
    }

    // Apply gradients
    this.applyGradients(gradients, batch.length);

    // Update target network periodically
    this.stepCount++;
    if (this.stepCount % this.config.targetUpdateFreq === 0) {
      this.targetWeights = this.copyNetwork(this.qWeights);
    }

    // Decay exploration: epsilon drops by 1/explorationDecay per update step
    // and is floored at explorationFinal.
    this.epsilon = Math.max(
      this.config.explorationFinal,
      this.config.explorationInitial - this.stepCount / this.config.explorationDecay,
    );

    this.updateCount++;
    this.avgLoss = totalLoss / batch.length;

    const elapsed = performance.now() - startTime;
    if (elapsed > 10) {
      console.warn(`DQN update exceeded target: ${elapsed.toFixed(2)}ms > 10ms`);
    }
    return {
      loss: this.avgLoss,
      epsilon: this.epsilon,
    };
  }

  /**
   * Get action using epsilon-greedy.
   *
   * @param {ArrayLike<number>} state Input features.
   * @param {boolean} [explore=true] When false the greedy action is always returned.
   * @returns {number} Action index in [0, numActions).
   */
  getAction(state, explore = true) {
    if (explore && Math.random() < this.epsilon) {
      return Math.floor(Math.random() * this.numActions);
    }
    return this.argmax(this.forward(state, this.qWeights));
  }

  /**
   * Get Q-values for a state.
   *
   * @param {ArrayLike<number>} state Input features.
   * @returns {Float32Array} One Q-value per action, from the online network.
   */
  getQValues(state) {
    return this.forward(state, this.qWeights);
  }

  /**
   * Get statistics.
   *
   * @returns {{updateCount: number, bufferSize: number, epsilon: number, avgLoss: number, stepCount: number}}
   *   `bufferSize` is the number of stored experiences; `avgLoss` is the mean
   *   squared TD error of the most recent update.
   */
  getStats() {
    return {
      updateCount: this.updateCount,
      bufferSize: this.buffer.length,
      epsilon: this.epsilon,
      avgLoss: this.avgLoss,
      stepCount: this.stepCount,
    };
  }

  // ==========================================================================
  // Private Methods
  // ==========================================================================

  /**
   * Create the two weight matrices: input -> hidden and hidden -> output.
   * Each weight is drawn uniformly from +/- sqrt(2 / fanIn) / 2.
   *
   * @returns {Float32Array[]} [layer1, layer2]
   */
  initializeNetwork() {
    const randomLayer = (fanIn, fanOut) => {
      const layer = new Float32Array(fanIn * fanOut);
      const scale = Math.sqrt(2 / fanIn);
      for (let i = 0; i < layer.length; i++) {
        layer[i] = (Math.random() - 0.5) * scale;
      }
      return layer;
    };
    return [
      randomLayer(this.inputDim, this.hiddenDim),
      randomLayer(this.hiddenDim, this.numActions),
    ];
  }

  /**
   * @param {Float32Array[]} weights
   * @returns {Float32Array[]} Independent copies of every weight matrix.
   */
  copyNetwork(weights) {
    return weights.map((w) => new Float32Array(w));
  }

  /**
   * Full forward pass.
   *
   * @param {ArrayLike<number>} state Input features; entries beyond inputDim are ignored.
   * @param {Float32Array[]} weights [layer1, layer2]
   * @returns {Float32Array} One Q-value per action.
   */
  forward(state, weights) {
    return this.outputLayer(this.hiddenActivations(state, weights), weights);
  }

  /**
   * Layer 1: ReLU(W1 * x).
   *
   * @param {ArrayLike<number>} state Input features; a state shorter than
   *   inputDim only contributes the entries it has.
   * @param {Float32Array[]} weights [layer1, layer2]
   * @returns {Float32Array} Hidden activations of length hiddenDim.
   */
  hiddenActivations(state, weights) {
    const { hiddenDim } = this;
    const inputCount = Math.min(state.length, this.inputDim);
    const w1 = weights[0];
    const hidden = new Float32Array(hiddenDim);
    for (let h = 0; h < hiddenDim; h++) {
      let sum = 0;
      for (let i = 0; i < inputCount; i++) {
        sum += state[i] * w1[i * hiddenDim + h];
      }
      hidden[h] = Math.max(0, sum); // ReLU
    }
    return hidden;
  }

  /**
   * Layer 2: W2 * hidden (no activation for Q-values).
   *
   * @param {Float32Array} hidden Hidden activations.
   * @param {Float32Array[]} weights [layer1, layer2]
   * @returns {Float32Array} One Q-value per action.
   */
  outputLayer(hidden, weights) {
    const { hiddenDim, numActions } = this;
    const w2 = weights[1];
    const output = new Float32Array(numActions);
    for (let a = 0; a < numActions; a++) {
      let sum = 0;
      for (let h = 0; h < hiddenDim; h++) {
        sum += hidden[h] * w2[h * numActions + a];
      }
      output[a] = sum;
    }
    return output;
  }

  /**
   * TD target for one experience.
   *
   * Terminal transitions use the raw reward. Otherwise the discounted
   * next-state value comes from the target network, evaluated at the online
   * network's best action (Double DQN) or at its own maximum (standard DQN).
   *
   * @param {{reward: number, nextState: *, done: boolean}} exp
   * @returns {number}
   */
  computeTarget(exp) {
    if (exp.done) {
      return exp.reward;
    }
    const nextQTarget = this.forward(exp.nextState, this.targetWeights);
    if (this.config.doubleDQN) {
      // Double DQN: use online network to select action, target to evaluate
      const bestAction = this.argmax(this.forward(exp.nextState, this.qWeights));
      return exp.reward + this.config.gamma * nextQTarget[bestAction];
    }
    // Standard DQN
    return exp.reward + this.config.gamma * Math.max(...nextQTarget);
  }

  /**
   * Add one sample's contribution to the gradient accumulators.
   *
   * The accumulated value is tdError * dQ/dw, i.e. the direction that
   * reduces the squared TD error when added to the weights.
   *
   * @param {Float32Array[]} gradients Accumulators, mutated in place.
   * @param {ArrayLike<number>} state Input features of the sample.
   * @param {number} action Index of the action that was taken.
   * @param {number} tdError targetQ - currentQ.
   * @param {Float32Array} [hidden] Hidden activations of `state` under the
   *   online network; recomputed when omitted.
   */
  accumulateGradients(gradients, state, action, tdError, hidden = this.hiddenActivations(state, this.qWeights)) {
    const { hiddenDim, numActions } = this;
    const inputCount = Math.min(state.length, this.inputDim);
    const w2 = this.qWeights[1];

    // Gradient for layer 2 (only for selected action)
    for (let h = 0; h < hiddenDim; h++) {
      gradients[1][h * numActions + action] += hidden[h] * tdError;
    }

    // Gradient for layer 1 (backprop through ReLU)
    for (let h = 0; h < hiddenDim; h++) {
      if (hidden[h] > 0) { // ReLU gradient is zero for inactive units
        const grad = tdError * w2[h * numActions + action];
        for (let i = 0; i < inputCount; i++) {
          gradients[0][i * hiddenDim + h] += state[i] * grad;
        }
      }
    }
  }

  /**
   * Apply one momentum step to the online network.
   *
   * Each accumulated gradient element is clamped to +/- maxGradNorm
   * (element-wise, not a norm clip), blended into the momentum buffer with
   * beta = 0.9 and added to the weights scaled by learningRate / batchSize.
   *
   * @param {Float32Array[]} gradients Summed per-sample gradients.
   * @param {number} batchSize Number of samples summed into `gradients`.
   */
  applyGradients(gradients, batchSize) {
    if (!(batchSize > 0)) {
      // learningRate / 0 is Infinity and Infinity * 0 is NaN; leave weights untouched.
      return;
    }
    const lr = this.config.learningRate / batchSize;
    const beta = 0.9;
    const limit = this.config.maxGradNorm;
    for (let layer = 0; layer < gradients.length; layer++) {
      const layerGrad = gradients[layer];
      const momentum = this.qMomentum[layer];
      const weights = this.qWeights[layer];
      for (let i = 0; i < layerGrad.length; i++) {
        // Gradient clipping
        const grad = Math.max(Math.min(layerGrad[i], limit), -limit);
        // Momentum update
        momentum[i] = beta * momentum[i] + (1 - beta) * grad;
        weights[i] += lr * momentum[i];
      }
    }
  }

  /**
   * Draw up to miniBatchSize distinct experiences uniformly at random.
   *
   * @returns {object[]} Sampled experiences (never more than the buffer holds).
   */
  sampleBatch() {
    const indices = new Set();
    while (indices.size < this.config.miniBatchSize && indices.size < this.buffer.length) {
      indices.add(Math.floor(Math.random() * this.buffer.length));
    }
    const batch = [];
    for (const idx of indices) {
      batch.push(this.buffer[idx]);
    }
    return batch;
  }

  /**
   * Map an action string to an action index with a rolling base-31 hash
   * reduced modulo numActions at every character.
   *
   * @param {string} action Action label (read via `length` / `charCodeAt`).
   * @returns {number} Index in [0, numActions); 0 for an empty string.
   */
  hashAction(action) {
    let hash = 0;
    for (let i = 0; i < action.length; i++) {
      hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
    }
    return hash;
  }

  /**
   * @param {ArrayLike<number>} values
   * @returns {number} Index of the first maximum; 0 for an empty input.
   */
  argmax(values) {
    let maxIdx = 0;
    let maxVal = values[0];
    for (let i = 1; i < values.length; i++) {
      if (values[i] > maxVal) {
        maxVal = values[i];
        maxIdx = i;
      }
    }
    return maxIdx;
  }
}

/**
 * Factory function.
 *
 * @param {object} [config] Overrides forwarded unchanged to the DQNAlgorithm constructor.
 * @returns {DQNAlgorithm} A newly constructed agent.
 */
export function createDQN(config) {
  return new DQNAlgorithm(config);
}

//# sourceMappingURL=dqn.js.map