Skip to content

Commit

Permalink
undo sbt copy refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ComfyFluffy committed Dec 2, 2024
1 parent a6a8739 commit 6d68913
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 120 deletions.
4 changes: 2 additions & 2 deletions vulkano/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2194,8 +2194,8 @@ impl<T> Deref for DeviceOwnedDebugWrapper<T> {

#[derive(Clone, Debug)]
pub struct ShaderGroupHandlesData {
pub(crate) data: Vec<u8>,
pub(crate) handle_size: u32,
data: Vec<u8>,
handle_size: u32,
}

impl ShaderGroupHandlesData {
Expand Down
142 changes: 24 additions & 118 deletions vulkano/src/pipeline/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::{
};
use crate::{
buffer::{Buffer, BufferCreateInfo, BufferUsage, Subbuffer},
device::{Device, DeviceOwned, DeviceOwnedDebugWrapper, ShaderGroupHandlesData},
device::{Device, DeviceOwned, DeviceOwnedDebugWrapper},
instance::InstanceOwnedDebugWrapper,
macros::impl_id_counter,
memory::{
Expand Down Expand Up @@ -826,17 +826,29 @@ impl ShaderBindingTable {

{
let mut sbt_buffer_write = sbt_buffer.write().unwrap();
copy_shader_handles(
&handle_data,
&mut sbt_buffer_write,
raygen.size as usize,
miss.size as usize,
miss.stride as usize,
hit.size as usize,
hit.stride as usize,
callable.size as usize,
callable.stride as usize,
);

let mut handle_iter = handle_data.iter();

let handle_size = handle_data.handle_size() as usize;
sbt_buffer_write[..handle_size].copy_from_slice(handle_iter.next().unwrap());
let mut offset = raygen.size as usize;
for _ in 0..miss_shader_count {
sbt_buffer_write[offset..offset + handle_size]
.copy_from_slice(handle_iter.next().unwrap());
offset += miss.stride as usize;
}
offset = (raygen.size + miss.size) as usize;
for _ in 0..hit_shader_count {
sbt_buffer_write[offset..offset + handle_size]
.copy_from_slice(handle_iter.next().unwrap());
offset += hit.stride as usize;
}
offset = (raygen.size + miss.size + hit.size) as usize;
for _ in 0..callable_shader_count {
sbt_buffer_write[offset..offset + handle_size]
.copy_from_slice(handle_iter.next().unwrap());
offset += callable.stride as usize;
}
}

Ok(Self {
Expand All @@ -850,109 +862,3 @@ impl ShaderBindingTable {
})
}
}

fn copy_shader_handles(
handle_data: &ShaderGroupHandlesData,
output: &mut [u8],
raygen_size: usize,
miss_size: usize,
miss_stride: usize,
hit_size: usize,
hit_stride: usize,
callable_size: usize,
callable_stride: usize,
) {
let handle_size = handle_data.handle_size() as usize;
let mut handle_iter = handle_data.iter();

// Copy raygen shader handle
output[..handle_size].copy_from_slice(handle_iter.next().unwrap());

// Copy miss shader handles
let mut offset = raygen_size;
while offset < raygen_size + miss_size {
output[offset..offset + handle_size].copy_from_slice(handle_iter.next().unwrap());
offset += miss_stride;
}

// Copy hit shader handles
assert_eq!(offset, raygen_size + miss_size);
while offset < raygen_size + miss_size + hit_size {
output[offset..offset + handle_size].copy_from_slice(handle_iter.next().unwrap());
offset += hit_stride;
}

// Copy callable shader handles
assert_eq!(offset, raygen_size + miss_size + hit_size);
while offset < raygen_size + miss_size + hit_size + callable_size {
output[offset..offset + handle_size].copy_from_slice(handle_iter.next().unwrap());
offset += callable_stride;
}

assert_eq!(offset, raygen_size + miss_size + hit_size + callable_size);
}

#[cfg(test)]
mod tests {
use super::*;
use crate::device::ShaderGroupHandlesData;

#[test]
fn test_copy_shader_handles_single_raygen() {
let handle_data = ShaderGroupHandlesData {
data: vec![1, 2, 3, 4],
handle_size: 4,
};
let mut output = vec![0; 4];

copy_shader_handles(&handle_data, &mut output, 4, 0, 0, 0, 0, 0, 0);

assert_eq!(output, vec![1, 2, 3, 4]);
}

#[test]
fn test_copy_shader_handles_with_stride() {
let handle_data = ShaderGroupHandlesData {
data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
handle_size: 4,
};
let mut output = vec![0; 24];

// raygen (4 bytes) + 2 miss shaders with 8-byte stride
copy_shader_handles(&handle_data, &mut output, 4, 16, 8, 4, 4, 0, 0);

assert_eq!(
output,
vec![
1, 2, 3, 4, // raygen
5, 6, 7, 8, // first miss shader
0, 0, 0, 0, // padding
9, 10, 11, 12, // second miss shader
0, 0, 0, 0, // padding
13, 14, 15, 16 // hit shader
]
);
}

#[test]
fn test_copy_shader_handles_all_types() {
let handle_data = ShaderGroupHandlesData {
data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
handle_size: 4,
};
let mut output = vec![0; 24];

// raygen (4 bytes) + miss (4 bytes) + hit (4 bytes) + callable (4 bytes)
copy_shader_handles(&handle_data, &mut output, 4, 4, 4, 4, 4, 4, 4);

assert_eq!(
output[..16],
vec![
1, 2, 3, 4, // raygen
5, 6, 7, 8, // miss
9, 10, 11, 12, // hit
13, 14, 15, 16 // callable
]
);
}
}

0 comments on commit 6d68913

Please # to comment.