177 lines
7.2 KiB
TypeScript
177 lines
7.2 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key';
|
|
import {Graph} from '../../../graph';
|
|
import {NUMBER_TYPES, OperatorImplementation, OperatorInitialization} from '../../../operators';
|
|
import {Tensor} from '../../../tensor';
|
|
import {ShapeUtil} from '../../../util';
|
|
import {WebGLInferenceHandler} from '../inference-handler';
|
|
import {ProgramInfo, ProgramMetadata, TextureType} from '../types';
|
|
|
|
/**
 * Attributes shared by the WebGL Reduce* operators (produced by `parseReduceAttributes`).
 * The inherited cache key is forwarded to the inference handler as the program cache hint.
 */
export interface ReduceAttributes extends AttributeWithCacheKey {
  /**
   * Axes to reduce over, normalized against the input rank by `ShapeUtil.normalizeAxes`.
   * An empty list reduces every axis.
   */
  readonly axes: number[];

  /** When true each reduced axis stays in the output as a length-1 dimension; when false it is dropped. */
  readonly keepDims: boolean;
}
|
|
|
|
/**
 * Supplies the GLSL fragments for one reduction flavour as the list `[init ops, reduce ops, final ops]`:
 * - init ops run once, after the kept output coordinates have been copied into `inputIdx`;
 * - reduce ops run for every visited input element, inside the loops over the reduced axes;
 * - final ops run once after those loops, just before `value` is returned.
 * The axes argument has already been normalized against the input rank.
 */
type ReduceOp = (operatorInputs: Tensor[], normalizedAxes: number[]) => string[];
|
|
|
|
/**
 * Shared driver for the Reduce* operators in this file.
 *
 * Validates the inputs, then asks the handler to run a single-input program that reads an unpacked
 * texture. `attributes.cacheKey` is passed as the cache hint. The program info itself is produced
 * lazily through `get` by `createReduceProgramInfo`.
 *
 * @param inferenceHandler Handler that runs the WebGL program.
 * @param inputs Operator inputs; `validateInputs` requires exactly one numeric tensor.
 * @param attributes Parsed reduce attributes.
 * @param name Program name, e.g. 'ReduceSum'.
 * @param reduceOp Supplies the init / per-element / final GLSL fragments.
 * @returns A one-element array holding the handler's output tensor.
 * @throws Error (from `validateInputs`) when the inputs are invalid.
 */
const reduce =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes, name: string,
     reduceOp: ReduceOp): Tensor[] => {
      validateInputs(inputs);

      const metadata = {name, inputNames: ['A'], inputTypes: [TextureType.unpacked]};
      // Deferred so the program info is only built when the handler asks for it.
      const buildProgramInfo = () =>
          createReduceProgramInfo(inferenceHandler, inputs, attributes, name, reduceOp, metadata);

      const result =
          inferenceHandler.run({...metadata, cacheHint: attributes.cacheKey, get: buildProgramInfo}, inputs);
      return [result];
    };
|
|
|
|
/**
 * Reads the Reduce* node attributes.
 *
 * `axes` defaults to an empty list. `keepdims` defaults to 1, and `keepDims` is true only when
 * `keepdims` equals exactly 1.
 */
export const parseReduceAttributes: OperatorInitialization<ReduceAttributes> = (node: Graph.Node): ReduceAttributes => {
  const {attributes} = node;
  // Key order (axes, then keepDims) is kept stable because the cache key is derived from this object.
  return createAttributeWithCacheKey({
    axes: attributes.getInts('axes', []),
    keepDims: attributes.getInt('keepdims', 1) === 1,
  });
};
|
|
|
|
/**
 * Builds the WebGL program info for one Reduce* operator.
 *
 * The generated GLSL `process` function:
 *   1. copies each kept (non-reduced) output coordinate into `inputIdx`,
 *   2. runs the init fragment `ops[0]`,
 *   3. runs the per-element fragment `ops[1]` inside one nested `for` loop per reduced axis,
 *   4. runs the final fragment `ops[2]` and returns `value`.
 * An empty `attributes.axes` list reduces every axis. Reduced axes are kept with length 1 when
 * `attributes.keepDims` is true and are dropped otherwise.
 *
 * @param handler Not read by this function; kept so the signature stays unchanged.
 * @param inputs Operator inputs; only `inputs[0]` is read.
 * @param attributes Parsed reduce attributes; `axes` is normalized against the input rank.
 * @param name Not read here; the program name travels in `reduceProgramMetadata`.
 * @param reduceOp Produces the `[init, perElement, final]` GLSL fragments.
 * @param reduceProgramMetadata Metadata spread into the returned program info.
 * @returns Program info with the computed output shape, the input's data type and an unpacked texture.
 */
