@@ -37,7 +37,8 @@ export function makeShader(inputsInfo: InputInfo[], outputShape: ShapeInfo,
3737 getOutputSamplingSnippet ( outputShape . logicalShape , outTexShape ) ;
3838 const source = [
3939 SHADER_PREFIX , inputPrefixSnippet , SAMPLE_1D_SNIPPET , SAMPLE_2D_SNIPPET ,
40- SAMPLE_3D_SNIPPET , inputSamplingSnippet , outputSamplingSnippet , userCode
40+ SAMPLE_3D_SNIPPET , SAMPLE_4D_SNIPPET , inputSamplingSnippet ,
41+ outputSamplingSnippet , userCode
4142 ] . join ( '\n' ) ;
4243 return source ;
4344}
@@ -63,6 +64,10 @@ function getInputSamplingSnippet(
6364 res += getSampler3D (
6465 inInfo . name , shape as [ number , number , number ] , texShape ) ;
6566 break ;
67+ case 4 :
68+ res += getSampler4D (
69+ inInfo . name , shape as [ number , number , number , number ] , texShape ) ;
70+ break ;
6671 default :
6772 throw new Error (
6873 `${ shape . length } -D input sampling` +
@@ -93,6 +98,9 @@ function getOutputSamplingSnippet(
9398 case 3 :
9499 return getOutput3DCoords ( outShape as [ number , number , number ] ,
95100 outTexShape ) ;
101+ case 4 :
102+ return getOutput4DCoords ( outShape as [ number , number , number , number ] ,
103+ outTexShape ) ;
96104 default :
97105 throw new Error (
98106 `${ outShape . length } -D output sampling is not yet supported` ) ;
@@ -144,6 +152,19 @@ const SAMPLE_3D_SNIPPET = `
144152 }
145153` ;
146154
155+ const SAMPLE_4D_SNIPPET = `
156+ float sample4D(sampler2D texture, float texNumR, float texNumC, float stride0,
157+ float stride1, float stride2, float row, float col, float depth,
158+ float depth2) {
159+ float index = dot(vec4(row, col, depth, depth2),
160+ vec4(stride0, stride1, stride2, 1.0));
161+ float texR = floor(index / texNumC);
162+ float texC = mod(index, texNumC);
163+ vec2 uv = (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
164+ return texture2D(texture, uv).r;
165+ }
166+ ` ;
167+
147168function getOutput1DCoords (
148169 shape : [ number ] , texShape : [ number , number ] ) : string {
149170 if ( texShape [ 0 ] === 1 ) {
@@ -185,6 +206,30 @@ function getOutput3DCoords(shape: [number, number, number],
185206 ` ;
186207}
187208
209+ function getOutput4DCoords ( shape : [ number , number , number , number ] ,
210+ texShape : [ number , number ] ) : string {
211+ const stride2 = shape [ 3 ] ;
212+ const stride1 = shape [ 2 ] * stride2 ;
213+ const stride0 = shape [ 1 ] * stride1 ;
214+ return `
215+ vec4 getOutputCoords() {
216+ vec2 resTexRC = floor(gl_FragCoord.yx);
217+ float index = dot(resTexRC, vec2(${ texShape [ 1 ] } .0, 1.0));
218+
219+ float r = floor(index / ${ stride0 } .0);
220+ index -= r * ${ stride0 } .0;
221+
222+ float c = floor(index / ${ stride1 } .0);
223+ index -= c * ${ stride1 } .0;
224+
225+ float d = floor(index / ${ stride2 } .0);
226+ float d2 = mod(index, ${ stride2 } .0);
227+
228+ return vec4(r, c, d, d2);
229+ }
230+ ` ;
231+ }
232+
188233function getOutput2DCoords (
189234 shape : [ number , number ] , texShape : [ number , number ] ) : string {
190235 if ( util . arraysEqual ( shape , texShape ) ) {
@@ -265,6 +310,24 @@ function getSampler3D(
265310 ` ;
266311}
267312
313+ function getSampler4D (
314+ texName : string , shape : [ number , number , number , number ] ,
315+ texShape : [ number , number ] ) : string {
316+ const funcName = 'get' + texName . charAt ( 0 ) . toUpperCase ( ) + texName . slice ( 1 ) ;
317+ const tR = texShape [ 0 ] ;
318+ const tC = texShape [ 1 ] ;
319+ const stride2 = shape [ 3 ] ;
320+ const stride1 = shape [ 2 ] * stride2 ;
321+ const stride0 = shape [ 1 ] * stride1 ;
322+
323+ return `
324+ float ${ funcName } (float row, float col, float depth, float depth2) {
325+ return sample4D(${ texName } , ${ tR } .0, ${ tC } .0, ${ stride0 } .0, ${ stride1 } .0,
326+ ${ stride2 } .0, row, col, depth, depth2);
327+ }
328+ ` ;
329+ }
330+
268331function getSampler2D (
269332 texName : string , shape : [ number , number ] ,
270333 texShape : [ number , number ] ) : string {
0 commit comments