From 37741ed9aab2383fbf6b49ded75b3f56df28215a Mon Sep 17 00:00:00 2001 From: rowan Date: Sun, 20 Apr 2025 05:41:50 -0500 Subject: [PATCH] more shader work --- src/resources/shader-module.js | 138 +++++++++++++++++++-------------- src/utils/bitflags.js | 41 ++++++++++ src/utils/wgsl-to-wgpu.js | 32 ++++++++ 3 files changed, 151 insertions(+), 60 deletions(-) create mode 100644 src/utils/bitflags.js diff --git a/src/resources/shader-module.js b/src/resources/shader-module.js index 3038594..fd008e6 100644 --- a/src/resources/shader-module.js +++ b/src/resources/shader-module.js @@ -6,11 +6,15 @@ import { accessToBufferType, accessToStorageTextureAccess, parseTextureType, + typeToSamplerBindingType, + typeToTextureSampleType, typeToViewDimension, wgslToWgpuFormat } from '../utils/wgsl-to-wgpu.js' +import { BufferBindingType } from '../enum.js' +import { BitFlags } from '../utils/bitflags.js' -/** @import { WGSLAccess, WGSLSampledType, WGSLSamplerType, WGSLTextureType } from '../utils/wgsl-to-wgpu.js' */ +/** @import { WGSLAccess, WGSLSamplerType } from '../utils/wgsl-to-wgpu.js' */ export class ShaderModule { _handle @@ -45,6 +49,7 @@ export class ShaderModule { reflect() { if (this._reflection == null) { this._reflection = new WgslReflect(this._code) + // no longer needed allow the GC to collect it this._code = undefined } @@ -53,13 +58,40 @@ export class ShaderModule { } export class ReflectedShader { + _shader + + /** + * @param {ShaderModule} shader + */ + constructor(shader) { + this._shader = shader + } + + /** + * @returns {GPUShaderStageFlags} + */ + getShaderStages() { + const entry = this._shader.reflect().entry + let stages = 0 + + stages |= entry.vertex.length > 0 ? + GPUShaderStage.VERTEX : 0 + + stages |= entry.fragment.length > 0 ? + GPUShaderStage.FRAGMENT : 0 + + stages |= entry.compute.length > 0 ? + GPUShaderStage.COMPUTE : 0 + + return stages + } + /** - * @param {WgslReflect} reflection * @param {GPUShaderStageFlags} stages * @param {GroupBindingMap} [out=new GroupBindingMap()] */ - _getBindingsForStage(reflection, stages, out = new GroupBindingMap()) { - const groups = reflection.getBindGroups() + getBindingsForStage(stages, out = new GroupBindingMap()) { + const groups = this._shader.reflect().getBindGroups() groups.forEach((bindings, groupIndex) => { if (!out.has(groupIndex)) { @@ -79,6 +111,7 @@ export class ReflectedShader { return out } + /** * @param {Map} map * @returns {number[]} @@ -93,7 +126,7 @@ export class ReflectedShader { */ _parseUniform(_variableInfo) { return { - type: 'uniform', + type: BufferBindingType.Uniform, // TODO: infer these two properties hasDynamicOffset: false, minBindingSize: 0 @@ -116,60 +149,29 @@ export class ReflectedShader { } } - /** - * @param {WGSLTextureType} type - * @param {WGSLSampledType} sampledType - * @returns {GPUTextureSampleType} - */ - _parseSampleType(type, sampledType) { - if (type.includes('depth')) { - return 'depth' - } - - switch (sampledType) { - case 'f32': - case 'i32': - case 'u32': - default: - return 'float' - } - } - /** * @param {VariableInfo} variableInfo * @returns {GPUTextureBindingLayout} */ _parseTexture(variableInfo) { - const [type, sampledType] = parseTextureType(variableInfo.type.name) + const [type, sampledType] = parseTextureType( + variableInfo.type.name + ) return { - sampleType: this._parseSampleType(type, sampledType), + sampleType: typeToTextureSampleType(type, sampledType), viewDimension: typeToViewDimension(type), multisampled: type.includes('multisampled') } } - /** - * @param {WGSLSamplerType} type - * @returns {GPUSamplerBindingType} - */ - _parseSamplerType(type) { - switch (type) { - case 'sampler_comparison': - return 'comparison' - case 'sampler': - default: - return 'filtering' - } - } - /** * @param {VariableInfo} variableInfo * @returns {GPUSamplerBindingLayout} */ _parseSampler(variableInfo) { return { - type: this._parseSamplerType( + type: typeToSamplerBindingType( /** @type {WGSLSamplerType} */(variableInfo.type.name) ) } @@ -238,7 +240,7 @@ export class ReflectedShader { /** * @param {GroupBindingMap} groupBindings */ - _createBindGroupLayoutEntries(groupBindings) { + createBindGroupLayoutEntries(groupBindings) { const sortedGroupIndices = this._sortKeyIndices(groupBindings) return sortedGroupIndices.map(groupIndex => { @@ -251,53 +253,67 @@ export class ReflectedShader { } } -export class UnifiedShader extends ReflectedShader { +export class UnifiedShader { _shader /** - * @param {ShaderModule} shader + * @param {ReflectedShader} shader */ constructor(shader) { - super() - this._shader = shader } createBindGroupLayoutEntries() { - return this._createBindGroupLayoutEntries( - this._getBindingsForStage( - this._shader.reflect(), - GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT + const stages = this._shader.getShaderStages() + + const unifiedShader = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT + + if (!BitFlags.has(stages, unifiedShader)) { + throw new Error('cant do it') + } + + return this._shader.createBindGroupLayoutEntries( + this._shader.getBindingsForStage( + unifiedShader ) ) } } -export class ReflectedShaderPair extends ReflectedShader { +export class ReflectedShaderPair { _vertex _fragment /** - * @param {ShaderModule} vertex - * @param {ShaderModule} [fragment] + * @param {ReflectedShader} vertex + * @param {ReflectedShader} [fragment] */ constructor(vertex, fragment) { - super() - this._vertex = vertex this._fragment = fragment } _createGroupBindings() { const groupBindings = new GroupBindingMap() - this.getBindingsForStage( - this._vertex.reflect(), + + if ( + !BitFlags.has( + this._vertex.getShaderStages(), + GPUShaderStage.VERTEX) + && !BitFlags.has( + this._fragment.getShaderStages(), + GPUShaderStage.FRAGMENT + ) + ) { + throw new Error('nope') + } + + this._vertex.getBindingsForStage( GPUShaderStage.VERTEX, groupBindings ) - this.getBindingsForStage( - this._fragment.reflect(), + this._fragment.getBindingsForStage( GPUShaderStage.FRAGMENT, groupBindings ) @@ -306,7 +322,9 @@ export class ReflectedShaderPair extends ReflectedShader { } createBindGroupLayoutEntries() { - return this._createBindGroupLayoutEntries( + // FIXME: move this call and all the other calls + // somewhere else + return this._shader.createBindGroupLayoutEntries( this._createGroupBindings() ) } diff --git a/src/utils/bitflags.js b/src/utils/bitflags.js new file mode 100644 index 0000000..5402171 --- /dev/null +++ b/src/utils/bitflags.js @@ -0,0 +1,41 @@ +export class BitFlags { + _value + + get flags() { + return this._value + } + + constructor(value) { + this._value = value + } + /** + * @param {number} a + * @param {number} b + */ + static has(a, b) { + return (a & b) === b + } + + /** + * @param {number} a + * @param {number} b + */ + static add(a, b) { + return a | b + } + + /** + * @param {number} b + */ + has(b) { + return BitFlags.has(this._value, b) + } + + /** + * @param {number} b + */ + add(b) { + return BitFlags.add(this._value, b) + } +} + diff --git a/src/utils/wgsl-to-wgpu.js b/src/utils/wgsl-to-wgpu.js index 57ec775..7329235 100644 --- a/src/utils/wgsl-to-wgpu.js +++ b/src/utils/wgsl-to-wgpu.js @@ -299,3 +299,35 @@ export const wgslToWgpuFormat = (format) => { } } +/** + * @param {WGSLTextureType} type + * @param {WGSLSampledType} sampledType + * @returns {GPUTextureSampleType} + */ +export const typeToTextureSampleType = (type, sampledType) => { + if (type.includes('depth')) { + return 'depth' + } + + switch (sampledType) { + case 'f32': + case 'i32': + case 'u32': + default: + return 'float' + } +} + +/** + * @param {WGSLSamplerType} type + * @returns {GPUSamplerBindingType} + */ +export const typeToSamplerBindingType = type => { + switch (type) { + case 'sampler_comparison': + return 'comparison' + case 'sampler': + default: + return 'filtering' + } +}