const createReduceProgramInfo =
    (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes, name: string, reduceOp: ReduceOp,
     reduceProgramMetadata: ProgramMetadata): ProgramInfo => {
      const inputDims = inputs[0].dims;
      const inputRank = inputDims.length;
      const outputShape: number[] = [];
      // GLSL arrays must have a positive size, so a scalar input still gets a 1-element index array.
      const iRank = inputRank || 1;

      const idxCopy: string[] = [];  // copy output indexes to input indexes
      if (inputRank === 0) {
        // A scalar input has no axis to copy or loop over. Initialize its placeholder index so
        // `_A(inputIdx)` never reads an uninitialized GLSL variable.
        idxCopy.push('inputIdx[0] = 0;');
      }

      const axes = ShapeUtil.normalizeAxes(attributes.axes, inputRank);
      // An empty axes list means "reduce over every axis".
      const isReduced = (axis: number): boolean => axes.length === 0 || axes.indexOf(axis) >= 0;

      const ops = reduceOp(inputs, axes);
      let reduceOps = ops[1];

      for (let k = 0; k < inputRank; k++) {
        if (isReduced(k)) {
          if (attributes.keepDims) {
            outputShape.push(1);
          }  // otherwise the reduced axis is removed from the output shape

          // Wrap the fragment built so far in a loop over the k-th input axis.
          reduceOps = `
          for(int j${k} = 0; j${k} < ${inputDims[k]}; j${k}++) {
            inputIdx[${k}] = j${k};
            ${reduceOps}
          }`;
        } else {
          // Kept axis: its output position equals the number of output dims emitted so far.
          idxCopy.push(`inputIdx[${k}] = outputIdx[${outputShape.length}];`);
          outputShape.push(inputDims[k]);
        }
      }

      // Same positive-size requirement for the output index array (e.g. a fully reduced output without keepDims).
      const oRank = outputShape.length || 1;

      const shaderSource = `
      float process(int outputIdx[${oRank}]) {
        float value;                 // final result
        int inputIdx[${iRank}];      // addressing input data
        ${idxCopy.join('\n')}
        ${ops[0]} // init ops for reduce max/min
        ${reduceOps}
        ${ops[2]} // final computation for reduce mean
        return value;
      }`;

      return {
        ...reduceProgramMetadata,
        output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked},
        shaderSource
      };
    };
|
|
|
|
/**
 * Checks that a Reduce* operator received exactly one input whose element type is listed in NUMBER_TYPES.
 *
 * @throws Error when the input list is missing or does not hold exactly one tensor, or when the
 *     tensor's type is not numeric.
 */
const validateInputs = (inputs: Tensor[]): void => {
  const hasSingleInput = !!inputs && inputs.length === 1;
  if (!hasSingleInput) {
    throw new Error('Reduce op requires 1 input.');
  }

  const isNumeric = NUMBER_TYPES.indexOf(inputs[0].type) !== -1;
  if (!isNumeric) {
    throw new Error('Invalid input type.');
  }
};
|
|
|
|
/**
 * ReduceSum: adds every visited input element into `value`, starting from 0.
 */
export const reduceSum: OperatorImplementation<ReduceAttributes> =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => {
      const sumFragments: ReduceOp = (): string[] => {
        const init = 'value = 0.0;';
        const step = 'value += _A(inputIdx);';
        const finish = '';
        return [init, step, finish];
      };
      return reduce(inferenceHandler, inputs, attributes, 'ReduceSum', sumFragments);
    };
|
|
|
|
/**
 * ReduceMean: sums every visited input element, then divides by the number of elements reduced
 * into each output value (the product of the reduced dimensions; all dimensions when `axes` is empty).
 */
