Skip to content

Commit

Permalink
refactor:
Browse files Browse the repository at this point in the history
- add doc
- add check for stages/groups
- pub modification
- `pub use` StridedDeviceAddressRegionKHR in lib.rs
  • Loading branch information
ComfyFluffy committed Dec 13, 2024
1 parent 2d8804b commit 9f6b795
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 17 deletions.
3 changes: 3 additions & 0 deletions vulkano/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ pub use ash::vk::DeviceAddress;
/// A [`DeviceAddress`] that is known not to equal zero.
pub type NonNullDeviceAddress = NonZeroU64;

/// Represents a region of device addresses with a stride.
pub use ash::vk::StridedDeviceAddressRegionKHR as StridedDeviceAddressRegion;

/// Holds 24 bits in the least significant bits of memory,
/// and 8 bytes in the most significant bits of that memory,
/// occupying a single [`u32`] in total.
Expand Down
62 changes: 45 additions & 17 deletions vulkano/src/pipeline/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ use crate::{
DeviceAlignment,
},
shader::{spirv::ExecutionModel, DescriptorBindingRequirements},
Validated, ValidationError, VulkanError, VulkanObject,
StridedDeviceAddressRegion, Validated, ValidationError, VulkanError, VulkanObject,
};
use ash::vk::StridedDeviceAddressRegionKHR;
use foldhash::{HashMap, HashSet};
use smallvec::SmallVec;
use std::{collections::hash_map::Entry, mem::MaybeUninit, num::NonZeroU64, ptr, sync::Arc};
Expand Down Expand Up @@ -140,6 +139,12 @@ impl RayTracingPipeline {
Ok(Self::from_handle(device, handle, create_info))
}

/// Creates a new `RayTracingPipeline` from a raw object handle.
///
/// # Safety
///
/// - `handle` must be a valid Vulkan object handle created from `device`.
/// - `create_info` must match the info used to create the object.
pub unsafe fn from_handle(
device: Arc<Device>,
handle: ash::vk::Pipeline,
Expand Down Expand Up @@ -196,18 +201,22 @@ impl RayTracingPipeline {
})
}

// Returns the shader groups that the pipeline was created with.
pub fn groups(&self) -> &[RayTracingShaderGroupCreateInfo] {
&self.groups
}

// Returns the shader stages that the pipeline was created with.
pub fn stages(&self) -> &[PipelineShaderStageCreateInfo] {
&self.stages
}

/// Returns the `Device` that the pipeline was created with.
pub fn device(&self) -> &Arc<Device> {
&self.device
}

/// Returns the flags that the pipeline was created with.
pub fn flags(&self) -> PipelineCreateFlags {
self.flags
}
Expand Down Expand Up @@ -275,12 +284,12 @@ pub struct RayTracingPipelineCreateInfo {

/// The ray tracing shader stages to use.
///
/// There is no default value.
/// The default value is empty, which must be overridden.
pub stages: SmallVec<[PipelineShaderStageCreateInfo; 5]>,

/// The shader groups to use. They reference the shader stages in `stages`.
///
/// The default value is empty.
/// The default value is empty, which must be overridden.
pub groups: SmallVec<[RayTracingShaderGroupCreateInfo; 5]>,

/// The maximum recursion depth of the pipeline.
Expand Down Expand Up @@ -377,6 +386,13 @@ impl RayTracingPipelineCreateInfo {
}));
}

