tasq/node_modules/onnxruntime-web/lib/onnxjs/backends/webgl/ops/unpack.ts

66 lines
2.0 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {Tensor} from '../../../tensor';
import {getGlsl} from '../glsl-source';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, ProgramInfoLoader, TextureType} from '../types';
import {getCoordsDataType} from '../utils';
import {getChannels, unpackFromChannel} from './packing-utils';
/**
 * Static metadata spread into both the unpack program info and its loader:
 * a single input, bound as `A`, that is supplied as a packed texture.
 */
const unpackProgramMetadata = {name: 'unpack', inputNames: ['A'], inputTypes: [TextureType.packed]};
/**
 * Builds the WebGL program that reads the packed texture of `input` (sampler `A`) and writes an
 * unpacked texture holding one value per texel.
 *
 * @param handler - inference handler; its GL context version selects the GLSL flavor
 *   (e.g. the name of the fragment output variable).
 * @param input - tensor whose packed texture is sampled.
 * @returns program info whose output keeps `input`'s dims and type but uses an unpacked texture.
 */
export const createUnpackProgramInfo = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfo => {
  const {dims} = input;
  const rank = dims.length;

  const rcChannels = getChannels('rc', rank);
  // The two innermost coordinates are handed to getChannel to pick the RGBA component.
  const innerChannels = rcChannels.slice(-2);
  const coordsType = getCoordsDataType(rank);
  const channelHelpers = unpackFromChannel();

  // A scalar input is sampled with no coordinates at all.
  let samplerArgs = '';
  if (dims.length !== 0) {
    samplerArgs = getSourceCoords(rank, rcChannels);
  }

  // Rank 0/1 outputs pass `rc` straight through; higher ranks build a vec2 from the inner coordinates.
  let channelSelector = 'rc';
  if (rank > 1) {
    channelSelector = `vec2(${innerChannels.join(',')})`;
  }

  const glsl = getGlsl(handler.session.backend.glContext.version);
  const shaderSource = `
${channelHelpers}
void main() {
${coordsType} rc = getOutputCoords();
// Sample the texture with the coords to get the rgba channel value.
vec4 packedInput = getA(${samplerArgs});
${glsl.output} = vec4(getChannel(packedInput, ${channelSelector}), 0, 0, 0);
}
`;

  const programInfo: ProgramInfo = {
    ...unpackProgramMetadata,
    hasMain: true,
    output: {dims, type: input.type, textureType: TextureType.unpacked},
    shaderSource
  };
  return programInfo;
};
/**
 * Wraps the unpack program in a loader: the metadata is available immediately, and the full
 * program info is only built when `get()` is called.
 *
 * @param handler - inference handler forwarded to createUnpackProgramInfo.
 * @param input - tensor forwarded to createUnpackProgramInfo.
 */
export const createUnpackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => {
  const get = () => createUnpackProgramInfo(handler, input);
  return {...unpackProgramMetadata, get};
};
/**
 * Builds the comma-separated argument list used to sample the packed input (`getA(...)`).
 *
 * @param rank - rank of the input tensor.
 * @param dims - coordinate expression for each dimension (the caller passes `getChannels('rc', rank)`);
 *   only the first `rank` entries are used.
 * @returns `''` for rank 0 (no coordinates), `'rc'` for rank 1, otherwise the first `rank`
 *   entries of `dims` joined with `,`.
 */
function getSourceCoords(rank: number, dims: string[]): string {
  if (rank <= 0) {
    // Scalars are sampled without coordinates.
    return '';
  }
  if (rank === 1) {
    // For rank 1 the sampler is called with `rc` itself.
    return 'rc';
  }
  return dims.slice(0, rank).join(',');
}