// Source: tasq/node_modules/agentic-flow/dist/router/providers/onnx-phi4.js
/**
* ONNX Runtime Provider for Phi-4 Model
*
* Hybrid implementation with fallback to HuggingFace Inference API
* when local ONNX model is not available
*/
import { HfInference } from '@huggingface/inference';
/**
 * Provider that serves Phi chat completions.
 *
 * All inference currently routes through the HuggingFace Inference API;
 * local ONNX execution is a stub until the model weights are downloaded
 * (see `chatViaONNX`).
 */
export class ONNXPhi4Provider {
    name = 'onnx-phi4';
    type = 'custom';
    supportsStreaming = true;
    supportsTools = false;
    supportsMCP = false;
    config;
    hf;
    // Expected on-disk location of the quantized ONNX model (not used yet).
    modelPath = './models/phi-4/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/model.onnx';
    /**
     * @param {object} [config]
     * @param {string} [config.modelId] - HuggingFace model id (defaults to Phi-3-mini).
     * @param {boolean} [config.useLocalONNX] - Route via local ONNX instead of the API.
     * @param {string} [config.huggingfaceApiKey] - Falls back to HUGGINGFACE_API_KEY env var.
     * @param {number} [config.maxTokens] - Default generation budget (default 512).
     * @param {number} [config.temperature] - Default sampling temperature (default 0.7).
     */
    constructor(config = {}) {
        this.config = {
            modelId: config.modelId || 'microsoft/Phi-3-mini-4k-instruct',
            useLocalONNX: config.useLocalONNX ?? false, // Default to API until local model downloaded
            huggingfaceApiKey: config.huggingfaceApiKey || process.env.HUGGINGFACE_API_KEY || '',
            // ?? (not ||) so an explicit 0 for maxTokens/temperature is honored.
            maxTokens: config.maxTokens ?? 512,
            temperature: config.temperature ?? 0.7
        };
        this.hf = new HfInference(this.config.huggingfaceApiKey);
    }
    /**
     * Format messages for the Phi chat template:
     * `<|system|>` / `<|user|>` / `<|assistant|>` turns each terminated by `<|end|>`,
     * followed by an open `<|assistant|>` turn so the model continues as the assistant.
     * Non-text content parts are dropped; messages with unknown roles are skipped.
     * @param {Array<{role: string, content: string | Array<{type: string, text?: string}>}>} messages
     * @returns {string} prompt string for text generation
     */
    formatMessages(messages) {
        let prompt = '';
        for (const msg of messages) {
            const content = typeof msg.content === 'string'
                ? msg.content
                : msg.content.map(c => c.type === 'text' ? c.text : '').join('');
            if (msg.role === 'system') {
                prompt += `<|system|>\n${content}<|end|>\n`;
            }
            else if (msg.role === 'user') {
                prompt += `<|user|>\n${content}<|end|>\n`;
            }
            else if (msg.role === 'assistant') {
                prompt += `<|assistant|>\n${content}<|end|>\n`;
            }
        }
        prompt += '<|assistant|>\n';
        return prompt;
    }
    /**
     * Chat completion via HuggingFace Inference API.
     * @param {object} params - { messages, maxTokens?, temperature? }
     * @returns {Promise<object>} Anthropic-style response (id, model, content,
     *   stopReason, usage, metadata).
     * @throws {Error} name 'ONNXPhi4APIError', retryable, wrapping the API failure.
     */
    async chatViaAPI(params) {
        const startTime = Date.now();
        const prompt = this.formatMessages(params.messages);
        try {
            const result = await this.hf.textGeneration({
                model: this.config.modelId,
                inputs: prompt,
                parameters: {
                    // ?? (not ||) so a caller's explicit 0 is not silently replaced.
                    max_new_tokens: params.maxTokens ?? this.config.maxTokens,
                    temperature: params.temperature ?? this.config.temperature,
                    return_full_text: false,
                    do_sample: true,
                    top_p: 0.9
                }
            });
            const latency = Date.now() - startTime;
            // Drop anything the model emitted after its end-of-turn marker.
            const assistantResponse = result.generated_text
                .split('<|end|>')[0]
                .trim();
            // Rough token estimate (~4 chars/token); this API path does not report usage.
            const inputTokens = Math.ceil(prompt.length / 4);
            const outputTokens = Math.ceil(assistantResponse.length / 4);
            const content = [{
                    type: 'text',
                    text: assistantResponse
                }];
            return {
                id: `onnx-phi4-${Date.now()}`,
                model: this.config.modelId,
                content,
                stopReason: 'end_turn',
                usage: {
                    inputTokens,
                    outputTokens
                },
                metadata: {
                    provider: 'onnx-phi4',
                    model: this.config.modelId,
                    latency,
                    cost: outputTokens * 0.000002, // Rough estimate
                    mode: 'api'
                }
            };
        }
        catch (error) {
            // Throw a real Error (stack trace, `cause`) instead of a plain object,
            // while keeping the fields downstream retry logic inspects.
            const providerError = new Error(`HuggingFace API inference failed: ${error}`, { cause: error });
            providerError.name = 'ONNXPhi4APIError';
            providerError.provider = 'onnx-phi4';
            providerError.retryable = true;
            throw providerError;
        }
    }
    /**
     * Chat completion via local ONNX runtime.
     * @throws {Error} always — local inference is not implemented yet.
     */
    async chatViaONNX(params) {
        throw new Error('Local ONNX inference not yet implemented. Download model.onnx.data first.');
    }
    /**
     * Chat completion; dispatches to local ONNX or the HuggingFace API
     * based on `config.useLocalONNX`.
     * @param {object} params - { messages, maxTokens?, temperature? }
     * @returns {Promise<object>} see `chatViaAPI`
     */
    async chat(params) {
        if (this.config.useLocalONNX) {
            return this.chatViaONNX(params);
        }
        else {
            return this.chatViaAPI(params);
        }
    }
    /**
     * Streaming generation via HuggingFace API.
     * Yields `content_block_delta` events per token, then a final `message_stop`.
     * @param {object} params - { messages, maxTokens?, temperature? }
     * @throws {Error} name 'ONNXPhi4StreamError', retryable, wrapping the failure.
     */
    async *stream(params) {
        const prompt = this.formatMessages(params.messages);
        try {
            const stream = this.hf.textGenerationStream({
                model: this.config.modelId,
                inputs: prompt,
                parameters: {
                    // ?? (not ||) so a caller's explicit 0 is not silently replaced.
                    max_new_tokens: params.maxTokens ?? this.config.maxTokens,
                    temperature: params.temperature ?? this.config.temperature,
                    return_full_text: false
                }
            });
            for await (const chunk of stream) {
                if (chunk.token.text) {
                    yield {
                        type: 'content_block_delta',
                        delta: {
                            type: 'text_delta',
                            text: chunk.token.text
                        }
                    };
                }
            }
            yield {
                type: 'message_stop'
            };
        }
        catch (error) {
            // Real Error with cause, same inspectable fields as before.
            const providerError = new Error(`Streaming failed: ${error}`, { cause: error });
            providerError.name = 'ONNXPhi4StreamError';
            providerError.provider = 'onnx-phi4';
            providerError.retryable = true;
            throw providerError;
        }
    }
    /**
     * Validate capabilities: only 'chat' and 'stream' are supported.
     * @param {string[]} features
     * @returns {boolean} true when every requested feature is supported
     */
    validateCapabilities(features) {
        const supported = ['chat', 'stream'];
        return features.every(f => supported.includes(f));
    }
    /**
     * Get model info for diagnostics; the API key is masked, never returned.
     * @returns {{modelId: string, mode: string, supportsLocalInference: boolean, modelPath: string, apiKey: (string|undefined)}}
     */
    getModelInfo() {
        return {
            modelId: this.config.modelId,
            mode: this.config.useLocalONNX ? 'local-onnx' : 'api',
            supportsLocalInference: false, // Will be true when model.onnx.data downloaded
            modelPath: this.modelPath,
            apiKey: this.config.huggingfaceApiKey ? '***' : undefined
        };
    }
    /**
     * Switch between API and local ONNX inference.
     * @param {boolean} useLocalONNX
     */
    setMode(useLocalONNX) {
        this.config.useLocalONNX = useLocalONNX;
    }
}
//# sourceMappingURL=onnx-phi4.js.map