export const reduceMean: OperatorImplementation<ReduceAttributes> =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => {
      const meanFragments: ReduceOp = (meanInputs: Tensor[], reducedAxes: number[]): string[] => {
        const reducesEverything = reducedAxes.length === 0;
        const elementCount = meanInputs[0].dims.reduce(
            (count, dim, axis) => (reducesEverything || reducedAxes.indexOf(axis) >= 0 ? count * dim : count), 1.0);
        // the trailing `.` makes the divisor a GLSL float literal
        return ['value = 0.0;', 'value += _A(inputIdx);', `value /= ${elementCount}.;`];
      };
      return reduce(inferenceHandler, inputs, attributes, 'ReduceMean', meanFragments);
    };
|
|
|
|
/**
 * ReduceMax: seeds `value` with the first element along every reduced axis, then keeps the running maximum.
 */
export const reduceMax: OperatorImplementation<ReduceAttributes> =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => {
      const maxFragments: ReduceOp = (maxInputs: Tensor[], reducedAxes: number[]): string[] => {
        const resetReducedIdx = maxInputs[0]
                                    .dims.map((_dim, axis) => axis)
                                    .filter(axis => reducedAxes.length === 0 || reducedAxes.indexOf(axis) >= 0)
                                    .map(axis => `inputIdx[${axis}] = 0;`);  // first element
        const init = `${resetReducedIdx.join('\n')}\nvalue = _A(inputIdx);`;
        return [init, 'value = max(value, _A(inputIdx));', ''];
      };
      return reduce(inferenceHandler, inputs, attributes, 'ReduceMax', maxFragments);
    };
|
|
|
|
/**
 * ReduceMin: seeds `value` with the first element along every reduced axis, then keeps the running minimum.
 */
export const reduceMin: OperatorImplementation<ReduceAttributes> =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => {
      const minFragments: ReduceOp = (minInputs: Tensor[], reducedAxes: number[]): string[] => {
        const reducesEverything = reducedAxes.length === 0;
        const resetReducedIdx: string[] = [];
        minInputs[0].dims.forEach((_dim, axis) => {
          if (reducesEverything || reducedAxes.indexOf(axis) >= 0) {
            resetReducedIdx.push(`inputIdx[${axis}] = 0;`);  // first element
          }
        });
        return [`${resetReducedIdx.join('\n')}\nvalue = _A(inputIdx);`, 'value = min(value, _A(inputIdx));', ''];
      };
      return reduce(inferenceHandler, inputs, attributes, 'ReduceMin', minFragments);
    };
|
|
|
|
/**
 * ReduceProd: multiplies every visited input element into `value`, starting from 1.
 */
export const reduceProd: OperatorImplementation<ReduceAttributes> =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => {
      const productFragments: ReduceOp = (): string[] => {
        const init = 'value = 1.0;';  // multiplicative identity
        const step = 'value *= _A(inputIdx);';
        return [init, step, ''];
      };
      return reduce(inferenceHandler, inputs, attributes, 'ReduceProd', productFragments);
    };
|
|
|
|
/**
 * ReduceLogSum: sums every visited input element (starting from 0), then applies GLSL `log`
 * (natural logarithm) to the total.
 */
export const reduceLogSum: OperatorImplementation<ReduceAttributes> =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => {
      const logSumFragments: ReduceOp = (): string[] => {
        const init = 'value = 0.0;';
        const step = 'value += _A(inputIdx);';
        const finish = 'value = log(value);';
        return [init, step, finish];
      };
      return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSum', logSumFragments);
    };
|
|
|
|
/**
 * Accumulates the square of every visited input element into `value`, starting from 0.
 *
 * NOTE(review): despite the name, the final fragment is empty, so no logarithm is applied. The
 * result is a plain sum of squares (the semantics of ONNX ReduceSumSquare). Confirm which ONNX op
 * this implementation is registered for before relying on the name.
 */
export const reduceLogSumSquare: OperatorImplementation<ReduceAttributes> =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => {
      const sumOfSquaresFragments: ReduceOp = (): string[] => {
        const init = 'float t; value = 0.0;';  // `t` holds the element currently being squared
        const step = 't = _A(inputIdx); value += t * t;';
        return [init, step, ''];
      };
      return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSumSquare', sumOfSquaresFragments);
    };