/**
 * ONNX Runtime Provider for Local Model Inference
 *
 * Supports CPU and GPU execution providers for optimized local inference.
 * Compatible with Phi-3, Llama, and other ONNX models.
 */
// Dynamic imports for optional ONNX dependencies
let ort;
let transformers;
async function ensureOnnxDependencies() {
    if (!ort) {
        try {
            ort = await import('onnxruntime-node');
        }
        catch (e) {
            throw new Error('onnxruntime-node not installed. Run: npm install onnxruntime-node');
        }
    }
    if (!transformers) {
        try {
            transformers = await import('@xenova/transformers');
            transformers.env.allowLocalModels = true;
        }
        catch (e) {
            throw new Error('@xenova/transformers not installed. Run: npm install @xenova/transformers');
        }
    }
}
export class ONNXProvider {
    name = 'onnx';
    type = 'custom';
    supportsStreaming = true;
    supportsTools = false;
    supportsMCP = false;
    session = null;
    generator = null;
    config;
    executionProviders = [];
    constructor(config = {}) {
        this.config = {
            modelId: config.modelId || 'Xenova/Phi-3-mini-4k-instruct',
            maxTokens: config.maxTokens ?? 512,
            temperature: config.temperature ?? 0.7,
            ...config
        };
    }
    /**
     * Detect available execution providers
     */
    async detectExecutionProviders() {
        const providers = [];
        // Assume CUDA for NVIDIA GPUs on Linux
        if (process.platform === 'linux') {
            providers.push('cuda');
        }
        // Assume DirectML for GPUs on Windows
        if (process.platform === 'win32') {
            providers.push('dml');
        }
        // Always fall back to CPU
        providers.push('cpu');
        // Assign (rather than append) so repeated calls don't accumulate duplicates
        this.executionProviders = providers;
        console.log(`🔧 ONNX Execution Providers: ${this.executionProviders.join(', ')}`);
        return providers;
    }
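    // Illustrative results: ['cuda', 'cpu'] on Linux, ['dml', 'cpu'] on Windows,
    // and ['cpu'] everywhere else. Availability is inferred from the platform only;
    // the GPU hardware and drivers themselves are not probed.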
    /**
     * Initialize ONNX session with model
     */
    async initializeSession() {
        if (this.generator)
            return;
        try {
            await ensureOnnxDependencies();
            // Populate executionProviders so metadata and getModelInfo() reflect them
            await this.detectExecutionProviders();
            console.log(`📦 Loading ONNX model: ${this.config.modelId}`);
            // Use Transformers.js for easier model loading
            this.generator = await transformers.pipeline('text-generation', this.config.modelId, {
                quantized: true, // Use quantized models for better CPU performance
            });
            console.log(`✅ ONNX model loaded successfully`);
        }
        catch (error) {
            const providerError = {
                name: 'ONNXInitError',
                message: `Failed to initialize ONNX model: ${error}`,
                provider: 'onnx',
                retryable: false
            };
            throw providerError;
        }
    }
    /**
     * Format messages for model input
     */
    formatMessages(messages) {
        // Simple chat template for Phi-3
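        // For illustration, a system prompt followed by one user turn is serialized
        // roughly as below; the trailing '<|assistant|>' cues the model to respond:
        //   <|system|>\nYou are helpful.<|end|>\n<|user|>\nHi!<|end|>\n<|assistant|>\n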
        let prompt = '';
        for (const msg of messages) {
            const content = typeof msg.content === 'string'
                ? msg.content
                : msg.content.map(c => c.type === 'text' ? c.text : '').join('');
            if (msg.role === 'user') {
                prompt += `<|user|>\n${content}<|end|>\n`;
            }
            else if (msg.role === 'assistant') {
                prompt += `<|assistant|>\n${content}<|end|>\n`;
            }
            else if (msg.role === 'system') {
                prompt += `<|system|>\n${content}<|end|>\n`;
            }
        }
        prompt += '<|assistant|>\n';
        return prompt;
    }
    /**
     * Chat completion
     */
    async chat(params) {
        await this.initializeSession();
        const startTime = Date.now();
        const prompt = this.formatMessages(params.messages);
        try {
            const result = await this.generator(prompt, {
                max_new_tokens: params.maxTokens ?? this.config.maxTokens,
                temperature: params.temperature ?? this.config.temperature,
                do_sample: true,
                top_p: 0.9,
            });
            const generatedText = result[0].generated_text;
            // Extract only the new assistant response
            const assistantResponse = generatedText
                .split('<|assistant|>')
                .pop()
                ?.split('<|end|>')[0]
                ?.trim() || '';
            const latency = Date.now() - startTime;
            // Estimate token counts (rough approximation: ~4 characters per token)
            const inputTokens = Math.ceil(prompt.length / 4);
            const outputTokens = Math.ceil(assistantResponse.length / 4);
            const content = [{
                type: 'text',
                text: assistantResponse
            }];
            return {
                id: `onnx-${Date.now()}`,
                model: this.config.modelId || 'onnx-model',
                content,
                stopReason: 'end_turn',
                usage: {
                    inputTokens,
                    outputTokens
                },
                metadata: {
                    provider: 'onnx',
                    model: this.config.modelId,
                    latency,
                    cost: 0, // Local inference is free
                    executionProviders: this.executionProviders
                }
            };
        }
        catch (error) {
            const providerError = {
                name: 'ONNXInferenceError',
                message: `ONNX inference failed: ${error}`,
                provider: 'onnx',
                retryable: true
            };
            throw providerError;
        }
    }
    /**
     * Streaming generation
     */
    async *stream(params) {
        await this.initializeSession();
        const prompt = this.formatMessages(params.messages);
        try {
            // Note: Transformers.js doesn't natively support streaming here,
            // so we generate the full response and then yield it in chunks.
            const result = await this.generator(prompt, {
                max_new_tokens: params.maxTokens ?? this.config.maxTokens,
                temperature: params.temperature ?? this.config.temperature,
                do_sample: true,
                top_p: 0.9,
            });
            const generatedText = result[0].generated_text;
            const assistantResponse = generatedText
                .split('<|assistant|>')
                .pop()
                ?.split('<|end|>')[0]
                ?.trim() || '';
            // Simulate streaming by chunking the response word by word
            const words = assistantResponse.split(' ');
            for (let i = 0; i < words.length; i++) {
                const chunk = words[i] + (i < words.length - 1 ? ' ' : '');
                yield {
                    type: 'content_block_delta',
                    delta: {
                        type: 'text_delta',
                        text: chunk
                    }
                };
                // Small delay to simulate real streaming
                await new Promise(resolve => setTimeout(resolve, 10));
            }
            yield {
                type: 'message_stop'
            };
        }
        catch (error) {
            const providerError = {
                name: 'ONNXStreamError',
                message: `ONNX streaming failed: ${error}`,
                provider: 'onnx',
                retryable: true
            };
            throw providerError;
        }
    }
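    // Illustrative consumer for stream(); event shapes mirror the objects yielded above
    // ('content_block_delta' carrying a 'text_delta', followed by a final 'message_stop'):
    //
    //   for await (const event of provider.stream({ messages })) {
    //       if (event.type === 'content_block_delta') process.stdout.write(event.delta.text);
    //   }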
    /**
     * Validate capabilities
     */
    validateCapabilities(features) {
        const supported = ['chat', 'stream'];
        return features.every(f => supported.includes(f));
    }
    /**
     * Get model info
     */
    getModelInfo() {
        return {
            modelId: this.config.modelId,
            executionProviders: this.executionProviders,
            supportsGPU: this.executionProviders.includes('cuda') || this.executionProviders.includes('dml'),
            initialized: this.generator !== null
        };
    }
    /**
     * Cleanup resources
     */
    async dispose() {
        if (this.generator) {
            this.generator = null;
        }
        if (this.session) {
            await this.session.release();
            this.session = null;
        }
    }
}
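// Example usage (illustrative; assumes onnxruntime-node and @xenova/transformers are
// installed and that the default Xenova/Phi-3-mini-4k-instruct weights can be resolved
// locally or downloaded):
//
//   const provider = new ONNXProvider({ maxTokens: 256, temperature: 0.2 });
//   const response = await provider.chat({
//       messages: [{ role: 'user', content: 'Summarize ONNX Runtime in one sentence.' }]
//   });
//   console.log(response.content[0].text, response.metadata.latency);
//   await provider.dispose();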
//# sourceMappingURL=onnx.js.map