use anyhow::Result;

use vulkanalia::prelude::v1_0::*;

use crate::app_data;
use crate::queue_family_indices;
use crate::scene;

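/// Creates the command pool that both the per-framebuffer command buffers and
/// the single-time transfer command buffers are allocated from.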
pub unsafe fn create_command_pool(
    instance: &Instance,
    device: &Device,
    data: &mut app_data::AppData,
) -> Result<()> {
    let indices = queue_family_indices::QueueFamilyIndices::get(instance, data, data.physical_device)?;

    let info = vk::CommandPoolCreateInfo::builder()
        .flags(vk::CommandPoolCreateFlags::empty()) // Optional.
        .queue_family_index(indices.graphics);

    data.command_pool = device.create_command_pool(&info, None)?;

    Ok(())
}

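/// Records one primary command buffer per swapchain framebuffer. Each buffer
/// runs an optional compute pre-pass that turns the scene's volumetrics into
/// cuboid vertex/index data, then a render pass that draws the cube, cuboid,
/// compute-generated, and quad geometry.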
pub unsafe fn create_command_buffers(
    device: &Device,
    data: &mut app_data::AppData,
    scene_handler: &scene::Scene,
) -> Result<()> {
    let allocate_info = vk::CommandBufferAllocateInfo::builder()
        .command_pool(data.command_pool)
        .level(vk::CommandBufferLevel::PRIMARY)
        .command_buffer_count(data.framebuffers.len() as u32);
    
    data.command_buffers = device.allocate_command_buffers(&allocate_info)?;

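    // Record one command buffer per framebuffer; `i` also indexes the
    // per-frame descriptor sets and compute output buffers.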
    for (i, command_buffer) in data.command_buffers.iter().enumerate() {
        let inheritance = vk::CommandBufferInheritanceInfo::builder(); // Ignored for primary command buffers.
    
        let info = vk::CommandBufferBeginInfo::builder()
            .flags(vk::CommandBufferUsageFlags::empty()) // Optional.
            .inheritance_info(&inheritance);             // Optional.
    
        device.begin_command_buffer(*command_buffer, &info)?;

        let render_area = vk::Rect2D::builder()
            .offset(vk::Offset2D::default())
            .extent(data.swapchain_extent);

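        // Clear values are consumed in attachment order: color first, then depth/stencil.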
        let color_clear_value = vk::ClearValue {
            color: vk::ClearColorValue {
                float32: [0.0, 0.0, 0.0, 1.0],
            },
        };

        let depth_clear_value = vk::ClearValue {
            depth_stencil: vk::ClearDepthStencilValue {
                depth: 1.0, // 1.0 is the far plane.
                stencil: 0,
            },
        };

        // Record the compute pre-pass (volumetric rasterization) before entering the render pass.
        if !scene_handler.volumetrics.is_empty() {
            device.cmd_bind_pipeline(
                *command_buffer, vk::PipelineBindPoint::COMPUTE, data.pipeline_compute_rasterize); // TODO: build a dedicated pipeline.

            device.cmd_bind_descriptor_sets(
                *command_buffer,
                vk::PipelineBindPoint::COMPUTE,
                data.pipeline_layout,
                0,
                &[data.descriptor_sets[i]],
                &[],
            );
            
            // One workgroup per 16 items; assumes the compute shader's local workgroup size is 16.
            device.cmd_dispatch(*command_buffer, (data.compute_task_one_size as f64 / 16.0).ceil() as u32, 1, 1);

            // Make the compute-written cuboid geometry visible to the vertex input stage.
            let buffer_memory_barrier_vertex = vk::BufferMemoryBarrier::builder()
                .buffer(data.compute_out_cuboid_buffers[i])
                .src_access_mask(vk::AccessFlags::SHADER_READ | vk::AccessFlags::SHADER_WRITE)
                .dst_access_mask(vk::AccessFlags::VERTEX_ATTRIBUTE_READ)
                .size(vk::WHOLE_SIZE)
                .build();
            let buffer_memory_barrier_index = vk::BufferMemoryBarrier::builder()
                .buffer(data.compute_out_cuboid_index_buffers[i])
                .src_access_mask(vk::AccessFlags::SHADER_READ | vk::AccessFlags::SHADER_WRITE)
                .dst_access_mask(vk::AccessFlags::INDEX_READ)
                .size(vk::WHOLE_SIZE)
                .build();

            device.cmd_pipeline_barrier(*command_buffer,
                vk::PipelineStageFlags::COMPUTE_SHADER,
                vk::PipelineStageFlags::VERTEX_INPUT,
                vk::DependencyFlags::empty(),
                &[] as &[vk::MemoryBarrier],
                &[buffer_memory_barrier_index, buffer_memory_barrier_vertex],
                &[] as &[vk::ImageMemoryBarrier]);

            // Make the color/transparency storage writes visible to the following compute passes.
            let buffer_memory_barrier_color = vk::BufferMemoryBarrier::builder()
                .buffer(data.compute_out_storage_buffers_color[i])
                .src_access_mask(vk::AccessFlags::SHADER_WRITE)
                .dst_access_mask(vk::AccessFlags::SHADER_READ)
                .size(vk::WHOLE_SIZE)
                .build();

            let buffer_memory_barrier_transparent = vk::BufferMemoryBarrier::builder()
                .buffer(data.compute_out_storage_buffers_transparent[i])
                .src_access_mask(vk::AccessFlags::SHADER_WRITE)
                .dst_access_mask(vk::AccessFlags::SHADER_READ)
                .size(vk::WHOLE_SIZE)
                .build();

            device.cmd_pipeline_barrier(*command_buffer,
                vk::PipelineStageFlags::COMPUTE_SHADER,
                vk::PipelineStageFlags::COMPUTE_SHADER,
                vk::DependencyFlags::empty(),
                &[] as &[vk::MemoryBarrier],
                &[buffer_memory_barrier_color, buffer_memory_barrier_transparent],
                &[] as &[vk::ImageMemoryBarrier]);

            // Grow pass, x axis.
            device.cmd_bind_pipeline(
                *command_buffer, vk::PipelineBindPoint::COMPUTE, data.pipeline_compute_grow_one);

            device.cmd_bind_descriptor_sets(
                *command_buffer,
                vk::PipelineBindPoint::COMPUTE,
                data.pipeline_layout,
                0,
                &[data.descriptor_sets[i]],
                &[],
            );

            device.cmd_dispatch(*command_buffer, (data.compute_task_one_size as f64 / 16.0).ceil() as u32, 1, 1);

            // Make the grow-pass output visible to the next grow pass.
            let buffer_memory_barrier_out = vk::BufferMemoryBarrier::builder()
                .buffer(data.compute_out_storage_buffers_size_three[i])
                .src_access_mask(vk::AccessFlags::SHADER_WRITE)
                .dst_access_mask(vk::AccessFlags::SHADER_READ)
                .size(vk::WHOLE_SIZE)
                .build();

            device.cmd_pipeline_barrier(*command_buffer,
                vk::PipelineStageFlags::COMPUTE_SHADER,
                vk::PipelineStageFlags::COMPUTE_SHADER,
                vk::DependencyFlags::empty(),
                &[] as &[vk::MemoryBarrier],
                &[buffer_memory_barrier_out],
                &[] as &[vk::ImageMemoryBarrier]);

            // Grow pass, y axis.
            device.cmd_bind_pipeline(
                *command_buffer, vk::PipelineBindPoint::COMPUTE, data.pipeline_compute_grow_two);

            device.cmd_bind_descriptor_sets(
                *command_buffer,
                vk::PipelineBindPoint::COMPUTE,
                data.pipeline_layout,
                0,
                &[data.descriptor_sets[i]],
                &[],
            );

            device.cmd_dispatch(*command_buffer, (data.compute_task_one_size as f64 / 16.0).ceil() as u32, 1, 1);

            let buffer_memory_barrier_out = vk::BufferMemoryBarrier::builder()
                .buffer(data.compute_out_storage_buffers_size_two[i])
                .src_access_mask(vk::AccessFlags::SHADER_WRITE)
                .dst_access_mask(vk::AccessFlags::SHADER_READ)
                .size(vk::WHOLE_SIZE)
                .build();

            device.cmd_pipeline_barrier(*command_buffer,
                vk::PipelineStageFlags::COMPUTE_SHADER,
                vk::PipelineStageFlags::COMPUTE_SHADER,
                vk::DependencyFlags::empty(),
                &[] as &[vk::MemoryBarrier],
                &[buffer_memory_barrier_out],
                &[] as &[vk::ImageMemoryBarrier]);

            // Grow pass, z axis.
            device.cmd_bind_pipeline(
                *command_buffer, vk::PipelineBindPoint::COMPUTE, data.pipeline_compute_grow_three);

            device.cmd_bind_descriptor_sets(
                *command_buffer,
                vk::PipelineBindPoint::COMPUTE,
                data.pipeline_layout,
                0,
                &[data.descriptor_sets[i]],
                &[],
            );

            device.cmd_dispatch(*command_buffer, (data.compute_task_one_size as f64 / 16.0).ceil() as u32, 1, 1);

            let buffer_memory_barrier_out = vk::BufferMemoryBarrier::builder()
                .buffer(data.compute_out_storage_buffers_size_three[i])
                .src_access_mask(vk::AccessFlags::SHADER_WRITE)
                .dst_access_mask(vk::AccessFlags::SHADER_READ)
                .size(vk::WHOLE_SIZE)
                .build();

            device.cmd_pipeline_barrier(*command_buffer,
                vk::PipelineStageFlags::COMPUTE_SHADER,
                vk::PipelineStageFlags::COMPUTE_SHADER,
                vk::DependencyFlags::empty(),
                &[] as &[vk::MemoryBarrier],
                &[buffer_memory_barrier_out],
                &[] as &[vk::ImageMemoryBarrier]);
        }
        // Begin the render pass.
        let clear_values = &[color_clear_value, depth_clear_value];
        let info = vk::RenderPassBeginInfo::builder()
            .render_pass(data.render_pass)
            .framebuffer(data.framebuffers[i])
            .render_area(render_area)
            .clear_values(clear_values);

        device.cmd_begin_render_pass(
            *command_buffer, &info, vk::SubpassContents::INLINE);
        
        // Draw the static cube geometry.
        if !scene_handler.vertices.is_empty() {
            device.cmd_bind_pipeline(
                *command_buffer, vk::PipelineBindPoint::GRAPHICS, data.pipeline_cube);

            device.cmd_bind_vertex_buffers(*command_buffer, 0, &[scene_handler.vertex_buffer_cube], &[0]);
            device.cmd_bind_index_buffer(*command_buffer, scene_handler.index_buffer_cube, 0, vk::IndexType::UINT32);

            device.cmd_bind_descriptor_sets(
                *command_buffer,
                vk::PipelineBindPoint::GRAPHICS,
                data.pipeline_layout,
                0,
                &[data.descriptor_sets[i]],
                &[],
            );

            device.cmd_draw_indexed(*command_buffer, scene_handler.indices_cube.len() as u32, 1, 0, 0, 0);
        }

        // Draw the static cuboid (sized vertex) geometry.
        if !scene_handler.sized_vertices.is_empty() {
            device.cmd_bind_pipeline(
                *command_buffer, vk::PipelineBindPoint::GRAPHICS, data.pipeline_cuboid);

            device.cmd_bind_vertex_buffers(*command_buffer, 0, &[scene_handler.vertex_buffer_cuboid], &[0]);
            device.cmd_bind_index_buffer(*command_buffer, scene_handler.index_buffer_cuboid, 0, vk::IndexType::UINT32);

            device.cmd_bind_descriptor_sets(
                *command_buffer,
                vk::PipelineBindPoint::GRAPHICS,
                data.pipeline_layout,
                0,
                &[data.descriptor_sets[i]],
                &[],
            );

            device.cmd_draw_indexed(*command_buffer, scene_handler.indices_cuboid.len() as u32, 1, 0, 0, 0);
        }
        // Draw the cuboid geometry generated by the compute pre-pass.
        if !scene_handler.volumetrics.is_empty() {
            device.cmd_bind_pipeline(
                *command_buffer, vk::PipelineBindPoint::GRAPHICS, data.pipeline_cuboid);

            device.cmd_bind_vertex_buffers(*command_buffer, 0, &[data.compute_out_cuboid_buffers[i]], &[0]);
            device.cmd_bind_index_buffer(*command_buffer, data.compute_out_cuboid_index_buffers[i], 0, vk::IndexType::UINT32);

            device.cmd_bind_descriptor_sets(
                *command_buffer,
                vk::PipelineBindPoint::GRAPHICS,
                data.pipeline_layout,
                0,
                &[data.descriptor_sets[i]],
                &[],
            );

            // 36 indices per generated cuboid (12 triangles x 3 indices).
            device.cmd_draw_indexed(*command_buffer, (data.compute_task_one_out_size * 36) as u32, 1, 0, 0, 0);
        }

        // Draw the quad geometry for the rt vertices.
        if !scene_handler.rt_vertices.is_empty() {
            device.cmd_bind_pipeline(
                *command_buffer, vk::PipelineBindPoint::GRAPHICS, data.pipeline_quad); // TODO: build a dedicated pipeline.

            device.cmd_bind_vertex_buffers(*command_buffer, 0, &[scene_handler.vertex_buffer_quad], &[0]);
            device.cmd_bind_index_buffer(*command_buffer, scene_handler.index_buffer_quad, 0, vk::IndexType::UINT32);

            device.cmd_bind_descriptor_sets(
                *command_buffer,
                vk::PipelineBindPoint::GRAPHICS,
                data.pipeline_layout,
                0,
                &[data.descriptor_sets[i]],
                &[],
            );

            device.cmd_draw_indexed(*command_buffer, scene_handler.indices_rt.len() as u32, 1, 0, 0, 0);
        }

        device.cmd_end_render_pass(*command_buffer);

        device.end_command_buffer(*command_buffer)?;
    }

    Ok(())
}

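/// Allocates and begins a short-lived primary command buffer for one-off work
/// (e.g. buffer copies); pair every call with `end_single_time_commands`.
///
/// Typical usage (sketch; `staging`, `dst`, and `copy_region` are placeholders):
/// ```ignore
/// let command_buffer = begin_single_time_commands(device, data)?;
/// device.cmd_copy_buffer(command_buffer, staging, dst, &[copy_region]);
/// end_single_time_commands(device, data, command_buffer)?;
/// ```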
pub unsafe fn begin_single_time_commands(
    device: &Device,
    data: &app_data::AppData,
) -> Result<vk::CommandBuffer> {
    let info = vk::CommandBufferAllocateInfo::builder()
        .level(vk::CommandBufferLevel::PRIMARY)
        .command_pool(data.command_pool)
        .command_buffer_count(1);

    let command_buffer = device.allocate_command_buffers(&info)?[0];

    let info = vk::CommandBufferBeginInfo::builder()
        .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT);

    device.begin_command_buffer(command_buffer, &info)?;

    Ok(command_buffer)
}

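/// Ends, submits, and frees a command buffer obtained from
/// `begin_single_time_commands`, blocking until the graphics queue is idle.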
pub unsafe fn end_single_time_commands(
    device: &Device,
    data: &app_data::AppData,
    command_buffer: vk::CommandBuffer,
) -> Result<()> {
    device.end_command_buffer(command_buffer)?;

    let command_buffers = &[command_buffer];
    let info = vk::SubmitInfo::builder()
        .command_buffers(command_buffers);

    device.queue_submit(data.graphics_queue, &[info], vk::Fence::null())?;
    device.queue_wait_idle(data.graphics_queue)?;

    device.free_command_buffers(data.command_pool, &[command_buffer]);

    Ok(())
}