tasq/node_modules/onnxruntime-web/lib/onnxjs/backends/webgl/ops/gather.js

87 lines
3.9 KiB
JavaScript

"use strict";
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
Object.defineProperty(exports, "__esModule", { value: true });
exports.parseGatherAttributes = exports.gather = void 0;
const attribute_with_cache_key_1 = require("../../../attribute-with-cache-key");
const operators_1 = require("../../../operators");
const util_1 = require("../../../util");
const types_1 = require("../types");
/**
 * WebGL implementation of the Gather operator.
 *
 * Validates the operands, asks the inference handler to run the Gather
 * program built for them, and wraps the single result in an array.
 *
 * @param inferenceHandler Handler whose `run(programInfoLoader, inputs)` executes the program.
 * @param inputs Operand list passed to validation and to the program (data first, indices second).
 * @param attributes Parsed attributes; `axis` is validated here, the rest is consumed by the loader.
 * @returns A one-element array holding whatever `inferenceHandler.run` returned.
 * @throws Whatever `validateInputs` throws for unusable operands.
 */
const gather = (inferenceHandler, inputs, attributes) => {
    validateInputs(inputs, attributes.axis);
    const programLoader = createGatherProgramInfoLoader(inferenceHandler, inputs, attributes);
    const result = inferenceHandler.run(programLoader, inputs);
    return [result];
};
exports.gather = gather;
/**
 * Reads the Gather attributes from a graph node.
 *
 * Only `axis` is read (integer, defaulting to 0 when absent); the resulting
 * `{ axis }` record is passed through `createAttributeWithCacheKey`.
 *
 * @param node Graph node exposing `attributes.getInt(name, defaultValue)`.
 * @returns The value returned by `createAttributeWithCacheKey({ axis })`.
 */
const parseGatherAttributes = (node) => {
    // Pull the factory out first so it is invoked as a plain (unbound) function.
    const { createAttributeWithCacheKey } = attribute_with_cache_key_1;
    const axis = node.attributes.getInt('axis', 0);
    return createAttributeWithCacheKey({ axis });
};
exports.parseGatherAttributes = parseGatherAttributes;
// Static metadata for the Gather program: its name, the names of its two
// inputs, and the texture type requested for each input (both unpacked).
// createGatherProgramInfoLoader copies this object and adds a `cacheHint`.
// NOTE(review): the names 'A' and 'B' line up with the `_A(...)` / `_B(...)`
// calls in the generated shader; presumably the framework derives those
// sampling functions from `inputNames` — confirm against the GLSL generator.
const gatherProgramMetadata = {
name: 'Gather',
inputNames: ['A', 'B'],
inputTypes: [types_1.TextureType.unpacked, types_1.TextureType.unpacked],
};
/**
 * Builds the program info (output description + shader source) for Gather.
 *
 * The output index space is laid out in three consecutive regions:
 *
 *   [ data dims before `axis` | every dim of the indices tensor | data dims after `axis` ]
 *
 * For each output coordinate the generated `process` function copies the
 * first and last regions into `inputIdx`, copies the middle region into
 * `indexDataIdx`, reads the gathered index from input B at `indexDataIdx`,
 * wraps a negative index by adding the size of the gathered dimension, and
 * finally reads input A at `inputIdx`.
 *
 * @param handler Not used by this function; kept for signature compatibility.
 * @param metadata Program metadata whose properties are copied onto the result.
 * @param inputs `[data, indices]`; only `dims` of both and `type` of the data tensor are read.
 * @param axis Gather axis, normalized against the data rank via `ShapeUtil.normalizeAxis`.
 * @returns A copy of `metadata` extended with `output` and `shaderSource`.
 */
