// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
|
|
|
|
import {Tensor} from '../../../tensor';
|
|
import {BroadcastUtil, ShapeUtil} from '../../../util';
|
|
import {getGlsl} from '../glsl-source';
|
|
import {WebGLInferenceHandler} from '../inference-handler';
|
|
import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types';
|
|
import {getCoordsDataType, getGlChannels} from '../utils';
|
|
|
|
import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils';
|
|
import {getBiasForMatmul} from './matmul';
|
|
|
|
/**
 * Builds the static metadata for the packed MatMul program.
 *
 * @param hasBias - When true, a third input named `Bias` is declared in addition to `A` and `B`.
 * @param cacheHint - Value copied verbatim into the `cacheHint` field of the returned metadata.
 * @returns An object with `name`, `inputNames`, `inputTypes` (all `TextureType.packed`) and `cacheHint`.
 */
const createPackedMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({
  name: 'MatMul (packed)',
  inputNames: hasBias ? ['A', 'B', 'Bias'] : ['A', 'B'],
  inputTypes: hasBias ? [TextureType.packed, TextureType.packed, TextureType.packed] :
                        [TextureType.packed, TextureType.packed],
  cacheHint
});
|
|
|
|
/**
 * Generates the shader source and output description for a packed MatMul.
 *
 * `inputs[0]` is A and `inputs[1]` is B; if a third input is present it is treated as a bias and added
 * to the accumulated value before the activation snippet is applied.
 *
 * Two sampling strategies are emitted:
 *  - When the dims of A and B differ, helper functions `getAAtOutCoordsMatmul` / `getBAtOutCoordsMatmul`
 *    (produced by `getBcastSamplerForMatmul`) are emitted and called inside the loop.
 *  - Otherwise `getA(...)` / `getB(...)` are called directly with coordinates derived from `rc`, the output
 *    coordinates with their last two components swapped.
 *
 * @param inferenceHandler - Used to look up the GL context version for `getGlsl`.
 * @param metadata - Program metadata spread into the returned `ProgramInfo`.
 * @param inputs - `[A, B]` or `[A, B, Bias]`.
 * @param activationAttributes - Forwarded to `getActivationSnippet`.
 * @returns The `ProgramInfo` whose output is a packed texture of the broadcast output shape and the type of A.
 * @throws Error if `BroadcastUtil.calcShape` cannot produce an output shape for A and B.
 */
