tasq/node_modules/agentic-flow/dist/core/gnn-wrapper.js

236 lines
7.7 KiB
JavaScript

/**
* GNN Compatibility Wrapper
*
* Fixes API issues with @ruvector/gnn by:
* 1. Auto-converting regular arrays to Float32Array
* 2. Providing fallback implementations for broken functions
* 3. Type-safe interface matching documentation
*/
// Dynamic GNN import with graceful degradation
const gnn = await (async () => {
try {
const module = await import('@ruvector/gnn');
return module.default || module;
}
catch {
return null;
}
})();
/**
* Fixed differentiableSearch that accepts regular arrays
* Automatically converts to Float32Array internally
*/
export function differentiableSearch(query, candidateEmbeddings, k, temperature = 1.0) {
// Convert to Float32Array
const queryTyped = new Float32Array(query);
const candidatesTyped = candidateEmbeddings.map(candidate => new Float32Array(candidate));
try {
const result = gnn.differentiableSearch(queryTyped, candidatesTyped, k, temperature);
return {
indices: Array.from(result.indices),
weights: Array.from(result.weights)
};
}
catch (error) {
throw new Error(`GNN differentiableSearch failed: ${error.message}`);
}
}
/**
* Fallback hierarchicalForward using simple matrix multiplication
* Since the native implementation is broken
*/
export function hierarchicalForward(input, weights, inputDim, outputDim) {
try {
// Try native implementation first
const inputTyped = new Float32Array(input);
const weightsTyped = Array.isArray(weights[0])
? weights.map(w => new Float32Array(w))
: new Float32Array(weights);
const result = gnn.hierarchicalForward(inputTyped, weightsTyped, inputDim, outputDim);
return Array.from(result);
}
catch (error) {
// Fallback to JavaScript implementation
console.warn('GNN hierarchicalForward failed, using fallback implementation');
return hierarchicalForwardFallback(input, weights, inputDim, outputDim);
}
}
/**
* Fallback implementation using basic matrix multiplication
*/
function hierarchicalForwardFallback(input, weights, inputDim, outputDim) {
// Simple matrix multiplication: output = input * weights^T
const output = new Array(outputDim).fill(0);
if (Array.isArray(weights[0])) {
// weights is 2D array [outputDim][inputDim]
const weightsMatrix = weights;
for (let i = 0; i < outputDim; i++) {
for (let j = 0; j < inputDim; j++) {
output[i] += input[j] * weightsMatrix[i][j];
}
}
}
else {
// weights is 1D array (flattened)
const weightsFlat = weights;
for (let i = 0; i < outputDim; i++) {
for (let j = 0; j < inputDim; j++) {
output[i] += input[j] * weightsFlat[i * inputDim + j];
}
}
}
return output;
}
/**
* RuvectorLayer wrapper with fallback
*/
export class RuvectorLayer {
inputDim;
outputDim;
weights;
activation;
constructor(inputDim, outputDim, activation = 'relu') {
this.inputDim = inputDim;
this.outputDim = outputDim;
this.activation = activation;
// Initialize random weights
this.weights = Array.from({ length: outputDim }, () => Array.from({ length: inputDim }, () => (Math.random() - 0.5) * 0.1));
}
forward(input) {
if (input.length !== this.inputDim) {
throw new Error(`Input dimension mismatch: expected ${this.inputDim}, got ${input.length}`);
}
// Matrix multiplication
let output = hierarchicalForwardFallback(input, this.weights, this.inputDim, this.outputDim);
// Apply activation
output = this.applyActivation(output);
return output;
}
applyActivation(values) {
switch (this.activation) {
case 'relu':
return values.map(v => Math.max(0, v));
case 'tanh':
return values.map(v => Math.tanh(v));
case 'sigmoid':
return values.map(v => 1 / (1 + Math.exp(-v)));
case 'none':
return values;
default:
return values;
}
}
getWeights() {
return this.weights;
}
setWeights(weights) {
if (weights.length !== this.outputDim || weights[0].length !== this.inputDim) {
throw new Error('Weight dimensions do not match layer dimensions');
}
this.weights = weights;
}
}
/**
* TensorCompress wrapper with working compression levels
*/
export class TensorCompress {
config;
constructor(config) {
if (typeof config === 'string') {
this.config = { levelType: config };
}
else {
this.config = config;
}
}
compress(tensor) {
switch (this.config.levelType) {
case 'none':
return tensor;
case 'half':
// 16-bit float compression (approximate with scale)
const scale = this.config.scale || 1.0;
return tensor.map(v => Math.round(v / scale) * scale);
case 'pq8':
case 'pq4':
// Product quantization (simplified)
const bits = this.config.levelType === 'pq8' ? 8 : 4;
const levels = Math.pow(2, bits);
const min = Math.min(...tensor);
const max = Math.max(...tensor);
const range = max - min;
return tensor.map(v => {
const quantized = Math.round(((v - min) / range) * (levels - 1));
return (quantized / (levels - 1)) * range + min;
});
case 'binary':
// Binary quantization
const threshold = this.config.threshold || 0;
return tensor.map(v => (v > threshold ? 1 : 0));
default:
return tensor;
}
}
decompress(compressed) {
// For simple compressions, decompression is identity
return compressed;
}
getCompressionRatio() {
switch (this.config.levelType) {
case 'none':
return 1;
case 'half':
return 2; // 32-bit → 16-bit
case 'pq8':
return 4; // 32-bit → 8-bit
case 'pq4':
return 8; // 32-bit → 4-bit
case 'binary':
return 32; // 32-bit → 1-bit
default:
return 1;
}
}
}
/**
* Get compression level configuration
* Fixed version that returns proper config objects
*/
export function getCompressionLevel(level) {
const configs = {
none: { levelType: 'none' },
half: { levelType: 'half', scale: 1.0 },
pq8: { levelType: 'pq8', subvectors: 8, centroids: 256 },
pq4: { levelType: 'pq4', subvectors: 8, centroids: 16, outlierThreshold: 0.1 },
binary: { levelType: 'binary', threshold: 0.0 }
};
if (!(level in configs)) {
throw new Error(`Invalid compression level: ${level}. Valid options: ${Object.keys(configs).join(', ')}`);
}
return configs[level];
}
/**
* Check if GNN native module is available and working
*/
export function isGNNAvailable() {
try {
const query = new Float32Array([1.0, 0.0]);
const candidates = [new Float32Array([1.0, 0.0]), new Float32Array([0.0, 1.0])];
gnn.differentiableSearch(query, candidates, 2, 1.0);
return true;
}
catch {
return false;
}
}
/**
* Initialize GNN module (if needed)
*/
export function initGNN() {
if (typeof gnn.init === 'function') {
gnn.init();
}
}
// Export original for advanced users who want direct access
export { gnn as gnnRaw };
//# sourceMappingURL=gnn-wrapper.js.map