#version 450

// Camera / transform uniforms. Nothing in this compute shader reads `ubo`;
// presumably it is declared to match a descriptor layout shared with the
// graphics pipeline — TODO confirm.
// NOTE(review): if this block is laid out as std140 (the Vulkan default for
// uniform blocks), every element of `use_geom_shader` occupies a 16-byte slot;
// make sure the host-side struct matches.
layout(binding = 0) uniform UniformBufferObject {
    mat4 model;
    mat4 geom_rot;
    mat4 view;
    mat4 proj;
    vec3 camera_pos;
    bool[16] use_geom_shader;
} ubo;

// Packed scene description (read-only), as decoded by main():
//   compound header at offset c:
//     [c+0] grid size g (cells per axis)   [c+1] cell scale (float bits)
//     [c+2] offset of the next compound    [c+3] included-component count
//     [c+4] excluded-component count       [c+5..c+7] origin (float bits)
//     [c+9 ...] offsets of the included components, then the excluded ones
//   component record at offset k:
//     [k+0] type (0 sphere, 1 cone, 2 box) [k+1..k+3] position (float bits)
//     [k+4..k+6] rotation angles           [k+7] packed color
//     [k+8] transparency flag (non-zero = transparent)
//     [k+9 ...] type-specific parameters
layout(binding = 3) readonly buffer SceneInfoBuffer {
   uint compounds[];
};

// One packed color per voxel cell; 0 marks an empty cell. A compound's cells
// start at its output offset and are indexed (x * g + y) * g + z.
layout(binding = 4) buffer SceneInfoBuffer2 {
   uint grid[];
};

// Output vertex data: 11 floats per vertex, 8 vertices per cube. add_cube()
// writes floats 0-2 (position) and 3-5 (color); floats 6-10 are not written
// by this shader.
layout(binding = 5) buffer SizedVertices {
   float vertices[];
};

// Output triangle indices: 36 per cube (12 triangles).
layout(binding = 6) buffer Indices {
   uint indices[];
};

// Per-cell transparency flag, indexed exactly like `grid`.
layout(binding = 9) buffer transparencies {
   bool transparent_grid[];
};

// 16 invocations per workgroup; each invocation processes one (x, y) column
// of one compound (see main()). Presumably the host dispatches enough groups
// to cover the sum of g*g over all compounds — TODO confirm.
layout (local_size_x = 16, local_size_y = 1, local_size_z = 1) in;

// Splits a packed 32-bit value into its four bytes, least-significant first:
//   .x = bits 0-7, .y = bits 8-15, .z = bits 16-23, .w = bits 24-31.
uvec4 unpack_color(uint val) {
    const uint byte_mask = 0xFFu;
    uint lowest   = val & byte_mask;
    uint second   = (val >> 8u) & byte_mask;
    uint third    = (val >> 16u) & byte_mask;
    uint highest  = val >> 24u;
    return uvec4(lowest, second, third, highest);
}

// Corner-index triples for the 12 triangles of a cube, two per face, all wound
// consistently so their normals point outwards. Corner k of a cube sits at
//   x = +h if bit 0 of k is set, else -h
//   y = -h if bit 1 of k is set, else +h
//   z = -h if bit 2 of k is set, else +h
// where h is half the edge length.
const uint CUBE_TRIANGLES[36] = uint[36](
    3u, 0u, 2u,   3u, 1u, 0u,   // top    (+z)
    6u, 4u, 7u,   4u, 5u, 7u,   // bottom (-z)
    0u, 4u, 2u,   6u, 2u, 4u,   // left   (-x)
    1u, 3u, 5u,   5u, 3u, 7u,   // right  (+x)
    6u, 3u, 2u,   3u, 6u, 7u,   // near   (-y)
    0u, 1u, 4u,   5u, 4u, 1u    // far    (+y)
);

