diff --git a/vulkano/src/device/mod.rs b/vulkano/src/device/mod.rs index 0157c3efac..a1fde8e53b 100644 --- a/vulkano/src/device/mod.rs +++ b/vulkano/src/device/mod.rs @@ -2194,8 +2194,8 @@ impl Deref for DeviceOwnedDebugWrapper { #[derive(Clone, Debug)] pub struct ShaderGroupHandlesData { - pub(crate) data: Vec, - pub(crate) handle_size: u32, + data: Vec, + handle_size: u32, } impl ShaderGroupHandlesData { diff --git a/vulkano/src/pipeline/ray_tracing.rs b/vulkano/src/pipeline/ray_tracing.rs index 4f99a8a190..795b41ce61 100644 --- a/vulkano/src/pipeline/ray_tracing.rs +++ b/vulkano/src/pipeline/ray_tracing.rs @@ -5,7 +5,7 @@ use super::{ }; use crate::{ buffer::{Buffer, BufferCreateInfo, BufferUsage, Subbuffer}, - device::{Device, DeviceOwned, DeviceOwnedDebugWrapper, ShaderGroupHandlesData}, + device::{Device, DeviceOwned, DeviceOwnedDebugWrapper}, instance::InstanceOwnedDebugWrapper, macros::impl_id_counter, memory::{ @@ -826,17 +826,29 @@ impl ShaderBindingTable { { let mut sbt_buffer_write = sbt_buffer.write().unwrap(); - copy_shader_handles( - &handle_data, - &mut sbt_buffer_write, - raygen.size as usize, - miss.size as usize, - miss.stride as usize, - hit.size as usize, - hit.stride as usize, - callable.size as usize, - callable.stride as usize, - ); + + let mut handle_iter = handle_data.iter(); + + let handle_size = handle_data.handle_size() as usize; + sbt_buffer_write[..handle_size].copy_from_slice(handle_iter.next().unwrap()); + let mut offset = raygen.size as usize; + for _ in 0..miss_shader_count { + sbt_buffer_write[offset..offset + handle_size] + .copy_from_slice(handle_iter.next().unwrap()); + offset += miss.stride as usize; + } + offset = (raygen.size + miss.size) as usize; + for _ in 0..hit_shader_count { + sbt_buffer_write[offset..offset + handle_size] + .copy_from_slice(handle_iter.next().unwrap()); + offset += hit.stride as usize; + } + offset = (raygen.size + miss.size + hit.size) as usize; + for _ in 0..callable_shader_count { + sbt_buffer_write[offset..offset + handle_size] + .copy_from_slice(handle_iter.next().unwrap()); + offset += callable.stride as usize; + } } Ok(Self { @@ -850,109 +862,3 @@ impl ShaderBindingTable { }) } } - -fn copy_shader_handles( - handle_data: &ShaderGroupHandlesData, - output: &mut [u8], - raygen_size: usize, - miss_size: usize, - miss_stride: usize, - hit_size: usize, - hit_stride: usize, - callable_size: usize, - callable_stride: usize, -) { - let handle_size = handle_data.handle_size() as usize; - let mut handle_iter = handle_data.iter(); - - // Copy raygen shader handle - output[..handle_size].copy_from_slice(handle_iter.next().unwrap()); - - // Copy miss shader handles - let mut offset = raygen_size; - while offset < raygen_size + miss_size { - output[offset..offset + handle_size].copy_from_slice(handle_iter.next().unwrap()); - offset += miss_stride; - } - - // Copy hit shader handles - assert_eq!(offset, raygen_size + miss_size); - while offset < raygen_size + miss_size + hit_size { - output[offset..offset + handle_size].copy_from_slice(handle_iter.next().unwrap()); - offset += hit_stride; - } - - // Copy callable shader handles - assert_eq!(offset, raygen_size + miss_size + hit_size); - while offset < raygen_size + miss_size + hit_size + callable_size { - output[offset..offset + handle_size].copy_from_slice(handle_iter.next().unwrap()); - offset += callable_stride; - } - - assert_eq!(offset, raygen_size + miss_size + hit_size + callable_size); -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::device::ShaderGroupHandlesData; - - #[test] - fn test_copy_shader_handles_single_raygen() { - let handle_data = ShaderGroupHandlesData { - data: vec![1, 2, 3, 4], - handle_size: 4, - }; - let mut output = vec![0; 4]; - - copy_shader_handles(&handle_data, &mut output, 4, 0, 0, 0, 0, 0, 0); - - assert_eq!(output, vec![1, 2, 3, 4]); - } - - #[test] - fn test_copy_shader_handles_with_stride() { - let handle_data = ShaderGroupHandlesData { - data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], - handle_size: 4, - }; - let mut output = vec![0; 24]; - - // raygen (4 bytes) + 2 miss shaders with 8-byte stride - copy_shader_handles(&handle_data, &mut output, 4, 16, 8, 4, 4, 0, 0); - - assert_eq!( - output, - vec![ - 1, 2, 3, 4, // raygen - 5, 6, 7, 8, // first miss shader - 0, 0, 0, 0, // padding - 9, 10, 11, 12, // second miss shader - 0, 0, 0, 0, // padding - 13, 14, 15, 16 // hit shader - ] - ); - } - - #[test] - fn test_copy_shader_handles_all_types() { - let handle_data = ShaderGroupHandlesData { - data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], - handle_size: 4, - }; - let mut output = vec![0; 24]; - - // raygen (4 bytes) + miss (4 bytes) + hit (4 bytes) + callable (4 bytes) - copy_shader_handles(&handle_data, &mut output, 4, 4, 4, 4, 4, 4, 4); - - assert_eq!( - output[..16], - vec![ - 1, 2, 3, 4, // raygen - 5, 6, 7, 8, // miss - 9, 10, 11, 12, // hit - 13, 14, 15, 16 // callable - ] - ); - } -}