#version 450

// Per-fragment inputs from the previous stage (presumably the vertex/geometry shader - confirm).
// Face-local (u, v) cell of this fragment; used to index the face's color/neighbor arrays.
layout(location = 0) flat in uvec2 fragRasterPos;
// Offset of this fragment's volume inside scene_info.infos.
layout(location = 1) flat in uint fragVolumeStart;
// Interpolated fragment position, presumably in the same space as the volume origins/sizes
// (main() clamps it against those borders before tracing).
layout(location = 2) in vec3 origPosition;
// Face index 0..5 of the fragment's volume face; main() passes it to the sampling and
// tracing helpers as their face argument.
layout(location = 3) flat in uint facing;

// Final shaded color; main() always writes alpha = 1.0.
layout(location = 0) out vec4 outColor;

// Per-draw transform / camera data at binding 0.
// Only camera_pos is read by this fragment shader (as the origin of the view ray in main());
// the matrices and use_geom_shader are presumably consumed by earlier stages - confirm.
// NOTE(review): if this block uses std140 rules, the bool[16] array starts on a 16-byte
// boundary after camera_pos and every element has a 16-byte stride; confirm the host-side
// struct is padded to match.
layout(binding = 0) uniform UniformBufferObject {
    mat4 model;
    mat4 geom_rot;
    mat4 view;
    mat4 proj;
    // world-space camera position; start of the view ray traced in main()
    vec3 camera_pos;
    bool[16] use_geom_shader;
} ubo;

// Flat scene description read by every tracing helper. Global header:
//   infos[0] - maximum number of lights referenced per volume (chunk). Neighbor pointers are
//              offsets into infos, and a neighbor of 0 means "no neighbor" (trace_ray and main
//              treat 0 as absent); offset 0 is this header entry, so it is never a volume start.
//   infos[1] - maximum trace cycles per light; get_lighting_color's total budget is
//              infos[0] * infos[1], shared by all lights of a sample point
//   infos[2] - n, the half-width of the diffuse raster: (2*n + 1) * (2*n + 1) samples are taken,
//              so the offset-0 sample on the fragment itself is always included
//   infos[3] - diffuse raster spacing, a float stored as its bit pattern (decode with
//              uintBitsToFloat)
//   infos[4] - max recursive rays (not read by this shader)
//   infos[5] - diffuse rays per hit (not read by this shader)
//
// Per volume, starting at its offset V (as read by the functions below):
//   V+0..V+2        integer origin x, y, z
//   V+3..V+5        size in cells along x, y, z
//   V+6 ...         infos[0] light slots, each the offset of a light record (0 ends the list)
//   next 24 uints   (u, v) sizes of the six color arrays, then of the six neighbor arrays,
//                   each in face order top, bottom, left, right, front, back (faces 0..5)
//   then            the six color arrays back to back, then the six neighbor arrays
// Light record at offset L: L+0..L+2 position (float bits), L+3..L+5 color channels 0..255.
layout(binding = 2) buffer SceneInfoBuffer{
     uint infos[]; 
} scene_info;

// Splits a packed 32-bit value into its four bytes.
// x receives the least significant byte and w the most significant byte.
uvec4 unpack_color(uint val) {
    const uint byte_mask = 0xFFu;
    return uvec4(val & byte_mask,
                 (val >> 8) & byte_mask,
                 (val >> 16) & byte_mask,
                 val >> 24);
}

// Returns the neighbor-volume offset stored for one cell of a volume face (0 = no neighbor).
//
// volume_start: offset of the volume in scene_info.infos.
// raster_pos:   face-local (u, v) cell. It is clamped to the face's neighbor array, so an
//               out-of-range position (e.g. u == size from rounding at a volume edge) reads the
//               nearest edge cell instead of running into the next face's data.
// f:            face index 0..5 (top, bottom, left, right, front, back); values above 5 are
//               treated as 5.
//
// The neighbor arrays follow the six color arrays; each is stored u-major
// (index = u * v_size + v). A dimension of size <= 1 ignores the matching raster coordinate.
uint sample_neighbor_from_scene_info(uint volume_start, uvec2 raster_pos, uint f) {
    uint array_descr_start = volume_start + 6u + scene_info.infos[0];
    // neighbor (u, v) sizes follow the 12 color (u, v) sizes
    uint neighbor_descr_start = array_descr_start + 12u;
    uint face = min(f, 5u);

    uint array_start = array_descr_start + 24u;
    // skip all six color arrays ...
    for (uint i = 0u; i < 6u; i++) {
        array_start += scene_info.infos[array_descr_start + 2u * i] * scene_info.infos[array_descr_start + 2u * i + 1u];
    }
    // ... and the neighbor arrays of the faces stored before this one
    for (uint i = 0u; i < face; i++) {
        array_start += scene_info.infos[neighbor_descr_start + 2u * i] * scene_info.infos[neighbor_descr_start + 2u * i + 1u];
    }

    uint u_size = scene_info.infos[neighbor_descr_start + 2u * face];
    uint v_size = scene_info.infos[neighbor_descr_start + 2u * face + 1u];

    // Clamp to the last valid cell. A size of 0 or 1 collapses the axis to cell 0; the
    // uint(size > 1) factor also discards the wrapped (size - 1) when size == 0.
    uint u = min(raster_pos.x, u_size - 1u) * uint(u_size > 1u);
    uint v = min(raster_pos.y, v_size - 1u) * uint(v_size > 1u);

    return scene_info.infos[array_start + u * v_size + v];
}

