// tasq/node_modules/@claude-flow/neural/dist/algorithms/curiosity.js
// (392 lines, 15 KiB, JavaScript)

/**
* Curiosity-Driven Exploration
*
* Implements intrinsic motivation for exploration:
* - Intrinsic Curiosity Module (ICM)
* - Random Network Distillation (RND)
* - Forward and inverse dynamics models
* - Exploration bonus generation
*
* Performance Target: <5ms per forward pass
*/
/**
 * Default Curiosity configuration.
 *
 * Frozen so the shared defaults cannot be mutated at a distance; consumers
 * merge over it with a spread (`{ ...DEFAULT_CURIOSITY_CONFIG, ...config }`).
 */
export const DEFAULT_CURIOSITY_CONFIG = Object.freeze({
    algorithm: 'curiosity',
    learningRate: 0.0001,   // LR for the RND predictor network
    gamma: 0.99,            // discount factor (RL convention)
    entropyCoef: 0.01,
    valueLossCoef: 0.5,
    maxGradNorm: 0.5,
    epochs: 1,
    miniBatchSize: 32,
    intrinsicCoef: 0.1,     // scale applied to the normalized intrinsic reward
    forwardLR: 0.001,       // LR for the forward dynamics model
    inverseLR: 0.001,       // LR for the inverse dynamics model
    featureDim: 64,         // size of the encoded state feature vector
    useRND: false,          // false => ICM reward, true => RND reward
});
/**
 * Curiosity-Driven Exploration Module.
 *
 * Produces intrinsic rewards that push an agent toward novel states, via
 * either:
 *  - ICM (Intrinsic Curiosity Module): reward = prediction error of a
 *    learned forward dynamics model in feature space, or
 *  - RND (Random Network Distillation): reward = prediction error of a
 *    trained predictor against a fixed random target network.
 *
 * All "networks" are single linear layers stored as flat Float32Array
 * weight matrices, trained with SGD + momentum.
 *
 * FIX vs. previous revision: `avgIntrinsicReward` was declared and exposed
 * through getStats() but never written; it is now maintained as an EMA in
 * both reward paths.
 */
