diff --git a/vulkano/src/lib.rs b/vulkano/src/lib.rs index 288e8c1009..07eb7b18ae 100644 --- a/vulkano/src/lib.rs +++ b/vulkano/src/lib.rs @@ -179,6 +179,9 @@ pub use ash::vk::DeviceAddress; /// A [`DeviceAddress`] that is known not to equal zero. pub type NonNullDeviceAddress = NonZeroU64; +/// Represents a region of device addresses with a stride. +pub use ash::vk::StridedDeviceAddressRegionKHR as StridedDeviceAddressRegion; + /// Holds 24 bits in the least significant bits of memory, /// and 8 bytes in the most significant bits of that memory, /// occupying a single [`u32`] in total. diff --git a/vulkano/src/pipeline/ray_tracing.rs b/vulkano/src/pipeline/ray_tracing.rs index 3c6e90fc12..abee90419c 100644 --- a/vulkano/src/pipeline/ray_tracing.rs +++ b/vulkano/src/pipeline/ray_tracing.rs @@ -52,9 +52,8 @@ use crate::{ DeviceAlignment, }, shader::{spirv::ExecutionModel, DescriptorBindingRequirements}, - Validated, ValidationError, VulkanError, VulkanObject, + StridedDeviceAddressRegion, Validated, ValidationError, VulkanError, VulkanObject, }; -use ash::vk::StridedDeviceAddressRegionKHR; use foldhash::{HashMap, HashSet}; use smallvec::SmallVec; use std::{collections::hash_map::Entry, mem::MaybeUninit, num::NonZeroU64, ptr, sync::Arc}; @@ -140,6 +139,12 @@ impl RayTracingPipeline { Ok(Self::from_handle(device, handle, create_info)) } + /// Creates a new `RayTracingPipeline` from a raw object handle. + /// + /// # Safety + /// + /// - `handle` must be a valid Vulkan object handle created from `device`. + /// - `create_info` must match the info used to create the object. pub unsafe fn from_handle( device: Arc, handle: ash::vk::Pipeline, @@ -196,18 +201,22 @@ impl RayTracingPipeline { }) } + // Returns the shader groups that the pipeline was created with. pub fn groups(&self) -> &[RayTracingShaderGroupCreateInfo] { &self.groups } + // Returns the shader stages that the pipeline was created with. pub fn stages(&self) -> &[PipelineShaderStageCreateInfo] { &self.stages } + /// Returns the `Device` that the pipeline was created with. pub fn device(&self) -> &Arc { &self.device } + /// Returns the flags that the pipeline was created with. pub fn flags(&self) -> PipelineCreateFlags { self.flags } @@ -275,12 +284,12 @@ pub struct RayTracingPipelineCreateInfo { /// The ray tracing shader stages to use. /// - /// There is no default value. + /// The default value is empty, which must be overridden. pub stages: SmallVec<[PipelineShaderStageCreateInfo; 5]>, /// The shader groups to use. They reference the shader stages in `stages`. /// - /// The default value is empty. + /// The default value is empty, which must be overridden. pub groups: SmallVec<[RayTracingShaderGroupCreateInfo; 5]>, /// The maximum recursion depth of the pipeline. @@ -377,6 +386,13 @@ impl RayTracingPipelineCreateInfo { })); } + if stages.is_empty() { + return Err(Box::new(ValidationError { + problem: "`stages` is empty".into(), + vuids: &["VUID-VkRayTracingPipelineCreateInfoKHR-pLibraryInfo-07999"], + ..Default::default() + })); + } for stage in stages { stage.validate(device).map_err(|err| { err.add_context("stages") @@ -407,6 +423,13 @@ impl RayTracingPipelineCreateInfo { })?; } + if groups.is_empty() { + return Err(Box::new(ValidationError { + problem: "`groups` is empty".into(), + vuids: &["VUID-VkRayTracingPipelineCreateInfoKHR-flags-08700"], + ..Default::default() + })); + } for group in groups { group.validate(stages).map_err(|err| { err.add_context("groups") @@ -746,32 +769,36 @@ impl RayTracingShaderGroupCreateInfo { } } -pub struct RayTracingPipelineCreateInfoFields1Vk<'a> { +pub(crate) struct RayTracingPipelineCreateInfoFields1Vk<'a> { pub(crate) stages_vk: SmallVec<[ash::vk::PipelineShaderStageCreateInfo<'a>; 5]>, pub(crate) groups_vk: SmallVec<[ash::vk::RayTracingShaderGroupCreateInfoKHR<'static>; 5]>, pub(crate) dynamic_state_vk: Option>, } -pub struct RayTracingPipelineCreateInfoFields1ExtensionsVk { +pub(crate) struct RayTracingPipelineCreateInfoFields1ExtensionsVk { pub(crate) stages_extensions_vk: SmallVec<[PipelineShaderStageCreateInfoExtensionsVk; 5]>, } -pub struct RayTracingPipelineCreateInfoFields2Vk<'a> { +pub(crate) struct RayTracingPipelineCreateInfoFields2Vk<'a> { pub(crate) stages_fields1_vk: SmallVec<[PipelineShaderStageCreateInfoFields1Vk<'a>; 5]>, pub(crate) dynamic_states_vk: SmallVec<[ash::vk::DynamicState; 4]>, } -pub struct RayTracingPipelineCreateInfoFields3Vk { +pub(crate) struct RayTracingPipelineCreateInfoFields3Vk { pub(crate) stages_fields2_vk: SmallVec<[PipelineShaderStageCreateInfoFields2Vk; 5]>, } -/// An object that holds the addresses of the shader groups in a shader binding table. +/// An object that holds the strided addresses of the shader groups in a shader binding table. #[derive(Debug, Clone)] pub struct ShaderBindingTableAddresses { - pub raygen: StridedDeviceAddressRegionKHR, - pub miss: StridedDeviceAddressRegionKHR, - pub hit: StridedDeviceAddressRegionKHR, - pub callable: StridedDeviceAddressRegionKHR, + /// The address of the ray generation shader group handle. + pub raygen: StridedDeviceAddressRegion, + /// The address of the miss shader group handles. + pub miss: StridedDeviceAddressRegion, + /// The address of the hit shader group handles. + pub hit: StridedDeviceAddressRegion, + /// The address of the callable shader group handles. + pub callable: StridedDeviceAddressRegion, } /// An object that holds the shader binding table buffer and its addresses. @@ -782,6 +809,7 @@ pub struct ShaderBindingTable { } impl ShaderBindingTable { + /// Returns the addresses of the shader groups in the shader binding table. pub fn addresses(&self) -> &ShaderBindingTableAddresses { &self.addresses } @@ -837,12 +865,12 @@ impl ShaderBindingTable { let raygen_stride = align_up(handle_size_aligned, shader_group_base_alignment); - let mut raygen = StridedDeviceAddressRegionKHR { + let mut raygen = StridedDeviceAddressRegion { stride: raygen_stride, size: raygen_stride, device_address: 0, }; - let mut miss = StridedDeviceAddressRegionKHR { + let mut miss = StridedDeviceAddressRegion { stride: handle_size_aligned, size: align_up( handle_size_aligned * miss_shader_count, @@ -850,7 +878,7 @@ impl ShaderBindingTable { ), device_address: 0, }; - let mut hit = StridedDeviceAddressRegionKHR { + let mut hit = StridedDeviceAddressRegion { stride: handle_size_aligned, size: align_up( handle_size_aligned * hit_shader_count, @@ -858,7 +886,7 @@ impl ShaderBindingTable { ), device_address: 0, }; - let mut callable = StridedDeviceAddressRegionKHR { + let mut callable = StridedDeviceAddressRegion { stride: handle_size_aligned, size: align_up( handle_size_aligned * callable_shader_count,