/** * ONNX Runtime Local Inference Provider for Phi-4 * * Uses onnxruntime-node for true local CPU/GPU inference */ import * as ort from 'onnxruntime-node'; import { get_encoding } from 'tiktoken'; import { ensurePhi4Model, ModelDownloader } from '../../utils/model-downloader.js'; export class ONNXLocalProvider { name = 'onnx-local'; type = 'custom'; supportsStreaming = false; // Streaming requires complex token generation loop supportsTools = false; supportsMCP = false; session = null; config; tokenizer = null; tiktoken = null; constructor(config = {}) { this.config = { modelPath: config.modelPath || './models/phi-4/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/model.onnx', executionProviders: config.executionProviders || ['cpu'], maxTokens: config.maxTokens || 100, temperature: config.temperature || 0.7 }; } /** * Load optimized tiktoken tokenizer (cl100k_base for Phi-4) */ async loadTokenizer() { if (this.tiktoken) return; try { // Use cl100k_base encoding (GPT-4, similar to Phi-4) this.tiktoken = get_encoding('cl100k_base'); console.log('✅ Tokenizer loaded (tiktoken cl100k_base)'); } catch (error) { console.error('❌ Failed to load tiktoken:', error); throw new Error(`Tokenizer loading failed: ${error}`); } } /** * Encode text using tiktoken (fast BPE) */ encode(text) { return Array.from(this.tiktoken.encode(text)); } /** * Decode tokens using tiktoken */ decode(ids) { try { const decoded = this.tiktoken.decode(new Uint32Array(ids)); // tiktoken returns buffer, convert to string if (typeof decoded === 'string') { return decoded; } else if (decoded instanceof Uint8Array || decoded instanceof Buffer) { return new TextDecoder().decode(decoded); } return String(decoded); } catch (error) { console.warn('Decode error, returning raw IDs:', error); return ids.join(','); } } /** * Initialize ONNX session (with automatic model download) */ async initializeSession() { if (this.session) return; try { // Ensure model is downloaded console.log(`🔍 Checking for Phi-4 ONNX model...`); const modelPath = await ensurePhi4Model((progress) => { if (progress.percentage % 10 < 1) { // Log every ~10% console.log(` 📥 Downloading: ${ModelDownloader.formatProgress(progress)}`); } }); // Update config with actual model path this.config.modelPath = modelPath; console.log(`📦 Loading ONNX model: ${this.config.modelPath}`); this.session = await ort.InferenceSession.create(this.config.modelPath, { executionProviders: this.config.executionProviders, graphOptimizationLevel: 'all', enableCpuMemArena: true, enableMemPattern: true }); console.log(`✅ ONNX model loaded`); console.log(`🔧 Execution providers: ${this.config.executionProviders.join(', ')}`); // Load tokenizer await this.loadTokenizer(); } catch (error) { const providerError = { name: 'ONNXInitError', message: `Failed to initialize ONNX model: ${error}`, provider: 'onnx-local', retryable: false }; throw providerError; } } /** * Format messages for Phi-4 chat template */ formatMessages(messages) { let prompt = ''; for (const msg of messages) { const content = typeof msg.content === 'string' ? msg.content : msg.content.map(c => c.type === 'text' ? c.text : '').join(''); if (msg.role === 'system') { prompt += `<|system|>\n${content}<|end|>\n`; } else if (msg.role === 'user') { prompt += `<|user|>\n${content}<|end|>\n`; } else if (msg.role === 'assistant') { prompt += `<|assistant|>\n${content}<|end|>\n`; } } prompt += '<|assistant|>\n'; return prompt; } /** * Initialize KV cache tensors for all 32 layers * Phi-4 architecture: 32 layers, 8 KV heads, 128 head_dim */ initializeKVCache(batchSize, sequenceLength) { const numLayers = 32; const numKVHeads = 8; const headDim = 128; // 3072 / 24 = 128 const kvCache = {}; // Initialize empty cache for each layer (key and value) for (let i = 0; i < numLayers; i++) { // Empty cache: [batch_size, num_kv_heads, 0, head_dim] const emptyCache = new Float32Array(0); kvCache[`past_key_values.${i}.key`] = new ort.Tensor('float32', emptyCache, [batchSize, numKVHeads, 0, headDim]); kvCache[`past_key_values.${i}.value`] = new ort.Tensor('float32', emptyCache, [batchSize, numKVHeads, 0, headDim]); } return kvCache; } /** * Chat completion using ONNX with KV cache */ async chat(params) { await this.initializeSession(); const startTime = Date.now(); const prompt = this.formatMessages(params.messages); try { // Tokenize input using optimized tiktoken const inputIds = this.encode(prompt); console.log(`📝 Input tokens: ${inputIds.length}`); // Initialize KV cache (reusable for batch) let pastKVCache = this.initializeKVCache(1, 0); // Track all generated tokens const allTokenIds = [...inputIds]; const outputIds = []; // Pre-allocate tensor buffers for performance const maxSeqLen = inputIds.length + (params.maxTokens || this.config.maxTokens); // Autoregressive generation loop const maxNewTokens = params.maxTokens || this.config.maxTokens; for (let step = 0; step < maxNewTokens; step++) { // For first step, use all input tokens; for subsequent steps, use only last token const currentInputIds = step === 0 ? inputIds : [outputIds[outputIds.length - 1]]; const currentSeqLen = currentInputIds.length; // Create input tensor for current step const inputTensor = new ort.Tensor('int64', BigInt64Array.from(currentInputIds.map(BigInt)), [1, currentSeqLen]); // Create attention mask for current step const totalSeqLen = allTokenIds.length; const attentionMask = new ort.Tensor('int64', BigInt64Array.from(Array(totalSeqLen).fill(1n)), [1, totalSeqLen]); // Build feeds with input, attention mask, and KV cache const feeds = { input_ids: inputTensor, attention_mask: attentionMask, ...pastKVCache }; // Run inference const results = await this.session.run(feeds); // Get logits for next token (last position) const logits = results.logits.data; const vocabSize = results.logits.dims[results.logits.dims.length - 1]; // Extract logits for last token const lastTokenLogitsOffset = (currentSeqLen - 1) * vocabSize; // Apply temperature and get next token let nextToken = 0; let maxVal = -Infinity; for (let i = 0; i < vocabSize; i++) { const logit = logits[lastTokenLogitsOffset + i] / (params.temperature || this.config.temperature); if (logit > maxVal) { maxVal = logit; nextToken = i; } } // Add to output outputIds.push(nextToken); allTokenIds.push(nextToken); // Check for end token (2 is typical EOS for Phi models) if (nextToken === 2 || nextToken === 0) { console.log(`🛑 Stop token detected: ${nextToken}`); break; } // Update KV cache from outputs for next iteration pastKVCache = {}; for (let i = 0; i < 32; i++) { pastKVCache[`past_key_values.${i}.key`] = results[`present.${i}.key`]; pastKVCache[`past_key_values.${i}.value`] = results[`present.${i}.value`]; } // Progress indicator if ((step + 1) % 10 === 0) { console.log(`🔄 Generated ${step + 1} tokens...`); } } // Decode output using optimized tiktoken const generatedText = this.decode(outputIds); const latency = Date.now() - startTime; const tokensPerSecond = (outputIds.length / (latency / 1000)).toFixed(1); console.log(`✅ Generated: ${generatedText}`); console.log(`⏱️ Latency: ${latency}ms (${tokensPerSecond} tokens/sec)`); const content = [{ type: 'text', text: generatedText.trim() }]; return { id: `onnx-local-${Date.now()}`, model: this.config.modelPath, content, stopReason: 'end_turn', usage: { inputTokens: inputIds.length, outputTokens: outputIds.length }, metadata: { provider: 'onnx-local', model: 'Phi-4-mini-instruct-onnx', latency, cost: 0, // Local inference is free executionProviders: this.config.executionProviders, tokensPerSecond: parseFloat(tokensPerSecond) } }; } catch (error) { const providerError = { name: 'ONNXInferenceError', message: `ONNX inference failed: ${error}`, provider: 'onnx-local', retryable: true }; throw providerError; } } /** * Streaming not implemented (requires complex generation loop) */ async *stream(params) { throw new Error('Streaming not yet implemented for ONNX local inference'); } /** * Validate capabilities */ validateCapabilities(features) { const supported = ['chat']; return features.every(f => supported.includes(f)); } /** * Get model info */ getModelInfo() { return { modelPath: this.config.modelPath, executionProviders: this.config.executionProviders, initialized: this.session !== null, tokenizerLoaded: this.tiktoken !== null }; } /** * Cleanup resources */ async dispose() { if (this.session) { // ONNX Runtime sessions don't have explicit disposal in Node.js this.session = null; } if (this.tiktoken) { this.tiktoken.free(); this.tiktoken = null; } } } //# sourceMappingURL=onnx-local.js.map