66 lines
2.2 KiB
TypeScript
66 lines
2.2 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {NUMBER_TYPES} from '../../../operators';
|
|
import {Tensor} from '../../../tensor';
|
|
import {WebGLInferenceHandler} from '../inference-handler';
|
|
import {ProgramInfo, ProgramMetadata, TextureType} from '../types';
|
|
|
|
export const tile = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
|
|
validateInputs(inputs);
|
|
|
|
const tileProgramMetadata = {
|
|
name: 'Tile',
|
|
inputNames: ['A'],
|
|
inputTypes: [TextureType.unpacked],
|
|
};
|
|
|
|
const output = inferenceHandler.run(
|
|
{...tileProgramMetadata, get: () => createTileProgramInfo(inferenceHandler, inputs, tileProgramMetadata)},
|
|
inputs);
|
|
return [output];
|
|
};
|
|
|
|
const createTileProgramInfo =
|
|
(handler: WebGLInferenceHandler, inputs: Tensor[], tileProgramMetadata: ProgramMetadata): ProgramInfo => {
|
|
const inputShape = inputs[0].dims.slice();
|
|
const outputShape = new Array(inputShape.length);
|
|
|
|
const tileOps: string[] = [];
|
|
for (let i = 0; i < inputShape.length; i++) {
|
|
outputShape[i] = inputShape[i] * inputs[1].numberData[i];
|
|
tileOps.push(`inputIdx[${i}] = int(mod(float(outputIdx[${i}]), ${inputShape[i]}.));`);
|
|
}
|
|
|
|
const rank = outputShape.length;
|
|
const shaderSource = `
|
|
float process(int outputIdx[${rank}]) {
|
|
int inputIdx[${rank}];
|
|
${tileOps.join('\n')}
|
|
return _A(inputIdx);
|
|
}
|
|
`;
|
|
return {
|
|
...tileProgramMetadata,
|
|
output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked},
|
|
shaderSource
|
|
};
|
|
};
|
|
|
|
const validateInputs = (inputs: Tensor[]): void => {
|
|
if (!inputs || inputs.length !== 2) {
|
|
throw new Error('Tile requires 2 input.');
|
|
}
|
|
if (inputs[1].dims.length !== 1) {
|
|
throw new Error('The second input shape must 1 dimension.');
|
|
}
|
|
if (inputs[1].dims[0] !== inputs[0].dims.length) {
|
|
throw new Error('Invalid input shape.');
|
|
}
|
|
if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
|
|
throw new Error('Invalid input type.');
|
|
}
|
|
if (inputs[1].type !== 'int32' && inputs[1].type !== 'int16') {
|
|
throw new Error('Invalid repeat type.');
|
|
}
|
|
}; |