"use strict";
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Interop helper emitted by the TypeScript compiler: wraps a CommonJS module
// that is not flagged `__esModule` so it can be read through `.default`.
var __importDefault = (this && this.__importDefault) || function (mod) {
    return (mod && mod.__esModule) ? mod : { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.Tensor = void 0;
const guid_typescript_1 = require("guid-typescript");
const long_1 = __importDefault(require("long"));
const onnx_proto_1 = require("onnx-proto");
const ort_generated_1 = require("./ort-schema/ort-generated");
const util_1 = require("./util");
const ortFbs = ort_generated_1.onnxruntime.experimental.fbs;

/**
 * A multi-dimensional array of a single element type.
 *
 * The element storage is either supplied up front (`cache`), fetched lazily
 * through `dataProvider` / `asyncDataProvider`, or - when none of the three is
 * given - allocated by the constructor.
 */
class Tensor {
    /**
     * @param dims the dimensions of the tensor
     * @param type the element type of the tensor (e.g. 'float32', 'int32', 'string')
     * @param dataProvider optional function called with `dataId` to obtain the data synchronously
     * @param asyncDataProvider optional function called with `dataId` to obtain the data asynchronously
     * @param cache optional pre-populated data; a string array for 'string' tensors, otherwise the
     *     typed array returned by `dataviewConstructor(type)`
     * @param dataId the data ID used to map to the tensor data; defaults to a freshly created Guid
     * @throws {RangeError} when `cache.length` does not match the size implied by `dims`
     * @throws {TypeError} when `cache` is not of the container type required by `type`
     */
    constructor(dims, type, dataProvider, asyncDataProvider, cache, dataId = guid_typescript_1.Guid.create()) {
        this.dims = dims;
        this.type = type;
        this.dataProvider = dataProvider;
        this.asyncDataProvider = asyncDataProvider;
        this.cache = cache;
        this.dataId = dataId;
        this.size = util_1.ShapeUtil.validateDimsAndCalcSize(dims);
        const size = this.size;
        // "empty" means nobody supplied data or a way to get it, so storage is allocated here.
        const empty = (dataProvider === undefined && asyncDataProvider === undefined && cache === undefined);
        if (cache !== undefined) {
            if (cache.length !== size) {
                throw new RangeError('Input dims doesn\'t match data length.');
            }
        }
        if (type === 'string') {
            if (cache !== undefined && (!Array.isArray(cache) || !cache.every(i => typeof i === 'string'))) {
                throw new TypeError('cache should be a string array');
            }
            if (empty) {
                // Pre-sized (sparse) array so that `length` equals the tensor size.
                this.cache = new Array(size);
            }
        }
        else {
            if (cache !== undefined) {
                const constructor = dataviewConstructor(type);
                if (!(cache instanceof constructor)) {
                    throw new TypeError(`cache should be type ${constructor.name}`);
                }
            }
            if (empty) {
                const buf = new ArrayBuffer(size * sizeof(type));
                this.cache = createView(buf, type);
            }
        }
    }
    /**
     * get the underlying tensor data
     *
     * On first access without a cache the data is pulled from `dataProvider`
     * and its length is verified against the tensor size.
     * @throws {Error} when the provider returns data of the wrong length
     */
    get data() {
        if (this.cache === undefined) {
            const data = this.dataProvider(this.dataId);
            if (data.length !== this.size) {
                throw new Error('Length of data provided by the Data Provider is inconsistent with the dims of this Tensor.');
            }
            this.cache = data;
        }
        return this.cache;
    }
    /**
     * get the underlying string tensor data. Should only use when type is STRING
     * @throws {TypeError} when the tensor type is not 'string'
     */
    get stringData() {
        if (this.type !== 'string') {
            throw new TypeError('data type is not string');
        }
        return this.data;
    }
    /**
     * get the underlying integer tensor data. Should only use when type is one of the following: (UINT8, INT8, UINT16,
     * INT16, INT32, UINT32, BOOL)
     * @throws {TypeError} when the tensor type is not one of the integer types
     */
    get integerData() {
        switch (this.type) {
            case 'uint8':
            case 'int8':
            case 'uint16':
            case 'int16':
            case 'int32':
            case 'uint32':
            case 'bool':
                return this.data;
            default:
                throw new TypeError('data type is not integer (uint8, int8, uint16, int16, int32, uint32, bool)');
        }
    }
    /**
     * get the underlying float tensor data. Should only use when type is one of the following: (FLOAT, DOUBLE)
     * @throws {TypeError} when the tensor type is neither 'float32' nor 'float64'
     */
    get floatData() {
        switch (this.type) {
            case 'float32':
            case 'float64':
                return this.data;
            default:
                throw new TypeError('data type is not float (float32, float64)');
        }
    }
    /**
     * get the underlying number tensor data. Should only use when type is one of the following: (UINT8, INT8, UINT16,
     * INT16, INT32, UINT32, BOOL, FLOAT, DOUBLE)
     * @throws {TypeError} when the tensor type is 'string'
     */
    get numberData() {
        if (this.type !== 'string') {
            return this.data;
        }
        throw new TypeError('type cannot be non-number (string)');
    }
    /**
     * get value of an element at the given indices
     * @param indices one index per dimension
     */
    get(indices) {
        return this.data[util_1.ShapeUtil.indicesToOffset(indices, this.strides)];
    }
    /**
     * set value of an element at the given indices
     * @param indices one index per dimension
     * @param value the value to store
     */
    set(indices, value) {
        this.data[util_1.ShapeUtil.indicesToOffset(indices, this.strides)] = value;
    }
    /**
     * get the underlying tensor data asynchronously
     *
     * Uses the cache when present, otherwise awaits `asyncDataProvider` and caches the result.
     */
    async getData() {
        if (this.cache === undefined) {
            this.cache = await this.asyncDataProvider(this.dataId);
        }
        return this.cache;
    }
    /**
     * get the strides for each dimension (computed from `dims` on first access, then memoized)
     */
    get strides() {
        if (!this._strides) {
            this._strides = util_1.ShapeUtil.computeStrides(this.dims);
        }
        return this._strides;
    }
    /**
     * Construct new Tensor from a ONNX Tensor object
     *
     * Data is taken from `stringData` for string tensors, from `rawData` when it is non-empty,
     * and otherwise from the typed repeated field selected by `tensorProto.dataType`.
     * @param tensorProto the ONNX Tensor
     * @returns a new Tensor holding a copy of the proto's data
     * @throws {Error} when `tensorProto` is missing, the data type is unsupported, or the stored
     *     data length disagrees with the dims
     * @throws {TypeError} when a 64-bit value does not fit the supported range (see longToNumber)
     */
    static fromProto(tensorProto) {
        if (!tensorProto) {
            throw new Error('cannot construct Value from an empty tensor');
        }
        const type = util_1.ProtoUtil.tensorDataTypeFromProto(tensorProto.dataType);
        const dims = util_1.ProtoUtil.tensorDimsFromProto(tensorProto.dims);
        const value = new Tensor(dims, type);
        if (type === 'string') {
            // When it's STRING type, the value should always be stored in field
            // 'stringData'
            tensorProto.stringData.forEach((str, i) => {
                value.data[i] = (0, util_1.decodeUtf8String)(str);
            });
        }
        else if (tensorProto.rawData &&
            typeof tensorProto.rawData.byteLength === 'number' &&
            tensorProto.rawData.byteLength > 0) {
            // NOT considering segment for now (IMPORTANT)
            // populate value from rawData
            fillFromRawBytes(value.data, tensorProto.rawData, tensorProto.rawData.byteLength, tensorProto.dataType);
        }
        else {
            // populate value from array
            let array;
            switch (tensorProto.dataType) {
                case onnx_proto_1.onnx.TensorProto.DataType.FLOAT:
                    array = tensorProto.floatData;
                    break;
                case onnx_proto_1.onnx.TensorProto.DataType.INT32:
                case onnx_proto_1.onnx.TensorProto.DataType.INT16:
                case onnx_proto_1.onnx.TensorProto.DataType.UINT16:
                case onnx_proto_1.onnx.TensorProto.DataType.INT8:
                case onnx_proto_1.onnx.TensorProto.DataType.UINT8:
                case onnx_proto_1.onnx.TensorProto.DataType.BOOL:
                    array = tensorProto.int32Data;
                    break;
                case onnx_proto_1.onnx.TensorProto.DataType.INT64:
                    array = tensorProto.int64Data;
                    break;
                case onnx_proto_1.onnx.TensorProto.DataType.DOUBLE:
                    array = tensorProto.doubleData;
                    break;
                case onnx_proto_1.onnx.TensorProto.DataType.UINT32:
                case onnx_proto_1.onnx.TensorProto.DataType.UINT64:
                    array = tensorProto.uint64Data;
                    break;
                default:
                    // should never run here
                    throw new Error('unspecific error');
            }
            if (array === null || array === undefined) {
                throw new Error('failed to populate data from a tensorproto value');
            }
            const data = value.data;
            if (data.length !== array.length) {
                throw new Error('array length mismatch');
            }
            for (let i = 0; i < array.length; i++) {
                const element = array[i];
                // 64-bit fields may surface as Long objects; narrow them to plain numbers.
                if (long_1.default.isLong(element)) {
                    data[i] = longToNumber(element, tensorProto.dataType);
                }
                else {
                    data[i] = element;
                }
            }
        }
        return value;
    }
    /**
     * Construct new Tensor from raw data
     * @param data the raw data object. Should be a string array for 'string' tensor, and the corresponding typed array
     * for other types of tensor.
     * @param dims the dimensions of the tensor
     * @param type the type of the tensor
     */
    static fromData(data, dims, type) {
        return new Tensor(dims, type, undefined, undefined, data);
    }
    /**
     * Construct new Tensor from an ORT-format (flatbuffers) tensor object
     *
     * String tensors are read element by element through `stringData(i)`; other tensors are
     * decoded from the raw data bytes when those are present and non-empty. If neither applies
     * the returned tensor keeps the storage allocated by the constructor.
     * @param ortTensor the ORT-format tensor accessor object
     * @returns a new Tensor holding a copy of the data
     * @throws {Error} when `ortTensor` is missing, the data type is unsupported, or the raw data
     *     length disagrees with the dims
     * @throws {TypeError} when a 64-bit value does not fit the supported range (see longToNumber)
     */
    static fromOrtTensor(ortTensor) {
        if (!ortTensor) {
            throw new Error('cannot construct Value from an empty tensor');
        }
        const dims = util_1.ProtoUtil.tensorDimsFromORTFormat(ortTensor);
        // Read the accessor once; it is reused for size lookup and for every element decode.
        const dataType = ortTensor.dataType();
        const type = util_1.ProtoUtil.tensorDataTypeFromProto(dataType);
        const value = new Tensor(dims, type);
        if (type === 'string') {
            // When it's STRING type, the value should always be stored in field
            // 'stringData'
            const count = ortTensor.stringDataLength();
            for (let i = 0; i < count; i++) {
                value.data[i] = ortTensor.stringData(i);
            }
        }
        else {
            const rawData = ortTensor.rawDataArray();
            const rawByteLength = ortTensor.rawDataLength();
            if (rawData && typeof rawByteLength === 'number' && rawByteLength > 0) {
                // NOT considering segment for now (IMPORTANT)
                // populate value from rawData
                fillFromRawBytes(value.data, rawData, rawByteLength, dataType);
            }
        }
        return value;
    }
}
exports.Tensor = Tensor;

/**
 * Decode `byteLength` bytes of `bytes` as little-endian elements of `protoType`
 * and store them into `dataDest`.
 * @param dataDest destination array / typed array; its length must equal the element count
 * @param bytes a typed-array view (provides `buffer` and `byteOffset`) over the raw data
 * @param byteLength number of bytes of `bytes` to decode
 * @param protoType the ONNX TensorProto data type of the stored elements
 * @throws {Error} 'invalid buffer length' when `byteLength` is not a multiple of the element size
 * @throws {Error} 'buffer length mismatch' when the element count differs from `dataDest.length`
 */
function fillFromRawBytes(dataDest, bytes, byteLength, protoType) {
    const dataSource = new DataView(bytes.buffer, bytes.byteOffset, byteLength);
    const elementSize = sizeofProto(protoType);
    if (byteLength % elementSize !== 0) {
        throw new Error('invalid buffer length');
    }
    const length = byteLength / elementSize;
    if (dataDest.length !== length) {
        throw new Error('buffer length mismatch');
    }
    for (let i = 0; i < length; i++) {
        dataDest[i] = readProto(dataSource, protoType, i * elementSize);
    }
}

/**
 * Size in bytes of one element of the given tensor type name.
 * @param type tensor type name such as 'int8' or 'float32'
 * @throws {Error} for types without a fixed element size here (e.g. 'string')
 */
function sizeof(type) {
    switch (type) {
        case 'bool':
        case 'int8':
        case 'uint8':
            return 1;
        case 'int16':
        case 'uint16':
            return 2;
        case 'int32':
        case 'uint32':
        case 'float32':
            return 4;
        case 'float64':
            return 8;
        default:
            throw new Error(`cannot calculate sizeof() on type ${type}`);
    }
}

/**
 * Size in bytes of one stored element of the given ONNX TensorProto data type.
 * @param type a value of onnx.TensorProto.DataType
 * @throws {Error} for data types not listed below
 */
function sizeofProto(type) {
    switch (type) {
        case onnx_proto_1.onnx.TensorProto.DataType.UINT8:
        case onnx_proto_1.onnx.TensorProto.DataType.INT8:
        case onnx_proto_1.onnx.TensorProto.DataType.BOOL:
            return 1;
        case onnx_proto_1.onnx.TensorProto.DataType.UINT16:
        case onnx_proto_1.onnx.TensorProto.DataType.INT16:
            return 2;
        case onnx_proto_1.onnx.TensorProto.DataType.FLOAT:
        case onnx_proto_1.onnx.TensorProto.DataType.INT32:
        case onnx_proto_1.onnx.TensorProto.DataType.UINT32:
            return 4;
        case onnx_proto_1.onnx.TensorProto.DataType.INT64:
        case onnx_proto_1.onnx.TensorProto.DataType.DOUBLE:
        case onnx_proto_1.onnx.TensorProto.DataType.UINT64:
            return 8;
        default:
            throw new Error(`cannot calculate sizeof() on type ${onnx_proto_1.onnx.TensorProto.DataType[type]}`);
    }
}

/**
 * Create the typed-array view matching `type` over `dataBuffer`.
 * @param dataBuffer the ArrayBuffer to wrap
 * @param type tensor type name
 */
function createView(dataBuffer, type) {
    return new (dataviewConstructor(type))(dataBuffer);
}

/**
 * Map a tensor type name to the typed-array constructor used to store it.
 * 'bool' shares Uint8Array with 'uint8'.
 * @param type tensor type name
 * @throws {Error} for type names without a typed-array representation
 */
function dataviewConstructor(type) {
    switch (type) {
        case 'bool':
        case 'uint8':
            return Uint8Array;
        case 'int8':
            return Int8Array;
        case 'int16':
            return Int16Array;
        case 'uint16':
            return Uint16Array;
        case 'int32':
            return Int32Array;
        case 'uint32':
            return Uint32Array;
        case 'float32':
            return Float32Array;
        case 'float64':
            return Float64Array;
        default:
            // should never run to here
            throw new Error('unspecified error');
    }
}

/**
 * Convert a Long to a plain number that fits in 32 bits (cast-down).
 *
 * INT64 values must lie in [-2^31, 2^31); UINT32 / UINT64 values must lie in [0, 2^32).
 * @param i the Long value
 * @param type the data type the value was stored as (ONNX or ORT-format enum value)
 * @throws {TypeError} when the value is out of range or `type` is not INT64 / UINT32 / UINT64
 */
function longToNumber(i, type) {
    // INT64, UINT32, UINT64
    if (type === onnx_proto_1.onnx.TensorProto.DataType.INT64 || type === ortFbs.TensorDataType.INT64) {
        if (i.greaterThanOrEqual(2147483648) || i.lessThan(-2147483648)) {
            throw new TypeError('int64 is not supported');
        }
    }
    else if (type === onnx_proto_1.onnx.TensorProto.DataType.UINT32 ||
        type === ortFbs.TensorDataType.UINT32 ||
        type === onnx_proto_1.onnx.TensorProto.DataType.UINT64 ||
        type === ortFbs.TensorDataType.UINT64) {
        if (i.greaterThanOrEqual(4294967296) || i.lessThan(0)) {
            throw new TypeError('uint64 is not supported');
        }
    }
    else {
        throw new TypeError(`not a LONG type: ${onnx_proto_1.onnx.TensorProto.DataType[type]}`);
    }
    return i.toNumber();
}

/**
 * Read one value of the given ONNX data type from a DataView.
 * Multi-byte values are read little-endian; 64-bit integers are assembled from
 * two 32-bit halves and narrowed through longToNumber.
 * @param view the DataView over the raw tensor bytes
 * @param type a value of onnx.TensorProto.DataType
 * @param byteOffset offset of the element within `view`
 * @throws {Error} for data types that cannot be read
 * @throws {TypeError} when a 64-bit value is out of the supported range
 */
function readProto(view, type, byteOffset) {
    switch (type) {
        case onnx_proto_1.onnx.TensorProto.DataType.BOOL:
        case onnx_proto_1.onnx.TensorProto.DataType.UINT8:
            return view.getUint8(byteOffset);
        case onnx_proto_1.onnx.TensorProto.DataType.INT8:
            return view.getInt8(byteOffset);
        case onnx_proto_1.onnx.TensorProto.DataType.UINT16:
            return view.getUint16(byteOffset, true);
        case onnx_proto_1.onnx.TensorProto.DataType.INT16:
            return view.getInt16(byteOffset, true);
        case onnx_proto_1.onnx.TensorProto.DataType.FLOAT:
            return view.getFloat32(byteOffset, true);
        case onnx_proto_1.onnx.TensorProto.DataType.INT32:
            return view.getInt32(byteOffset, true);
        case onnx_proto_1.onnx.TensorProto.DataType.UINT32:
            return view.getUint32(byteOffset, true);
        case onnx_proto_1.onnx.TensorProto.DataType.INT64:
            return longToNumber(long_1.default.fromBits(view.getUint32(byteOffset, true), view.getUint32(byteOffset + 4, true), false), type);
        case onnx_proto_1.onnx.TensorProto.DataType.DOUBLE:
            return view.getFloat64(byteOffset, true);
        case onnx_proto_1.onnx.TensorProto.DataType.UINT64:
            return longToNumber(long_1.default.fromBits(view.getUint32(byteOffset, true), view.getUint32(byteOffset + 4, true), true), type);
        default:
            throw new Error(`cannot read from DataView for type ${onnx_proto_1.onnx.TensorProto.DataType[type]}`);
    }
}
//# sourceMappingURL=tensor.js.map