// Writes one axis-aligned cube into the vertex and index buffers.
//   cube_num - cube slot; it owns vertices [cube_num*8, cube_num*8 + 8)
//              (11 floats each) and indices [cube_num*36, cube_num*36 + 36).
//   scale    - edge length of the cube.
//   pos      - centre of the cube.
//   color    - RGB stored in floats 3..5 of every vertex.
// Only floats 0..5 of each vertex are written; floats 6..10 are left as-is.
void add_cube(uint cube_num, float scale, vec3 pos, vec3 color) {
    float half_extent = 0.5 * scale;
    uint first_vertex = cube_num * 8u;

    for (uint corner = 0u; corner < 8u; corner++) {
        vec3 offset = vec3(
            ((corner & 1u) != 0u) ? half_extent : -half_extent,
            ((corner & 2u) != 0u) ? -half_extent : half_extent,
            ((corner & 4u) != 0u) ? -half_extent : half_extent);
        uint base = (first_vertex + corner) * 11u;

        vertices[base + 0u] = pos.x + offset.x;
        vertices[base + 1u] = pos.y + offset.y;
        vertices[base + 2u] = pos.z + offset.z;

        vertices[base + 3u] = color.x;
        vertices[base + 4u] = color.y;
        vertices[base + 5u] = color.z;
    }

    uint first_index = cube_num * 36u;
    for (uint i = 0u; i < 36u; i++) {
        indices[first_index + i] = first_vertex + CUBE_TRIANGLES[i];
    }
}

// Builds a component's rotation matrix from the three angles in its record,
// composed as Rx * Ry * Rz (GLSL matrix constructors take columns).
// NOTE(review): Rx and Rz follow the usual right-handed form, while Ry is its
// transpose (i.e. it rotates by -rot.y). That sign is kept from the original
// code — confirm it is intended.
mat3 component_rotation(vec3 rot) {
    mat3 rot_x = mat3(
        vec3(1.0, 0.0, 0.0),
        vec3(0.0, cos(rot.x), sin(rot.x)),
        vec3(0.0, -sin(rot.x), cos(rot.x))
    );
    mat3 rot_y = mat3(
        vec3(cos(rot.y), 0.0, sin(rot.y)),
        vec3(0.0, 1.0, 0.0),
        vec3(-sin(rot.y), 0.0, cos(rot.y))
    );
    // The second column previously used cos(rot.y), which made the z rotation
    // non-orthogonal whenever cos(rot.y) != cos(rot.z).
    mat3 rot_z = mat3(
        vec3(cos(rot.z), sin(rot.z), 0.0),
        vec3(-sin(rot.z), cos(rot.z), 0.0),
        vec3(0.0, 0.0, 1.0)
    );
    return rot_x * rot_y * rot_z;
}

// Returns true when `point` lies inside the component record stored at
// compounds[component_index]. Type-specific parameters start at offset +9:
//   type 0, sphere: [+9] radius (rotation is ignored)
//   type 1, cone:   [+9] radius at the component position (factor 0),
//                   [+10] radius at position + axis (factor 1),
//                   [+11..+13] axis vector, rotated by the component rotation
//   type 2, box:    [+9..+11] edge lengths along the rotated x / y / z axes
// Any other type is treated as not containing the point.
bool component_contains(uint component_index, vec3 point) {
    uint component_type = compounds[component_index];
    vec3 component_pos = vec3(
        uintBitsToFloat(compounds[component_index + 1]),
        uintBitsToFloat(compounds[component_index + 2]),
        uintBitsToFloat(compounds[component_index + 3]));

    if (component_type == 0u) {
        float radius = uintBitsToFloat(compounds[component_index + 9]);
        return length(component_pos - point) <= radius;
    }

    if (component_type != 1u && component_type != 2u) {
        return false;
    }

    vec3 component_rot = vec3(
        uintBitsToFloat(compounds[component_index + 4]),
        uintBitsToFloat(compounds[component_index + 5]),
        uintBitsToFloat(compounds[component_index + 6]));
    mat3 rot = component_rotation(component_rot);
    vec3 diff = point - component_pos;

    if (component_type == 1u) {
        float radius1 = uintBitsToFloat(compounds[component_index + 9]);
        float radius2 = uintBitsToFloat(compounds[component_index + 10]);
        vec3 direction = rot * vec3(
            uintBitsToFloat(compounds[component_index + 11]),
            uintBitsToFloat(compounds[component_index + 12]),
            uintBitsToFloat(compounds[component_index + 13]));

        // Position of the point's projection along the axis (0 .. 1 inside).
        float factor = dot(direction, diff) / dot(direction, direction);
        // Perpendicular offset of the point from the axis.
        vec3 n = diff - factor * direction;
        // Radius interpolated linearly between the two ends.
        float radius = radius1 * (1.0 - factor) + radius2 * factor;

        return length(n) <= radius && 0.0 <= factor && factor <= 1.0;
    }

    // Type 2, box: each direction is a rotated half-edge. The point is inside
    // when its projection onto every half-edge lies within [-1, 1].
    vec3 size = vec3(
        uintBitsToFloat(compounds[component_index + 9]),
        uintBitsToFloat(compounds[component_index + 10]),
        uintBitsToFloat(compounds[component_index + 11]));
    vec3 direction1 = rot * vec3(size.x, 0.0, 0.0) / 2.0;
    vec3 direction2 = rot * vec3(0.0, size.y, 0.0) / 2.0;
    vec3 direction3 = rot * vec3(0.0, 0.0, size.z) / 2.0;

    float factor1 = dot(direction1, diff) / dot(direction1, direction1);
    float factor2 = dot(direction2, diff) / dot(direction2, direction2);
    float factor3 = dot(direction3, diff) / dot(direction3, direction3);
    return (-1.0 <= factor1 && factor1 <= 1.0)
        && (-1.0 <= factor2 && factor2 <= 1.0)
        && (-1.0 <= factor3 && factor3 <= 1.0);
}

