#custom render pass

9 messages · Page 1 of 1 (latest)

mighty agate
#

Hello, I'm trying to add a custom render pass. The current goal is to read some f32 values from a Vec and draw a black line on the screen: the x coordinate of each point on the line comes from its index in the Vec, and the y coordinate comes from its f32 value. I'm doing something very wrong, but I don't know what. Here's what I have so far; it's a bit of a mess...

use std::{f32::consts::PI, num::NonZeroU64};
use bevy::{
    prelude::*,
    render::{
        mesh::PrimitiveTopology,
        render_graph::{NodeRunError, RenderGraphApp, RenderGraphContext, RenderLabel, ViewNode, ViewNodeRunner},
        render_resource::{
            BindGroupEntry, BindGroupLayout, BindGroupLayoutEntry, BindingResource,
            BufferAddress, BufferBindingType, BufferInitDescriptor, BufferSize, BufferUsages,
            CachedRenderPipelineId, ColorTargetState, ColorWrites, FragmentState, LoadOp, Operations,
            PipelineCache, PrimitiveState, RenderPassColorAttachment, RenderPassDescriptor,
            RenderPipelineDescriptor, ShaderStages, StoreOp, TextureFormat,
            VertexAttribute, VertexBufferLayout, VertexFormat, VertexState, VertexStepMode,
        },
        renderer::{RenderContext, RenderDevice},
        view::ViewTarget,
        MainWorld, RenderApp,
    },
};
use bytemuck::{Pod, Zeroable};

/// Capacity of the `uPoints` uniform array; must stay equal to `MAX_POINTS` in `signal.wgsl`.
const MAX_POINTS: u32 = 100_u32;

/// Render-graph label for the node that draws the signal (plays the same role as `PostProcessLabel`).
#[derive(RenderLabel, Clone, Debug, PartialEq, Eq, Hash)]
struct SignalRenderLabel;

/// One vertex of the signal line strip.
///
/// The `#[repr(C)]` layout matches the pipeline's vertex buffer layout:
/// `position` at shader location 0 (`Float32x2`), `color` at location 1 (`Float32x4`).
#[repr(C)]
#[derive(Debug, Clone, Copy, Zeroable, Pod)]
struct SignalVertex {
    /// Position in pixel space.
    position: [f32; 2],
    /// RGBA color of this vertex.
    color: [f32; 4],
}

/// Contents of the group-0 uniform buffer (`uResolution` in the shader).
#[repr(C)]
#[derive(Copy, Clone, Zeroable, Pod)]
struct SignalUniform {
    /// Size in pixels used to convert pixel positions to clip space.
    resolution: [f32; 2],
}

/// Render-world resource holding the queued signal pipeline and its two bind group layouts.
#[derive(Resource)]
struct SignalPipeline {
    /// Layout of group 0: the resolution uniform.
    uniform_bind_group_layout: BindGroupLayout,
    /// Layout of group 1: line data (point count and points array).
    line_bind_group_layout: BindGroupLayout,
    /// Id of the pipeline queued in the `PipelineCache`.
    pipeline_id: CachedRenderPipelineId,
}
#
impl FromWorld for SignalPipeline {
    fn from_world(world: &mut World) -> Self {
        let render_device = world.resource::<RenderDevice>();
        // Load our custom shader (which must match the expected groups below).
        let shader = world.load_asset("shaders/signal.wgsl");

        // Create group 0: resolution uniform.
        let uniform_size = std::mem::size_of::<SignalUniform>() as u64;
        let uniform_bind_group_layout = render_device.create_bind_group_layout(
            "signal_uniform_bind_group_layout",
            &[BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::VERTEX,
                ty: bevy::render::render_resource::BindingType::Buffer {
                    ty: BufferBindingType::Uniform,
                    has_dynamic_offset: false,
                    min_binding_size: NonZeroU64::new(uniform_size),
                },
                count: None,
            }],
        );

