// tasq/node_modules/@claude-flow/neural/dist/modes/balanced.js
// 234 lines · 8.6 KiB · JavaScript

/**
* Balanced Mode Implementation
*
* General-purpose mode with:
* - +25% quality improvement
* - 18ms overhead
* - Rank-4 LoRA
* - Pattern caching
* - Standard learning pipeline
*/
import { BaseModeImplementation } from './base.js';
/**
* Balanced mode for general-purpose learning
*/
/**
 * Balanced mode for general-purpose learning.
 *
 * Combines three mechanisms:
 * - pattern lookup via cosine similarity with a bounded FIFO cache,
 * - a contrastive, momentum-based gradient-accumulation pipeline,
 * - LoRA adaptation blended into the input at a moderate alpha (0.2).
 */
export class BalancedMode extends BaseModeImplementation {
  mode = 'balanced';

  // Pattern cache: quantized-embedding key -> cached top-k match list.
  patternCache = new Map();
  cacheHits = 0;
  cacheMisses = 0;

  // Learning state, keyed by gradient channel ('positive' / 'negative').
  gradientAccumulator = new Map();
  momentumBuffers = new Map();

  // Rolling statistics reported by getStats().
  totalPatternMatches = 0;
  totalPatternTime = 0;
  totalLearnTime = 0;
  learnIterations = 0;
  qualityImprovements = [];

  /** Reset all cached and learning state on (re)initialization. */
  async initialize() {
    await super.initialize();
    this.patternCache.clear();
    this.gradientAccumulator.clear();
    this.momentumBuffers.clear();
  }

  /** Release local state before delegating cleanup to the base class. */
  async cleanup() {
    this.patternCache.clear();
    this.gradientAccumulator.clear();
    this.momentumBuffers.clear();
    await super.cleanup();
  }

  /**
   * Find the top-k most similar patterns for an embedding, with caching.
   *
   * @param {ArrayLike<number>} embedding - Query embedding.
   * @param {number} k - Number of matches to return.
   * @param {Array<object>} patterns - Candidates; each needs `.embedding`
   *   and `.successRate`.
   * @returns {Promise<Array<object>>} Matches sorted by similarity, desc.
   */
  async findPatterns(embedding, k, patterns) {
    const startTime = performance.now();
    // Fast path: a previous query quantized to the same key and cached at
    // least k results. NOTE(review): entries are never invalidated when
    // `patterns` or success rates change, so hits can be stale — confirm
    // that is acceptable for this mode.
    const cacheKey = this.computeCacheKey(embedding);
    const cached = this.patternCache.get(cacheKey);
    if (cached && cached.length >= k) {
      this.cacheHits++;
      this.totalPatternTime += performance.now() - startTime;
      this.totalPatternMatches++;
      return cached.slice(0, k);
    }
    this.cacheMisses++;
    // Score every candidate; confidence weights similarity by the
    // pattern's historical success rate.
    const matches = patterns.map((pattern) => {
      const similarity = this.cosineSimilarity(embedding, pattern.embedding);
      return {
        pattern,
        similarity,
        confidence: similarity * pattern.successRate,
        latencyMs: 0,
      };
    });
    matches.sort((a, b) => b.similarity - a.similarity);
    const topK = matches.slice(0, k);
    // FIFO eviction BEFORE inserting so the cache is bounded at 500
    // entries (the original checked `> 500` after the fact and let the
    // map reach 501). Maps iterate keys in insertion order, so
    // .keys().next() yields the oldest entry.
    if (this.patternCache.size >= 500) {
      const oldestKey = this.patternCache.keys().next().value;
      // Compare against undefined: '' is a valid (falsy) key, produced
      // for an empty embedding, and must still be evictable.
      if (oldestKey !== undefined) {
        this.patternCache.delete(oldestKey);
      }
    }
    this.patternCache.set(cacheKey, topK);
    this.totalPatternTime += performance.now() - startTime;
    this.totalPatternMatches++;
    return topK;
  }

