// tasq/node_modules/agentic-flow/dist/router/providers/onnx.js

/**
* ONNX Runtime Provider for Local Model Inference
*
* Supports CPU and GPU execution providers for optimized local inference
* Compatible with Phi-3, Llama, and other ONNX models
*/
// Dynamic imports for optional ONNX dependencies
let ort;
let transformers;
async function ensureOnnxDependencies() {
    if (!ort) {
        try {
            const ortModule = await import('onnxruntime-node');
            ort = ortModule;
        }
        catch (e) {
            throw new Error('onnxruntime-node not installed. Run: npm install onnxruntime-node');
        }
    }
    if (!transformers) {
        try {
            const transformersModule = await import('@xenova/transformers');
            transformers = transformersModule;
            transformers.env.allowLocalModels = true;
        }
        catch (e) {
            throw new Error('@xenova/transformers not installed. Run: npm install @xenova/transformers');
        }
    }
}
export class ONNXProvider {
    name = 'onnx';
    type = 'custom';
    supportsStreaming = true;
    supportsTools = false;
    supportsMCP = false;
    session = null;
    generator = null;
    config;
    executionProviders = [];
    constructor(config = {}) {
        this.config = {
            modelId: config.modelId || 'Xenova/Phi-3-mini-4k-instruct',
            maxTokens: config.maxTokens || 512,
            temperature: config.temperature || 0.7,
            ...config
        };
    }
    /**
     * Select execution providers. This is a platform heuristic, not a true
     * capability probe: CUDA is assumed on Linux and DirectML on Windows,
     * with CPU always appended as the fallback.
     */
    async detectExecutionProviders() {
        const providers = [];
        // Assume CUDA for NVIDIA GPUs on Linux
        if (process.platform === 'linux') {
            providers.push('cuda');
            this.executionProviders.push('cuda');
        }
        // Assume DirectML for GPUs on Windows
        if (process.platform === 'win32') {
            providers.push('dml');
            this.executionProviders.push('dml');
        }
        // Always fall back to CPU
        providers.push('cpu');
        this.executionProviders.push('cpu');
        console.log(`🔧 ONNX Execution Providers: ${this.executionProviders.join(', ')}`);
        return providers;
    }
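    // Resulting provider order under the heuristic above:
    //   linux -> ['cuda', 'cpu'],  win32 -> ['dml', 'cpu'],  other -> ['cpu']
    //
    // A minimal sketch of an actual capability probe, assuming
    // onnxruntime-node's InferenceSession.create accepts an
    // `executionProviders` session option and a `modelPath` is available at
    // detection time (verify both against the installed version):
    //
    //   try {
    //       await ort.InferenceSession.create(modelPath, { executionProviders: ['cuda'] });
    //       providers.push('cuda'); // session created -> CUDA EP is usable
    //   } catch { /* CUDA unavailable; fall through to CPU */ }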
    /**
     * Initialize ONNX session with model
     */
    async initializeSession() {
        if (this.generator)
            return;
        try {
            await ensureOnnxDependencies();
            console.log(`📦 Loading ONNX model: ${this.config.modelId}`);
            // Use Transformers.js for easier model loading
            this.generator = await transformers.pipeline('text-generation', this.config.modelId, {
                quantized: true, // Use quantized models for better CPU performance
            });
            console.log(`✅ ONNX model loaded successfully`);
        }
        catch (error) {
            const providerError = {
                name: 'ONNXInitError',
                message: `Failed to initialize ONNX model: ${error}`,
                provider: 'onnx',
                retryable: false
            };
            throw providerError;
        }
    }
    /**
     * Format messages into the Phi-3 chat template
     */
    formatMessages(messages) {
        let prompt = '';
        for (const msg of messages) {
            // Content may be a plain string or an array of content blocks;
            // only text blocks are kept.
            const content = typeof msg.content === 'string'
                ? msg.content
                : msg.content.map(c => c.type === 'text' ? c.text : '').join('');
            if (msg.role === 'user') {
                prompt += `<|user|>\n${content}<|end|>\n`;
            }
            else if (msg.role === 'assistant') {
                prompt += `<|assistant|>\n${content}<|end|>\n`;
            }
            else if (msg.role === 'system') {
                prompt += `<|system|>\n${content}<|end|>\n`;
            }
        }
        // Trailing assistant tag prompts the model to generate its reply
        prompt += '<|assistant|>\n';
        return prompt;
    }
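    // Example: formatMessages([{ role: 'user', content: 'Hello' }]) produces
    // the Phi-3 chat-template prompt:
    //
    //   <|user|>
    //   Hello<|end|>
    //   <|assistant|>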
    /**
     * Chat completion
     */
    async chat(params) {
        await this.initializeSession();
        const startTime = Date.now();
        const prompt = this.formatMessages(params.messages);
        try {
            const result = await this.generator(prompt, {
                // ?? rather than ||, so an explicit temperature of 0 (greedy
                // decoding) is honored instead of falling back to the default
                max_new_tokens: params.maxTokens ?? this.config.maxTokens,
                temperature: params.temperature ?? this.config.temperature,
                do_sample: true,
                top_p: 0.9,
            });
            const generatedText = result[0].generated_text;
            // Extract only the new assistant response
            const assistantResponse = generatedText
                .split('<|assistant|>')
                .pop()
                ?.split('<|end|>')[0]
                ?.trim() || '';
            const latency = Date.now() - startTime;
            // Estimate token counts (rough approximation: ~4 characters per token)
            const inputTokens = Math.ceil(prompt.length / 4);
            const outputTokens = Math.ceil(assistantResponse.length / 4);
            const content = [{
                    type: 'text',
                    text: assistantResponse
                }];
            return {
                id: `onnx-${Date.now()}`,
                model: this.config.modelId || 'onnx-model',
                content,
                stopReason: 'end_turn',
                usage: {
                    inputTokens,
                    outputTokens
                },
                metadata: {
                    provider: 'onnx',
                    model: this.config.modelId,
                    latency,
                    cost: 0, // Local inference is free
                    executionProviders: this.executionProviders
                }
            };
        }
        catch (error) {
            const providerError = {
                name: 'ONNXInferenceError',
                message: `ONNX inference failed: ${error}`,
                provider: 'onnx',
                retryable: true
            };
            throw providerError;
        }
    }
    /**
     * Streaming generation
     */
    async *stream(params) {
        await this.initializeSession();
        const prompt = this.formatMessages(params.messages);
        try {
            // Note: this Transformers.js pipeline call is not streamed; the
            // full response is generated first, then chunked below
            const result = await this.generator(prompt, {
                max_new_tokens: params.maxTokens ?? this.config.maxTokens,
                temperature: params.temperature ?? this.config.temperature,
                do_sample: true,
                top_p: 0.9,
            });
            const generatedText = result[0].generated_text;
            const assistantResponse = generatedText
                .split('<|assistant|>')
                .pop()
                ?.split('<|end|>')[0]
                ?.trim() || '';
            // Simulate streaming by chunking the response word by word
            const words = assistantResponse.split(' ');
            for (let i = 0; i < words.length; i++) {
                const chunk = words[i] + (i < words.length - 1 ? ' ' : '');
                yield {
                    type: 'content_block_delta',
                    delta: {
                        type: 'text_delta',
                        text: chunk
                    }
                };
                // Small delay to simulate real streaming
                await new Promise(resolve => setTimeout(resolve, 10));
            }
            yield {
                type: 'message_stop'
            };
        }
        catch (error) {
            const providerError = {
                name: 'ONNXStreamError',
                message: `ONNX streaming failed: ${error}`,
                provider: 'onnx',
                retryable: true
            };
            throw providerError;
        }
    }
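    // A sketch of true token-level streaming, assuming the installed
    // @xenova/transformers version supports the `callback_function`
    // generation option and exposes the pipeline's tokenizer (both present
    // in recent 2.x releases; verify before relying on them):
    //
    //   await this.generator(prompt, {
    //       max_new_tokens: params.maxTokens ?? this.config.maxTokens,
    //       callback_function: (beams) => {
    //           const text = this.generator.tokenizer.decode(
    //               beams[0].output_token_ids, { skip_special_tokens: true });
    //           // push `text` (or its delta vs. the previous call) to the consumer
    //       }
    //   });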
    /**
     * Validate capabilities
     */
    validateCapabilities(features) {
        const supported = ['chat', 'stream'];
        return features.every(f => supported.includes(f));
    }
    /**
     * Get model info
     */
    getModelInfo() {
        return {
            modelId: this.config.modelId,
            executionProviders: this.executionProviders,
            supportsGPU: this.executionProviders.includes('cuda') || this.executionProviders.includes('dml'),
            initialized: this.generator !== null
        };
    }
    /**
     * Cleanup resources
     */
    async dispose() {
        if (this.generator) {
            this.generator = null;
        }
        if (this.session) {
            await this.session.release();
            this.session = null;
        }
    }
}
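// Minimal usage sketch (assumes onnxruntime-node and @xenova/transformers are
// installed; the message shape mirrors what chat()/stream() expect above):
//
//   import { ONNXProvider } from './onnx.js';
//
//   const provider = new ONNXProvider({ maxTokens: 256 });
//   const response = await provider.chat({
//       messages: [{ role: 'user', content: 'Hello!' }]
//   });
//   console.log(response.content[0].text);
//
//   for await (const event of provider.stream({
//       messages: [{ role: 'user', content: 'Stream this.' }]
//   })) {
//       if (event.type === 'content_block_delta') process.stdout.write(event.delta.text);
//   }
//   await provider.dispose();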
//# sourceMappingURL=onnx.js.map