tasq/node_modules/@claude-flow/neural/dist/modes/batch.js

316 lines
11 KiB
JavaScript

/**
* Batch Mode Implementation
*
* Optimized for high-throughput processing with:
* - Large batch sizes (128)
* - Rank-8 LoRA
* - Gradient accumulation
* - Async batch processing
* - 50ms latency budget
*/
import { BaseModeImplementation } from './base.js';
/**
 * Batch mode for high-throughput processing.
 *
 * Pattern lookups against large pattern sets are queued and answered together
 * after a short accumulation window; trajectories are buffered until a full
 * learning batch is available; LoRA adaptation blends the transformed input
 * into a copy of the input for each known projection module.
 *
 * Relies on `this.config.loraRank`, `this.cosineSimilarity()` and
 * `this.applyLoRATransform()`, none of which are defined in this class
 * (presumably provided by BaseModeImplementation — confirm against base.js).
 */
export class BatchMode extends BaseModeImplementation {
    /** Pattern sets smaller than this are searched inline instead of queued. */
    static #DIRECT_SEARCH_LIMIT = 100;
    /** Milliseconds to wait for more pattern requests before running a batch. */
    static #BATCH_WINDOW_MS = 10;
    /** Number of processed learning batches between gradient applications. */
    static #GRADIENT_ACCUMULATION_STEPS = 4;
    /** Relative weight of below-threshold trajectories (applied negatively). */
    static #CONTRASTIVE_SCALE = 0.3;
    /** Blend factor between the running output and each adapted projection. */
    static #LORA_BLEND_ALPHA = 0.25;
    /** Weight-map keys visited, in order, when applying LoRA. */
    static #LORA_MODULES = ['q_proj', 'v_proj', 'k_proj', 'o_proj'];

    mode = 'batch';
    // Batch processing queues
    patternQueue = [];
    learningQueue = [];
    // Batch buffers
    embeddingBuffer = null;
    batchEmbeddings = [];
    // Gradient accumulation (keyed by trajectory.domain)
    accumulatedGradients = new Map();
    gradientSteps = 0;
    // Batch processing state
    isBatchProcessing = false;
    batchTimer = null;
    // Stats
    totalBatches = 0;
    totalItems = 0;
    totalBatchTime = 0;
    learnIterations = 0;

    /**
     * Initialise the base implementation, then reset all queued work and the
     * gradient accumulators. Pattern requests still waiting in the queue are
     * resolved with an empty result rather than being left pending forever.
     */
    async initialize() {
        await super.initialize();
        this.#resetPendingWork();
        this.gradientSteps = 0;
    }

    /**
     * Cancel the pending batch timer, release queued work, then clean up the
     * base implementation. Pattern requests still waiting in the queue are
     * resolved with an empty result so their callers do not hang.
     */
    async cleanup() {
        this.#resetPendingWork();
        await super.cleanup();
    }

    /**
     * Find the `k` patterns most similar to `embedding`.
     *
     * Small pattern sets are searched immediately; larger ones are queued and
     * answered when the next batch runs.
     *
     * @param embedding Query vector passed to `cosineSimilarity`.
     * @param k Maximum number of matches to return.
     * @param patterns Array of patterns, each with `embedding` and `successRate`.
     * @returns Promise of matches `{ pattern, similarity, confidence, latencyMs }`
     *          ordered by descending similarity.
     */
    async findPatterns(embedding, k, patterns) {
        if (patterns.length < BatchMode.#DIRECT_SEARCH_LIMIT) {
            return this.findPatternsDirect(embedding, k, patterns);
        }
        return new Promise((resolve, reject) => {
            // Each request remembers its own pattern set: requests that share a
            // batch window are not guaranteed to target the same patterns.
            this.patternQueue.push({ embedding, k, patterns, resolve, reject });
            this.scheduleBatchProcessing(patterns);
        });
    }