  /**
   * Learn from trajectories using contrastive gradient accumulation.
   *
   * Good trajectories (quality >= threshold) pull the accumulator toward
   * their final state; up to 3 bad trajectories push away from theirs at
   * half the learning rate.
   *
   * @param {Array<object>} trajectories - Each has `.qualityScore` and
   *   `.steps` (whose last entry carries `.stateAfter`).
   * @param {object} config - Uses `.qualityThreshold` and `.learningRate`.
   * @param {object} ewcState - EWC state (see NOTE below; currently unused).
   * @returns {Promise<number>} Non-negative quality-improvement delta.
   */
  async learn(trajectories, config, ewcState) {
    const startTime = performance.now();
    if (trajectories.length === 0) return 0;
    const { qualityThreshold, learningRate } = config;
    const goodTrajectories = trajectories.filter((t) => t.qualityScore >= qualityThreshold);
    const badTrajectories = trajectories.filter((t) => t.qualityScore < qualityThreshold);
    if (goodTrajectories.length === 0) return 0;
    // The same (at most) 3 contrastive examples are used for every good
    // trajectory; hoisted out of the loop (the original re-sliced on
    // every iteration).
    const contrastSet = badTrajectories.slice(0, 3);
    for (const good of goodTrajectories) {
      if (good.steps.length === 0) continue;
      // The last step's post-state serves as the "goal state".
      const goalState = good.steps[good.steps.length - 1].stateAfter;
      // Positive gradient: move toward good outcomes.
      this.accumulateGradient('positive', this.computeGradient(goalState, good.qualityScore), learningRate);
      for (const bad of contrastSet) {
        if (bad.steps.length === 0) continue;
        const badState = bad.steps[bad.steps.length - 1].stateAfter;
        // Negated reward + half learning rate: move away, gently.
        this.accumulateGradient('negative', this.computeGradient(badState, -bad.qualityScore), learningRate * 0.5);
      }
    }
    // NOTE(review): the original summed gradient norms plus an EWC penalty
    // (computeEWCPenalty is side-effect free) into a local that was never
    // read; that dead computation is dropped here. EWC regularization
    // therefore still has no effect on the accumulated gradients — confirm
    // whether it was intended to modify them.
    const avgGoodQuality = goodTrajectories.reduce((s, t) => s + t.qualityScore, 0) / goodTrajectories.length;
    const baselineQuality = 0.5; // Assumed neutral baseline.
    const improvementDelta = avgGoodQuality - baselineQuality;
    this.qualityImprovements.push(improvementDelta);
    if (this.qualityImprovements.length > 100) {
      // Keep only the most recent 100 samples.
      this.qualityImprovements = this.qualityImprovements.slice(-100);
    }
    this.totalLearnTime += performance.now() - startTime;
    this.learnIterations++;
    return Math.max(0, improvementDelta);
  }

  /**
   * Apply LoRA adaptations (rank from `this.config.loraRank`) to an input.
   *
   * For each target module whose A and B factors are present, blends the
   * transformed input into the output at alpha = 0.2.
   *
   * @param {Float32Array} input - Input activations.
   * @param {{A: Map, B: Map}|null|undefined} weights - LoRA factor maps
   *   keyed by module name; the input is returned unchanged when absent.
   * @returns {Promise<Float32Array>} Adapted (or original) activations.
   */
  async applyLoRA(input, weights) {
    if (!weights) return input;
    const output = new Float32Array(input.length);
    output.set(input);
    const rank = this.config.loraRank;
    const alpha = 0.2; // Moderate blending factor.
    for (const module of ['q_proj', 'v_proj', 'k_proj', 'o_proj']) {
      const A = weights.A.get(module);
      const B = weights.B.get(module);
      if (!A || !B) continue;
      const adapted = this.applyLoRATransform(input, A, B, rank);
      for (let i = 0; i < output.length; i++) {
        output[i] = output[i] * (1 - alpha) + adapted[i] * alpha;
      }
    }
    return output;
  }

