66 lines
2.8 KiB
JavaScript
66 lines
2.8 KiB
JavaScript
"use strict";
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
exports.parseTransposeAttributes = exports.transpose = void 0;
|
|
const attribute_with_cache_key_1 = require("../../../attribute-with-cache-key");
|
|
const util_1 = require("../../../util");
|
|
const types_1 = require("../types");
|
|
/**
 * Static description of the Transpose WebGL program: its name, the single
 * input it binds (exposed to the shader as `A`), and the texture layout that
 * input is expected in (unpacked). Copied into every program info this op builds.
 */
const transposeProgramMetadata = {
  name: 'Transpose',
  inputNames: ['A'],
  inputTypes: [types_1.TextureType.unpacked],
};
|
|
/**
 * Transpose kernel entry point.
 *
 * Validates the inputs, then asks the inference handler to run the Transpose
 * program. The program info is built lazily via `get` and keyed by the
 * attribute cache key.
 *
 * @param inferenceHandler - handler whose `run(loader, inputs)` executes the program.
 * @param inputs - exactly one float32/float64 tensor (enforced by validateInputs).
 * @param attributes - parsed attributes carrying `perm` and `cacheKey`.
 * @returns {Array} a one-element array holding the transposed output.
 */
const transpose = (inferenceHandler, inputs, attributes) => {
  validateInputs(inputs);
  const programInfoLoader = Object.assign({}, transposeProgramMetadata, {
    cacheHint: attributes.cacheKey,
    get: () => createTransposeProgramInfo(inferenceHandler, inputs[0], attributes.perm),
  });
  const result = inferenceHandler.run(programInfoLoader, inputs);
  return [result];
};
exports.transpose = transpose;
|
|
/**
 * Reads the Transpose node's `perm` attribute into a cacheable attribute object.
 *
 * When the node has no `perm`, `[]` is substituted. getAdjustedPerm later treats
 * that as "reverse all axes".
 *
 * @param node - graph node whose `attributes.getInts` supplies `perm`.
 * @returns attribute object produced by createAttributeWithCacheKey({ perm }).
 */
const parseTransposeAttributes = (node) => {
  // Pull the helper off the module so it is called without a `this` binding,
  // matching the `(0, fn)(...)` call form emitted by the compiler.
  const { createAttributeWithCacheKey } = attribute_with_cache_key_1;
  const perm = node.attributes.getInts('perm', []);
  return createAttributeWithCacheKey({ perm });
};
exports.parseTransposeAttributes = parseTransposeAttributes;
|
|
/**
 * Builds the WebGL program info for transposing `input`.
 *
 * The generated GLSL defines a `perm` function. It maps each output coordinate
 * back to the input coordinate it came from (output axis i reads input axis
 * perm[i]). `process` then samples input `A` at that coordinate through `_A(...)`,
 * presumably the accessor generated for the input named 'A' (confirm against
 * the program builder).
 *
 * @param inferenceHandler - not read by this function; kept for the loader signature.
 * @param input - tensor to transpose; only `dims` and `type` are read.
 * @param perm - requested axis order, normalized through getAdjustedPerm.
 * @returns program info: the Transpose metadata, an unpacked output descriptor
 *   with the permuted dims, and the shader source.
 */
const createTransposeProgramInfo = (inferenceHandler, input, perm) => {
  const inputShape = input.dims;
  const adjustedPerm = getAdjustedPerm(inputShape, perm);
  const outputShape = getOutputShape(inputShape, adjustedPerm);
  const rank = inputShape.length;
  // Joined with '\n' and led by an empty entry, so the source starts with a newline.
  const shaderSource = [
    '',
    getPermFunctionBody('perm', adjustedPerm, rank),
    `float process(int indices[${rank}]) {`,
    `int a[${rank}];`,
    'perm(a, indices);',
    'return _A(a);',
    '}',
  ].join('\n');
  return Object.assign({}, transposeProgramMetadata, {
    output: { dims: outputShape, type: input.type, textureType: types_1.TextureType.unpacked },
    shaderSource,
  });
};
|
|
/**
 * Resolves the axis order to use when transposing a tensor of shape `inputShape`.
 *
 * ONNX Transpose reverses all axes when no `perm` is given. The attribute parser
 * supplies `[]` in that case, which fails the length check below and takes the
 * reversal branch. A missing (`null`/`undefined`) perm is handled the same way.
 * It used to be returned unchanged, and callers then crashed indexing it.
 *
 * NOTE(review): a non-empty perm of the wrong length is also silently replaced
 * by the reversal rather than rejected. Perm values are not validated either.
 *
 * @param {ReadonlyArray<number>} inputShape - dims of the input tensor.
 * @param {ReadonlyArray<number>|null|undefined} perm - requested axis order.
 * @returns {ReadonlyArray<number>} `perm` itself when its length matches the
 *   rank; otherwise the reversed axis indices [rank-1, ..., 0].
 */
const getAdjustedPerm = (inputShape, perm) => {
  if (perm == null || perm.length !== inputShape.length) {
    return [...inputShape.keys()].reverse();
  }
  return perm;
};
|
|
/**
 * Computes the transposed output dims: the input dims reordered by the
 * (normalized) permutation, via ShapeUtil.sortBasedOnPerm.
 *
 * @param inputShape - dims of the input tensor.
 * @param perm - requested axis order; normalized through getAdjustedPerm first.
 * @returns the permuted dims.
 */
const getOutputShape = (inputShape, perm) =>
  util_1.ShapeUtil.sortBasedOnPerm(inputShape, getAdjustedPerm(inputShape, perm));
|
|
/**
 * Emits a GLSL function `name(out int a[rank], int src[rank])` that scatters a
 * coordinate through `perm`: for each axis i, it writes `a[perm[i]] = src[i]`.
 *
 * @param {string} name - GLSL function name to declare.
 * @param {ReadonlyArray<number>} perm - axis permutation; entries 0..rank-1 are read.
 * @param {number} rank - number of axes (array length in the GLSL signature).
 * @returns {string} the function source, lines separated by '\n'.
 */
const getPermFunctionBody = (name, perm, rank) => {
  const header = `void ${name}(out int a[${rank}], int src[${rank}]) {`;
  const assignments = [];
  for (let srcAxis = 0; srcAxis < rank; srcAxis += 1) {
    assignments.push(`\ta[${perm[srcAxis]}]=src[${srcAxis}];`);
  }
  const footer = '\t}';
  return [header, ...assignments, footer].join('\n');
};
|
|
/**
 * Checks that Transpose received exactly one input and that it is a float tensor.
 *
 * @param inputs - the kernel's input tensors.
 * @throws {Error} 'Transpose requires 1 input.' when `inputs` is missing or not of length 1.
 * @throws {Error} 'input should be float tensor' when the input type is neither float32 nor float64.
 */
const validateInputs = (inputs) => {
  const hasSingleInput = Boolean(inputs) && inputs.length === 1;
  if (!hasSingleInput) {
    throw new Error('Transpose requires 1 input.');
  }
  const supportedTypes = ['float32', 'float64'];
  if (!supportedTypes.includes(inputs[0].type)) {
    throw new Error('input should be float tensor');
  }
};
|
|
//# sourceMappingURL=transpose.js.map
|