// Reads the packed color of one cell of a volume face and unpacks it with unpack_color.
//
// volume_start: offset of the volume in scene_info.infos.
// raster_pos:   face-local (u, v) cell. It is clamped to the last valid cell (size - 1), so an
//               out-of-range position reads the nearest edge cell, not the first cell of the
//               next face's array.
// f:            face index 0..5 (top, bottom, left, right, front, back); values above 5 are
//               treated as 5.
//
// The color arrays start right after the 24 descriptor uints and are stored back to back,
// each u-major (index = u * v_size + v). A dimension of size <= 1 ignores the matching
// raster coordinate.
uvec4 sample_color_from_scene_info(uint volume_start, uvec2 raster_pos, uint f) {
    uint array_descr_start = volume_start + 6u + scene_info.infos[0];
    uint face = min(f, 5u);

    // skip the color arrays of the faces stored before this one
    uint array_start = array_descr_start + 24u;
    for (uint i = 0u; i < face; i++) {
        array_start += scene_info.infos[array_descr_start + 2u * i] * scene_info.infos[array_descr_start + 2u * i + 1u];
    }

    uint u_size = scene_info.infos[array_descr_start + 2u * face];
    uint v_size = scene_info.infos[array_descr_start + 2u * face + 1u];

    // Clamp to the last valid cell. A size of 0 or 1 collapses the axis to cell 0; the
    // uint(size > 1) factor also discards the wrapped (size - 1) when size == 0.
    uint u = min(raster_pos.x, u_size - 1u) * uint(u_size > 1u);
    uint v = min(raster_pos.y, v_size - 1u) * uint(v_size > 1u);

    return unpack_color(scene_info.infos[array_start + u * v_size + v]);
}

// Decodes a light's position: the first three uints of the light record hold float bit patterns.
vec3 get_light_position(uint light_index) {
    uvec3 raw_bits = uvec3(scene_info.infos[light_index],
                           scene_info.infos[light_index + 1],
                           scene_info.infos[light_index + 2]);
    return uintBitsToFloat(raw_bits);
}

// Decodes a light's color: record entries 3..5 are 0..255 channels, returned scaled to 0..1.
vec3 get_light_color(uint light_index) {
    uvec3 channels = uvec3(scene_info.infos[light_index + 3],
                           scene_info.infos[light_index + 4],
                           scene_info.infos[light_index + 5]);
    return vec3(channels) / 255.0;
}

// Result of trace_ray.
struct Tracing {
    // Intersection point on the last volume border the ray crossed.
    vec3 end_pos;
    // Unpacked color of the face that stopped the ray; meaningful only when has_hit is true.
    uvec4 end_color;
    // Offset (into scene_info.infos) of the volume the ray was in when tracing stopped.
    uint end_volume;
    // Face index (0..5) of the last border crossed.
    uint end_facing;
    // Smallest of the per-axis ray parameters when tracing stopped.
    float end_factor;
    // Cycle counter after tracing; get_lighting_color feeds it back as the next start_cycle.
    uint end_cycle;
    // True when the ray stopped on a colored face that has no neighbor volume behind it.
    bool has_hit;
    // Product of the tints of the transparent faces crossed; (1, 1, 1) if none.
    vec3 color_mul;
    // Face-local (u, v) cell of the hit; meaningful only when has_hit is true.
    uvec2 end_raster;
};