  /**
   * Report cache, pattern-matching, and learning statistics.
   * @returns {object} Aggregated stats; rates and averages are 0 when no
   *   samples have been recorded yet.
   */
  getStats() {
    const lookups = this.cacheHits + this.cacheMisses;
    const avgImprovement = this.qualityImprovements.length > 0
      ? this.qualityImprovements.reduce((a, b) => a + b, 0) / this.qualityImprovements.length
      : 0;
    return {
      cacheHitRate: lookups > 0 ? this.cacheHits / lookups : 0,
      avgPatternMatchMs: this.totalPatternMatches > 0
        ? this.totalPatternTime / this.totalPatternMatches
        : 0,
      avgLearnMs: this.learnIterations > 0
        ? this.totalLearnTime / this.learnIterations
        : 0,
      avgImprovement,
      patternCacheSize: this.patternCache.size,
      learnIterations: this.learnIterations,
    };
  }

  /**
   * Build a cache key from the first 16 dimensions of an embedding,
   * quantized to 2 decimal places. Distinct embeddings agreeing on this
   * prefix intentionally collide (they share a cache slot).
   * @param {ArrayLike<number>} embedding
   * @returns {string} Comma-joined quantized prefix ('' when empty).
   */
  computeCacheKey(embedding) {
    const prefixLen = Math.min(16, embedding.length);
    const parts = [];
    for (let i = 0; i < prefixLen; i++) {
      parts.push(embedding[i].toFixed(2));
    }
    return parts.join(',');
  }

  /**
   * Simple linear gradient: the state vector scaled by the reward.
   * @param {ArrayLike<number>} state
   * @param {number} reward - May be negative (contrastive push-away).
   * @returns {Float32Array}
   */
  computeGradient(state, reward) {
    const gradient = new Float32Array(state.length);
    for (let i = 0; i < state.length; i++) {
      gradient[i] = state[i] * reward;
    }
    return gradient;
  }

  /**
   * Fold a gradient into the per-key accumulator with EMA momentum
   * (beta = 0.9), scaled by the learning rate. Buffers are lazily
   * created at the first gradient's length for each key.
   * @param {string} key - Gradient channel ('positive' / 'negative').
   * @param {Float32Array} gradient
   * @param {number} lr - Learning rate applied to the momentum update.
   * @returns {number} L2 norm of the updated momentum buffer.
   */
  accumulateGradient(key, gradient, lr) {
    let momentum = this.momentumBuffers.get(key);
    if (!momentum) {
      momentum = new Float32Array(gradient.length);
      this.momentumBuffers.set(key, momentum);
    }
    let accumulator = this.gradientAccumulator.get(key);
    if (!accumulator) {
      accumulator = new Float32Array(gradient.length);
      this.gradientAccumulator.set(key, accumulator);
    }
    const beta = 0.9; // Momentum coefficient.
    let normSq = 0;
    for (let i = 0; i < gradient.length; i++) {
      momentum[i] = beta * momentum[i] + (1 - beta) * gradient[i];
      accumulator[i] += lr * momentum[i];
      normSq += momentum[i] * momentum[i];
    }
    return Math.sqrt(normSq);
  }

  /**
   * Quadratic EWC penalty: 0.5 * lambda * sum(F_i * (theta_i - mu_i)^2),
   * over keys present in both the EWC state and the local accumulator.
   * Side-effect free.
   * @param {{fisher: Map, means: Map}} ewcState
   * @param {number} lambda - Regularization strength.
   * @returns {number}
   */
  computeEWCPenalty(ewcState, lambda) {
    let penalty = 0;
    for (const [key, fisher] of ewcState.fisher) {
      const means = ewcState.means.get(key);
      const current = this.gradientAccumulator.get(key);
      if (!means || !current) continue;
      const len = Math.min(fisher.length, means.length, current.length);
      for (let i = 0; i < len; i++) {
        const diff = current[i] - means[i];
        penalty += fisher[i] * diff * diff;
      }
    }
    return lambda * penalty * 0.5;
  }
}
//# sourceMappingURL=balanced.js.map