more shader work

This commit is contained in:
Rowan 2025-04-20 05:41:50 -05:00
parent a4f94c5bf4
commit 37741ed9aa
3 changed files with 151 additions and 60 deletions

View file

@ -6,11 +6,15 @@ import {
accessToBufferType, accessToBufferType,
accessToStorageTextureAccess, accessToStorageTextureAccess,
parseTextureType, parseTextureType,
typeToSamplerBindingType,
typeToTextureSampleType,
typeToViewDimension, typeToViewDimension,
wgslToWgpuFormat wgslToWgpuFormat
} from '../utils/wgsl-to-wgpu.js' } from '../utils/wgsl-to-wgpu.js'
import { BufferBindingType } from '../enum.js'
import { BitFlags } from '../utils/bitflags.js'
/** @import { WGSLAccess, WGSLSampledType, WGSLSamplerType, WGSLTextureType } from '../utils/wgsl-to-wgpu.js' */ /** @import { WGSLAccess, WGSLSamplerType } from '../utils/wgsl-to-wgpu.js' */
export class ShaderModule { export class ShaderModule {
_handle _handle
@ -45,6 +49,7 @@ export class ShaderModule {
reflect() { reflect() {
if (this._reflection == null) { if (this._reflection == null) {
this._reflection = new WgslReflect(this._code) this._reflection = new WgslReflect(this._code)
// no longer needed allow the GC to collect it
this._code = undefined this._code = undefined
} }
@ -53,13 +58,40 @@ export class ShaderModule {
} }
export class ReflectedShader { export class ReflectedShader {
_shader
/**
* @param {ShaderModule} shader
*/
constructor(shader) {
this._shader = shader
}
/**
* @returns {GPUShaderStageFlags}
*/
getShaderStages() {
const entry = this._shader.reflect().entry
let stages = 0
stages |= entry.vertex.length > 0 ?
GPUShaderStage.VERTEX : 0
stages |= entry.fragment.length > 0 ?
GPUShaderStage.FRAGMENT : 0
stages |= entry.compute.length > 0 ?
GPUShaderStage.COMPUTE : 0
return stages
}
/** /**
* @param {WgslReflect} reflection
* @param {GPUShaderStageFlags} stages * @param {GPUShaderStageFlags} stages
* @param {GroupBindingMap} [out=new GroupBindingMap()] * @param {GroupBindingMap} [out=new GroupBindingMap()]
*/ */
_getBindingsForStage(reflection, stages, out = new GroupBindingMap()) { getBindingsForStage(stages, out = new GroupBindingMap()) {
const groups = reflection.getBindGroups() const groups = this._shader.reflect().getBindGroups()
groups.forEach((bindings, groupIndex) => { groups.forEach((bindings, groupIndex) => {
if (!out.has(groupIndex)) { if (!out.has(groupIndex)) {
@ -79,6 +111,7 @@ export class ReflectedShader {
return out return out
} }
/** /**
* @param {Map<any, any>} map * @param {Map<any, any>} map
* @returns {number[]} * @returns {number[]}
@ -93,7 +126,7 @@ export class ReflectedShader {
*/ */
_parseUniform(_variableInfo) { _parseUniform(_variableInfo) {
return { return {
type: 'uniform', type: BufferBindingType.Uniform,
// TODO: infer these two properties // TODO: infer these two properties
hasDynamicOffset: false, hasDynamicOffset: false,
minBindingSize: 0 minBindingSize: 0
@ -116,60 +149,29 @@ export class ReflectedShader {
} }
} }
/**
* @param {WGSLTextureType} type
* @param {WGSLSampledType} sampledType
* @returns {GPUTextureSampleType}
*/
_parseSampleType(type, sampledType) {
if (type.includes('depth')) {
return 'depth'
}
switch (sampledType) {
case 'f32':
case 'i32':
case 'u32':
default:
return 'float'
}
}
/** /**
* @param {VariableInfo} variableInfo * @param {VariableInfo} variableInfo
* @returns {GPUTextureBindingLayout} * @returns {GPUTextureBindingLayout}
*/ */
_parseTexture(variableInfo) { _parseTexture(variableInfo) {
const [type, sampledType] = parseTextureType(variableInfo.type.name) const [type, sampledType] = parseTextureType(
variableInfo.type.name
)
return { return {
sampleType: this._parseSampleType(type, sampledType), sampleType: typeToTextureSampleType(type, sampledType),
viewDimension: typeToViewDimension(type), viewDimension: typeToViewDimension(type),
multisampled: type.includes('multisampled') multisampled: type.includes('multisampled')
} }
} }
/**
* @param {WGSLSamplerType} type
* @returns {GPUSamplerBindingType}
*/
_parseSamplerType(type) {
switch (type) {
case 'sampler_comparison':
return 'comparison'
case 'sampler':
default:
return 'filtering'
}
}
/** /**
* @param {VariableInfo} variableInfo * @param {VariableInfo} variableInfo
* @returns {GPUSamplerBindingLayout} * @returns {GPUSamplerBindingLayout}
*/ */
_parseSampler(variableInfo) { _parseSampler(variableInfo) {
return { return {
type: this._parseSamplerType( type: typeToSamplerBindingType(
/** @type {WGSLSamplerType} */(variableInfo.type.name) /** @type {WGSLSamplerType} */(variableInfo.type.name)
) )
} }
@ -238,7 +240,7 @@ export class ReflectedShader {
/** /**
* @param {GroupBindingMap} groupBindings * @param {GroupBindingMap} groupBindings
*/ */
_createBindGroupLayoutEntries(groupBindings) { createBindGroupLayoutEntries(groupBindings) {
const sortedGroupIndices = this._sortKeyIndices(groupBindings) const sortedGroupIndices = this._sortKeyIndices(groupBindings)
return sortedGroupIndices.map(groupIndex => { return sortedGroupIndices.map(groupIndex => {
@ -251,53 +253,67 @@ export class ReflectedShader {
} }
} }
export class UnifiedShader extends ReflectedShader { export class UnifiedShader {
_shader _shader
/** /**
* @param {ShaderModule} shader * @param {ReflectedShader} shader
*/ */
constructor(shader) { constructor(shader) {
super()
this._shader = shader this._shader = shader
} }
createBindGroupLayoutEntries() { createBindGroupLayoutEntries() {
return this._createBindGroupLayoutEntries( const stages = this._shader.getShaderStages()
this._getBindingsForStage(
this._shader.reflect(), const unifiedShader = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT
GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT
if (!BitFlags.has(stages, unifiedShader)) {
throw new Error('cant do it')
}
return this._shader.createBindGroupLayoutEntries(
this._shader.getBindingsForStage(
unifiedShader
) )
) )
} }
} }
export class ReflectedShaderPair extends ReflectedShader { export class ReflectedShaderPair {
_vertex _vertex
_fragment _fragment
/** /**
* @param {ShaderModule} vertex * @param {ReflectedShader} vertex
* @param {ShaderModule} [fragment] * @param {ReflectedShader} [fragment]
*/ */
constructor(vertex, fragment) { constructor(vertex, fragment) {
super()
this._vertex = vertex this._vertex = vertex
this._fragment = fragment this._fragment = fragment
} }
_createGroupBindings() { _createGroupBindings() {
const groupBindings = new GroupBindingMap() const groupBindings = new GroupBindingMap()
this.getBindingsForStage(
this._vertex.reflect(), if (
!BitFlags.has(
this._vertex.getShaderStages(),
GPUShaderStage.VERTEX)
&& !BitFlags.has(
this._fragment.getShaderStages(),
GPUShaderStage.FRAGMENT
)
) {
throw new Error('nope')
}
this._vertex.getBindingsForStage(
GPUShaderStage.VERTEX, GPUShaderStage.VERTEX,
groupBindings groupBindings
) )
this.getBindingsForStage( this._fragment.getBindingsForStage(
this._fragment.reflect(),
GPUShaderStage.FRAGMENT, GPUShaderStage.FRAGMENT,
groupBindings groupBindings
) )
@ -306,7 +322,9 @@ export class ReflectedShaderPair extends ReflectedShader {
} }
createBindGroupLayoutEntries() { createBindGroupLayoutEntries() {
return this._createBindGroupLayoutEntries( // FIXME: move this call and all the other calls
// somewhere else
return this._shader.createBindGroupLayoutEntries(
this._createGroupBindings() this._createGroupBindings()
) )
} }

41
src/utils/bitflags.js Normal file
View file

@ -0,0 +1,41 @@
export class BitFlags {
_value
get flags() {
return this._value
}
constructor(value) {
this._value = value
}
/**
* @param {number} a
* @param {number} b
*/
static has(a, b) {
return (a & b) === b
}
/**
* @param {number} a
* @param {number} b
*/
static add(a, b) {
return a | b
}
/**
* @param {number} b
*/
has(b) {
return BitFlags.has(this._value, b)
}
/**
* @param {number} b
*/
add(b) {
return BitFlags.add(this._value, b)
}
}

View file

@ -299,3 +299,35 @@ export const wgslToWgpuFormat = (format) => {
} }
} }
/**
* @param {WGSLTextureType} type
* @param {WGSLSampledType} sampledType
* @returns {GPUTextureSampleType}
*/
export const typeToTextureSampleType = (type, sampledType) => {
if (type.includes('depth')) {
return 'depth'
}
switch (sampledType) {
case 'f32':
case 'i32':
case 'u32':
default:
return 'float'
}
}
/**
* @param {WGSLSamplerType} type
* @returns {GPUSamplerBindingType}
*/
export const typeToSamplerBindingType = type => {
switch (type) {
case 'sampler_comparison':
return 'comparison'
case 'sampler':
default:
return 'filtering'
}
}