// Converts a world coordinate on a volume border into a face-local cell index.
// It rounds to the nearest cell, subtracts the volume origin and clamps at 0. The clamp keeps
// a coordinate that rounds to just below the origin (possible at volume edges) from wrapping
// to a huge uint. It also avoids converting a negative float to uint, which is undefined.
uint trace_ray_cell_offset(float world_coord, uint volume_origin) {
    return uint(max(int(round(world_coord)) - int(volume_origin), 0));
}

// Marches the ray starting_pos + factor * direction through the chain of volumes in scene_info.
//
// Each cycle finds the nearest border of the current volume in the direction of travel and
// samples that face's color and neighbor arrays:
//   no color, neighbor      -> continue in the neighbor volume
//   color,    neighbor      -> transparent face: multiply its tint into color_mul, continue
//   color,    no neighbor   -> opaque hit: has_hit = true, stop
//   no color, no neighbor   -> ray leaves the described scene: has_hit = false, stop
// Tracing also stops with has_hit = false when every border factor is >= max_factor (the segment
// ends inside the current volume) or when the cycle counter reaches max_cycle.
//
// Face numbering: 0/1 = +z/-z border, 2/3 = -x/+x border, 4/5 = -y/+y border.
//
// volume_start:            offset of the starting volume in scene_info.infos
// max_factor:              ray parameter where the segment ends (1.0 = at starting_pos + direction)
// start_cycle, max_cycle:  cycle budget; end_cycle in the result lets callers share one budget
Tracing trace_ray(uint volume_start, vec3 starting_pos, vec3 direction, float max_factor, uint start_cycle, uint max_cycle) {
    uint cycle = start_cycle;

    // current volume and its integer origin
    uint volume_index = volume_start;
    uint volume_pos_x = scene_info.infos[volume_index + 0u];
    uint volume_pos_y = scene_info.infos[volume_index + 1u];
    uint volume_pos_z = scene_info.infos[volume_index + 2u];

    bool x_pos = direction.x > 0.0;
    bool x_null = (direction.x == 0.0);

    bool y_pos = direction.y > 0.0;
    bool y_null = (direction.y == 0.0);

    bool z_pos = direction.z > 0.0;
    bool z_null = (direction.z == 0.0);

    // An axis the ray is parallel to keeps max_factor, so it never wins the nearest-border test;
    // one of the other axes always scores first.
    float x_factor = max_factor;
    float y_factor = max_factor;
    float z_factor = max_factor;

    // Initialise every field so each exit path (including neighbor miss and exhausted budget)
    // returns defined values; GLSL locals are undefined until written.
    Tracing result;
    result.end_pos = starting_pos;
    result.end_color = uvec4(0u);
    result.end_volume = volume_start;
    result.end_facing = 0u;
    result.end_factor = max_factor;
    result.end_cycle = start_cycle;
    result.has_hit = false;
    result.color_mul = vec3(1.0, 1.0, 1.0);
    result.end_raster = uvec2(0u);

    while (cycle < max_cycle) {
        cycle++;

        // Border of the current volume lying in the direction of travel on each axis; borders
        // sit half a cell away from the integer cell coordinates.
        float x_border = float(volume_pos_x + scene_info.infos[volume_index + 3u] * uint(x_pos)) - 0.5;
        float y_border = float(volume_pos_y + scene_info.infos[volume_index + 4u] * uint(y_pos)) - 0.5;
        float z_border = float(volume_pos_z + scene_info.infos[volume_index + 5u] * uint(z_pos)) - 0.5;

        if (!x_null) {
            x_factor = (x_border - starting_pos.x) / direction.x;
        }
        if (!y_null) {
            y_factor = (y_border - starting_pos.y) / direction.y;
        }
        if (!z_null) {
            z_factor = (z_border - starting_pos.z) / direction.z;
        }

        if ((x_factor >= max_factor) && (y_factor >= max_factor) && (z_factor >= max_factor)) {
            // the segment ends before reaching any border: nothing hit
            result.has_hit = false;
            break;
        }

        // Pick the nearest border. On ties the later check wins (z over y over x).
        uint hit_facing = 0u;
        uint u = 0u;
        uint v = 0u;
        if (x_factor <= y_factor && x_factor <= z_factor) {
            hit_facing = x_pos ? 3u : 2u;
            vec3 intersection_pos = starting_pos + x_factor * direction;
            u = trace_ray_cell_offset(intersection_pos.y, volume_pos_y);
            v = trace_ray_cell_offset(intersection_pos.z, volume_pos_z);
            result.end_pos = intersection_pos;
            result.end_facing = hit_facing;
        }
        if (y_factor <= x_factor && y_factor <= z_factor) {
            hit_facing = y_pos ? 5u : 4u;
            vec3 intersection_pos = starting_pos + y_factor * direction;
            u = trace_ray_cell_offset(intersection_pos.x, volume_pos_x);
            v = trace_ray_cell_offset(intersection_pos.z, volume_pos_z);
            result.end_pos = intersection_pos;
            result.end_facing = hit_facing;
        }
        if (z_factor <= x_factor && z_factor <= y_factor) {
            hit_facing = z_pos ? 0u : 1u;
            vec3 intersection_pos = starting_pos + z_factor * direction;
            u = trace_ray_cell_offset(intersection_pos.x, volume_pos_x);
            v = trace_ray_cell_offset(intersection_pos.y, volume_pos_y);
            result.end_pos = intersection_pos;
            result.end_facing = hit_facing;
        }

        uvec2 raster = uvec2(u, v);
        uint next_neighbor = sample_neighbor_from_scene_info(volume_index, raster, hit_facing);
        uvec4 color_sample = sample_color_from_scene_info(volume_index, raster, hit_facing);
        bool has_color = color_sample != uvec4(0u);

        if (next_neighbor == 0u) {
            // No volume behind this face: either an opaque surface or the edge of the scene.
            if (has_color) {
                result.end_color = color_sample;
                result.end_raster = raster;
                result.has_hit = true;
            }
            break;
        }

        if (has_color) {
            // transparent face: tint whatever is seen through it
            result.color_mul *= vec3(color_sample.xyz) / 255.0;
        }

        // continue in the neighbor volume
        volume_index = next_neighbor;
        volume_pos_x = scene_info.infos[volume_index + 0u];
        volume_pos_y = scene_info.infos[volume_index + 1u];
        volume_pos_z = scene_info.infos[volume_index + 2u];
    }

    result.end_volume = volume_index;
    result.end_factor = min(min(x_factor, y_factor), z_factor);
    result.end_cycle = cycle;

    return result;
}

