// tasq/node_modules/@claude-flow/neural/__tests__/algorithms.test.ts
// 583 lines, 16 KiB, TypeScript

/**
* RL Algorithms Tests
*
* Tests for reinforcement learning algorithms:
* - Q-Learning
* - SARSA
* - DQN
* - PPO
* - Decision Transformer
*
* Performance target: <10ms per update
*/
import { describe, it, expect, beforeEach } from 'vitest';
import { QLearning, createQLearning } from '../src/algorithms/q-learning.js';
import { SARSAAlgorithm, createSARSA } from '../src/algorithms/sarsa.js';
import { DQNAlgorithm, createDQN } from '../src/algorithms/dqn.js';
import { PPOAlgorithm, createPPO } from '../src/algorithms/ppo.js';
import { DecisionTransformer, createDecisionTransformer } from '../src/algorithms/decision-transformer.js';
import type { Trajectory } from '../src/types.js';
/**
 * Build a synthetic, complete trajectory for exercising the RL algorithms.
 *
 * Each of the `steps` entries carries 768-dim before/after state vectors,
 * one of 4 discrete actions (cycled), and a reward that ramps from 0.5
 * toward 1.0 across the episode.
 *
 * @param steps Number of steps in the trajectory (default 5).
 * @returns A complete Trajectory with qualityScore 0.75.
 */
