401 lines
15 KiB
JavaScript
401 lines
15 KiB
JavaScript
"use strict";
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
var __importDefault = (this && this.__importDefault) || function (mod) {
    // ES-module namespaces pass through untouched; any other value is wrapped so `.default` yields it.
    const isEsModule = Boolean(mod && mod.__esModule);
    return isEsModule ? mod : { "default": mod };
};
|
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
exports.Tensor = void 0;
|
|
const guid_typescript_1 = require("guid-typescript");
|
|
const long_1 = __importDefault(require("long"));
|
|
const onnx_proto_1 = require("onnx-proto");
|
|
const ort_generated_1 = require("./ort-schema/ort-generated");
|
|
const util_1 = require("./util");
|
|
var ortFbs = ort_generated_1.onnxruntime.experimental.fbs;
|
|
class Tensor {
    /**
     * get the underlying tensor data.
     *
     * When nothing is cached yet, the synchronous data provider is called with this tensor's `dataId`; the result
     * is validated against `this.size` and cached for later accesses.
     * @throws {Error} if the provider returns data whose length differs from `this.size`.
     */
    get data() {
        if (this.cache === undefined) {
            const data = this.dataProvider(this.dataId);
            if (data.length !== this.size) {
                throw new Error('Length of data provided by the Data Provider is inconsistent with the dims of this Tensor.');
            }
            this.cache = data;
        }
        return this.cache;
    }
    /**
     * get the underlying string tensor data. Should only be used when type is 'string'.
     * @throws {TypeError} if the tensor type is not 'string'.
     */
    get stringData() {
        if (this.type !== 'string') {
            throw new TypeError('data type is not string');
        }
        return this.data;
    }
    /**
     * get the underlying integer tensor data. Should only be used when type is one of: uint8, int8, uint16, int16,
     * int32, uint32, bool.
     * @throws {TypeError} for any other tensor type.
     */
    get integerData() {
        switch (this.type) {
            case 'uint8':
            case 'int8':
            case 'uint16':
            case 'int16':
            case 'int32':
            case 'uint32':
            case 'bool':
                return this.data;
            default:
                throw new TypeError('data type is not integer (uint8, int8, uint16, int16, int32, uint32, bool)');
        }
    }
    /**
     * get the underlying float tensor data. Should only be used when type is 'float32' or 'float64'.
     * @throws {TypeError} for any other tensor type.
     */
    get floatData() {
        switch (this.type) {
            case 'float32':
            case 'float64':
                return this.data;
            default:
                throw new TypeError('data type is not float (float32, float64)');
        }
    }
    /**
     * get the underlying number tensor data, i.e. the data of any tensor whose type is not 'string'.
     * @throws {TypeError} if the tensor type is 'string'.
     */
    get numberData() {
        if (this.type !== 'string') {
            return this.data;
        }
        throw new TypeError('type cannot be non-number (string)');
    }
    /**
     * get value of an element at the given indices (converted to a flat offset using this tensor's strides).
     */
    get(indices) {
        return this.data[util_1.ShapeUtil.indicesToOffset(indices, this.strides)];
    }
    /**
     * set value of an element at the given indices (converted to a flat offset using this tensor's strides).
     */
    set(indices, value) {
        this.data[util_1.ShapeUtil.indicesToOffset(indices, this.strides)] = value;
    }
    /**
     * get the underlying tensor data asynchronously.
     *
     * When nothing is cached yet, the asynchronous data provider is awaited with this tensor's `dataId`. The result
     * is validated against `this.size` exactly like the synchronous `data` getter before it is cached, so a
     * misbehaving provider cannot leave an inconsistently-sized cache behind.
     * @throws {Error} (as a rejection) if the provided data's length differs from `this.size`.
     */
    async getData() {
        if (this.cache === undefined) {
            const data = await this.asyncDataProvider(this.dataId);
            if (data.length !== this.size) {
                throw new Error('Length of data provided by the Data Provider is inconsistent with the dims of this Tensor.');
            }
            this.cache = data;
        }
        return this.cache;
    }
    /**
     * get the strides for each dimension (computed from `dims` on first access, then memoized).
     */
    get strides() {
        if (!this._strides) {
            this._strides = util_1.ShapeUtil.computeStrides(this.dims);
        }
        return this._strides;
    }
    /**
     * @param dims the dimensions of the tensor; `size` is derived from them via
     *     `ShapeUtil.validateDimsAndCalcSize` (presumably also rejecting invalid dims — confirm in util).
     * @param type the type of the tensor: 'string' or a numeric type such as 'float32'.
     * @param dataProvider optional; called synchronously with `dataId` by the `data` getter when nothing is cached.
     * @param asyncDataProvider optional; awaited with `dataId` by `getData()` when nothing is cached.
     * @param cache optional initial data. Its length must equal the tensor size; for 'string' tensors it must be
     *     an array of strings, otherwise an instance of the typed-array class for `type`.
     * @param dataId the data ID used to map to a tensor data; a freshly created GUID by default.
     * @throws {RangeError} if `cache` length does not match the tensor size.
     * @throws {TypeError} if `cache` is not the expected array kind for `type`.
     */
    constructor(dims, type, dataProvider, asyncDataProvider, cache, dataId = guid_typescript_1.Guid.create()) {
        this.dims = dims;
        this.type = type;
        this.dataProvider = dataProvider;
        this.asyncDataProvider = asyncDataProvider;
        this.cache = cache;
        this.dataId = dataId;
        this.size = util_1.ShapeUtil.validateDimsAndCalcSize(dims);
        const size = this.size;
        // With no provider and no initial data, storage is allocated right away: a sparse Array for strings,
        // a zero-filled typed array otherwise.
        const empty = (dataProvider === undefined && asyncDataProvider === undefined && cache === undefined);
        if (cache !== undefined && cache.length !== size) {
            throw new RangeError('Input dims doesn\'t match data length.');
        }
        if (type === 'string') {
            if (cache !== undefined && (!Array.isArray(cache) || !cache.every(i => typeof i === 'string'))) {
                throw new TypeError('cache should be a string array');
            }
            if (empty) {
                this.cache = new Array(size);
            }
        }
        else {
            if (cache !== undefined) {
                const expectedClass = dataviewConstructor(type);
                if (!(cache instanceof expectedClass)) {
                    throw new TypeError(`cache should be type ${expectedClass.name}`);
                }
            }
            if (empty) {
                const buf = new ArrayBuffer(size * sizeof(type));
                this.cache = createView(buf, type);
            }
        }
    }
    /**
     * Construct new Tensor from a ONNX Tensor object.
     *
     * Data is taken from `stringData` for string tensors, otherwise from `rawData` when it is non-empty, otherwise
     * from the typed repeated field matching `dataType`.
     * @param tensorProto the ONNX Tensor
     * @throws {Error} if the tensor is missing, or if the supplied data does not match the tensor size.
     */
    static fromProto(tensorProto) {
        if (!tensorProto) {
            throw new Error('cannot construct Value from an empty tensor');
        }
        const type = util_1.ProtoUtil.tensorDataTypeFromProto(tensorProto.dataType);
        const dims = util_1.ProtoUtil.tensorDimsFromProto(tensorProto.dims);
        const value = new Tensor(dims, type);
        if (type === 'string') {
            // When it's STRING type, the value should always be stored in field 'stringData'.
            const strings = tensorProto.stringData;
            const data = value.data;
            // Reject mismatches instead of silently growing (or leaving holes in) the preallocated array.
            if (strings.length !== data.length) {
                throw new Error('string data length mismatch');
            }
            strings.forEach((str, i) => {
                data[i] = (0, util_1.decodeUtf8String)(str);
            });
        }
        else if (tensorProto.rawData && typeof tensorProto.rawData.byteLength === 'number' &&
            tensorProto.rawData.byteLength > 0) {
            // NOT considering segment for now (IMPORTANT)
            populateFromRawData(value.data, tensorProto.rawData, tensorProto.rawData.byteLength, tensorProto.dataType);
        }
        else {
            // populate value from the typed repeated field that matches the data type
            const DataType = onnx_proto_1.onnx.TensorProto.DataType;
            let array;
            switch (tensorProto.dataType) {
                case DataType.FLOAT:
                    array = tensorProto.floatData;
                    break;
                case DataType.INT32:
                case DataType.INT16:
                case DataType.UINT16:
                case DataType.INT8:
                case DataType.UINT8:
                case DataType.BOOL:
                    array = tensorProto.int32Data;
                    break;
                case DataType.INT64:
                    array = tensorProto.int64Data;
                    break;
                case DataType.DOUBLE:
                    array = tensorProto.doubleData;
                    break;
                case DataType.UINT32:
                case DataType.UINT64:
                    // uint32 values are read from the uint64 field as well (per the ONNX TensorProto definition)
                    array = tensorProto.uint64Data;
                    break;
                default:
                    // should never run here
                    throw new Error('unspecific error');
            }
            if (array === null || array === undefined) {
                throw new Error('failed to populate data from a tensorproto value');
            }
            const data = value.data;
            if (data.length !== array.length) {
                throw new Error('array length mismatch');
            }
            for (let i = 0; i < array.length; i++) {
                const element = array[i];
                // 64-bit fields may decode as Long objects; narrow those to plain numbers (range-checked).
                data[i] = long_1.default.isLong(element) ? longToNumber(element, tensorProto.dataType) : element;
            }
        }
        return value;
    }
    /**
     * Construct new Tensor from raw data
     * @param data the raw data object. Should be a string array for 'string' tensor, and the corresponding typed array
     * for other types of tensor.
     * @param dims the dimensions of the tensor
     * @param type the type of the tensor
     */
    static fromData(data, dims, type) {
        return new Tensor(dims, type, undefined, undefined, data);
    }
    /**
     * Construct new Tensor from an ORT-format tensor object (presumably `onnxruntime.experimental.fbs.Tensor`
     * — TODO confirm).
     *
     * String tensors are filled from `stringData(i)`; other tensors from `rawDataArray()` when it is non-empty.
     * Otherwise the freshly allocated (zero-filled) storage is kept.
     * @param ortTensor the ORT tensor
     * @throws {Error} if the tensor is missing, or if the supplied data does not match the tensor size.
     */
    static fromOrtTensor(ortTensor) {
        if (!ortTensor) {
            throw new Error('cannot construct Value from an empty tensor');
        }
        const dims = util_1.ProtoUtil.tensorDimsFromORTFormat(ortTensor);
        // NOTE(review): the ORT flatbuffers data-type code is passed straight to ONNX-proto-keyed helpers, which
        // assumes both enums use the same numeric values — confirm against the schemas.
        const dataType = ortTensor.dataType();
        const type = util_1.ProtoUtil.tensorDataTypeFromProto(dataType);
        const value = new Tensor(dims, type);
        if (type === 'string') {
            const data = value.data;
            const count = ortTensor.stringDataLength();
            // Reject mismatches instead of silently growing (or leaving holes in) the preallocated array.
            if (count !== data.length) {
                throw new Error('string data length mismatch');
            }
            for (let i = 0; i < count; i++) {
                data[i] = ortTensor.stringData(i);
            }
        }
        else {
            const rawBytes = ortTensor.rawDataArray();
            const rawLength = ortTensor.rawDataLength();
            if (rawBytes && typeof rawLength === 'number' && rawLength > 0) {
                // NOT considering segment for now (IMPORTANT)
                populateFromRawData(value.data, rawBytes, rawLength, dataType);
            }
        }
        return value;
    }
}
/**
 * Decode packed little-endian element bytes into `dataDest`, one element per `sizeofProto(dataType)` bytes,
 * using `readProto` for each element. Shared by `Tensor.fromProto` and `Tensor.fromOrtTensor`.
 * @param dataDest destination array; its length must equal the number of encoded elements.
 * @param bytes a typed array whose `buffer`/`byteOffset` locate the encoded bytes.
 * @param byteLength number of encoded bytes to read.
 * @param dataType ONNX TensorProto data type code of the encoded elements.
 * @throws {Error} 'invalid buffer length' if `byteLength` is not a multiple of the element size, or
 *     'buffer length mismatch' if the element count differs from `dataDest.length`.
 */