export class CuriosityModule {
  config;
  // Feature encoder: random linear projection stateDim -> featureDim, ReLU.
  featureEncoder;
  // Forward dynamics model: predicts next feature from (feature, one-hot action).
  forwardModel;
  // Inverse dynamics model: predicts action logits from (feature, nextFeature).
  inverseModel;
  // RND networks: fixed random target and trained predictor.
  rndTarget;
  rndPredictor;
  // Momentum buffers for the SGD-with-momentum updates below.
  forwardMomentum;
  inverseMomentum;
  rndMomentum;
  // Input dimensions. NOTE(review): fixed at construction; states longer than
  // stateDim are truncated (see encodeState), shorter ones act as zero-padded.
  stateDim = 768;
  numActions = 4;
  // Running EMA statistics used to normalize raw intrinsic errors.
  intrinsicMean = 0;
  intrinsicVar = 1;
  updateCount = 0;
  // Diagnostics surfaced through getStats().
  avgForwardLoss = 0;
  avgInverseLoss = 0;
  avgIntrinsicReward = 0;
  /**
   * @param {object} [config] - Partial config merged over DEFAULT_CURIOSITY_CONFIG.
   */
  constructor(config = {}) {
    this.config = { ...DEFAULT_CURIOSITY_CONFIG, ...config };
    const featureDim = this.config.featureDim;
    // Initialize feature encoder: state_dim -> feature_dim
    this.featureEncoder = this.initWeight(this.stateDim, featureDim);
    // Forward model: (feature_dim + num_actions) -> feature_dim
    this.forwardModel = this.initWeight(featureDim + this.numActions, featureDim);
    // Inverse model: (2 * feature_dim) -> num_actions
    this.inverseModel = this.initWeight(2 * featureDim, this.numActions);
    // RND networks: target stays fixed forever; only the predictor is trained.
    this.rndTarget = this.initWeight(this.stateDim, featureDim);
    this.rndPredictor = this.initWeight(this.stateDim, featureDim);
    // Momentum buffers, zero-initialized, one entry per weight.
    this.forwardMomentum = new Float32Array(this.forwardModel.length);
    this.inverseMomentum = new Float32Array(this.inverseModel.length);
    this.rndMomentum = new Float32Array(this.rndPredictor.length);
  }
  /**
   * Compute the intrinsic reward for one transition, dispatching to RND or
   * ICM depending on config.useRND.
   *
   * @param {Float32Array|number[]} state - State before the action.
   * @param {string} action - Action identifier (hashed to an index).
   * @param {Float32Array|number[]} nextState - State after the action.
   * @returns {number} Scaled, normalized intrinsic reward.
   */
  computeIntrinsicReward(state, action, nextState) {
    if (this.config.useRND) {
      // RND only looks at the novelty of the resulting state.
      return this.computeRNDReward(nextState);
    }
    else {
      return this.computeICMReward(state, action, nextState);
    }
  }
  /**
   * ICM intrinsic reward: squared error between the forward model's
   * predicted next-state feature and the actual next-state feature.
   */
  computeICMReward(state, action, nextState) {
    const startTime = performance.now();
    // Encode states to features
    const stateFeature = this.encodeState(state);
    const nextStateFeature = this.encodeState(nextState);
    // Predict next state feature
    const actionIdx = this.hashAction(action);
    const predictedFeature = this.forwardPredict(stateFeature, actionIdx);
    // Compute prediction error as intrinsic reward
    let error = 0;
    for (let i = 0; i < this.config.featureDim; i++) {
      error += (predictedFeature[i] - nextStateFeature[i]) ** 2;
    }
    // Normalize intrinsic reward
    const intrinsic = this.normalizeIntrinsic(error);
    const elapsed = performance.now() - startTime;
    if (elapsed > 5) {
      console.warn(`ICM reward exceeded target: ${elapsed.toFixed(2)}ms > 5ms`);
    }
    const reward = intrinsic * this.config.intrinsicCoef;
    // FIX: keep the running-average diagnostic alive (previously never set).
    this.avgIntrinsicReward = 0.99 * this.avgIntrinsicReward + 0.01 * reward;
    return reward;
  }
  /**
   * RND intrinsic reward: squared error between the trained predictor's
   * output and the fixed random target network's output for the same state.
   */
  computeRNDReward(state) {
    const startTime = performance.now();
    // Target network output (fixed random features)
    const targetOutput = this.rndForward(state, this.rndTarget);
    // Predictor network output (trained to match target)
    const predictorOutput = this.rndForward(state, this.rndPredictor);
    // Compute prediction error
    let error = 0;
    for (let i = 0; i < this.config.featureDim; i++) {
      error += (predictorOutput[i] - targetOutput[i]) ** 2;
    }
    // Normalize
    const intrinsic = this.normalizeIntrinsic(error);
    const elapsed = performance.now() - startTime;
    if (elapsed > 5) {
      console.warn(`RND reward exceeded target: ${elapsed.toFixed(2)}ms > 5ms`);
    }
    const reward = intrinsic * this.config.intrinsicCoef;
    // FIX: keep the running-average diagnostic alive (previously never set).
    this.avgIntrinsicReward = 0.99 * this.avgIntrinsicReward + 0.01 * reward;
    return reward;
  }
  /**
   * Update the curiosity models from one trajectory: one SGD step per
   * consecutive (step, nextStep) pair for the forward and inverse models,
   * plus the RND predictor when enabled.
   *
   * @param {{steps: Array<{stateAfter: Float32Array|number[], action: string, reward: number}>}} trajectory
   * @returns {{forwardLoss: number, inverseLoss: number}} Mean losses over the pairs.
   */
  update(trajectory) {
    const startTime = performance.now();
    // Need at least two steps to form a (state, nextState) pair.
    if (trajectory.steps.length < 2) {
      return { forwardLoss: 0, inverseLoss: 0 };
    }
    let totalForwardLoss = 0;
    let totalInverseLoss = 0;
    let count = 0;
    for (let i = 0; i < trajectory.steps.length - 1; i++) {
      const step = trajectory.steps[i];
      const nextStep = trajectory.steps[i + 1];
      const stateFeature = this.encodeState(step.stateAfter);
      const nextStateFeature = this.encodeState(nextStep.stateAfter);
      const actionIdx = this.hashAction(step.action);
      // Update forward model
      const forwardLoss = this.updateForwardModel(stateFeature, actionIdx, nextStateFeature);
      totalForwardLoss += forwardLoss;
      // Update inverse model
      const inverseLoss = this.updateInverseModel(stateFeature, nextStateFeature, actionIdx);
      totalInverseLoss += inverseLoss;
      // Update RND predictor if using RND
      if (this.config.useRND) {
        this.updateRNDPredictor(nextStep.stateAfter);
      }
      count++;
    }
    this.updateCount++;
    this.avgForwardLoss = count > 0 ? totalForwardLoss / count : 0;
    this.avgInverseLoss = count > 0 ? totalInverseLoss / count : 0;
    const elapsed = performance.now() - startTime;
    if (elapsed > 10) {
      console.warn(`Curiosity update exceeded target: ${elapsed.toFixed(2)}ms > 10ms`);
    }
    return {
      forwardLoss: this.avgForwardLoss,
      inverseLoss: this.avgInverseLoss,
    };
  }
  /**
   * Return a shallow copy of the trajectory with intrinsic rewards added to
   * each step's extrinsic reward. The final step is left unchanged because it
   * has no successor state to compute novelty against.
   */
  augmentTrajectory(trajectory) {
    const augmented = { ...trajectory, steps: [...trajectory.steps] };
    for (let i = 0; i < augmented.steps.length - 1; i++) {
      const step = augmented.steps[i];
      const nextStep = augmented.steps[i + 1];
      const intrinsic = this.computeIntrinsicReward(step.stateAfter, step.action, nextStep.stateAfter);
      // Augment reward (copy the step; never mutate the caller's objects)
      augmented.steps[i] = {
        ...step,
        reward: step.reward + intrinsic,
      };
    }
    return augmented;
  }
  /**
   * Snapshot of training/diagnostic statistics.
   */
  getStats() {
    return {
      updateCount: this.updateCount,
      avgForwardLoss: this.avgForwardLoss,
      avgInverseLoss: this.avgInverseLoss,
      avgIntrinsicReward: this.avgIntrinsicReward,
      intrinsicMean: this.intrinsicMean,
      intrinsicStd: Math.sqrt(this.intrinsicVar),
    };
  }
  // ==========================================================================
  // Private Methods
  // ==========================================================================
  /**
   * Allocate a flat (inputDim x outputDim) weight matrix with small uniform
   * random values scaled He-style by sqrt(2 / inputDim).
   */
  initWeight(inputDim, outputDim) {
    const weight = new Float32Array(inputDim * outputDim);
    const scale = Math.sqrt(2 / inputDim);
    for (let i = 0; i < weight.length; i++) {
      weight[i] = (Math.random() - 0.5) * scale;
    }
    return weight;
  }
  /**
   * Encode a raw state vector into a featureDim-sized feature via the fixed
   * random linear encoder followed by ReLU. States longer than stateDim are
   * truncated; shorter states simply contribute fewer terms.
   */
  encodeState(state) {
    const featureDim = this.config.featureDim;
    const feature = new Float32Array(featureDim);
    for (let f = 0; f < featureDim; f++) {
      let sum = 0;
      for (let s = 0; s < Math.min(state.length, this.stateDim); s++) {
        sum += state[s] * this.featureEncoder[s * featureDim + f];
      }
      feature[f] = Math.max(0, sum); // ReLU
    }
    return feature;
  }
  /**
   * Forward dynamics prediction: linear map over [stateFeature ++ oneHot(action)].
   * No activation on the output (regression target is a ReLU feature).
   */
  forwardPredict(stateFeature, action) {
    const featureDim = this.config.featureDim;
    const inputDim = featureDim + this.numActions;
    const predicted = new Float32Array(featureDim);
    // Concatenate feature and one-hot action
    const input = new Float32Array(inputDim);
    input.set(stateFeature);
    input[featureDim + action] = 1;
    // Forward pass
    for (let f = 0; f < featureDim; f++) {
      let sum = 0;
      for (let i = 0; i < inputDim; i++) {
        sum += input[i] * this.forwardModel[i * featureDim + f];
      }
      predicted[f] = sum;
    }
    return predicted;
  }
  /**
   * Inverse dynamics prediction: softmax action distribution from the
   * concatenation [stateFeature ++ nextStateFeature].
   */
  inversePredict(stateFeature, nextStateFeature) {
    const featureDim = this.config.featureDim;
    const logits = new Float32Array(this.numActions);
    // Concatenate features
    const input = new Float32Array(2 * featureDim);
    input.set(stateFeature);
    input.set(nextStateFeature, featureDim);
    // Forward pass
    for (let a = 0; a < this.numActions; a++) {
      let sum = 0;
      for (let i = 0; i < 2 * featureDim; i++) {
        sum += input[i] * this.inverseModel[i * this.numActions + a];
      }
      logits[a] = sum;
    }
    return this.softmax(logits);
  }
  /**
   * Shared forward pass for the RND target/predictor: linear + ReLU over the
   * raw state (same truncation rule as encodeState).
   */
  rndForward(state, weights) {
    const featureDim = this.config.featureDim;
    const output = new Float32Array(featureDim);
    for (let f = 0; f < featureDim; f++) {
      let sum = 0;
      for (let s = 0; s < Math.min(state.length, this.stateDim); s++) {
        sum += state[s] * weights[s * featureDim + f];
      }
      output[f] = Math.max(0, sum); // ReLU
    }
    return output;
  }
  /**
   * One SGD-with-momentum step on the forward model toward targetFeature.
   * @returns {number} Sum-of-squares loss before the update.
   */
  updateForwardModel(stateFeature, action, targetFeature) {
    const featureDim = this.config.featureDim;
    const inputDim = featureDim + this.numActions;
    const lr = this.config.forwardLR;
    const beta = 0.9; // momentum coefficient
    // Forward pass
    const predicted = this.forwardPredict(stateFeature, action);
    // Compute loss and output gradient (d/dy of sum (y - t)^2 = 2(y - t))
    let loss = 0;
    const grad = new Float32Array(predicted.length);
    for (let f = 0; f < featureDim; f++) {
      const diff = predicted[f] - targetFeature[f];
      loss += diff * diff;
      grad[f] = 2 * diff;
    }
    // Backprop to weights: rebuild the same [feature ++ one-hot] input
    const input = new Float32Array(inputDim);
    input.set(stateFeature);
    input[featureDim + action] = 1;
    for (let i = 0; i < inputDim; i++) {
      for (let f = 0; f < featureDim; f++) {
        const weightGrad = input[i] * grad[f];
        const idx = i * featureDim + f;
        this.forwardMomentum[idx] = beta * this.forwardMomentum[idx] + (1 - beta) * weightGrad;
        this.forwardModel[idx] -= lr * this.forwardMomentum[idx];
      }
    }
    return loss;
  }
  /**
   * One SGD-with-momentum step on the inverse model (cross-entropy against
   * the taken action's index).
   * @returns {number} Cross-entropy loss before the update.
   */
  updateInverseModel(stateFeature, nextStateFeature, targetAction) {
    const featureDim = this.config.featureDim;
    const lr = this.config.inverseLR;
    const beta = 0.9; // momentum coefficient
    // Forward pass
    const probs = this.inversePredict(stateFeature, nextStateFeature);
    // Cross-entropy loss (epsilon guards log(0))
    const loss = -Math.log(probs[targetAction] + 1e-8);
    // Softmax + cross-entropy gradient: p - onehot(target)
    const grad = new Float32Array(this.numActions);
    for (let a = 0; a < this.numActions; a++) {
      grad[a] = probs[a] - (a === targetAction ? 1 : 0);
    }
    // Backprop to weights
    const input = new Float32Array(2 * featureDim);
    input.set(stateFeature);
    input.set(nextStateFeature, featureDim);
    for (let i = 0; i < 2 * featureDim; i++) {
      for (let a = 0; a < this.numActions; a++) {
        const weightGrad = input[i] * grad[a];
        const idx = i * this.numActions + a;
        this.inverseMomentum[idx] = beta * this.inverseMomentum[idx] + (1 - beta) * weightGrad;
        this.inverseModel[idx] -= lr * this.inverseMomentum[idx];
      }
    }
    return loss;
  }
  /**
   * One SGD-with-momentum step moving the RND predictor toward the fixed
   * target network's output for this state.
   */
  updateRNDPredictor(state) {
    const featureDim = this.config.featureDim;
    const lr = this.config.learningRate;
    const beta = 0.9; // momentum coefficient
    // Target output (fixed)
    const targetOutput = this.rndForward(state, this.rndTarget);
    // Predictor output
    const predictorOutput = this.rndForward(state, this.rndPredictor);
    // Output gradient of squared error
    const grad = new Float32Array(featureDim);
    for (let f = 0; f < featureDim; f++) {
      grad[f] = 2 * (predictorOutput[f] - targetOutput[f]);
    }
    // Update predictor weights
    for (let s = 0; s < Math.min(state.length, this.stateDim); s++) {
      for (let f = 0; f < featureDim; f++) {
        if (predictorOutput[f] > 0) { // ReLU gradient: zero where output clipped
          const weightGrad = state[s] * grad[f];
          const idx = s * featureDim + f;
          this.rndMomentum[idx] = beta * this.rndMomentum[idx] + (1 - beta) * weightGrad;
          this.rndPredictor[idx] -= lr * this.rndMomentum[idx];
        }
      }
    }
  }
  /**
   * Standardize a raw prediction error against EMA running mean/variance and
   * clip the z-score to [-5, 5].
   */
  normalizeIntrinsic(raw) {
    // Update running statistics
    const alpha = 0.01;
    this.intrinsicMean = (1 - alpha) * this.intrinsicMean + alpha * raw;
    this.intrinsicVar = (1 - alpha) * this.intrinsicVar + alpha * (raw - this.intrinsicMean) ** 2;
    // Normalize (epsilon guards division by ~zero std)
    const normalized = (raw - this.intrinsicMean) / (Math.sqrt(this.intrinsicVar) + 1e-8);
    // Clip to reasonable range
    return Math.max(-5, Math.min(5, normalized));
  }
  /**
   * Numerically stable softmax (subtracts the max logit before exp).
   */
  softmax(logits) {
    const max = Math.max(...logits);
    const exps = new Float32Array(logits.length);
    let sum = 0;
    for (let i = 0; i < logits.length; i++) {
      exps[i] = Math.exp(logits[i] - max);
      sum += exps[i];
    }
    for (let i = 0; i < exps.length; i++) {
      exps[i] /= sum;
    }
    return exps;
  }
  /**
   * Deterministically map an action string to an index in [0, numActions).
   * NOTE(review): distinct actions can collide — acceptable for a hash-bucket
   * action space, but verify callers expect that.
   */
  hashAction(action) {
    let hash = 0;
    for (let i = 0; i < action.length; i++) {
      hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
    }
    return hash;
  }
}
/**
 * Factory: build a CuriosityModule from an optional partial config.
 */
export const createCuriosity = (config) => new CuriosityModule(config);
//# sourceMappingURL=curiosity.js.map