// Direct lighting at one sample point from the lights listed in the volume's light slots.
//
// Starts from 1% of the surface color. For each light slot (until a 0 slot, infos[0] slots,
// or the shared cycle budget infos[0] * infos[1] is used up), a shadow ray is traced toward
// the light. If nothing opaque blocks it, the light adds a Lambert term attenuated by the
// squared distance and tinted by any transparent faces crossed.
vec3 get_lighting_color(uint volume_start, vec3 starting_pos, vec4 orig_color_sample, vec3 normal) {
    uint lights_per_volume = scene_info.infos[0];
    uint cycle_budget = lights_per_volume * scene_info.infos[1];

    vec3 accumulated = vec3(0.0, 0.0, 0.0) + (orig_color_sample.xyz * 0.01);

    uint cycles_used = 0u;
    for (uint slot = 0u; slot < lights_per_volume && cycles_used < cycle_budget; slot++) {
        uint light_index = scene_info.infos[volume_start + 6 + slot];
        if (light_index == 0u) {
            // an empty slot terminates the light list
            break;
        }

        vec3 to_light = get_light_position(light_index) - starting_pos;
        vec3 light_color = get_light_color(light_index);

        Tracing shadow = trace_ray(volume_start, starting_pos, to_light, 1.0, cycles_used, cycle_budget);
        if (!shadow.has_hit) {
            float lambert = max(dot(normal, normalize(to_light)), 0.0);
            float distance_sq = length(to_light) * length(to_light);
            accumulated += shadow.color_mul * lambert * (orig_color_sample.xyz * light_color) / distance_sq;
        }
        cycles_used = shadow.end_cycle;
    }

    return accumulated;
}

// Normal of a volume face, pointing back into the volume.
// The face numbering matches trace_ray: 0/1 = +z/-z border, 2/3 = -x/+x border,
// 4/5 = -y/+y border. Faces 2/3 therefore get x normals and faces 4/5 get y normals. This
// makes the normal perpendicular to the plane diffuse_tracing offsets its samples in.
// Any other index returns the zero vector.
vec3 normal_for_facing(uint facing) {
    const vec3 inward_normals[6] = vec3[6](
        vec3(0.0, 0.0, -1.0),  // 0: +z border (top)
        vec3(0.0, 0.0, 1.0),   // 1: -z border (bottom)
        vec3(1.0, 0.0, 0.0),   // 2: -x border (left)
        vec3(-1.0, 0.0, 0.0),  // 3: +x border (right)
        vec3(0.0, 1.0, 0.0),   // 4: -y border (front)
        vec3(0.0, -1.0, 0.0)   // 5: +y border (back)
    );
    if (facing > 5u) {
        return vec3(0.0, 0.0, 0.0);
    }
    return inward_normals[facing];
}

