/**
 * ONNX Runtime Local Inference Provider for Phi-4
 *
 * Uses onnxruntime-node for true local CPU/GPU inference
 */
import * as ort from 'onnxruntime-node';
import { get_encoding } from 'tiktoken';
import { ensurePhi4Model, ModelDownloader } from '../../utils/model-downloader.js';
export class ONNXLocalProvider {
    name = 'onnx-local';
    type = 'custom';
    supportsStreaming = false; // Streaming requires complex token generation loop
    supportsTools = false;
    supportsMCP = false;
    session = null;
    config;
    tokenizer = null;
    tiktoken = null;
    constructor(config = {}) {
        this.config = {
            modelPath: config.modelPath || './models/phi-4/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/model.onnx',
            executionProviders: config.executionProviders || ['cpu'],
            maxTokens: config.maxTokens || 100,
            temperature: config.temperature || 0.7
        };
    }
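    // Example: `new ONNXLocalProvider({ maxTokens: 256, temperature: 0.2 })`.
    // Any option omitted from `config` falls back to the defaults above; the default
    // modelPath is only a fallback, since initializeSession() resolves the actual
    // path via ensurePhi4Model().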
    /**
     * Load optimized tiktoken tokenizer (cl100k_base for Phi-4)
     */
    async loadTokenizer() {
        if (this.tiktoken)
            return;
        try {
            // Use cl100k_base encoding (GPT-4, similar to Phi-4)
            this.tiktoken = get_encoding('cl100k_base');
            console.log('✅ Tokenizer loaded (tiktoken cl100k_base)');
        }
        catch (error) {
            console.error('❌ Failed to load tiktoken:', error);
            throw new Error(`Tokenizer loading failed: ${error}`);
        }
    }
    /**
     * Encode text using tiktoken (fast BPE)
     */
    encode(text) {
        return Array.from(this.tiktoken.encode(text));
    }
    /**
     * Decode tokens using tiktoken
     */
    decode(ids) {
        try {
            const decoded = this.tiktoken.decode(new Uint32Array(ids));
            // tiktoken returns a buffer, convert to string
            if (typeof decoded === 'string') {
                return decoded;
            }
            else if (decoded instanceof Uint8Array || decoded instanceof Buffer) {
                return new TextDecoder().decode(decoded);
            }
            return String(decoded);
        }
        catch (error) {
            console.warn('Decode error, returning raw IDs:', error);
            return ids.join(',');
        }
    }
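    // Note: for ordinary text, decode(encode(text)) round-trips back to the original
    // string, since cl100k_base BPE is lossless over UTF-8 input.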
    /**
     * Initialize ONNX session (with automatic model download)
     */
    async initializeSession() {
        if (this.session)
            return;
        try {
            // Ensure model is downloaded
            console.log(`🔍 Checking for Phi-4 ONNX model...`);
            const modelPath = await ensurePhi4Model((progress) => {
                if (progress.percentage % 10 < 1) { // Log every ~10%
                    console.log(` 📥 Downloading: ${ModelDownloader.formatProgress(progress)}`);
                }
            });
            // Update config with actual model path
            this.config.modelPath = modelPath;
            console.log(`📦 Loading ONNX model: ${this.config.modelPath}`);
            this.session = await ort.InferenceSession.create(this.config.modelPath, {
                executionProviders: this.config.executionProviders,
                graphOptimizationLevel: 'all',
                enableCpuMemArena: true,
                enableMemPattern: true
            });
            console.log(`✅ ONNX model loaded`);
            console.log(`🔧 Execution providers: ${this.config.executionProviders.join(', ')}`);
            // Load tokenizer
            await this.loadTokenizer();
        }
        catch (error) {
            const providerError = {
                name: 'ONNXInitError',
                message: `Failed to initialize ONNX model: ${error}`,
                provider: 'onnx-local',
                retryable: false
            };
            throw providerError;
        }
    }
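    // Note: the defaults assume the CPU execution provider, which ships with every
    // onnxruntime-node build. Other providers (e.g. 'cuda' or 'dml') may be usable
    // through config.executionProviders depending on the installed build; that is
    // an assumption about the runtime, not something this module verifies.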
    /**
     * Format messages for Phi-4 chat template
     */
    formatMessages(messages) {
        let prompt = '';
        for (const msg of messages) {
            const content = typeof msg.content === 'string'
                ? msg.content
                : msg.content.map(c => c.type === 'text' ? c.text : '').join('');
            if (msg.role === 'system') {
                prompt += `<|system|>\n${content}<|end|>\n`;
            }
            else if (msg.role === 'user') {
                prompt += `<|user|>\n${content}<|end|>\n`;
            }
            else if (msg.role === 'assistant') {
                prompt += `<|assistant|>\n${content}<|end|>\n`;
            }
        }
        prompt += '<|assistant|>\n';
        return prompt;
    }
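    // For a single user turn, formatMessages() produces a prompt of the form:
    //   <|user|>
    //   Hello<|end|>
    //   <|assistant|>
    // i.e. the trailing '<|assistant|>\n' cues the model to generate the reply.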
    /**
     * Initialize KV cache tensors for all 32 layers
     * Phi-4 architecture: 32 layers, 8 KV heads, 128 head_dim
     */
    initializeKVCache(batchSize, sequenceLength) {
        const numLayers = 32;
        const numKVHeads = 8;
        const headDim = 128; // 3072 / 24 = 128
        const kvCache = {};
        // Initialize empty cache for each layer (key and value)
        for (let i = 0; i < numLayers; i++) {
            // Empty cache: [batch_size, num_kv_heads, 0, head_dim]
            const emptyCache = new Float32Array(0);
            kvCache[`past_key_values.${i}.key`] = new ort.Tensor('float32', emptyCache, [batchSize, numKVHeads, 0, headDim]);
            kvCache[`past_key_values.${i}.value`] = new ort.Tensor('float32', emptyCache, [batchSize, numKVHeads, 0, headDim]);
        }
        return kvCache;
    }
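    // Each cache entry has shape [batch, num_kv_heads, past_seq_len, head_dim];
    // initializeKVCache() always builds the empty (past_seq_len = 0) cache, so the
    // sequenceLength parameter is currently unused. During generation, chat()
    // replaces these entries with the `present.{i}.key` / `present.{i}.value`
    // outputs from the previous step.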
    /**
     * Chat completion using ONNX with KV cache
     */
    async chat(params) {
        await this.initializeSession();
        const startTime = Date.now();
        const prompt = this.formatMessages(params.messages);
        try {
            // Tokenize input using optimized tiktoken
            const inputIds = this.encode(prompt);
            console.log(`📝 Input tokens: ${inputIds.length}`);
            // Initialize KV cache (reusable for batch)
            let pastKVCache = this.initializeKVCache(1, 0);
            // Track all generated tokens
            const allTokenIds = [...inputIds];
            const outputIds = [];
            // Maximum total sequence length (prompt + generated tokens); not used below
            const maxSeqLen = inputIds.length + (params.maxTokens || this.config.maxTokens);
            // Autoregressive generation loop
            const maxNewTokens = params.maxTokens || this.config.maxTokens;
            for (let step = 0; step < maxNewTokens; step++) {
                // For first step, use all input tokens; for subsequent steps, use only last token
                const currentInputIds = step === 0 ? inputIds : [outputIds[outputIds.length - 1]];
                const currentSeqLen = currentInputIds.length;
                // Create input tensor for current step
                const inputTensor = new ort.Tensor('int64', BigInt64Array.from(currentInputIds.map(BigInt)), [1, currentSeqLen]);
                // Create attention mask for current step
                const totalSeqLen = allTokenIds.length;
                const attentionMask = new ort.Tensor('int64', BigInt64Array.from(Array(totalSeqLen).fill(1n)), [1, totalSeqLen]);
                // Build feeds with input, attention mask, and KV cache
                const feeds = {
                    input_ids: inputTensor,
                    attention_mask: attentionMask,
                    ...pastKVCache
                };
                // Run inference
                const results = await this.session.run(feeds);
                // Get logits for next token (last position)
                const logits = results.logits.data;
                const vocabSize = results.logits.dims[results.logits.dims.length - 1];
                // Extract logits for last token
                const lastTokenLogitsOffset = (currentSeqLen - 1) * vocabSize;
                // Greedy decoding: pick the arg-max token. Dividing by temperature
                // rescales the logits but does not change which index is largest.
                let nextToken = 0;
                let maxVal = -Infinity;
                for (let i = 0; i < vocabSize; i++) {
                    const logit = logits[lastTokenLogitsOffset + i] / (params.temperature || this.config.temperature);
                    if (logit > maxVal) {
                        maxVal = logit;
                        nextToken = i;
                    }
                }
                // Add to output
                outputIds.push(nextToken);
                allTokenIds.push(nextToken);
                // Check for end token (2 is typical EOS for Phi models)
                if (nextToken === 2 || nextToken === 0) {
                    console.log(`🛑 Stop token detected: ${nextToken}`);
                    break;
                }
                // Update KV cache from outputs for next iteration
                pastKVCache = {};
                for (let i = 0; i < 32; i++) {
                    pastKVCache[`past_key_values.${i}.key`] = results[`present.${i}.key`];
                    pastKVCache[`past_key_values.${i}.value`] = results[`present.${i}.value`];
                }
                // Progress indicator
                if ((step + 1) % 10 === 0) {
                    console.log(`🔄 Generated ${step + 1} tokens...`);
                }
            }
            // Decode output using optimized tiktoken
            const generatedText = this.decode(outputIds);
            const latency = Date.now() - startTime;
            const tokensPerSecond = (outputIds.length / (latency / 1000)).toFixed(1);
            console.log(`✅ Generated: ${generatedText}`);
            console.log(`⏱️ Latency: ${latency}ms (${tokensPerSecond} tokens/sec)`);
            const content = [{
                type: 'text',
                text: generatedText.trim()
            }];
            return {
                id: `onnx-local-${Date.now()}`,
                model: this.config.modelPath,
                content,
                stopReason: 'end_turn',
                usage: {
                    inputTokens: inputIds.length,
                    outputTokens: outputIds.length
                },
                metadata: {
                    provider: 'onnx-local',
                    model: 'Phi-4-mini-instruct-onnx',
                    latency,
                    cost: 0, // Local inference is free
                    executionProviders: this.config.executionProviders,
                    tokensPerSecond: parseFloat(tokensPerSecond)
                }
            };
        }
        catch (error) {
            const providerError = {
                name: 'ONNXInferenceError',
                message: `ONNX inference failed: ${error}`,
                provider: 'onnx-local',
                retryable: true
            };
            throw providerError;
        }
    }
    /**
     * Streaming not implemented (requires complex generation loop)
     */
    async *stream(params) {
        throw new Error('Streaming not yet implemented for ONNX local inference');
    }
    /**
     * Validate capabilities
     */
    validateCapabilities(features) {
        const supported = ['chat'];
        return features.every(f => supported.includes(f));
    }
    /**
     * Get model info
     */
    getModelInfo() {
        return {
            modelPath: this.config.modelPath,
            executionProviders: this.config.executionProviders,
            initialized: this.session !== null,
            tokenizerLoaded: this.tiktoken !== null
        };
    }
    /**
     * Cleanup resources
     */
    async dispose() {
        if (this.session) {
            // ONNX Runtime sessions don't have explicit disposal in Node.js
            this.session = null;
        }
        if (this.tiktoken) {
            this.tiktoken.free();
            this.tiktoken = null;
        }
    }
}
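// Minimal usage sketch (run inside an async context; message shape follows what
// formatMessages() accepts):
//
//   const provider = new ONNXLocalProvider({ executionProviders: ['cpu'], maxTokens: 64 });
//   const result = await provider.chat({
//       messages: [{ role: 'user', content: 'Summarize ONNX Runtime in one sentence.' }]
//   });
//   console.log(result.content[0].text, `${result.metadata.tokensPerSecond} tok/s`);
//   await provider.dispose();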
//# sourceMappingURL=onnx-local.js.map