const createPackedMatmulProgramInfo =
    (inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[],
     activationAttributes: InternalActivationAttributes): ProgramInfo => {
      const hasBias = inputs.length > 2;
      const processBias = hasBias ? 'value += getBiasForMatmul();' : '';
      const aShape = inputs[0].dims;
      const bShape = inputs[1].dims;
      const outputShape = BroadcastUtil.calcShape(aShape, bShape, true);
      // Any difference between the two input shapes selects the broadcast sampling path.
      const isBroadcast = !ShapeUtil.areEqual(inputs[0].dims, inputs[1].dims);

      if (!outputShape) {
        throw new Error('Can\'t use matmul on the given tensors');
      }
      const sharedDim = aShape[aShape.length - 1];
      // The loop advances the shared dimension two elements at a time (see the `i*2` coordinate).
      const sharedDimIndex = Math.ceil(sharedDim / 2);
      const aRank = aShape.length;
      const bRank = bShape.length;

      const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
      const coordsDataType = getCoordsDataType(outputShape.length);
      const outRank = outputShape.length;
      const allGlChannels = getGlChannels();
      const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes);

      const getBiasForMatmulSnippet =
          hasBias ? getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, true) : '';

      const getBcastedSamplerForMatmulSnippet =
          isBroadcast ? getBcastSamplerForMatmul(coordsDataType, allGlChannels, inputs, outputShape) : '';

      const getSamplerAInLoopSnippet = isBroadcast ? 'getAAtOutCoordsMatmul(i)' : `getA(${getA(allGlChannels, aRank)})`;
      const getSamplerBInLoopSnippet = isBroadcast ? 'getBAtOutCoordsMatmul(i)' : `getB(${getB(allGlChannels, bRank)})`;
      // Non-broadcast path only: declare `rc` and swap its last two components. The broadcast helpers
      // fetch and swap the output coordinates themselves.
      const getOutputCoordsSnippet = isBroadcast ? '' : `${coordsDataType} rc =
          getOutputCoords(); int lastDim = rc.${allGlChannels[outRank - 1]}; rc.${allGlChannels[outRank - 1]} =
          rc.${allGlChannels[outRank - 2]}; rc.${allGlChannels[outRank - 2]} = lastDim;
      `;
      // Each iteration adds the component-wise products of the swizzled A and B texels to `value`.
      // NOTE(review): presumably this is a 2x2 block product of packed texels - confirm against the packing layout.
      const shaderSource = `
            ${getBcastedSamplerForMatmulSnippet}
            ${getBiasForMatmulSnippet}
            ${activationFunction}
            void main() {
              ${getOutputCoordsSnippet}

              vec4 value = vec4(0);
              for (int i = 0; i < ${sharedDimIndex}; i++) {
                vec4 a = ${getSamplerAInLoopSnippet};
                vec4 b = ${getSamplerBInLoopSnippet};

                value += (a.rrbb * b.rgrg);
                value += (a.ggaa * b.baba);
              }
              ${processBias}
              ${applyActivation}
              ${glsl.output} = value;
            }`;
      return {
        ...metadata,
        output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.packed},
        shaderSource,
        hasMain: true
      };
    };
|
|
|
|
/**
 * Creates a loader for the packed MatMul program.
 *
 * The metadata is built immediately; the `ProgramInfo` itself is only built when the returned
 * `get` callback is invoked.
 *
 * @param inferenceHandler - Forwarded to `createPackedMatmulProgramInfo`.
 * @param inputs - `[A, B]` or `[A, B, Bias]`; more than two inputs marks the program as having a bias.
 * @param activationAttributes - Its `activationCacheKey` becomes the metadata `cacheHint`; the whole object is
 *     forwarded to `createPackedMatmulProgramInfo`.
 * @returns The metadata fields plus a `get` callback producing the `ProgramInfo`.
 */
export const createPackedMatmulProgramInfoLoader =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[],
     activationAttributes: InternalActivationAttributes): ProgramInfoLoader => {
      const metadata = createPackedMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey);
      return {
        ...metadata,
        get: () => createPackedMatmulProgramInfo(inferenceHandler, metadata, inputs, activationAttributes)
      };
    };
|
|
|
|
/**
 * Emits the GLSL helper functions `getAAtOutCoordsMatmul(int i)` and `getBAtOutCoordsMatmul(int i)`.
 *
 * Each helper reads the output coordinates, swaps their last two components, zeroes every component that
 * `BroadcastUtil.getBroadcastDims` reports as broadcast for that input, and then samples the input with
 * `i*2` substituted for the last coordinate of A or the second-to-last coordinate of B.
 *
 * @param coordsDataType - GLSL type used to declare the `coords` variable.
 * @param allGlChannels - Component names used to index `coords` (e.g. as `coords.<channel>`).
 * @param inputs - Only `inputs[0].dims` (A) and `inputs[1].dims` (B) are read.
 * @param outShape - Output shape; its rank determines how input axes are aligned to output axes.
 * @returns GLSL source containing both helper functions.
 */