// Averages get_lighting_color over a square grid of sample points around pos.
//
// The grid has (2*n + 1)^2 points with n = infos[2] and spacing = infos[3] (float bits). It
// lies in the plane of face f: z faces (0/1) offset along x/y, x faces (2/3) along y/z, and
// y faces (4/5) along x/z. The surface color is the face cell at raster_pos.
vec3 diffuse_tracing(uint volume_start, uvec2 raster_pos, vec3 pos, uint f) {
    uvec4 texel = sample_color_from_scene_info(volume_start, raster_pos, f);
    vec4 albedo = vec4(vec3(texel.xyz) / 255.0, 1.0);
    vec3 surface_normal = normal_for_facing(f);

    int half_steps = int(scene_info.infos[2]);
    float spacing = uintBitsToFloat(scene_info.infos[3]);
    int sample_count = (2 * half_steps + 1) * (2 * half_steps + 1);

    // 1.0 where the face-local u / v offset moves along a world axis, 0.0 otherwise
    float u_along_x = float(f == 0u || f == 1u || f == 4u || f == 5u);
    float u_along_y = float(f == 2u || f == 3u);
    float v_along_y = float(f == 0u || f == 1u);
    float v_along_z = float(f == 2u || f == 3u || f == 4u || f == 5u);

    vec3 total = vec3(0.0);
    for (int du_step = -half_steps; du_step <= half_steps; ++du_step) {
        float du = spacing * float(du_step);
        for (int dv_step = -half_steps; dv_step <= half_steps; ++dv_step) {
            float dv = spacing * float(dv_step);
            vec3 offset = vec3(du * u_along_x,
                               du * u_along_y + dv * v_along_y,
                               dv * v_along_z);
            total += get_lighting_color(volume_start, pos + offset, albedo, surface_normal) / float(sample_count);
        }
    }

    return total;
}

// Clamps position component-wise into the volume's bounding box.
// The box runs from origin - 0.5 to origin + size - 0.5 on each axis.
vec3 clamp_to_volume(uint volume_start, vec3 position) {
    uvec3 origin = uvec3(scene_info.infos[volume_start + 0],
                         scene_info.infos[volume_start + 1],
                         scene_info.infos[volume_start + 2]);
    uvec3 extent = uvec3(scene_info.infos[volume_start + 3],
                         scene_info.infos[volume_start + 4],
                         scene_info.infos[volume_start + 5]);

    vec3 lower = vec3(origin) - 0.5;
    vec3 upper = vec3(origin + extent) - 0.5;

    return min(max(position, lower), upper);
}

// Shades one face fragment.
// Opaque faces (no neighbor volume behind them) get plain diffuse lighting. Faces with a
// neighbor are blended by opacity (the color's w byte / 255) between their own diffuse
// lighting and whatever a camera ray hits behind them, tinted by this face's color.
void main() {
    // Length and cycle budget of the view ray traced through transparent faces.
    const float camera_ray_max_factor = 100.0;
    const uint camera_ray_max_cycles = 20u;

    vec3 clamped_pos = clamp_to_volume(fragVolumeStart, origPosition);
    uvec4 color_roughness = sample_color_from_scene_info(fragVolumeStart, fragRasterPos, facing);
    vec3 orig_color_sample = vec3(color_roughness.xyz) / 255.0;

    // Direct lighting of this fragment's own surface is needed on every path.
    vec3 color_direct = diffuse_tracing(fragVolumeStart, fragRasterPos, clamped_pos, facing);
    vec3 color_sum = color_direct;

    uint orig_neighbor = sample_neighbor_from_scene_info(fragVolumeStart, fragRasterPos, facing);
    if (orig_neighbor != 0u) {
        float opacity = float(color_roughness.w) / 255.0;
        Tracing through = trace_ray(fragVolumeStart, ubo.camera_pos, clamped_pos - ubo.camera_pos,
                                    camera_ray_max_factor, 0u, camera_ray_max_cycles);

        // Todo: hit sky box - a miss currently shows black behind the face
        vec3 color_seen_through = vec3(0.0, 0.0, 0.0);
        if (through.has_hit) {
            color_seen_through = diffuse_tracing(through.end_volume, through.end_raster, through.end_pos, through.end_facing) * orig_color_sample;
        }
        color_sum = opacity * color_direct + (1.0 - opacity) * color_seen_through;
    }

    outColor = vec4(color_sum, 1.0);
}