234 lines
8.6 KiB
JavaScript
234 lines
8.6 KiB
JavaScript
/**
|
|
* Balanced Mode Implementation
|
|
*
|
|
* General-purpose mode with:
|
|
* - +25% quality improvement
|
|
* - 18ms overhead
|
|
* - Rank-4 LoRA
|
|
* - Pattern caching
|
|
* - Standard learning pipeline
|
|
*/
|
|
import { BaseModeImplementation } from './base.js';
|
|
/**
|
|
* Balanced mode for general-purpose learning
|
|
*/
|
|
export class BalancedMode extends BaseModeImplementation {
|
|
mode = 'balanced';
|
|
// Pattern cache
|
|
patternCache = new Map();
|
|
cacheHits = 0;
|
|
cacheMisses = 0;
|
|
// Learning state
|
|
gradientAccumulator = new Map();
|
|
momentumBuffers = new Map();
|
|
// Stats
|
|
totalPatternMatches = 0;
|
|
totalPatternTime = 0;
|
|
totalLearnTime = 0;
|
|
learnIterations = 0;
|
|
qualityImprovements = [];
|
|
async initialize() {
|
|
await super.initialize();
|
|
this.patternCache.clear();
|
|
this.gradientAccumulator.clear();
|
|
this.momentumBuffers.clear();
|
|
}
|
|
async cleanup() {
|
|
this.patternCache.clear();
|
|
this.gradientAccumulator.clear();
|
|
this.momentumBuffers.clear();
|
|
await super.cleanup();
|
|
}
|
|
/**
|
|
* Find patterns using similarity search with caching
|
|
*/
|
|
async findPatterns(embedding, k, patterns) {
|
|
const startTime = performance.now();
|
|
// Check cache
|
|
const cacheKey = this.computeCacheKey(embedding);
|
|
const cached = this.patternCache.get(cacheKey);
|
|
if (cached && cached.length >= k) {
|
|
this.cacheHits++;
|
|
this.totalPatternTime += performance.now() - startTime;
|
|
this.totalPatternMatches++;
|
|
return cached.slice(0, k);
|
|
}
|
|
this.cacheMisses++;
|
|
// Compute similarities for all patterns
|
|
const matches = [];
|
|
for (const pattern of patterns) {
|
|
const similarity = this.cosineSimilarity(embedding, pattern.embedding);
|
|
matches.push({
|
|
pattern,
|
|
similarity,
|
|
confidence: similarity * pattern.successRate,
|
|
latencyMs: 0,
|
|
});
|
|
}
|
|
// Sort by similarity descending
|
|
matches.sort((a, b) => b.similarity - a.similarity);
|
|
const topK = matches.slice(0, k);
|
|
// Cache result
|
|
if (this.patternCache.size > 500) {
|
|
const firstKey = this.patternCache.keys().next().value;
|
|
if (firstKey)
|
|
this.patternCache.delete(firstKey);
|
|
}
|
|
this.patternCache.set(cacheKey, topK);
|
|
this.totalPatternTime += performance.now() - startTime;
|
|
this.totalPatternMatches++;
|
|
return topK;
|
|
}
|
|
/**
|
|
* Learn from trajectories using standard gradient descent
|
|
*/
|
|
async learn(trajectories, config, ewcState) {
|
|
const startTime = performance.now();
|
|
if (trajectories.length === 0)
|
|
return 0;
|
|
const qualityThreshold = config.qualityThreshold;
|
|
const learningRate = config.learningRate;
|
|
// Separate positive and negative examples
|
|
const goodTrajectories = trajectories.filter(t => t.qualityScore >= qualityThreshold);
|
|
const badTrajectories = trajectories.filter(t => t.qualityScore < qualityThreshold);
|
|
if (goodTrajectories.length === 0)
|
|
return 0;
|
|
// Compute gradients from trajectory pairs
|
|
let totalGradientNorm = 0;
|
|
for (const good of goodTrajectories) {
|
|
// Use last step embedding as "goal state"
|
|
if (good.steps.length === 0)
|
|
continue;
|
|
const goalState = good.steps[good.steps.length - 1].stateAfter;
|
|
// Positive gradient: move toward good outcomes
|
|
const posGradient = this.computeGradient(goalState, good.qualityScore);
|
|
totalGradientNorm += this.accumulateGradient('positive', posGradient, learningRate);
|
|
// Negative gradient: move away from bad outcomes (contrastive)
|
|
for (const bad of badTrajectories.slice(0, 3)) {
|
|
if (bad.steps.length === 0)
|
|
continue;
|
|
const badState = bad.steps[bad.steps.length - 1].stateAfter;
|
|
const negGradient = this.computeGradient(badState, -bad.qualityScore);
|
|
totalGradientNorm += this.accumulateGradient('negative', negGradient, learningRate * 0.5);
|
|
}
|
|
}
|
|
// Apply EWC regularization
|
|
const ewcPenalty = this.computeEWCPenalty(ewcState, config.ewcLambda);
|
|
totalGradientNorm += ewcPenalty;
|
|
// Compute improvement delta
|
|
const avgGoodQuality = goodTrajectories.reduce((s, t) => s + t.qualityScore, 0) / goodTrajectories.length;
|
|
const baselineQuality = 0.5;
|
|
const improvementDelta = avgGoodQuality - baselineQuality;
|
|
this.qualityImprovements.push(improvementDelta);
|
|
if (this.qualityImprovements.length > 100) {
|
|
this.qualityImprovements = this.qualityImprovements.slice(-100);
|
|
}
|
|
this.totalLearnTime += performance.now() - startTime;
|
|
this.learnIterations++;
|
|
return Math.max(0, improvementDelta);
|
|
}
|
|
/**
|
|
* Apply LoRA adaptations with rank-4
|
|
*/
|
|
async applyLoRA(input, weights) {
|
|
if (!weights) {
|
|
return input;
|
|
}
|
|
const output = new Float32Array(input.length);
|
|
output.set(input);
|
|
const rank = this.config.loraRank;
|
|
// Apply to all target modules
|
|
for (const module of ['q_proj', 'v_proj', 'k_proj', 'o_proj']) {
|
|
const A = weights.A.get(module);
|
|
const B = weights.B.get(module);
|
|
if (A && B) {
|
|
const adapted = this.applyLoRATransform(input, A, B, rank);
|
|
const alpha = 0.2; // Moderate blending
|
|
for (let i = 0; i < output.length; i++) {
|
|
output[i] = output[i] * (1 - alpha) + adapted[i] * alpha;
|
|
}
|
|
}
|
|
}
|
|
return output;
|
|
}
|
|
getStats() {
|
|
const avgImprovement = this.qualityImprovements.length > 0
|
|
? this.qualityImprovements.reduce((a, b) => a + b, 0) / this.qualityImprovements.length
|
|
: 0;
|
|
return {
|
|
cacheHitRate: this.cacheHits + this.cacheMisses > 0
|
|
? this.cacheHits / (this.cacheHits + this.cacheMisses)
|
|
: 0,
|
|
avgPatternMatchMs: this.totalPatternMatches > 0
|
|
? this.totalPatternTime / this.totalPatternMatches
|
|
: 0,
|
|
avgLearnMs: this.learnIterations > 0
|
|
? this.totalLearnTime / this.learnIterations
|
|
: 0,
|
|
avgImprovement,
|
|
patternCacheSize: this.patternCache.size,
|
|
learnIterations: this.learnIterations,
|
|
};
|
|
}
|
|
/**
|
|
* Compute cache key from embedding
|
|
*/
|
|
computeCacheKey(embedding) {
|
|
const keyParts = [];
|
|
for (let i = 0; i < Math.min(16, embedding.length); i++) {
|
|
keyParts.push(embedding[i].toFixed(2));
|
|
}
|
|
return keyParts.join(',');
|
|
}
|
|
/**
|
|
* Compute gradient from state and reward
|
|
*/
|
|
computeGradient(state, reward) {
|
|
const gradient = new Float32Array(state.length);
|
|
for (let i = 0; i < state.length; i++) {
|
|
gradient[i] = state[i] * reward;
|
|
}
|
|
return gradient;
|
|
}
|
|
/**
|
|
* Accumulate gradient with momentum
|
|
*/
|
|
accumulateGradient(key, gradient, lr) {
|
|
let momentum = this.momentumBuffers.get(key);
|
|
if (!momentum) {
|
|
momentum = new Float32Array(gradient.length);
|
|
this.momentumBuffers.set(key, momentum);
|
|
}
|
|
let accumulator = this.gradientAccumulator.get(key);
|
|
if (!accumulator) {
|
|
accumulator = new Float32Array(gradient.length);
|
|
this.gradientAccumulator.set(key, accumulator);
|
|
}
|
|
const beta = 0.9; // Momentum coefficient
|
|
let norm = 0;
|
|
for (let i = 0; i < gradient.length; i++) {
|
|
momentum[i] = beta * momentum[i] + (1 - beta) * gradient[i];
|
|
accumulator[i] += lr * momentum[i];
|
|
norm += momentum[i] * momentum[i];
|
|
}
|
|
return Math.sqrt(norm);
|
|
}
|
|
/**
|
|
* Compute EWC penalty for continual learning
|
|
*/
|
|
computeEWCPenalty(ewcState, lambda) {
|
|
let penalty = 0;
|
|
for (const [key, fisher] of ewcState.fisher) {
|
|
const means = ewcState.means.get(key);
|
|
const current = this.gradientAccumulator.get(key);
|
|
if (means && current) {
|
|
for (let i = 0; i < Math.min(fisher.length, means.length, current.length); i++) {
|
|
const diff = current[i] - means[i];
|
|
penalty += fisher[i] * diff * diff;
|
|
}
|
|
}
|
|
}
|
|
return lambda * penalty * 0.5;
|
|
}
|
|
}
|
|
//# sourceMappingURL=balanced.js.map
|