function createTestTrajectory(steps: number = 5): Trajectory {
  const trajectorySteps = [];
  for (let i = 0; i < steps; i++) {
    trajectorySteps.push({
      stepId: `step-${i}`,
      timestamp: Date.now() + i * 100,
      action: `action-${i % 4}`, // cycle through 4 discrete actions
      stateBefore: new Float32Array(768).fill(i * 0.1),
      stateAfter: new Float32Array(768).fill((i + 1) * 0.1),
      reward: 0.5 + (i / steps) * 0.5, // reward increases through the episode
    });
  }
  return {
    trajectoryId: `test-traj-${Date.now()}`,
    context: 'Test task',
    domain: 'code',
    steps: trajectorySteps,
    qualityScore: 0.75,
    isComplete: true,
    startTime: Date.now() - 1000,
    endTime: Date.now(),
  };
}
describe('Q-Learning Algorithm', () => {
  let qlearning: QLearning;

  beforeEach(() => {
    // Fresh instance per test: the exploration schedule and Q-table must not
    // leak state across cases.
    qlearning = createQLearning({
      learningRate: 0.1,
      gamma: 0.99,
      explorationInitial: 1.0,
      explorationFinal: 0.01,
      explorationDecay: 1000,
    });
  });

  it('should initialize correctly', () => {
    expect(qlearning).toBeDefined();
    const stats = qlearning.getStats();
    expect(stats.updateCount).toBe(0);
    expect(stats.qTableSize).toBe(0);
    expect(stats.epsilon).toBeCloseTo(1.0);
  });

  it('should update Q-values from trajectory', () => {
    const trajectory = createTestTrajectory(5);
    const result = qlearning.update(trajectory);
    expect(result.tdError).toBeGreaterThanOrEqual(0);
    const stats = qlearning.getStats();
    expect(stats.updateCount).toBe(1);
    expect(stats.qTableSize).toBeGreaterThan(0);
  });

  // Title fixed: the assertion below (and the file-wide performance target)
  // is <10ms, not <1ms as the previous title claimed.
  it('should update under performance target (<10ms)', () => {
    const trajectory = createTestTrajectory(10);
    const startTime = performance.now();
    qlearning.update(trajectory);
    const elapsed = performance.now() - startTime;
    expect(elapsed).toBeLessThan(10); // Reasonable target for small trajectories
  });

  it('should decay exploration rate', () => {
    const trajectory = createTestTrajectory(5);
    const initialEpsilon = qlearning.getStats().epsilon;
    // Repeated updates should walk epsilon down the decay schedule.
    for (let i = 0; i < 10; i++) {
      qlearning.update(trajectory);
    }
    const finalEpsilon = qlearning.getStats().epsilon;
    expect(finalEpsilon).toBeLessThan(initialEpsilon);
  });

  it('should select actions with epsilon-greedy', () => {
    const state = new Float32Array(768).fill(0.5);
    // With exploration enabled the action is random but must stay in range.
    const action1 = qlearning.getAction(state, true);
    expect(action1).toBeGreaterThanOrEqual(0);
    expect(action1).toBeLessThan(4);
    // Without exploration, should be deterministic (greedy).
    const action2 = qlearning.getAction(state, false);
    expect(action2).toBeDefined();
  });

  it('should return Q-values for a state', () => {
    const trajectory = createTestTrajectory(5);
    qlearning.update(trajectory);
    const state = new Float32Array(768).fill(0.5);
    const qValues = qlearning.getQValues(state);
    expect(qValues).toBeInstanceOf(Float32Array);
    expect(qValues.length).toBe(4); // one entry per discrete action
  });

  it('should handle eligibility traces', () => {
    const qlearningWithTraces = createQLearning({
      useEligibilityTraces: true,
      traceDecay: 0.9,
    });
    const trajectory = createTestTrajectory(10);
    expect(() => qlearningWithTraces.update(trajectory)).not.toThrow();
  });

  it('should prune Q-table when over capacity', () => {
    const smallQLearning = createQLearning({
      maxStates: 10,
    });
    // Add many different trajectories to overflow the bounded Q-table.
    for (let i = 0; i < 20; i++) {
      const trajectory = createTestTrajectory(5);
      smallQLearning.update(trajectory);
    }
    const stats = smallQLearning.getStats();
    expect(stats.qTableSize).toBeLessThanOrEqual(10);
  });

  it('should reset correctly', () => {
    const trajectory = createTestTrajectory(5);
    qlearning.update(trajectory);
    qlearning.reset();
    // Reset must restore counters, table, and exploration to initial values.
    const stats = qlearning.getStats();
    expect(stats.updateCount).toBe(0);
    expect(stats.qTableSize).toBe(0);
    expect(stats.epsilon).toBeCloseTo(1.0);
  });
});
describe('SARSA Algorithm', () => {
  // On-policy TD control; shares the epsilon-greedy interface with Q-learning.
  let sarsa: SARSAAlgorithm;

  beforeEach(() => {
    sarsa = createSARSA({
      learningRate: 0.1,
      gamma: 0.99,
      explorationInitial: 1.0,
      explorationFinal: 0.01,
      explorationDecay: 1000,
    });
  });

  it('should initialize correctly', () => {
    expect(sarsa).toBeDefined();
    const { updateCount, qTableSize } = sarsa.getStats();
    expect(updateCount).toBe(0);
    expect(qTableSize).toBe(0);
  });

  it('should update using SARSA rule', () => {
    const traj = createTestTrajectory(5);
    const { tdError } = sarsa.update(traj);
    expect(tdError).toBeGreaterThanOrEqual(0);
    expect(sarsa.getStats().updateCount).toBe(1);
  });

  it('should handle expected SARSA variant', () => {
    const expectedSARSA = createSARSA({ useExpectedSARSA: true });
    expect(() => expectedSARSA.update(createTestTrajectory(5))).not.toThrow();
  });

  it('should return action probabilities', () => {
    const probs = sarsa.getActionProbabilities(new Float32Array(768).fill(0.5));
    expect(probs).toBeInstanceOf(Float32Array);
    expect(probs.length).toBe(4);
    // A valid distribution: the entries must sum to ~1.
    let total = 0;
    for (const p of probs) total += p;
    expect(total).toBeCloseTo(1.0, 2);
  });

  it('should select actions with epsilon-greedy policy', () => {
    const chosen = sarsa.getAction(new Float32Array(768).fill(0.5), true);
    expect(chosen).toBeGreaterThanOrEqual(0);
    expect(chosen).toBeLessThan(4);
  });

  it('should handle eligibility traces (SARSA-lambda)', () => {
    const sarsaLambda = createSARSA({ useEligibilityTraces: true, traceDecay: 0.9 });
    expect(() => sarsaLambda.update(createTestTrajectory(10))).not.toThrow();
  });

  it('should handle short trajectories gracefully', () => {
    // A single step yields no (s, a, r, s', a') tuple, so TD error stays 0.
    const outcome = sarsa.update(createTestTrajectory(1));
    expect(outcome.tdError).toBe(0);
  });

  it('should reset algorithm state', () => {
    sarsa.update(createTestTrajectory(5));
    sarsa.reset();
    const { updateCount, qTableSize } = sarsa.getStats();
    expect(updateCount).toBe(0);
    expect(qTableSize).toBe(0);
  });
});
describe('DQN Algorithm', () => {
  let dqn: DQNAlgorithm;

  // Seed an agent's replay buffer with `count` trajectories of `steps` steps each.
  const seedBuffer = (agent: DQNAlgorithm, count: number, steps: number): void => {
    for (let n = 0; n < count; n++) {
      agent.addExperience(createTestTrajectory(steps));
    }
  };

  beforeEach(() => {
    dqn = createDQN({
      learningRate: 0.0001,
      bufferSize: 1000,
      miniBatchSize: 32,
      doubleDQN: true,
      targetUpdateFreq: 100,
    });
  });

  it('should initialize correctly', () => {
    expect(dqn).toBeDefined();
    const { updateCount, bufferSize } = dqn.getStats();
    expect(updateCount).toBe(0);
    expect(bufferSize).toBe(0);
  });

  it('should add experience to replay buffer', () => {
    dqn.addExperience(createTestTrajectory(10));
    // Each step of the trajectory becomes one buffered transition.
    expect(dqn.getStats().bufferSize).toBe(10);
  });

  it('should perform DQN update', () => {
    // Provide enough transitions to form a mini-batch.
    seedBuffer(dqn, 5, 10);
    const result = dqn.update();
    expect(result.loss).toBeGreaterThanOrEqual(0);
    expect(result.epsilon).toBeGreaterThan(0);
  });

  it('should update under performance target (<10ms)', () => {
    seedBuffer(dqn, 5, 10);
    const t0 = performance.now();
    dqn.update();
    const elapsed = performance.now() - t0;
    // Generous bound: the production target is <10ms, but the neural network
    // can be considerably slower on CI hardware.
    expect(elapsed).toBeLessThan(500);
  });

  it('should use double DQN when enabled', () => {
    const doubleDQN = createDQN({ doubleDQN: true, miniBatchSize: 16 });
    seedBuffer(doubleDQN, 3, 10);
    expect(() => doubleDQN.update()).not.toThrow();
  });

  it('should select actions with epsilon-greedy', () => {
    const picked = dqn.getAction(new Float32Array(768).fill(0.5), true);
    expect(picked).toBeGreaterThanOrEqual(0);
    expect(picked).toBeLessThan(4);
  });

  it('should return Q-values for a state', () => {
    const qValues = dqn.getQValues(new Float32Array(768).fill(0.5));
    expect(qValues).toBeInstanceOf(Float32Array);
    expect(qValues.length).toBe(4);
  });

  it('should update target network periodically', () => {
    const dqnWithFreqUpdate = createDQN({ targetUpdateFreq: 5, miniBatchSize: 16 });
    seedBuffer(dqnWithFreqUpdate, 3, 10);
    // Enough updates to cross the target-sync frequency at least once.
    for (let n = 0; n < 10; n++) {
      dqnWithFreqUpdate.update();
    }
    expect(dqnWithFreqUpdate.getStats().stepCount).toBeGreaterThan(5);
  });

  it('should handle circular replay buffer correctly', () => {
    const smallDQN = createDQN({ bufferSize: 10, miniBatchSize: 4 });
    // 15 trajectories x 2 steps = 30 transitions, far past capacity.
    seedBuffer(smallDQN, 15, 2);
    expect(smallDQN.getStats().bufferSize).toBe(10);
  });
});
describe('PPO Algorithm', () => {
  let ppo: PPOAlgorithm;

  beforeEach(() => {
    ppo = createPPO({
      learningRate: 0.0003,
      clipRange: 0.2,
      gaeLambda: 0.95,
      epochs: 4,
      miniBatchSize: 64,
    });
  });

  it('should initialize correctly', () => {
    expect(ppo).toBeDefined();
    expect(ppo.getStats().updateCount).toBe(0);
  });

  it('should add experience from trajectory', () => {
    const traj = createTestTrajectory(10);
    expect(() => ppo.addExperience(traj)).not.toThrow();
    expect(ppo.getStats().bufferSize).toBe(10);
  });

  it('should perform PPO update with clipping', () => {
    // Accumulate enough rollout data before updating.
    for (let n = 0; n < 10; n++) {
      ppo.addExperience(createTestTrajectory(10));
    }
    const result = ppo.update();
    // Policy loss can be negative in PPO (we minimize -surrogate_objective),
    // so only check it is numeric; value loss and entropy are non-negative.
    expect(typeof result.policyLoss).toBe('number');
    expect(result.valueLoss).toBeGreaterThanOrEqual(0);
    expect(result.entropy).toBeGreaterThanOrEqual(0);
  });

  it('should update under performance target (<10ms for small batches)', () => {
    const smallPPO = createPPO({ miniBatchSize: 16, epochs: 1 });
    for (let n = 0; n < 3; n++) {
      smallPPO.addExperience(createTestTrajectory(10));
    }
    const t0 = performance.now();
    smallPPO.update();
    const elapsed = performance.now() - t0;
    expect(elapsed).toBeLessThan(100); // generous bound for PPO's extra work
  });

  it('should compute GAE advantages', () => {
    expect(() => ppo.addExperience(createTestTrajectory(20))).not.toThrow();
    // All 20 steps should land in the buffer with advantages attached.
    expect(ppo.getStats().bufferSize).toBe(20);
  });

  it('should sample actions from policy', () => {
    const sampled = ppo.getAction(new Float32Array(768).fill(0.5));
    expect(sampled.action).toBeGreaterThanOrEqual(0);
    expect(sampled.action).toBeLessThan(4);
    expect(sampled.logProb).toBeDefined();
    expect(sampled.value).toBeDefined();
  });

  it('should handle multiple training epochs', () => {
    const multiEpochPPO = createPPO({ epochs: 8, miniBatchSize: 32 });
    for (let n = 0; n < 5; n++) {
      multiEpochPPO.addExperience(createTestTrajectory(10));
    }
    expect(() => multiEpochPPO.update()).not.toThrow();
  });

  it('should clear buffer after update', () => {
    for (let n = 0; n < 10; n++) {
      ppo.addExperience(createTestTrajectory(10));
    }
    ppo.update();
    // PPO is on-policy: stale rollouts must be discarded after the update.
    expect(ppo.getStats().bufferSize).toBe(0);
  });
});
describe('Decision Transformer', () => {
  let dt: DecisionTransformer;

  beforeEach(() => {
    // Deliberately tiny transformer so tests stay fast.
    dt = createDecisionTransformer({
      contextLength: 20,
      numHeads: 4,
      numLayers: 2,
      hiddenDim: 64,
      embeddingDim: 32,
    });
  });

  it('should initialize correctly', () => {
    expect(dt).toBeDefined();
    const stats = dt.getStats();
    expect(stats.updateCount).toBe(0);
    expect(stats.bufferSize).toBe(0);
    expect(stats.contextLength).toBe(20);
    expect(stats.numLayers).toBe(2);
  });

  it('should add complete trajectories to buffer', () => {
    dt.addTrajectory(createTestTrajectory(10));
    expect(dt.getStats().bufferSize).toBe(1);
  });

  it('should not add incomplete trajectories', () => {
    // Flip isComplete on an otherwise valid trajectory; it must be rejected.
    const incompleteTrajectory: Trajectory = {
      ...createTestTrajectory(5),
      isComplete: false,
    };
    dt.addTrajectory(incompleteTrajectory);
    expect(dt.getStats().bufferSize).toBe(0);
  });

  it('should train on buffered trajectories', () => {
    for (let n = 0; n < 5; n++) {
      dt.addTrajectory(createTestTrajectory(10));
    }
    const result = dt.train();
    expect(result.loss).toBeGreaterThanOrEqual(0);
    // Accuracy is a fraction in [0, 1].
    expect(result.accuracy).toBeGreaterThanOrEqual(0);
    expect(result.accuracy).toBeLessThanOrEqual(1);
  });

  it('should train under performance target (<10ms per batch)', () => {
    for (let n = 0; n < 3; n++) {
      dt.addTrajectory(createTestTrajectory(5));
    }
    const t0 = performance.now();
    dt.train();
    const elapsed = performance.now() - t0;
    expect(elapsed).toBeLessThan(100); // generous bound for transformer overhead
  });

  it('should get action conditioned on target return', () => {
    const states = [
      new Float32Array(768).fill(0.1),
      new Float32Array(768).fill(0.2),
      new Float32Array(768).fill(0.3),
    ];
    const actions = [0, 1, 2];
    const targetReturn = 0.9;
    const action = dt.getAction(states, actions, targetReturn);
    expect(action).toBeGreaterThanOrEqual(0);
    expect(action).toBeLessThan(4);
  });

  it('should handle causal attention masking', () => {
    // Sequences shorter than contextLength exercise the attention mask path.
    for (let n = 0; n < 5; n++) {
      dt.addTrajectory(createTestTrajectory(15));
    }
    expect(() => dt.train()).not.toThrow();
  });

  it('should maintain bounded trajectory buffer', () => {
    // Push past the 1000-trajectory cap and verify the buffer stays bounded.
    for (let n = 0; n < 1100; n++) {
      dt.addTrajectory(createTestTrajectory(5));
    }
    expect(dt.getStats().bufferSize).toBe(1000);
  });

  it('should handle varying trajectory lengths', () => {
    for (const len of [3, 10, 25]) {
      dt.addTrajectory(createTestTrajectory(len));
    }
    expect(() => dt.train()).not.toThrow();
  });

  it('should compute returns-to-go correctly', () => {
    dt.addTrajectory(createTestTrajectory(5));
    expect(() => dt.train()).not.toThrow();
    expect(dt.getStats().avgLoss).toBeGreaterThanOrEqual(0);
  });
});