import { BindGroupLayout } from './bind-group-layout.js'
import { FragmentStateDescriptor, ShaderPair, ShaderPairDescriptor, VertexStateDescriptor } from './shader-module.js'
import { MaterialError, WebGPUObjectError } from '../utils/errors.js'
import { ResourceType } from 'wgsl_reflect'
import { BindGroup } from './bind-group.js'
import { Texture } from './texture.js'
import { Buffer } from './buffer.js'
import { Sampler } from './sampler.js'

/** @import { FragmentStateDescriptor, VertexStateDescriptor } from './shader-module.js' */

/** A resource that can be placed in a bind group entry; each provides `toBindingResource()`. */
type BindingResource = Buffer | Texture | Sampler

/**
 * Options for `Material.getRenderPipelineDescriptor`.
 *
 * Exported because it appears in the public signature of the exported `Material` class.
 */
export interface MaterialPipelineDescriptor {
  /** Debug label copied onto the resulting pipeline descriptor. */
  label?: string
  /** When set, used instead of the material's own pipeline layout. */
  pipelineLayout?: GPUPipelineLayout
  /** Vertex-stage options, forwarded (with the rest of this object) to `ShaderPair.getRenderPipelineStates`. */
  vertex: VertexStateDescriptor
  /** Optional fragment-stage options, forwarded the same way as `vertex`. */
  fragment?: FragmentStateDescriptor
  /** Primitive state; a triangle-list topology is used when omitted. */
  primitive?: GPUPrimitiveState
}

/**
 * Constructor options for `Material`: a shader pair plus optional explicit layouts.
 *
 * Exported because it appears in the public signature of the exported `Material` class.
 */
export interface MaterialDescriptor extends ShaderPairDescriptor {
  /** Explicit bind group layouts; when absent or empty they are reflected from the shaders. */
  bindGroupLayouts?: BindGroupLayout[]
}

/**
 * Couples a vertex/fragment `ShaderPair` with the bind group layouts and
 * pipeline layout needed to render with it.
 *
 * Bind group layouts come from the descriptor when provided (and non-empty);
 * otherwise they are derived from the shaders via reflection. The position of a
 * layout in `bindGroupLayouts` is used as its bind group index.
 */
export class Material {
  _device: GPUDevice
  _shaders: ShaderPair
  _bindGroupLayouts: BindGroupLayout[]
  /** Explicit pipeline layout; stays undefined when the material has no bind group layouts. */
  _pipelineLayout?: GPUPipelineLayout

  /** The shader pair this material renders with. */
  get shaders() {
    return this._shaders
  }

  /** Bind group layouts, where array position is the bind group index. */
  get bindGroupLayouts() {
    return this._bindGroupLayouts
  }

  /**
   * @param device - Device used to create the pipeline layout and, later, bind groups.
   * @param descriptor - Shader pair plus optional explicit bind group layouts.
   * @throws MaterialError when the descriptor is missing or has no vertex stage.
   * @throws WebGPUObjectError when the pipeline layout cannot be created.
   */
  constructor(device: GPUDevice, descriptor: MaterialDescriptor) {
    this._device = device
    this._shaders = Material._reflectShaders(descriptor)

    const provided = descriptor.bindGroupLayouts
    this._bindGroupLayouts = provided && provided.length > 0
      ? provided
      : this._reflectBindGroupLayouts(device, this._shaders)

    // With no layouts at all, no explicit pipeline layout is created;
    // getRenderPipelineDescriptor falls back to 'auto' in that case.
    if (this._bindGroupLayouts.length > 0) {
      try {
        this._pipelineLayout = device.createPipelineLayout({
          bindGroupLayouts: this._bindGroupLayouts.map(layout => layout.handle)
        })
      } catch (err) {
        throw WebGPUObjectError.from(err, Material)
      }
    }
  }

  /**
   * Builds a `ShaderPair` from a descriptor that carries a `vertex` stage.
   *
   * @throws MaterialError when the descriptor is missing or has no `vertex` property.
   */
  static _reflectShaders(shaders: ShaderPairDescriptor): ShaderPair {
    // Fail here with a clear error instead of returning undefined, which would
    // otherwise surface later as an opaque TypeError in the constructor.
    if (shaders == null || !('vertex' in shaders)) {
      throw MaterialError.missingShader('vertex')
    }

    return ShaderPair.fromPair(shaders)
  }

  /**
   * Creates one `BindGroupLayout` per group of entries reflected from the shaders.
   */
  _reflectBindGroupLayouts(device: GPUDevice, shaders: ShaderPair): BindGroupLayout[] {
    const layouts = shaders.createBindGroupLayoutEntries()
    return layouts.map(entries => BindGroupLayout.create(device, { entries }))
  }

  /**
   * Creates a bind group for bind group index `groupIndex` from named resources.
   *
   * Each key of `resources` is looked up as a shader variable in that group to find
   * its binding number; entries are sorted by binding number before creation.
   *
   * @param groupIndex - Integer index into `bindGroupLayouts`.
   * @param resources - Shader variable name to resource. Only own enumerable string
   *   keys are used.
   * @param label - Optional debug label.
   * @throws Error when `groupIndex` is not a valid integer index, or when the shaders
   *   report no variable for one of the names.
   */
  createBindGroup(groupIndex: number, resources: Record<PropertyKey, BindingResource>, label?: string) {
    // Number.isInteger also rejects NaN and fractional indices, which would
    // otherwise pass the range check and index to undefined.
    if (!Number.isInteger(groupIndex) || groupIndex < 0 || groupIndex >= this._bindGroupLayouts.length) {
      throw new Error(`Invalid bind group index: ${groupIndex}`)
    }

    const bgl = this._bindGroupLayouts[groupIndex]
    const entries: GPUBindGroupEntry[] = []

    // Object.entries skips inherited properties, unlike for...in.
    for (const [name, resource] of Object.entries(resources)) {
      const variableInfo = this._shaders.findVariableInfo(name, groupIndex)
      if (variableInfo == null) {
        throw new Error(`No shader variable named "${name}" in bind group ${groupIndex}`)
      }

      if (variableInfo.resourceType === ResourceType.Uniform || variableInfo.resourceType === ResourceType.Storage) {
        // TODO: handle user provided offset/size
      }

      entries.push({
        binding: variableInfo.binding,
        resource: resource.toBindingResource()
      })
    }

    entries.sort((a, b) => a.binding - b.binding)

    return BindGroup.create(this._device, {
      layout: bgl.handle,
      entries,
      label
    })
  }

  /**
   * Assembles a render pipeline descriptor from the material's shaders and layout.
   *
   * `descriptor.pipelineLayout` takes precedence over the material's own layout. When
   * neither exists (the material has no bind groups), `'auto'` is used, because
   * `layout` is a required member of the WebGPU pipeline descriptor.
   */
  getRenderPipelineDescriptor(descriptor: MaterialPipelineDescriptor): GPURenderPipelineDescriptor {
    const { fragment, vertex } = this.shaders.getRenderPipelineStates(descriptor)

    return {
      label: descriptor.label,
      layout: descriptor.pipelineLayout ?? this._pipelineLayout ?? 'auto',
      fragment,
      vertex,
      primitive: descriptor.primitive ?? { topology: 'triangle-list' }
    }
  }
}