        // Create group 1: line data (point count and points array).
        let line_bind_group_layout = render_device.create_bind_group_layout(
            "signal_line_bind_group_layout",
            &[
                // Binding 0: point count (u32, 4 bytes)
                BindGroupLayoutEntry {
                    binding: 0,
                    visibility: ShaderStages::FRAGMENT,
                    ty: bevy::render::render_resource::BindingType::Buffer {
                        ty: BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(4),
                    },
                    count: None,
                },
                // Binding 1: points array (array of vec4<f32>, each 16 bytes)
                BindGroupLayoutEntry {
                    binding: 1,
                    visibility: ShaderStages::FRAGMENT,
                    ty: bevy::render::render_resource::BindingType::Buffer {
                        ty: BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(16 * MAX_POINTS as u64),
                    },
                    count: None,
                },
            ],
        );
#
        // Queue a render pipeline using both bind groups.
        let pipeline_cache = world.resource_mut::<PipelineCache>();
        let pipeline_id = pipeline_cache.queue_render_pipeline(RenderPipelineDescriptor {
            label: Some("signal_pipeline".into()),
            layout: vec![uniform_bind_group_layout.clone(), line_bind_group_layout.clone()],
            vertex: VertexState {
                shader: shader.clone(),
                shader_defs: vec![],
                entry_point: "vertex_main".into(),
                buffers: vec![VertexBufferLayout {
                    array_stride: std::mem::size_of::<SignalVertex>() as BufferAddress,
                    step_mode: VertexStepMode::Vertex,
                    attributes: vec![
                        VertexAttribute {
                            format: VertexFormat::Float32x2,
                            offset: 0,
                            shader_location: 0,
                        },
                        VertexAttribute {
                            format: VertexFormat::Float32x4,
                            offset: std::mem::size_of::<[f32; 2]>() as BufferAddress,
                            shader_location: 1,
                        },
                    ],
                }],
            },
            fragment: Some(FragmentState {
                shader,
                shader_defs: vec![],
                entry_point: "fragment_main".into(),
                targets: vec![Some(ColorTargetState {
                    format: TextureFormat::Bgra8UnormSrgb, // Matches the view target.
                    blend: Some(bevy::render::render_resource::BlendState::ALPHA_BLENDING),
                    write_mask: ColorWrites::ALL,
                })],
            }),
            primitive: PrimitiveState {
                topology: PrimitiveTopology::LineStrip,
                ..Default::default()
            },
            depth_stencil: None,
            multisample: bevy::render::render_resource::MultisampleState::default(),
            push_constant_ranges: vec![],
            zero_initialize_workgroup_memory: false,
        });

        Self {
            uniform_bind_group_layout,
            line_bind_group_layout,
            pipeline_id,
        }
    }
}
#

/// Stateless render-graph node that draws the extracted signal as a line strip.
#[derive(Default)]
struct SignalNode;

impl ViewNode for SignalNode {
    // We only need the view target.
    type ViewQuery = (&'static ViewTarget,);
    fn run(
        &self,
        _graph: &mut RenderGraphContext,
        render_context: &mut RenderContext,
        (view_target,): bevy::ecs::query::QueryItem<Self::ViewQuery>,
        world: &World,
    ) -> Result<(), NodeRunError> {
        // Fixed resolution matching our uniform.
        let resolution = [800.0, 600.0];
        let signal_pipeline = world.resource::<SignalPipeline>();
        let pipeline_cache = world.resource::<PipelineCache>();

        let Some(pipeline) = pipeline_cache.get_render_pipeline(signal_pipeline.pipeline_id) else {
            return Ok(());
        };

        let render_device = world.resource::<RenderDevice>();

        // Create group 0 uniform (resolution).
        let signal_uniform = SignalUniform { resolution };
        let uniform_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
            label: None,
            contents: bytemuck::bytes_of(&signal_uniform),
            usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
        });
        let uniform_bind_group = render_device.create_bind_group(
            "signal_uniform_bind_group",
            &signal_pipeline.uniform_bind_group_layout,
            &[BindGroupEntry {
                binding: 0,
                resource: BindingResource::Buffer(uniform_buffer.as_entire_buffer_binding()),
            }],
        );