    /**
     * Buffer trajectories and run one learning batch once enough are queued.
     *
     * @param trajectories Trajectories with `qualityScore`, `domain` and `steps`.
     * @param config Reads `batchSize`, `qualityThreshold`, `learningRate`, `ewcLambda`.
     * @param ewcState Optional state with `fisher` and `means` maps.
     * @returns 0 for an empty input; the batch result when the queue reached
     *          `config.batchSize`; otherwise half of the amount by which the
     *          average quality of `trajectories` exceeds 0.5 (never negative).
     */
    async learn(trajectories, config, ewcState) {
        const startTime = performance.now();
        if (trajectories.length === 0) {
            return 0;
        }
        // Push one at a time: spreading a very large array into push() can
        // exceed the engine's argument-count limit.
        for (const trajectory of trajectories) {
            this.learningQueue.push(trajectory);
        }
        if (this.learningQueue.length >= config.batchSize) {
            return this.processBatchLearning(config, ewcState);
        }
        const qualitySum = trajectories.reduce((sum, t) => sum + t.qualityScore, 0);
        const avgQuality = qualitySum / trajectories.length;
        this.totalBatchTime += performance.now() - startTime;
        return Math.max(0, avgQuality - 0.5) * 0.5; // Partial estimate
    }

    /**
     * Apply LoRA adaptation to `input` using rank `this.config.loraRank`.
     *
     * @param input Vector to adapt.
     * @param weights Object with `A` and `B` maps keyed by module name; when
     *                falsy the input is returned unchanged.
     * @returns Promise of the adapted vector (a new Float32Array).
     */
    async applyLoRA(input, weights) {
        if (!weights) {
            return input;
        }
        this.batchEmbeddings.push(new Float32Array(input));
        try {
            // A lone request is served directly.
            if (this.batchEmbeddings.length === 1) {
                return await this.applyLoRADirect(input, weights);
            }
            // Other inputs are already buffered: process them together and
            // return the result for this call's input, which is the last one.
            const outputs = await this.applyLoRABatch(this.batchEmbeddings, weights);
            return outputs[outputs.length - 1];
        } finally {
            // Always drop the buffer, even on failure; otherwise a stale input
            // would be reprocessed (and could keep failing) on every later call.
            this.batchEmbeddings = [];
        }
    }

    /**
     * @returns Snapshot of batch counters and current queue sizes.
     */
    getStats() {
        const hasBatches = this.totalBatches > 0;
        return {
            totalBatches: this.totalBatches,
            avgItemsPerBatch: hasBatches ? this.totalItems / this.totalBatches : 0,
            avgBatchTimeMs: hasBatches ? this.totalBatchTime / this.totalBatches : 0,
            pendingPatternRequests: this.patternQueue.length,
            pendingTrajectories: this.learningQueue.length,
            accumulatedGradientSteps: this.gradientSteps,
            learnIterations: this.learnIterations,
        };
    }

    // ========================================================================
    // Direct processing (for small batches)
    // ========================================================================

    /**
     * Score every pattern against `embedding` and return the top `k`,
     * without going through the request queue.
     */
    findPatternsDirect(embedding, k, patterns) {
        const matches = [];
        for (const pattern of patterns) {
            const similarity = this.cosineSimilarity(embedding, pattern.embedding);
            matches.push({
                pattern,
                similarity,
                confidence: similarity * pattern.successRate,
                latencyMs: 0,
            });
        }
        matches.sort((a, b) => b.similarity - a.similarity);
        return matches.slice(0, k);
    }

    /**
     * Apply LoRA adaptation to a single input.
     *
     * @returns Promise of a new Float32Array; `input` is not modified.
     */
    async applyLoRADirect(input, weights) {
        return this.#adaptEmbedding(input, weights, this.config.loraRank);
    }

    // ========================================================================
    // Batch processing
    // ========================================================================

