59 lines
2.8 KiB
JavaScript
59 lines
2.8 KiB
JavaScript
"use strict";
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
exports.parseSplitAttributes = exports.split = void 0;
|
|
const attribute_with_cache_key_1 = require("../../../attribute-with-cache-key");
|
|
const util_1 = require("../../../util");
|
|
const types_1 = require("../types");
|
|
/**
 * Static metadata shared by every Split program: the operator name, the
 * single input name 'A' (read in the shader through `_A(...)`), and the
 * texture layout used for that input (unpacked).
 */
const splitProgramMetadata = { name: 'Split', inputNames: ['A'], inputTypes: [types_1.TextureType.unpacked] };
|
|
/**
 * WebGL implementation of the Split operator.
 *
 * Validates the inputs, normalizes the split axis, then runs one program per
 * output slice and collects the results in slice order.
 *
 * @param inferenceHandler - handler whose `run` executes each slice program.
 * @param inputs - operator inputs; exactly one tensor is expected.
 * @param attributes - parsed Split attributes (axis, split, numOutputs, cacheKey).
 * @returns the values returned by `inferenceHandler.run`, one per slice.
 * @throws Error when `validateInputs` rejects the inputs.
 */
const split = (inferenceHandler, inputs, attributes) => {
    validateInputs(inputs);
    const normalizedAxis = util_1.ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length);
    const programCount = getProgramCount(inferenceHandler, inputs, normalizedAxis, attributes);
    // Every slice gets its own cache hint (slice index appended); presumably
    // run() uses it to key compiled programs, since each slice bakes a
    // different offset into its shader.
    return Array.from({ length: programCount }, (_, sliceIndex) => {
        const loader = Object.assign(Object.assign({}, splitProgramMetadata), {
            cacheHint: `${attributes.cacheKey};${sliceIndex}`,
            get: () => createSplitProgramInfo(inferenceHandler, inputs[0], attributes, normalizedAxis, sliceIndex),
        });
        return inferenceHandler.run(loader, inputs);
    });
};
exports.split = split;
|
|
/**
 * Reads the Split attributes from a graph node.
 *
 * @param node - graph node exposing `attributes.getInt`/`attributes.getInts`
 *   and an `outputs` list.
 * @returns the object `{ axis, split, numOutputs }` passed through
 *   createAttributeWithCacheKey. `axis` is read with fallback argument 0 and
 *   `split` with fallback argument [] (presumably defaults when the attribute
 *   is absent — confirm against getInt/getInts). `numOutputs` is the node's
 *   output count.
 */
const parseSplitAttributes = (node) => {
    const nodeAttributes = node.attributes;
    const axis = nodeAttributes.getInt('axis', 0);
    // Named splitSizes to avoid shadowing the exported `split` operator.
    const splitSizes = nodeAttributes.getInts('split', []);
    const numOutputs = node.outputs.length;
    // Key order is kept as (axis, split, numOutputs) in case the derived cache
    // key depends on property order — TODO confirm.
    return (0, attribute_with_cache_key_1.createAttributeWithCacheKey)({ axis, split: splitSizes, numOutputs });
};
exports.parseSplitAttributes = parseSplitAttributes;
|
|
/**
 * Returns how many slice programs Split has to run.
 *
 * @param inferenceHandler - not used here.
 * @param inputs - operator inputs; only `inputs[0].dims` is read.
 * @param axis - split axis (the caller passes the normalized axis).
 * @param attributes - parsed Split attributes; `split` and `numOutputs` are read.
 * @returns the number of slice offsets reported by SplitUtil.splitShape.
 */
const getProgramCount = (inferenceHandler, inputs, axis, attributes) => {
    // splitShape returns [shapes, offsets]; only the offsets are needed.
    const sliceOffsets = util_1.SplitUtil.splitShape(inputs[0].dims, axis, attributes.split, attributes.numOutputs)[1];
    return sliceOffsets.length;
};
|
|
/**
 * Builds the program info for a single output slice of Split.
 *
 * The generated shader takes an output coordinate, shifts its component on
 * `axis` by this slice's start offset, and reads the input through `_A`.
 *
 * @param inferenceHandler - not used here.
 * @param input - input tensor; its `dims` and `type` are read.
 * @param attributes - parsed Split attributes (`split`, `numOutputs`, `cacheKey`).
 * @param axis - split axis (the caller passes the normalized axis).
 * @param index - which output slice to build.
 * @returns program info with the slice's output shape, the input's type, an
 *   unpacked output texture, and the generated shader source.
 */
const createSplitProgramInfo = (inferenceHandler, input, attributes, axis, index) => {
    const [shapes, offsets] = util_1.SplitUtil.splitShape(input.dims, axis, attributes.split, attributes.numOutputs);
    const offset = offsets[index];
    const outputShape = shapes[index];
    const rank = outputShape.length;
    const shaderSource = `
      float process(int indices[${rank}]) {
        indices[${axis}] += ${offset};
        return _A(indices);
      }`;
    // Use the same `<cacheKey>;<index>` format as the loader built in `split`
    // (previously `:` here), so both hints describe a slice identically.
    return Object.assign(Object.assign({}, splitProgramMetadata), {
        cacheHint: `${attributes.cacheKey};${index}`,
        output: { dims: outputShape, type: input.type, textureType: types_1.TextureType.unpacked },
        shaderSource,
    });
};
|
|
/**
 * Element types accepted for the Split input tensor.
 */
const SPLIT_SUPPORTED_INPUT_TYPES = new Set([
    'int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float32', 'float64', 'bool',
]);
/**
 * Checks that Split received exactly one input with a supported element type.
 *
 * @param inputs - operator inputs.
 * @throws Error('Split requires one input.') when `inputs` is falsy or its
 *   length is not exactly 1.
 * @throws Error('Invalid input type.') when `inputs[0].type` is not one of the
 *   supported types.
 */
const validateInputs = (inputs) => {
    const hasSingleInput = Boolean(inputs) && inputs.length === 1;
    if (!hasSingleInput) {
        throw new Error('Split requires one input.');
    }
    if (!SPLIT_SUPPORTED_INPUT_TYPES.has(inputs[0].type)) {
        throw new Error('Invalid input type.');
    }
};
|
|
//# sourceMappingURL=split.js.map
|