function getBcastSamplerForMatmul(
    coordsDataType: string, allGlChannels: readonly string[], inputs: Tensor[], outShape: readonly number[]): string {
  const inAShape = inputs[0].dims;
  const inBShape = inputs[1].dims;

  const inARank = inAShape.length;
  const inBRank = inBShape.length;

  const outRank = outShape.length;
  // Inputs are right-aligned against the output, so input axis `i` maps to output axis `i + rankDiff`.
  const rankADiff = outRank - inARank;
  const rankBDiff = outRank - inBRank;

  const unpackedACoords = inAShape.map((_, i) => `coords.${allGlChannels[i + rankADiff]}`);
  unpackedACoords[inARank - 1] = 'i*2';
  // Keep the joined string: previously the join result was discarded and the array was interpolated
  // through its implicit toString().
  const unpackedACoordsSnippet = unpackedACoords.join(', ');

  const unpackedBCoords = inBShape.map((_, i) => `coords.${allGlChannels[i + rankBDiff]}`);
  unpackedBCoords[inBRank - 2] = 'i*2';
  const unpackedBCoordsSnippet = unpackedBCoords.join(', ');

  const broadcastADims = BroadcastUtil.getBroadcastDims(inAShape, outShape);
  const broadcastBDims = BroadcastUtil.getBroadcastDims(inBShape, outShape);

  const coordsASnippet = broadcastADims.map(d => `coords.${allGlChannels[d + rankADiff]} = 0;`).join('\n');
  const coordsBSnippet = broadcastBDims.map(d => `coords.${allGlChannels[d + rankBDiff]} = 0;`).join('\n');
  const swapDimSnippet = `int lastDim = coords.${allGlChannels[outRank - 1]};
  coords.${allGlChannels[outRank - 1]} = coords.${allGlChannels[outRank - 2]};
  coords.${allGlChannels[outRank - 2]} = lastDim;`;

  const getBcastSamplerMatmulSource = `
vec4 getAAtOutCoordsMatmul(int i) {
  ${coordsDataType} coords = getOutputCoords();
  ${swapDimSnippet}
  ${coordsASnippet}
  vec4 outputValue = getA(${unpackedACoordsSnippet});
  return outputValue;
}

vec4 getBAtOutCoordsMatmul(int i) {
  ${coordsDataType} coords = getOutputCoords();
  ${swapDimSnippet}
  ${coordsBSnippet}
  vec4 outputValue = getB(${unpackedBCoordsSnippet});
  return outputValue;
}`;

  return getBcastSamplerMatmulSource;
}
|
|
|
|
/**
 * Builds the argument list used to sample A inside the shader loop.
 *
 * The result is `rc.<channel>` for each of the first `rank - 1` channels followed by `i*2`,
 * joined with `, ` (for example `rc.x, rc.y, i*2` when `rank` is 3).
 *
 * @param allGlChannels - Component names used to index `rc`; not modified.
 * @param rank - Rank of A.
 * @returns The comma-separated coordinate expression list.
 */
function getA(allGlChannels: readonly string[], rank: number): string {
  const args: string[] = [];
  for (let i = 0; i < rank - 2; i++) {
    args.push(`rc.${allGlChannels[i]}`);
  }
  // The last coordinate of A is replaced by the loop position.
  args.push(`rc.${allGlChannels[rank - 2]}`, 'i*2');
  return args.join(', ');
}
|
|
|
|
/**
 * Builds the argument list used to sample B inside the shader loop.
 *
 * The result is `rc.<channel>` for each of the first `rank - 2` channels, then `i*2`, then
 * `rc.<channel>` for channel `rank - 1`, joined with `, ` (for example `rc.x, i*2, rc.z` when `rank` is 3).
 *
 * @param allGlChannels - Component names used to index `rc`; not modified.
 * @param rank - Rank of B.
 * @returns The comma-separated coordinate expression list.
 */
function getB(allGlChannels: readonly string[], rank: number): string {
  const args: string[] = [];
  for (let i = 0; i < rank - 2; i++) {
    args.push(`rc.${allGlChannels[i]}`);
  }
  // The second-to-last coordinate of B is replaced by the loop position.
  args.push('i*2', `rc.${allGlChannels[rank - 1]}`);
  return args.join(', ');
}
|