    /**
     * Arm the batch timer unless one is already pending.
     *
     * @param patterns Fallback pattern set for queued requests that do not
     *                 carry their own `patterns` property.
     */
    scheduleBatchProcessing(patterns) {
        if (this.batchTimer != null) {
            return;
        }
        this.batchTimer = setTimeout(() => {
            // Not awaited: requests are settled individually inside.
            void this.processBatchPatterns(patterns);
        }, BatchMode.#BATCH_WINDOW_MS);
    }

    /**
     * Answer every queued pattern request.
     *
     * Each request is searched against its own pattern set (falling back to
     * `patterns` when it has none). A failure in one request rejects only that
     * request; the remaining requests are still answered.
     *
     * @param patterns Fallback pattern set, see scheduleBatchProcessing.
     */
    async processBatchPatterns(patterns) {
        this.batchTimer = null;
        if (this.patternQueue.length === 0) {
            return;
        }
        const startTime = performance.now();
        this.isBatchProcessing = true;
        const batch = this.patternQueue;
        this.patternQueue = [];
        // Embedding matrices are built once per distinct pattern set.
        const matrices = new Map();
        let unreported = null;
        try {
            for (const request of batch) {
                try {
                    const candidates = request.patterns ?? patterns;
                    let matrix = matrices.get(candidates);
                    if (matrix === undefined) {
                        matrix = candidates.map((p) => p.embedding);
                        matrices.set(candidates, matrix);
                    }
                    request.resolve(this.batchSimilaritySearch(request.embedding, request.k, candidates, matrix));
                } catch (error) {
                    if (typeof request.reject === 'function') {
                        request.reject(error);
                    } else {
                        // Entry has no reject callback; surface the first such
                        // error once the rest of the batch has been served.
                        unreported ??= error;
                    }
                }
            }
        } finally {
            this.totalBatches++;
            this.totalItems += batch.length;
            this.totalBatchTime += performance.now() - startTime;
            this.isBatchProcessing = false;
        }
        if (unreported !== null) {
            throw unreported;
        }
    }

    /**
     * Rank `patterns` by similarity of their precomputed embeddings to `query`.
     *
     * @param query Query vector.
     * @param k Maximum number of matches.
     * @param patterns Pattern objects, index-aligned with `patternMatrix`.
     * @param patternMatrix Array of pattern embeddings.
     * @returns Top-`k` matches ordered by descending similarity.
     */
    batchSimilaritySearch(query, k, patterns, patternMatrix) {
        const scored = [];
        for (let idx = 0; idx < patternMatrix.length; idx++) {
            scored.push({ idx, sim: this.cosineSimilarity(query, patternMatrix[idx]) });
        }
        scored.sort((a, b) => b.sim - a.sim);
        return scored.slice(0, k).map(({ idx, sim }) => ({
            pattern: patterns[idx],
            similarity: sim,
            confidence: sim * patterns[idx].successRate,
            latencyMs: 0,
        }));
    }

    /**
     * Take up to `config.batchSize` trajectories off the learning queue and
     * accumulate their gradients.
     *
     * Trajectories at or above `config.qualityThreshold` contribute positively;
     * at most as many below-threshold ones contribute with a negative, reduced
     * learning rate.
     *
     * @returns 0 when no trajectory met the threshold, otherwise the amount by
     *          which the average quality of the good trajectories exceeds 0.5
     *          (never negative).
     */
    async processBatchLearning(config, ewcState) {
        const startTime = performance.now();
        const batch = this.learningQueue.slice(0, config.batchSize);
        this.learningQueue = this.learningQueue.slice(config.batchSize);
        const { qualityThreshold, learningRate } = config;
        const good = [];
        const bad = [];
        for (const trajectory of batch) {
            if (trajectory.qualityScore >= qualityThreshold) {
                good.push(trajectory);
            } else if (trajectory.qualityScore < qualityThreshold) {
                bad.push(trajectory);
            }
            // A NaN score satisfies neither comparison and is ignored.
        }
        if (good.length === 0) {
            this.totalBatchTime += performance.now() - startTime;
            return 0;
        }
        for (const trajectory of good) {
            this.accumulateTrajectoryGradient(trajectory, learningRate);
        }
        // Contrastive learning from bad examples
        for (const trajectory of bad.slice(0, good.length)) {
            this.accumulateTrajectoryGradient(trajectory, -learningRate * BatchMode.#CONTRASTIVE_SCALE);
        }
        this.gradientSteps++;
        if (this.gradientSteps >= BatchMode.#GRADIENT_ACCUMULATION_STEPS) {
            await this.applyAccumulatedGradients(ewcState, config.ewcLambda);
            this.gradientSteps = 0;
        }
        const avgQuality = good.reduce((sum, t) => sum + t.qualityScore, 0) / good.length;
        this.learnIterations++;
        this.totalBatchTime += performance.now() - startTime;
        return Math.max(0, avgQuality - 0.5);
    }

    /**
     * Add one trajectory's contribution to the gradient of its domain.
     *
     * The gradient vector is created on first use, sized from the first step's
     * `stateAfter`. Each step adds `stateAfter[i] * qualityScore * scale * reward`.
     * Trajectories without steps are ignored.
     */
    accumulateTrajectoryGradient(trajectory, scale) {
        if (trajectory.steps.length === 0) {
            return;
        }
        const key = trajectory.domain;
        let gradient = this.accumulatedGradients.get(key);
        if (!gradient) {
            gradient = new Float32Array(trajectory.steps[0].stateAfter.length);
            this.accumulatedGradients.set(key, gradient);
        }
        const weight = trajectory.qualityScore * scale;
        for (const step of trajectory.steps) {
            const limit = Math.min(gradient.length, step.stateAfter.length);
            const stepWeight = weight * step.reward;
            for (let i = 0; i < limit; i++) {
                gradient[i] += step.stateAfter[i] * stepWeight;
            }
        }
    }

    /**
     * Normalise each accumulated gradient, apply the EWC penalty where Fisher
     * information and means exist for its key, then reset it.
     *
     * @param ewcState Optional; when absent (or lacking an entry for a key)
     *                 the penalty step is skipped.
     * @param ewcLambda Penalty strength.
     */
    async applyAccumulatedGradients(ewcState, ewcLambda) {
        for (const [key, gradient] of this.accumulatedGradients) {
            let squaredSum = 0;
            for (const value of gradient) {
                squaredSum += value * value;
            }
            const norm = Math.sqrt(squaredSum);
            if (norm > 0) {
                for (let i = 0; i < gradient.length; i++) {
                    gradient[i] /= norm;
                }
            }
            const fisher = ewcState?.fisher?.get(key);
            const means = ewcState?.means?.get(key);
            if (fisher && means) {
                const limit = Math.min(gradient.length, fisher.length, means.length);
                for (let i = 0; i < limit; i++) {
                    gradient[i] -= ewcLambda * fisher[i] * (gradient[i] - means[i]);
                }
            }
            // NOTE(review): the processed gradient is not written to any
            // weights before being reset — confirm whether an update step is
            // meant to happen here.
            gradient.fill(0);
        }
    }

    /**
     * Apply LoRA adaptation to every input.
     *
     * @returns Promise of an array of new Float32Arrays, index-aligned with `inputs`.
     */
    async applyLoRABatch(inputs, weights) {
        const rank = this.config.loraRank;
        const outputs = [];
        for (const input of inputs) {
            outputs.push(this.#adaptEmbedding(input, weights, rank));
        }
        return outputs;
    }

    /** Copy `input` and blend in the LoRA transform of each module that has both A and B weights. */
    #adaptEmbedding(input, weights, rank) {
        const alpha = BatchMode.#LORA_BLEND_ALPHA;
        const output = new Float32Array(input.length);
        output.set(input);
        for (const module of BatchMode.#LORA_MODULES) {
            const A = weights.A.get(module);
            const B = weights.B.get(module);
            if (!A || !B) {
                continue;
            }
            const adapted = this.applyLoRATransform(input, A, B, rank);
            for (let i = 0; i < output.length; i++) {
                output[i] = output[i] * (1 - alpha) + adapted[i] * alpha;
            }
        }
        return output;
    }

    /** Cancel the batch timer, resolve queued pattern requests with `[]`, and empty the learning state. */
    #resetPendingWork() {
        if (this.batchTimer != null) {
            clearTimeout(this.batchTimer);
            // Must be nulled: scheduleBatchProcessing refuses to arm a new
            // timer while this field is set.
            this.batchTimer = null;
        }
        const abandoned = this.patternQueue;
        this.patternQueue = [];
        for (const request of abandoned) {
            request.resolve([]);
        }
        this.learningQueue = [];
        this.accumulatedGradients.clear();
    }
}
//# sourceMappingURL=batch.js.map