// Voxelizes the compounds into `grid` / `transparent_grid` and emits one cube
// per filled cell via add_cube().
//
// Work decomposition: gl_GlobalInvocationID.x is a flat index over the (x, y)
// columns of all compounds. Compounds are visited by following the "next"
// offsets starting at compounds[0]. A compound with grid size g owns g*g
// consecutive invocations and g*g*g consecutive output cells. Each invocation
// walks its column along z.
//
// A cell is filled when it lies inside at least one included component and
// inside none of the excluded ones. Its color and transparency come from the
// first included component that contains it.
void main() {
    uint index = gl_GlobalInvocationID.x;
    uint output_offset = 0;
    uint compound_start = 0;
    // Find the compound owning this invocation. Local column indices run
    // 0 .. g*g-1, so advance while index >= g*g. The former `>` let
    // index == g*g through, giving x == g and overwriting the next compound's
    // first column.
    // NOTE(review): invocations past the last compound are not guarded here.
    // The invocation count is a multiple of the workgroup size (16), which need
    // not equal the total column count. How the compound list terminates is not
    // visible in this file — confirm the host keeps such invocations well-defined.
    while (index >= compounds[compound_start] * compounds[compound_start]) {
        uint size = compounds[compound_start];
        output_offset += size * size * size;
        index -= size * size;
        compound_start = compounds[compound_start + 2];
    }

    uint compound_grid_size = compounds[compound_start];
    float compound_scale = uintBitsToFloat(compounds[compound_start + 1]);
    uint included_count = compounds[compound_start + 3];
    uint excluded_count = compounds[compound_start + 4];
    // Component offsets are listed from +9: included first, then excluded.
    uint included_list = compound_start + 9;
    uint excluded_list = included_list + included_count;
    vec3 compound_pos = vec3(
        uintBitsToFloat(compounds[compound_start + 5]),
        uintBitsToFloat(compounds[compound_start + 6]),
        uintBitsToFloat(compounds[compound_start + 7]));
    // Cell centres sit half a cell away from the compound origin.
    vec3 mid_offset = vec3(compound_scale * 0.5);

    uint y = index % compound_grid_size;
    uint x = index / compound_grid_size;
    // Cells are stored x-major: output_offset + (x * g + y) * g + z, and
    // index == x * g + y.
    uint column_base = output_offset + index * compound_grid_size;

    // Walk the column upwards along z.
    for (uint z = 0; z < compound_grid_size; z++) {
        vec3 check_pos = compound_pos + vec3(float(x) * compound_scale, float(y) * compound_scale, float(z) * compound_scale) + mid_offset;
        uint cell = column_base + z;

        // The first included component containing the cell decides its look.
        bool render = false;
        uint color_int = 0u;
        bool transparent = false;
        vec3 color = vec3(0.0);
        for (uint o = 0; o < included_count; o++) {
            uint component_index = compounds[included_list + o];
            if (component_contains(component_index, check_pos)) {
                render = true;
                color_int = compounds[component_index + 7];
                transparent = compounds[component_index + 8] != 0u;
                // Only the three lowest bytes are used as RGB for the mesh.
                uvec4 component_color = unpack_color(color_int);
                color = vec3(component_color.xyz) / 255.0;
                break;
            }
        }

        // Any excluded component containing the cell carves it out again.
        if (render) {
            for (uint o = 0; o < excluded_count; o++) {
                if (component_contains(compounds[excluded_list + o], check_pos)) {
                    render = false;
                    break;
                }
            }
        }

        if (render) {
            grid[cell] = color_int;
            transparent_grid[cell] = transparent;
            add_cube(cell, compound_scale, check_pos, color);
        } else {
            grid[cell] = 0u;
            transparent_grid[cell] = false;
        }
    }
}