66 lines
2.0 KiB
TypeScript
66 lines
2.0 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
|
|
|
|
import {Tensor} from '../../../tensor';
|
|
import {getGlsl} from '../glsl-source';
|
|
import {WebGLInferenceHandler} from '../inference-handler';
|
|
import {ProgramInfo, ProgramInfoLoader, TextureType} from '../types';
|
|
import {getCoordsDataType} from '../utils';
|
|
|
|
import {getChannels, unpackFromChannel} from './packing-utils';
|
|
|
|
/**
 * Static description of the 'unpack' program, shared by the eager program-info builder and the lazy loader.
 * The program consumes exactly one input, bound under the name 'A', which must be a packed texture.
 */
const unpackProgramMetadata = {
  name: 'unpack',
  inputNames: ['A'],
  inputTypes: [TextureType.packed],
};
|
|
|
|
/**
 * Creates the program info for the 'unpack' operation. It reads the packed-texture input 'A' and produces an
 * unpacked-texture output with the same dims and type as `input`.
 *
 * For each output coordinate `rc`, the generated shader does two things:
 * - It samples the packed input at the same coordinates.
 * - It selects the element for `rc` with `getChannel`. For rank >= 2 the selection is keyed on the last two
 *   coordinates (`innerDims`); for rank 0/1 it is keyed on `rc` itself.
 *
 * The selected value is written to the red component of the output, and the remaining components are set to 0.
 *
 * @param handler inference handler; only used here to choose the GLSL dialect from the WebGL context version.
 * @param input tensor whose dims/type describe both the packed source and the unpacked result.
 * @returns a ProgramInfo that supplies its own `main()` (`hasMain: true`).
 */
export const createUnpackProgramInfo = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfo => {
  const rank = input.dims.length;

  const channels = getChannels('rc', rank);
  const innerDims = channels.slice(-2);
  const coordsDataType = getCoordsDataType(rank);
  const unpackChannel = unpackFromChannel();
  // A scalar input is sampled with an empty argument list: getA().
  const sourceCoords = rank === 0 ? '' : getSourceCoords(rank, channels);
  // getChannel is given `rc` directly for rank 0/1, and a vec2 of the two innermost coordinates otherwise.
  const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`;
  const glsl = getGlsl(handler.session.backend.glContext.version);

  // NOTE: template-literal text is emitted verbatim into the GLSL source, so keep it free of non-GLSL tokens.
  const shaderSource = `
    ${unpackChannel}
    void main() {
      ${coordsDataType} rc = getOutputCoords();

      // Sample the texture with the coords to get the rgba channel value.
      vec4 packedInput = getA(${sourceCoords});

      ${glsl.output} = vec4(getChannel(packedInput, ${coords}), 0, 0, 0);
    }
  `;

  return {
    ...unpackProgramMetadata,
    hasMain: true,
    output: {dims: input.dims, type: input.type, textureType: TextureType.unpacked},
    shaderSource
  };
};
|
|
|
|
/**
 * Lazy counterpart of createUnpackProgramInfo. The returned loader exposes the unpack metadata immediately.
 * Shader generation is deferred: each call to `get()` builds a fresh ProgramInfo from the captured
 * `handler` and `input`.
 */
export const createUnpackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => {
  const buildProgramInfo = () => createUnpackProgramInfo(handler, input);
  return {...unpackProgramMetadata, get: buildProgramInfo};
};
|
|
|
|
/**
 * Builds the comma-separated argument list used to sample the packed input at the output coordinates.
 *
 * @param rank rank of the tensor being unpacked.
 * @param dims per-dimension coordinate expressions; only the first `rank` entries are used.
 * @returns one of the following:
 *   - `''` for rank <= 0 (scalar: no coordinates);
 *   - `'rc'` for rank 1 (the whole coordinate variable is passed through);
 *   - otherwise, the first `rank` entries of `dims` joined with `','`.
 * @throws Error if rank >= 2 and `dims` has fewer than `rank` entries. Throwing here avoids emitting the text
 *   `undefined` into the shader source.
 */
function getSourceCoords(rank: number, dims: readonly string[]): string {
  if (rank <= 0) {
    return '';
  }
  if (rank === 1) {
    return 'rc';
  }
  if (dims.length < rank) {
    throw new Error(`unpack: expected at least ${rank} coordinate names, got ${dims.length}`);
  }
  return dims.slice(0, rank).join(',');
}
|