/**
 * ONNX Runtime Provider for Local Model Inference
 *
 * Supports CPU and GPU execution providers for optimized local inference.
 * Compatible with Phi-3, Llama, and other ONNX models.
 */

// Dynamic imports for optional ONNX dependencies
let ort;
let transformers;

async function ensureOnnxDependencies() {
    if (!ort) {
        try {
            const ortModule = await import('onnxruntime-node');
            ort = ortModule;
        } catch (e) {
            throw new Error('onnxruntime-node not installed. Run: npm install onnxruntime-node');
        }
    }
    if (!transformers) {
        try {
            const transformersModule = await import('@xenova/transformers');
            transformers = transformersModule;
            transformers.env.allowLocalModels = true;
        } catch (e) {
            throw new Error('@xenova/transformers not installed. Run: npm install @xenova/transformers');
        }
    }
}

export class ONNXProvider {
    name = 'onnx';
    type = 'custom';
    supportsStreaming = true;
    supportsTools = false;
    supportsMCP = false;
    session = null;
    generator = null;
    config;
    executionProviders = [];

    constructor(config = {}) {
        this.config = {
            modelId: config.modelId || 'Xenova/Phi-3-mini-4k-instruct',
            maxTokens: config.maxTokens ?? 512,
            temperature: config.temperature ?? 0.7,
            ...config
        };
    }

    /**
     * Detect available execution providers
     */
    async detectExecutionProviders() {
        const providers = [];
        this.executionProviders = [];
        // CUDA for NVIDIA GPUs (Linux builds of onnxruntime-node)
        if (process.platform === 'linux') {
            providers.push('cuda');
            this.executionProviders.push('cuda');
        }
        // DirectML for Windows GPUs
        if (process.platform === 'win32') {
            providers.push('dml');
            this.executionProviders.push('dml');
        }
        // Always fall back to CPU
        providers.push('cpu');
        this.executionProviders.push('cpu');
        console.log(`🔧 ONNX Execution Providers: ${this.executionProviders.join(', ')}`);
        return providers;
    }

    /**
     * Initialize ONNX session with model
     */
    async initializeSession() {
        if (this.generator) return;
        try {
            await ensureOnnxDependencies();
            // Record which execution providers this platform is expected to offer,
            // so they can be reported in response metadata and getModelInfo().
            if (this.executionProviders.length === 0) {
                await this.detectExecutionProviders();
            }
            console.log(`📦 Loading ONNX model: ${this.config.modelId}`);
            // Use Transformers.js for easier model loading
            this.generator = await transformers.pipeline('text-generation', this.config.modelId, {
                quantized: true, // Use quantized models for better CPU performance
            });
            console.log(`✅ ONNX model loaded successfully`);
        } catch (error) {
            const providerError = {
                name: 'ONNXInitError',
                message: `Failed to initialize ONNX model: ${error}`,
                provider: 'onnx',
                retryable: false
            };
            throw providerError;
        }
    }

    /**
     * Format messages for model input
     */
    formatMessages(messages) {
        // Simple chat template for Phi-3
        let prompt = '';
        for (const msg of messages) {
            const content = typeof msg.content === 'string'
                ? msg.content
                : msg.content.map(c => c.type === 'text' ? c.text : '').join('');
            if (msg.role === 'user') {
                prompt += `<|user|>\n${content}<|end|>\n`;
            } else if (msg.role === 'assistant') {
                prompt += `<|assistant|>\n${content}<|end|>\n`;
            } else if (msg.role === 'system') {
                prompt += `<|system|>\n${content}<|end|>\n`;
            }
        }
        prompt += '<|assistant|>\n';
        return prompt;
    }
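    // For illustration, a messages array such as
    //   [{ role: 'system', content: 'Be concise.' }, { role: 'user', content: 'Hello!' }]
    // is rendered by formatMessages() as the following Phi-3-style prompt:
    //
    //   <|system|>
    //   Be concise.<|end|>
    //   <|user|>
    //   Hello!<|end|>
    //   <|assistant|>
    //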
    /**
     * Chat completion
     */
    async chat(params) {
        await this.initializeSession();
        const startTime = Date.now();
        const prompt = this.formatMessages(params.messages);
        try {
            const result = await this.generator(prompt, {
                max_new_tokens: params.maxTokens ?? this.config.maxTokens,
                temperature: params.temperature ?? this.config.temperature,
                do_sample: true,
                top_p: 0.9,
            });
            const generatedText = result[0].generated_text;
            // Extract only the new assistant response
            const assistantResponse = generatedText
                .split('<|assistant|>')
                .pop()
                ?.split('<|end|>')[0]
                ?.trim() || '';
            const latency = Date.now() - startTime;
            // Estimate token counts (rough approximation: ~4 characters per token)
            const inputTokens = Math.ceil(prompt.length / 4);
            const outputTokens = Math.ceil(assistantResponse.length / 4);
            const content = [{ type: 'text', text: assistantResponse }];
            return {
                id: `onnx-${Date.now()}`,
                model: this.config.modelId || 'onnx-model',
                content,
                stopReason: 'end_turn',
                usage: { inputTokens, outputTokens },
                metadata: {
                    provider: 'onnx',
                    model: this.config.modelId,
                    latency,
                    cost: 0, // Local inference is free
                    executionProviders: this.executionProviders
                }
            };
        } catch (error) {
            const providerError = {
                name: 'ONNXInferenceError',
                message: `ONNX inference failed: ${error}`,
                provider: 'onnx',
                retryable: true
            };
            throw providerError;
        }
    }

    /**
     * Streaming generation
     */
    async *stream(params) {
        await this.initializeSession();
        const prompt = this.formatMessages(params.messages);
        try {
            // The text-generation pipeline used here returns the full completion in one call,
            // so we generate once and then emit the text in small chunks to emulate streaming.
            const result = await this.generator(prompt, {
                max_new_tokens: params.maxTokens ?? this.config.maxTokens,
                temperature: params.temperature ?? this.config.temperature,
                do_sample: true,
                top_p: 0.9,
            });
            const generatedText = result[0].generated_text;
            const assistantResponse = generatedText
                .split('<|assistant|>')
                .pop()
                ?.split('<|end|>')[0]
                ?.trim() || '';
            // Simulate streaming by chunking the response word by word
            const words = assistantResponse.split(' ');
            for (let i = 0; i < words.length; i++) {
                const chunk = words[i] + (i < words.length - 1 ? ' ' : '');
                yield {
                    type: 'content_block_delta',
                    delta: { type: 'text_delta', text: chunk }
                };
                // Small delay to simulate real streaming
                await new Promise(resolve => setTimeout(resolve, 10));
            }
            yield { type: 'message_stop' };
        } catch (error) {
            const providerError = {
                name: 'ONNXStreamError',
                message: `ONNX streaming failed: ${error}`,
                provider: 'onnx',
                retryable: true
            };
            throw providerError;
        }
    }

    /**
     * Validate capabilities
     */
    validateCapabilities(features) {
        const supported = ['chat', 'stream'];
        return features.every(f => supported.includes(f));
    }

    /**
     * Get model info
     */
    getModelInfo() {
        return {
            modelId: this.config.modelId,
            executionProviders: this.executionProviders,
            supportsGPU: this.executionProviders.includes('cuda') || this.executionProviders.includes('dml'),
            initialized: this.generator !== null
        };
    }

    /**
     * Cleanup resources
     */
    async dispose() {
        if (this.generator) {
            this.generator = null;
        }
        if (this.session) {
            await this.session.release();
            this.session = null;
        }
    }
}
//# sourceMappingURL=onnx.js.map
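
// Minimal usage sketch (illustrative only, kept as a comment so importing this module has no
// side effects). It assumes the optional dependencies onnxruntime-node and @xenova/transformers
// are installed and that this file is imported as an ES module, e.g. from './onnx.js':
//
//   import { ONNXProvider } from './onnx.js';
//
//   const provider = new ONNXProvider({ maxTokens: 256 });
//
//   // One-shot completion
//   const response = await provider.chat({
//       messages: [{ role: 'user', content: 'Say hello in one sentence.' }]
//   });
//   console.log(response.content[0].text);
//
//   // Simulated streaming
//   for await (const event of provider.stream({ messages: [{ role: 'user', content: 'Hi' }] })) {
//       if (event.type === 'content_block_delta') process.stdout.write(event.delta.text);
//   }
//
//   await provider.dispose();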