function populateFromRawData(dataDest, bytes, byteLength, dataType) {
    const dataSource = new DataView(bytes.buffer, bytes.byteOffset, byteLength);
    const elementSize = sizeofProto(dataType);
    if (byteLength % elementSize !== 0) {
        throw new Error('invalid buffer length');
    }
    const length = byteLength / elementSize;
    if (dataDest.length !== length) {
        throw new Error('buffer length mismatch');
    }
    for (let i = 0; i < length; i++) {
        dataDest[i] = readProto(dataSource, dataType, i * elementSize);
    }
}
|
|
exports.Tensor = Tensor;
|
|
function sizeof(type) {
    // Byte width of one element for every numeric tensor type; 'bool' is stored one byte per element.
    const byteWidths = new Map([
        ['bool', 1], ['int8', 1], ['uint8', 1],
        ['int16', 2], ['uint16', 2],
        ['int32', 4], ['uint32', 4], ['float32', 4],
        ['float64', 8],
    ]);
    const width = byteWidths.get(type);
    if (width === undefined) {
        throw new Error(`cannot calculate sizeof() on type ${type}`);
    }
    return width;
}
|
|
function sizeofProto(type) {
    // Byte width of one element for an ONNX TensorProto data type code, grouped from widest to narrowest.
    const DataType = onnx_proto_1.onnx.TensorProto.DataType;
    switch (type) {
        case DataType.INT64:
        case DataType.UINT64:
        case DataType.DOUBLE:
            return 8;
        case DataType.INT32:
        case DataType.UINT32:
        case DataType.FLOAT:
            return 4;
        case DataType.INT16:
        case DataType.UINT16:
            return 2;
        case DataType.INT8:
        case DataType.UINT8:
        case DataType.BOOL:
            return 1;
        default:
            throw new Error(`cannot calculate sizeof() on type ${DataType[type]}`);
    }
}
|
|
function createView(dataBuffer, type) {
    // Wrap the whole buffer in the typed-array class registered for `type`.
    const ViewClass = dataviewConstructor(type);
    return new ViewClass(dataBuffer);
}
|
|
function dataviewConstructor(type) {
    // Typed-array class used to store each numeric tensor type; 'bool' shares uint8 storage.
    const viewClasses = new Map([
        ['bool', Uint8Array],
        ['uint8', Uint8Array],
        ['int8', Int8Array],
        ['int16', Int16Array],
        ['uint16', Uint16Array],
        ['int32', Int32Array],
        ['uint32', Uint32Array],
        ['float32', Float32Array],
        ['float64', Float64Array],
    ]);
    const ViewClass = viewClasses.get(type);
    if (ViewClass !== undefined) {
        return ViewClass;
    }
    // should never run to here
    throw new Error('unspecified error');
}
|
|
// Convert a Long to a plain JS number, throwing a TypeError when it is outside the supported range:
// INT64 values must lie in [-2^31, 2^31); UINT32/UINT64 values must lie in [0, 2^32).
// `type` may be either an ONNX TensorProto or an ORT flatbuffers data type code.
function longToNumber(i, type) {
    const ProtoType = onnx_proto_1.onnx.TensorProto.DataType;
    const FbsType = ortFbs.TensorDataType;
    const isInt64 = type === ProtoType.INT64 || type === FbsType.INT64;
    const isUnsignedLong = type === ProtoType.UINT32 || type === FbsType.UINT32 ||
        type === ProtoType.UINT64 || type === FbsType.UINT64;
    if (isInt64) {
        if (i.lessThan(-2147483648) || i.greaterThanOrEqual(2147483648)) {
            throw new TypeError('int64 is not supported');
        }
    }
    else if (isUnsignedLong) {
        if (i.lessThan(0) || i.greaterThanOrEqual(4294967296)) {
            throw new TypeError('uint64 is not supported');
        }
    }
    else {
        throw new TypeError(`not a LONG type: ${ProtoType[type]}`);
    }
    return i.toNumber();
}
|
|
// Decode the single element of ONNX data type `type` stored at `byteOffset` in `view`.
// Multi-byte values are read little-endian; 64-bit integers are assembled from two 32-bit halves and passed
// through longToNumber.
function readProto(view, type, byteOffset) {
    const DataType = onnx_proto_1.onnx.TensorProto.DataType;
    const littleEndian = true;
    switch (type) {
        case DataType.BOOL:
        case DataType.UINT8:
            return view.getUint8(byteOffset);
        case DataType.INT8:
            return view.getInt8(byteOffset);
        case DataType.UINT16:
            return view.getUint16(byteOffset, littleEndian);
        case DataType.INT16:
            return view.getInt16(byteOffset, littleEndian);
        case DataType.INT32:
            return view.getInt32(byteOffset, littleEndian);
        case DataType.UINT32:
            return view.getUint32(byteOffset, littleEndian);
        case DataType.FLOAT:
            return view.getFloat32(byteOffset, littleEndian);
        case DataType.DOUBLE:
            return view.getFloat64(byteOffset, littleEndian);
        case DataType.INT64:
        case DataType.UINT64: {
            const lowBits = view.getUint32(byteOffset, littleEndian);
            const highBits = view.getUint32(byteOffset + 4, littleEndian);
            const unsigned = type === DataType.UINT64;
            return longToNumber(long_1.default.fromBits(lowBits, highBits, unsigned), type);
        }
        default:
            throw new Error(`cannot read from DataView for type ${DataType[type]}`);
    }
}
|
|
//# sourceMappingURL=tensor.js.map
|