#version 450

// Per-frame uniforms. NOTE(review): none of these fields are read by this compute shader.
layout(binding = 0) uniform UniformBufferObject {
    mat4 model;
    mat4 geom_rot;
    mat4 view;
    mat4 proj;
    vec3 camera_pos;
    bool[16] use_geom_shader;
} ubo;

// Sequence of per-compound header records, each pointing at the next; main() starts at
// compounds[1]. Fields read by main() relative to a record start s:
//   [s]        grid edge length of the compound (in voxels)
//   [s + 1]    voxel scale (float bits, decoded with uintBitsToFloat)
//   [s + 2]    start index of the next compound record
//   [s + 5..7] compound position x/y/z (float bits)
//   [s + 8]    offset of this compound's octree nodes inside out_memory
layout(binding = 3) readonly buffer CompoundBuffer {
   uint compounds[];
};

// Packed voxel colors of all compounds: one dense edge^3 grid per compound, stored back to
// back and indexed x * edge^2 + y * edge + z. A value of 0 produces no cube.
layout(binding = 4) readonly buffer ColorBuffer {
   uint grid_in[];
};

// NOTE(review): not read by this shader.
layout(binding = 9) readonly buffer TransparentBuffer {
   bool transparent_grid[];
};

// Cumulative per-voxel entry counts with the same layout as grid_in; main() recovers the
// number of entries inside an octree cell from it by inclusion-exclusion.
layout(binding = 8) readonly buffer SizeBuffer3D {
   uint grid_size_in[];
};

// Octree output, 9 uints per node: [0] parent node offset (0 for a root),
// [1..8] child node offsets (inner nodes) or packed voxel colors (leaf nodes).
layout(binding = 10) buffer OutMemory {
   uint out_memory[];
};

// Scene-wide values; only infos[0] is read here (into max_num_lights below).
layout(binding = 2) readonly buffer SceneInfoBuffer{
     uint infos[]; 
} scene_info;

// NOTE(review): not used anywhere in this shader — confirm it is still needed.
uint max_num_lights = scene_info.infos[0];

// 16 invocations per workgroup along x; main() maps gl_GlobalInvocationID.x to one octree node.
layout (local_size_x = 16, local_size_y = 1, local_size_z = 1) in;

// Total number of octree nodes for a cubic grid of edge `size`: 1 + 8 + 64 + ...,
// with one term per halving of `size` while it is still >= 2
// (e.g. size 2 -> 1, size 4 -> 9, size 8 -> 73; size < 2 -> 0).
uint num_nodes(uint size) {
   uint total = 0u;
   for (uint level_count = 1u; size >= 2u; size /= 2u) {
      total += level_count;
      level_count *= 8u;
   }
   return total;
}

// Mesh vertices written by add_cube: 11 floats per vertex, 8 vertices per cube.
// add_cube fills floats 0-2 (position) and 3-5 (color); floats 6-10 are not written here.
layout(binding = 5) buffer SizedVertices {
   float vertices[];
};

// Mesh triangle indices written by add_cube: 36 per cube (12 triangles).
layout(binding = 6) buffer Indices {
   uint indices[];
};

// Unpacks a color stored one channel per byte: the least significant byte becomes .x,
// the next byte .y and the third byte .z, each scaled from [0, 255] to [0, 1].
// The most significant byte is ignored.
vec3 unpack_color(uint val) {
    uvec3 channels = (uvec3(val) >> uvec3(0u, 8u, 16u)) & uvec3(0xFFu);
    return vec3(channels) / 255.0;
}

// Writes the geometry of one axis-aligned cube into the mesh buffers.
//
// cube_num selects the output slot: vertices cube_num * 8 .. cube_num * 8 + 7 and
// indices cube_num * 36 .. cube_num * 36 + 35. Each vertex occupies 11 floats in
// `vertices`; only floats 0-2 (position) and 3-5 (color) are written here.
// The corners lie at pos +/- 0.5 * scale on each axis. For corner number c:
// bit 0 set -> +x (else -x), bit 1 set -> -y (else +y), bit 2 set -> -z (else +z).
void add_cube(uint cube_num, float scale, vec3 pos, vec3 color) {
    const uint FLOATS_PER_VERTEX = 11u;
    // Two triangles per face, given as corner numbers local to this cube.
    const uint CUBE_INDICES[36] = uint[36](
        3u, 0u, 2u,   3u, 1u, 0u,   // top
        6u, 4u, 7u,   4u, 5u, 7u,   // bottom
        0u, 4u, 2u,   6u, 2u, 4u,   // left
        1u, 3u, 5u,   5u, 3u, 7u,   // right
        6u, 3u, 2u,   3u, 6u, 7u,   // near
        0u, 1u, 4u,   5u, 4u, 1u    // far
    );

    uint first_vertex = cube_num * 8u;
    float half_edge = 0.5 * scale;

    for (uint corner = 0u; corner < 8u; ++corner) {
        float vx = ((corner & 1u) != 0u) ? pos.x + half_edge : pos.x - half_edge;
        float vy = ((corner & 2u) != 0u) ? pos.y - half_edge : pos.y + half_edge;
        float vz = ((corner & 4u) != 0u) ? pos.z - half_edge : pos.z + half_edge;

        uint base = (first_vertex + corner) * FLOATS_PER_VERTEX;
        vertices[base + 0u] = vx;
        vertices[base + 1u] = vy;
        vertices[base + 2u] = vz;
        vertices[base + 3u] = color.x;
        vertices[base + 4u] = color.y;
        vertices[base + 5u] = color.z;
    }

    uint first_index = cube_num * 36u;
    for (uint i = 0u; i < 36u; ++i) {
        indices[first_index + i] = first_vertex + CUBE_INDICES[i];
    }
}