const createGatherProgramInfo = (handler, metadata, inputs, axis) => {
    const dataDims = inputs[0].dims.slice();
    const indexDims = inputs[1].dims.slice();
    const dataRank = dataDims.length;
    const indexRank = indexDims.length;
    const outputDims = new Array(dataRank + indexRank - 1);
    const gatherAxis = util_1.ShapeUtil.normalizeAxis(axis, dataRank);
    // First output position that belongs to the trailing data dims.
    const indexRegionEnd = gatherAxis + indexRank;
    const copyStatements = [];
    for (let outPos = 0; outPos < outputDims.length; outPos++) {
        if (outPos < gatherAxis) {
            // Leading data dims map straight through.
            outputDims[outPos] = dataDims[outPos];
            copyStatements.push(`inputIdx[${outPos}] = outputIdx[${outPos}];`);
        }
        else if (outPos < indexRegionEnd) {
            // Dims contributed by the indices tensor.
            const indexPos = outPos - gatherAxis;
            outputDims[outPos] = indexDims[indexPos];
            copyStatements.push(`indexDataIdx[${indexPos}] = outputIdx[${outPos}];`);
        }
        else {
            // Trailing data dims; the +1 skips over the gathered axis itself.
            const dataPos = outPos - indexRank + 1;
            outputDims[outPos] = dataDims[dataPos];
            copyStatements.push(`inputIdx[${dataPos}] = outputIdx[${outPos}];`);
        }
    }
    // A rank of 0 is declared as 1 so the shader arrays are never zero-length
    // (presumably because GLSL does not accept zero-sized arrays).
    const declaredOutputRank = outputDims.length || 1;
    const declaredIndexRank = indexRank || 1;
    const shaderSource = [
        '',
        `float process(int outputIdx[${declaredOutputRank}]) {`,
        `int inputIdx[${dataRank}];`,
        `int indexDataIdx[${declaredIndexRank}];`,
        'indexDataIdx[0] = 0;',
        copyStatements.join('\n '),
        'int idx = int(_B(indexDataIdx));',
        `inputIdx[${gatherAxis}] = idx < 0 ? idx + ${dataDims[gatherAxis]} : idx;`,
        'return _A(inputIdx);',
        '}',
    ].join('\n');
    const programInfo = Object.assign({}, metadata);
    programInfo.output = { dims: outputDims, type: inputs[0].type, textureType: types_1.TextureType.unpacked };
    programInfo.shaderSource = shaderSource;
    return programInfo;
};
/**
 * Creates the lazy program-info loader for one Gather invocation.
 *
 * The loader carries the shared Gather metadata plus a `cacheHint` taken from
 * `attributes.cacheKey`, and a `get()` thunk that builds the full program
 * info only when called.
 *
 * @param handler Forwarded unchanged to `createGatherProgramInfo`.
 * @param inputs Operand list forwarded to `createGatherProgramInfo`.
 * @param attributes Parsed attributes; `cacheKey` is read now, `axis` when `get()` runs.
 * @returns Metadata object extended with a `get` function.
 */
const createGatherProgramInfoLoader = (handler, inputs, attributes) => {
    const metadata = Object.assign({}, gatherProgramMetadata, { cacheHint: attributes.cacheKey });
    // The thunk receives `metadata` itself, i.e. the object without `get`.
    const get = () => createGatherProgramInfo(handler, metadata, inputs, attributes.axis);
    return Object.assign({}, metadata, { get });
};
/**
 * Checks that the operands handed to Gather are usable before a program is built.
 *
 * @param inputs Must be exactly two tensors: the data tensor followed by the indices tensor.
 * @param axis Gather axis; must lie in `[-rank, rank - 1]` for the data tensor's rank.
 * @throws {Error} 'Gather requires 2 inputs.' when `inputs` is missing or not of length 2.
 * @throws {Error} 'Invalid input shape.' when the data tensor has rank 0.
 * @throws {Error} 'Invalid axis.' when `axis` is outside the permitted range.
 * @throws {Error} 'Invalid input type.' when the data tensor's type is not in `NUMBER_TYPES`.
 * @throws {Error} 'Invalid indices type.' when the indices tensor is neither int32 nor int16.
 */
const validateInputs = (inputs, axis) => {
    if (!inputs || inputs.length !== 2) {
        throw new Error('Gather requires 2 inputs.');
    }
    const data = inputs[0];
    const indices = inputs[1];
    const tensorRank = data.dims.length;
    if (tensorRank < 1) {
        throw new Error('Invalid input shape.');
    }
    if (axis < -tensorRank || axis > tensorRank - 1) {
        throw new Error('Invalid axis.');
    }
    if (operators_1.NUMBER_TYPES.indexOf(data.type) === -1) {
        throw new Error('Invalid input type.');
    }
    // Reported separately from the data-type failure so the two can be told apart.
    if (indices.type !== 'int32' && indices.type !== 'int16') {
        throw new Error('Invalid indices type.');
    }
};
//# sourceMappingURL=gather.js.map