if stages.is_empty() {
return Err(Box::new(ValidationError {
problem: "`stages` is empty".into(),
vuids: &["VUID-VkRayTracingPipelineCreateInfoKHR-pLibraryInfo-07999"],
..Default::default()
}));
}
for stage in stages {
stage.validate(device).map_err(|err| {
err.add_context("stages")
Expand Down Expand Up @@ -407,6 +423,13 @@ impl RayTracingPipelineCreateInfo {
})?;
}

if groups.is_empty() {
return Err(Box::new(ValidationError {
problem: "`groups` is empty".into(),
vuids: &["VUID-VkRayTracingPipelineCreateInfoKHR-flags-08700"],
..Default::default()
}));
}
for group in groups {
group.validate(stages).map_err(|err| {
err.add_context("groups")
Expand Down Expand Up @@ -746,32 +769,36 @@ impl RayTracingShaderGroupCreateInfo {
}
}

pub struct RayTracingPipelineCreateInfoFields1Vk<'a> {
pub(crate) struct RayTracingPipelineCreateInfoFields1Vk<'a> {
pub(crate) stages_vk: SmallVec<[ash::vk::PipelineShaderStageCreateInfo<'a>; 5]>,
pub(crate) groups_vk: SmallVec<[ash::vk::RayTracingShaderGroupCreateInfoKHR<'static>; 5]>,
pub(crate) dynamic_state_vk: Option<ash::vk::PipelineDynamicStateCreateInfo<'a>>,
}

pub struct RayTracingPipelineCreateInfoFields1ExtensionsVk {
pub(crate) struct RayTracingPipelineCreateInfoFields1ExtensionsVk {
pub(crate) stages_extensions_vk: SmallVec<[PipelineShaderStageCreateInfoExtensionsVk; 5]>,
}

pub struct RayTracingPipelineCreateInfoFields2Vk<'a> {
pub(crate) struct RayTracingPipelineCreateInfoFields2Vk<'a> {
pub(crate) stages_fields1_vk: SmallVec<[PipelineShaderStageCreateInfoFields1Vk<'a>; 5]>,
pub(crate) dynamic_states_vk: SmallVec<[ash::vk::DynamicState; 4]>,
}

pub struct RayTracingPipelineCreateInfoFields3Vk {
pub(crate) struct RayTracingPipelineCreateInfoFields3Vk {
pub(crate) stages_fields2_vk: SmallVec<[PipelineShaderStageCreateInfoFields2Vk; 5]>,
}

/// An object that holds the addresses of the shader groups in a shader binding table.
/// An object that holds the strided addresses of the shader groups in a shader binding table.
#[derive(Debug, Clone)]
pub struct ShaderBindingTableAddresses {
pub raygen: StridedDeviceAddressRegionKHR,
pub miss: StridedDeviceAddressRegionKHR,
pub hit: StridedDeviceAddressRegionKHR,
pub callable: StridedDeviceAddressRegionKHR,
/// The address of the ray generation shader group handle.
pub raygen: StridedDeviceAddressRegion,
/// The address of the miss shader group handles.
pub miss: StridedDeviceAddressRegion,
/// The address of the hit shader group handles.
pub hit: StridedDeviceAddressRegion,
/// The address of the callable shader group handles.
pub callable: StridedDeviceAddressRegion,
}

/// An object that holds the shader binding table buffer and its addresses.
Expand All @@ -782,6 +809,7 @@ pub struct ShaderBindingTable {
}

impl ShaderBindingTable {
/// Returns the addresses of the shader groups in the shader binding table.
pub fn addresses(&self) -> &ShaderBindingTableAddresses {
&self.addresses
}
Expand Down Expand Up @@ -837,28 +865,28 @@ impl ShaderBindingTable {

let raygen_stride = align_up(handle_size_aligned, shader_group_base_alignment);

let mut raygen = StridedDeviceAddressRegionKHR {
let mut raygen = StridedDeviceAddressRegion {
stride: raygen_stride,
size: raygen_stride,
device_address: 0,
};
let mut miss = StridedDeviceAddressRegionKHR {
let mut miss = StridedDeviceAddressRegion {
stride: handle_size_aligned,
size: align_up(
handle_size_aligned * miss_shader_count,
shader_group_base_alignment,
),
device_address: 0,
};
let mut hit = StridedDeviceAddressRegionKHR {
let mut hit = StridedDeviceAddressRegion {
stride: handle_size_aligned,
size: align_up(
handle_size_aligned * hit_shader_count,
shader_group_base_alignment,
),
device_address: 0,
};
let mut callable = StridedDeviceAddressRegionKHR {
let mut callable = StridedDeviceAddressRegion {
stride: handle_size_aligned,
size: align_up(
handle_size_aligned * callable_shader_count,
Expand Down

0 comments on commit 9f6b795

Please # to comment.