// Linear index of the block of edge `block_size` that contains voxel (x, y, z), among all
// such blocks tiling a cube of edge `compound_size`. Block x varies fastest, then y, then z.
uint cohort_index_from_pos(uint x, uint y, uint z, uint block_size, uint compound_size) {
   uint blocks_per_axis = compound_size / block_size;
   uint bx = x / block_size;
   uint by = y / block_size;
   uint bz = z / block_size;
   return (bz * blocks_per_axis + by) * blocks_per_axis + bx;
}

// Flattened index of voxel (x, y, z) inside a dense cubic grid of edge `n`, with x the
// slowest-varying and z the fastest-varying axis (the layout of grid_in and grid_size_in).
uint compound_voxel_index(uint x, uint y, uint z, uint n) {
   return x * n * n + y * n + z;
}

// Reads grid_size_in at voxel (x, y, z) of the compound whose data starts at input_offset.
uint prefix_count_at(uint input_offset, uint x, uint y, uint z, uint n) {
   return grid_size_in[input_offset + compound_voxel_index(x, y, z, n)];
}

// Builds one octree node per invocation.
//
// gl_GlobalInvocationID.x enumerates the nodes of all compounds back to back. Within a
// compound, levels are stored consecutively (1, 8, 64, ... nodes), and each node takes 9
// uints in out_memory:
//   [0]    parent node offset (0 for the root)
//   [1..8] child node offsets for inner nodes, or packed voxel colors for leaf nodes
//          (cell edge 2)
// Every non-zero leaf voxel is also emitted as a cube via add_cube(). A node whose cell
// contains no entries is written as nine zeros.
void main() {
   uint index = gl_GlobalInvocationID.x;
   uint input_offset = 0;
   uint compound_start = 1;

   // Skip whole compounds until `index` falls inside one. `index` is 0-based, so a compound
   // with `nodes` nodes owns indices [0, nodes). The test must be `>=`: with `>`,
   // index == nodes would be treated as a node on a non-existent level below the leaves.
   uint nodes = num_nodes(compounds[compound_start]);
   while (index >= nodes) {
      uint skipped_edge = compounds[compound_start];
      input_offset += skipped_edge * skipped_edge * skipped_edge;
      index -= nodes;
      compound_start = compounds[compound_start + 2];
      nodes = num_nodes(compounds[compound_start]);
   }
   // NOTE(review): there is no end-of-list check. Invocations past the total node count
   // (e.g. padding of the last 16-wide workgroup) keep following the next-compound links.
   // Confirm that the dispatch size or the list terminator rules this out.

   uint output_offset = compounds[compound_start + 8];
   uint compound_grid_size = compounds[compound_start];
   uint n = compound_grid_size;
   float compound_scale = uintBitsToFloat(compounds[compound_start + 1]);

   // Find this node's level. Level k holds level_count = 8^k nodes whose cell edge is
   // `size` (the compound edge halved k times). cohort_start and parent_start are the
   // slot offsets (in uints, relative to output_offset) of this level and the level above.
   uint parent_start = 0;
   uint cohort_start = 0;
   uint cohort_index = index;
   uint size = compound_grid_size;
   uint nodes_before = 0;
   uint level_count = 1;
   while (cohort_index >= level_count) {
      nodes_before += level_count;
      cohort_index -= level_count;
      parent_start = cohort_start;
      cohort_start = nodes_before * 9;
      level_count *= 8;
      size /= 2;
   }

   // Minimum voxel corner of this node's cell. Cells of one level are ordered x fastest,
   // then y, then z (the order produced by cohort_index_from_pos).
   uint steps = compound_grid_size / size;
   uint x_no_offset = (cohort_index % steps) * size;
   uint y_no_offset = ((cohort_index / steps) % steps) * size;
   uint z_no_offset = (cohort_index / (steps * steps)) * size;

   uint node = output_offset + cohort_start + cohort_index * 9;

   // The root (its cell is the whole compound) gets parent 0. Any other node points at the
   // cell of twice its edge that contains it, one level up.
   uint parent = 0;
   if (size != compound_grid_size) {
      parent = output_offset + parent_start
             + cohort_index_from_pos(x_no_offset, y_no_offset, z_no_offset, size * 2, compound_grid_size) * 9;
   }

   // Far (maximum) corner of the cell. The node is anchored at the far end to match the
   // iteration directions used by the previous shaders.
   uint x = x_no_offset + (size - 1);
   uint y = y_no_offset + (size - 1);
   uint z = z_no_offset + (size - 1);

   // grid_size_in is a cumulative count: the entry at (x, y, z) sums all elements with
   // coordinates lower than x, y, z. The count inside this cell follows from 3D
   // inclusion-exclusion over the neighbouring cells. A term is dropped when that
   // neighbour would lie below coordinate 0 on an axis. Cell corners are multiples of
   // size >= 2, so `x_no_offset > 0` is equivalent to `x > size`.
   bool has_lower_x = x_no_offset > 0;
   bool has_lower_y = y_no_offset > 0;
   bool has_lower_z = z_no_offset > 0;

   // uint arithmetic wraps, so intermediate underflow is harmless; only the final sum matters.
   uint contained_entries = prefix_count_at(input_offset, x, y, z, n);
   if (has_lower_x) contained_entries -= prefix_count_at(input_offset, x - size, y, z, n);
   if (has_lower_y) contained_entries -= prefix_count_at(input_offset, x, y - size, z, n);
   if (has_lower_z) contained_entries -= prefix_count_at(input_offset, x, y, z - size, n);
   if (has_lower_y && has_lower_z) contained_entries += prefix_count_at(input_offset, x, y - size, z - size, n);
   if (has_lower_x && has_lower_z) contained_entries += prefix_count_at(input_offset, x - size, y, z - size, n);
   if (has_lower_x && has_lower_y) contained_entries += prefix_count_at(input_offset, x - size, y - size, z, n);
   if (has_lower_x && has_lower_y && has_lower_z) contained_entries -= prefix_count_at(input_offset, x - size, y - size, z - size, n);

   if (contained_entries == 0) {
      // Empty cell: clear the whole node.
      for (uint slot = 0u; slot < 9u; ++slot) {
         out_memory[node + slot] = 0;
      }
      return;
   }

   out_memory[node] = parent;

   if (size > 2) {
      // Inner node: slot 1 + c references child c. Bit 0 of c selects the upper x half,
      // bit 1 the upper y half and bit 2 the upper z half
      // (order: xyz, Xyz, xYz, XYz, xyZ, XyZ, xYZ, XYZ).
      uint child_size = size / 2;
      uint child_level_start = output_offset + cohort_start + 9 * level_count;
      for (uint c = 0u; c < 8u; ++c) {
         uint cx = x_no_offset + (c & 1u) * child_size;
         uint cy = y_no_offset + ((c >> 1) & 1u) * child_size;
         uint cz = z_no_offset + ((c >> 2) & 1u) * child_size;
         out_memory[node + 1 + c] = child_level_start + cohort_index_from_pos(cx, cy, cz, child_size, compound_grid_size) * 9;
      }
   } else {
      // Leaf node (cell edge 2): slot 1 + c stores the packed color of voxel c (same bit
      // order as above), and every non-zero voxel is emitted as a cube.
      vec3 compound_pos = vec3(uintBitsToFloat(compounds[compound_start + 5]), uintBitsToFloat(compounds[compound_start + 6]), uintBitsToFloat(compounds[compound_start + 7]));
      vec3 mid_offset = vec3(compound_scale * 0.5);
      vec3 check_pos = compound_pos + vec3(float(x) * compound_scale, float(y) * compound_scale, float(z) * compound_scale) + mid_offset;
      for (uint c = 0u; c < 8u; ++c) {
         uint dx = c & 1u;
         uint dy = (c >> 1) & 1u;
         uint dz = (c >> 2) & 1u;
         uint vx = x_no_offset + dx;
         uint vy = y_no_offset + dy;
         uint vz = z_no_offset + dz;
         uint color = grid_in[input_offset + compound_voxel_index(vx, vy, vz, n)];
         out_memory[node + 1 + c] = color;
         if (color != 0) {
            // The cube slot uses a z-major voxel index, unlike the x-major color/count grids.
            add_cube(input_offset + vz * n * n + vy * n + vx, compound_scale, check_pos - vec3(float(1u - dx), float(1u - dy), float(1u - dz)) * compound_scale, unpack_color(color));
         }
      }
   }
}