        // Get our extracted vertex data.
        let signal_data = world.resource::<ExtractedSignalData>();
        if signal_data.vertices.is_empty() {
            return Ok(());
        }
        let vertex_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
            label: None,
            contents: bytemuck::cast_slice(&signal_data.vertices),
            usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
        });

        // Prepare group 1 data:
        // Convert each vertex position (in pixel space) to normalized UV coordinates.
        let point_count = signal_data.vertices.len() as u32;
        let mut points: Vec<[f32; 4]> = Vec::with_capacity(MAX_POINTS as usize);
        for vertex in &signal_data.vertices {
            let uv = [vertex.position[0] / resolution[0], vertex.position[1] / resolution[1]];
            points.push([uv[0], uv[1], 0.0, 0.0]);
        }
        // If fewer than MAX_POINTS, pad the array.
        while points.len() < MAX_POINTS as usize {
            points.push([0.0, 0.0, 0.0, 0.0]);
        }

        let point_count_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
            label: None,
            contents: bytemuck::bytes_of(&point_count),
            usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
        });
        let points_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
            label: None,
            contents: bytemuck::cast_slice(&points),
            usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
        });
        let line_bind_group = render_device.create_bind_group(
            "signal_line_bind_group",
            &signal_pipeline.line_bind_group_layout,
            &[
                BindGroupEntry {
                    binding: 0,
                    resource: BindingResource::Buffer(point_count_buffer.as_entire_buffer_binding()),
                },
                BindGroupEntry {
                    binding: 1,
                    resource: BindingResource::Buffer(points_buffer.as_entire_buffer_binding()),
                },
            ],
        );
#

        // Begin the render pass.
        let mut render_pass = render_context.begin_tracked_render_pass(RenderPassDescriptor {
            label: Some("signal_render_pass".into()),
            color_attachments: &[Some(RenderPassColorAttachment {
                view: &view_target.out_texture(),
                resolve_target: None,
                ops: Operations {
                    load: LoadOp::Load,
                    store: StoreOp::Store,
                },
            })],
            depth_stencil_attachment: None,
            timestamp_writes: None,
            occlusion_query_set: None,
        });

        render_pass.set_render_pipeline(pipeline);
        render_pass.set_bind_group(0, &uniform_bind_group, &[]);
        render_pass.set_bind_group(1, &line_bind_group, &[]);
        render_pass.set_vertex_buffer(0, vertex_buffer.slice(..));
        render_pass.draw(0..(signal_data.vertices.len() as u32), 0..1);
        Ok(())
    }
}

// Plugin to set up our render node and update our signal data.
struct SignalPlugin;
impl Plugin for SignalPlugin {
    fn build(&self, app: &mut App) {
        app.insert_resource(SignalRenderData::default())
            .add_systems(Update, update_signal_stream);
    }
    fn finish(&self, app: &mut App) {
        if let Some(render_app) = app.get_sub_app_mut(RenderApp) {
            render_app
                .add_render_graph_node::<ViewNodeRunner<SignalNode>>(
                    bevy::core_pipeline::core_2d::graph::Core2d,
                    SignalRenderLabel,
                )
                .add_render_graph_edges(
                    bevy::core_pipeline::core_2d::graph::Core2d,
                    (
                        bevy::core_pipeline::core_2d::graph::Node2d::Tonemapping,
                        SignalRenderLabel,
                        bevy::core_pipeline::core_2d::graph::Node2d::EndMainPassPostProcessing,
                    ),
                );
            let w = render_app.world_mut();
            let res = SignalPipeline::from_world(w);
            render_app.insert_resource(res)
                .add_systems(ExtractSchedule, extract_signal_data);
            render_app.insert_resource(ExtractedSignalData::default());
        }
    }
}
#
/// Entry point: a 2D app with a white clear color and the signal plugin installed.
fn main() {
    let mut app = App::new();
    app.add_plugins(DefaultPlugins);
    app.add_plugins(SignalPlugin);
    app.insert_resource(ClearColor(Color::WHITE));
    app.add_systems(Startup, setup);
    app.run();
}

/// Startup system: spawns the 2D camera whose view the signal is drawn for.
fn setup(mut commands: Commands) {
    let camera = Camera2d;
    commands.spawn(camera);
}

/// Appends one sample of a 1 Hz sine wave to the signal each time it runs.
///
/// Each new sample is placed one pixel to the right of the previous one (the first at
/// x = 0), with its height mapped from `[-1, 1]` to `[0, 600]` pixels. Once 50 000
/// samples are stored, the oldest one is dropped before the new one is added.
fn update_signal_stream(mut signal_data: ResMut<SignalRenderData>, time: Res<Time>) {
    let phase = 2.0 * PI * time.elapsed_secs();
    let y = (0.5 * phase.sin() + 0.5) * 600.0;
    let x = signal_data
        .vertices
        .last()
        .map_or(0.0, |previous| previous.position[0] + 1.0);

    let at_capacity = signal_data.vertices.len() >= 50_000;
    if at_capacity {
        signal_data.vertices.remove(0);
    }

    signal_data.vertices.push(SignalVertex {
        position: [x, y],
        color: [0.0, 0.0, 0.0, 1.0],
    });
}

