tasq/node_modules/agentic-flow/dist/core/attention-native.js

/**
 * Native Attention Wrappers
 *
 * Wraps the @ruvector/attention native Rust implementations with
 * TypedArray conversions and consistent error handling.
 */
import * as nativeAttention from '@ruvector/attention';
/**
 * Convert regular array to Float32Array if needed
 */
function toFloat32Array(input) {
    if (input instanceof Float32Array) {
        return input;
    }
    return new Float32Array(input);
}
/**
 * Convert Float32Array to regular array for JS compatibility
 */
function toArray(input) {
    return Array.from(input);
}
/**
 * Native Multi-Head Attention (uses Rust implementation)
 *
 * This wrapper properly converts between JavaScript arrays and TypedArrays
 * required by the native Rust implementation.
 */
export class MultiHeadAttention {
    nativeInstance;
    hiddenDim;
    numHeads;
    constructor(config) {
        this.hiddenDim = config.hiddenDim;
        this.numHeads = config.numHeads || 8;
        try {
            // Create native Rust instance
            this.nativeInstance = new nativeAttention.MultiHeadAttention(this.hiddenDim, this.numHeads);
        }
        catch (error) {
            throw new Error(`Failed to initialize native MultiHeadAttention: ${error.message}`);
        }
    }
    /**
     * Forward pass using native Rust implementation
     */
    forward(query, key, value, mask) {
        try {
            // Convert to Float32Array for native code
            const q = toFloat32Array(query);
            const k = toFloat32Array(key);
            const v = toFloat32Array(value);
            const m = mask ? toFloat32Array(mask) : undefined;
            // Call native compute method
            const result = this.nativeInstance.compute(q, k, v, m);
            // Convert result back to regular arrays
            return {
                output: toArray(result),
                attentionWeights: [[]] // Native doesn't return weights separately
            };
        }
        catch (error) {
            throw new Error(`MultiHeadAttention forward failed: ${error.message}`);
        }
    }
    /**
     * Get dimensions
     */
    get dim() {
        return this.nativeInstance.dim;
    }
    get headDim() {
        return this.nativeInstance.headDim;
    }
}
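// Illustrative usage sketch, not part of the shipped API surface: the dimensions are
// example values, and the call assumes the native @ruvector/attention binding loaded.
//
//   const mha = new MultiHeadAttention({ hiddenDim: 64, numHeads: 8 });
//   const q = new Float32Array(64).fill(0.1);
//   const { output } = mha.forward(q, q, q); // mask argument is optional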
/**
 * Native Flash Attention (uses Rust implementation)
 *
 * Memory-efficient attention with tiling/chunking
 */
export class FlashAttention {
    nativeInstance;
    hiddenDim;
    constructor(config) {
        this.hiddenDim = config.hiddenDim;
        try {
            this.nativeInstance = new nativeAttention.FlashAttention(this.hiddenDim);
        }
        catch (error) {
            throw new Error(`Failed to initialize native FlashAttention: ${error.message}`);
        }
    }
    /**
     * Forward pass with batch support
     */
    forward(query, key, value, numHeads = 8) {
        try {
            // Convert batch to Float32Array
            const q = query.map(toFloat32Array);
            const k = key.map(toFloat32Array);
            const v = value.map(toFloat32Array);
            // Call native compute
            const result = this.nativeInstance.compute(q, k, v, numHeads);
            // Convert result back
            return {
                output: result.map((r) => toArray(r)),
                attentionScores: [[]]
            };
        }
        catch (error) {
            throw new Error(`FlashAttention forward failed: ${error.message}`);
        }
    }
}
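// Illustrative usage sketch (example values only): forward() takes one vector per
// sequence position and maps each through toFloat32Array before the native call.
//
//   const flash = new FlashAttention({ hiddenDim: 64 });
//   const seq = Array.from({ length: 3 }, () => new Float32Array(64).fill(0.1));
//   const { output } = flash.forward(seq, seq, seq, 8); // numHeads defaults to 8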
/**
 * Native Linear Attention (uses Rust implementation)
 *
 * O(n) complexity approximation of attention
 */
export class LinearAttention {
    nativeInstance;
    hiddenDim;
    constructor(config) {
        this.hiddenDim = config.hiddenDim;
        try {
            // LinearAttention constructor: (hiddenDim, seqLen)
            // We'll use hiddenDim for both since seqLen varies per forward call
            this.nativeInstance = new nativeAttention.LinearAttention(this.hiddenDim, this.hiddenDim);
        }
        catch (error) {
            throw new Error(`Failed to initialize native LinearAttention: ${error.message}`);
        }
    }
    forward(query, key, value) {
        try {
            const q = query.map(toFloat32Array);
            const k = key.map(toFloat32Array);
            const v = value.map(toFloat32Array);
            const result = this.nativeInstance.compute(q, k, v);
            return {
                output: result.map((r) => toArray(r))
            };
        }
        catch (error) {
            throw new Error(`LinearAttention forward failed: ${error.message}`);
        }
    }
}
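// Illustrative usage sketch (example values only), same batched calling convention as
// FlashAttention above; useful when sequence length makes O(n^2) attention costly.
//
//   const linear = new LinearAttention({ hiddenDim: 64 });
//   const seq = Array.from({ length: 3 }, () => new Float32Array(64).fill(0.1));
//   const { output } = linear.forward(seq, seq, seq);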
/**
 * Native Hyperbolic Attention (uses Rust implementation)
 */
export class HyperbolicAttention {
    nativeInstance;
    hiddenDim;
    constructor(config) {
        this.hiddenDim = config.hiddenDim;
        try {
            this.nativeInstance = new nativeAttention.HyperbolicAttention(this.hiddenDim);
        }
        catch (error) {
            throw new Error(`Failed to initialize native HyperbolicAttention: ${error.message}`);
        }
    }
    forward(query, key, value) {
        try {
            const q = toFloat32Array(query);
            const k = toFloat32Array(key);
            const v = toFloat32Array(value);
            const result = this.nativeInstance.compute(q, k, v);
            return {
                output: toArray(result.output),
                distance: result.distance
            };
        }
        catch (error) {
            throw new Error(`HyperbolicAttention forward failed: ${error.message}`);
        }
    }
}
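// Illustrative usage sketch (example values only): forward() takes single vectors and,
// as wrapped above, returns the output vector plus the native distance value.
//
//   const hyp = new HyperbolicAttention({ hiddenDim: 64 });
//   const q = new Float32Array(64).fill(0.1);
//   const { output, distance } = hyp.forward(q, q, q);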
/**
 * Native MoE Attention (uses Rust implementation)
 *
 * Mixture of Experts with top-k routing
 */
export class MoEAttention {
    nativeInstance;
    hiddenDim;
    numExperts;
    constructor(config) {
        this.hiddenDim = config.hiddenDim;
        this.numExperts = config.numExperts || 4;
        try {
            this.nativeInstance = new nativeAttention.MoEAttention(this.hiddenDim, this.numExperts);
        }
        catch (error) {
            throw new Error(`Failed to initialize native MoEAttention: ${error.message}`);
        }
    }
    forward(query, key, value, topK = 2) {
        try {
            const q = toFloat32Array(query);
            const k = toFloat32Array(key);
            const v = toFloat32Array(value);
            const result = this.nativeInstance.compute(q, k, v, topK);
            return {
                output: toArray(result.output),
                expertWeights: toArray(result.expertWeights)
            };
        }
        catch (error) {
            throw new Error(`MoEAttention forward failed: ${error.message}`);
        }
    }
}
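// Illustrative usage sketch (example values only): topK selects how many experts are
// routed to; the wrapper converts the native expertWeights back to a plain array.
//
//   const moe = new MoEAttention({ hiddenDim: 64, numExperts: 4 });
//   const q = new Float32Array(64).fill(0.1);
//   const { output, expertWeights } = moe.forward(q, q, q, 2);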
/**
 * Scaled Dot-Product Attention (pure JavaScript fallback, works without the native binding)
 */
export function scaledDotProductAttention(query, key, value, mask) {
    const dk = query.length;
    // Compute the single attention score: q · k / sqrt(dk)
    let score = 0;
    for (let i = 0; i < dk; i++) {
        score += query[i] * key[i];
    }
    score /= Math.sqrt(dk);
    // Apply mask if provided
    if (mask && mask[0] === 0) {
        score = -Infinity;
    }
    // Softmax over a single score normalizes to 1; a masked position gets weight 0
    const expScore = Math.exp(score); // exp(-Infinity) === 0 when masked out
    const weight = expScore > 0 ? 1 : 0;
    // Weighted value
    const output = value.map(v => v * weight);
    return { output, weights: [weight] };
}
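// Illustrative usage sketch: the pure-JS path needs no native binding and works on
// plain number arrays (values below are examples only).
//
//   const { output, weights } = scaledDotProductAttention([1, 0], [1, 0], [0.5, 0.5]);
//   // unmasked single-position case: weights === [1], output === [0.5, 0.5]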
/**
 * Check if native attention is available
 */
export function isNativeAttentionAvailable() {
    try {
        // Instantiating a small module is enough to verify the native binding loads
        new nativeAttention.MultiHeadAttention(128, 4);
        return true;
    }
    catch {
        return false;
    }
}
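// Typical fallback pattern (a sketch, not prescribed by this module): probe once at
// startup, then route to the pure-JS scaledDotProductAttention if the binding is absent.
//
//   const hiddenDim = 64; // example value
//   const attentionBackend = isNativeAttentionAvailable()
//       ? new MultiHeadAttention({ hiddenDim })
//       : null; // callers fall back to scaledDotProductAttention when null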
/**
 * Factory function to create appropriate attention module
 */
export function createAttention(type, config) {
    switch (type) {
        case 'multi-head':
            return new MultiHeadAttention(config);
        case 'flash':
            return new FlashAttention(config);
        case 'linear':
            return new LinearAttention(config);
        case 'hyperbolic':
            return new HyperbolicAttention(config);
        case 'moe':
            return new MoEAttention(config);
        default:
            return new MultiHeadAttention(config);
    }
}
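// Illustrative factory usage (example config): recognized type strings are 'multi-head',
// 'flash', 'linear', 'hyperbolic', and 'moe'; any other value falls back to multi-head.
//
//   const attention = createAttention('flash', { hiddenDim: 64 });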
//# sourceMappingURL=attention-native.js.map