/// Render-world copy of the signal vertices, overwritten by `extract_signal_data`.
#[derive(Resource, Default, Clone)]
struct ExtractedSignalData {
    /// Vertices drawn by `SignalNode` this frame.
    vertices: Vec<SignalVertex>,
}

/// Extract system: copies the main world's signal vertices into `ExtractedSignalData`.
///
/// # Panics
/// Panics if `SignalRenderData` is missing from the main world (`SignalPlugin::build`
/// inserts it).
fn extract_signal_data(main_world: Res<MainWorld>, mut extracted_signal: ResMut<ExtractedSignalData>) {
    let main_signal = main_world.resource::<SignalRenderData>();
    // `clone_from` reuses the existing allocation instead of allocating a fresh Vec
    // (up to 50 000 vertices) on every extraction.
    extracted_signal.vertices.clone_from(&main_signal.vertices);
}
#

... and the shader:

// Group 0: size in pixels used for the pixel -> clip-space conversion.
@group(0) @binding(0) var<uniform> uResolution: vec2<f32>;

// Must equal MAX_POINTS on the Rust side.
const MAX_POINTS: u32 = 100u;

// Group 1: number of valid entries in uPoints, then the points themselves
// (xy = position divided by resolution, zw = zero padding).
@group(1) @binding(0) var<uniform> uPointCount: u32;
@group(1) @binding(1) var<uniform> uPoints: array<vec4<f32>, MAX_POINTS>;

// Per-vertex attributes; locations match the Rust VertexBufferLayout
// (0 = Float32x2 pixel position, 1 = Float32x4 color).
struct VertexInput {
    @location(0) position: vec2<f32>,
    @location(1) color: vec4<f32>,
}

// Interface between vertex_main and fragment_main.
struct VertexOutput {
    // Clip-space position consumed by the rasterizer.
    @builtin(position) clip_position: vec4<f32>,
    // Pixel position divided by uResolution.
    @location(0) uv: vec2<f32>,
    // Vertex color passed through unchanged.
    @location(1) color: vec4<f32>,
}

// Maps a pixel-space vertex to clip space: pixels -> [0, 1] uv -> [-1, 1] NDC,
// with y negated so that larger pixel y values end up lower on screen.
@vertex
fn vertex_main(input: VertexInput) -> VertexOutput {
    let uv = input.position / uResolution;
    let ndc = uv * 2.0 - 1.0;
    return VertexOutput(vec4<f32>(ndc.x, -ndc.y, 0.0, 1.0), uv, input.color);
}

// Euclidean distance from point p to the segment a-b (all in the same space).
// A zero-length segment (a == b) is handled explicitly as the distance to a:
// projecting onto it would divide by dot(ba, ba) == 0 and yield NaN.
fn distanceToSegment(p: vec2<f32>, a: vec2<f32>, b: vec2<f32>) -> f32 {
    let pa = p - a;
    let ba = b - a;
    let len_sq = dot(ba, ba);
    if (len_sq == 0.0) {
        return length(pa);
    }
    let h = clamp(dot(pa, ba) / len_sq, 0.0, 1.0);
    return length(pa - ba * h);
}

// Fragment stage of the LineStrip pipeline.
//
// Line-strip rasterization only produces fragments along the line, so every
// fragment here belongs to the signal and takes the interpolated vertex color.
// A per-fragment distance test against uPoints is not used because:
//   - uPoints holds at most MAX_POINTS entries, so fragments of any later
//     segment would fail the test and be painted opaque white (invisible on a
//     white clear color);
//   - `uPointCount - 1u` underflows when the count is 0;
//   - counts above MAX_POINTS index past the end of uPoints.
@fragment
fn fragment_main(input: VertexOutput) -> @location(0) vec4<f32> {
    return input.color;
}
#

I'm not really sure where to start here

hollow sierra
#

Could you upload the Cargo project as a zip?