from OpenGL.GL import *
import numpy as np
from OpenGL.GL.ARB.vertex_array_object import glDeleteVertexArrays
from OpenGL.GL.framebufferobjects import glBindRenderbuffer
from OpenGL.GLUT import *
import OpenGL.GLUT.freeglut
from OpenGL.GLU import *
from OpenGL.GL import *
from ctypes import sizeof, c_float, c_void_p, c_uint

from Lights.Spotlight.Spotlight import Spotlight
from WorldProvider.WorldProvider import WorldProvider
from MatrixStuff.Transformations import perspectiveMatrix, lookAt, translate, rotate
from Objects.Cube.Cube import Cube
from Objects.Cuboid.Cuboid import Cuboid
from Objects.World import World
import json

import random
import time
from scipy.signal import convolve

# Far clipping distance passed to perspectiveMatrix() for the camera
# projection and for the spotlight depth-map projection.
MAX_DISTANCE = 2e2
# Not read by the code shown here; the (misspelled) name is kept so any
# external importer keeps working.
FRICTION_COEFFICENT = 1
# Small tolerance value; not read by the code shown here.
EPSILON = 1e-5


def value_to_color(v, min_value, max_value):
    """Map a scalar onto a blue -> green -> red colour ramp.

    ``v`` is normalised relative to ``[min_value, max_value]``: ``min_value``
    maps to pure blue, the midpoint to pure green and ``max_value`` to pure
    red, with linear blends in between.  The resulting ``(r, g, b)`` vector is
    scaled to unit Euclidean length.

    Values outside the range are clamped to the nearest end of the ramp.  A
    degenerate range (``max_value == min_value``) yields black.

    :param v: value to colour.
    :param min_value: value mapped to blue.
    :param max_value: value mapped to red.
    :return: ``(r, g, b)`` tuple of floats.
    """
    scope = max_value - min_value
    # Guard the division itself.  (Testing the range midpoint against zero,
    # after dividing, both crashed on empty ranges and blackened every range
    # centred on zero such as [-1, 1].)
    if scope == 0:
        return 0.0, 0.0, 0.0
    normalized = (v - min_value) / scope
    # Clamp so at least one channel stays non-zero; otherwise the
    # normalisation below would compute 0 / 0 for far out-of-range values.
    normalized = min(max(normalized, 0.0), 1.0)
    b = max(0.0, 1.0 - abs(2.0 * normalized))
    g = max(0.0, 1.0 - abs(2.0 * normalized - 1.0))
    r = max(0.0, 1.0 - abs(2.0 * normalized - 2.0))
    norm = np.sqrt(r * r + g * g + b * b)
    return float(r / norm), float(g / norm), float(b / norm)


class Client:
    """GLUT window that renders a voxel world and steps a particle simulation.

    The constructor opens the window, builds the GL programs, fills a
    100x100x1 grid with randomly coloured ``Cuboid`` voxels, initialises the
    simulation arrays and then either enters ``glutMainLoop`` or, when
    ``test`` is true, renders a single frame.  Each call of ``display`` draws
    a frame and advances the simulation by one step.  The voxel colours show
    the quantity selected by ``self.state``:
    0 = ``n``, 1 = ``|u| * n``, 2 = ``gravity_applies``.

    Simulation arrays (indexed ``[x, y, z]`` over ``self.field``):
      * ``n``           -- integer occupancy per cell (also used as mass,
                           see the TODO in ``_velocity``);
      * ``f``           -- per-cell force vector (trailing axis of length 3);
      * ``u``           -- per-cell velocity from the last step;
      * ``true_pos``    -- fractional offset of the particle inside its cell;
      * ``pos_indices`` -- the cell's own ``[x, y, z]`` index.
    """

    def __init__(self, test=False, pos=None):
        """Open the window, build the GL programs, world and simulation state.

        Reads ``./config.json`` and the GLSL files ``passthroughvertex.glsl``,
        ``vertex.glsl`` and ``fragment.glsl`` from the working directory.

        :param test: if true, call ``display`` once and then
            ``resize(100, 100)`` instead of entering ``glutMainLoop``.
        :param pos: camera position ``[x, y, z]`` passed to ``translate`` and
            to ``get_lights_to_render`` in ``display``.  Defaults to a new
            ``[0, 0, 0]`` list.  The given list is stored, not copied.
        :raises RuntimeError: if a shader fails to compile or a program fails
            to link.  The GL info log is used as the message.
        """
        self.state = 0
        with open('./config.json', 'r') as f:
            self.config = json.load(f)
        glutInit(sys.argv)
        self.width = 1920
        self.height = 1080
        glutInitDisplayMode(GLUT_DOUBLE | GLUT_RGB | GLUT_DEPTH)
        glutInitWindowSize(self.width, self.height)
        glutCreateWindow(b'Voxelengine')

        self.passthrough_vertex_shader_id = self._compile_shader('passthroughvertex.glsl', GL_VERTEX_SHADER)
        self.vertex_shader_id = self._compile_shader('vertex.glsl', GL_VERTEX_SHADER)
        self.fragment_shader_id = self._compile_shader('fragment.glsl', GL_FRAGMENT_SHADER)

        Cube.initializeShader()
        Cuboid.initializeShader()
        self.geometry_shaders = {
            Cube: Cube.GeometryShaderId,
            Cuboid: Cuboid.GeometryShaderId
        }

        # One regular program per object type.  Depth programs are keyed by
        # the id of the matching regular program, not by the object type.
        self.normal_program = {}
        self.depth_program = {}

        for key in self.geometry_shaders.keys():
            program_id = glCreateProgram()
            glAttachShader(program_id, self.vertex_shader_id)
            glAttachShader(program_id, key.GeometryShaderId)
            glAttachShader(program_id, self.fragment_shader_id)
            glLinkProgram(program_id)
            # Mirror the compile-status checks so a link error is reported
            # here instead of going unnoticed.
            if glGetProgramiv(program_id, GL_LINK_STATUS) != GL_TRUE:
                raise RuntimeError(glGetProgramInfoLog(program_id))
            self.normal_program[key] = program_id

            self.depth_program[program_id] = Spotlight.getDepthProgram(self.vertex_shader_id,
                                                                       key.GeometryShaderId)

        self.world_provider = WorldProvider(self.normal_program)
        for x_pos in range(0, 100):
            for y_pos in range(0, 100):
                for z_pos in range(0, 1):
                    self.world_provider.world.put_object(x_pos, y_pos, z_pos, Cuboid().setColor(
                        random.randint(0, 100) / 100.0, random.randint(0, 100) / 100.0, random.randint(0, 100) / 100.0))

        self.projMatrix = perspectiveMatrix(45.0, 400 / 400, 0.01, MAX_DISTANCE)

        self.rx = self.cx = self.cy = 0
        self.opening = 45

        glutReshapeFunc(self.resize)
        glutDisplayFunc(self.display)
        glutKeyboardFunc(self.keyboardHandler)
        glutSpecialFunc(self.funcKeydHandler)

        # A fresh list per instance: a ``[0, 0, 0]`` default in the signature
        # would be one list shared by every Client built without ``pos``.
        self.pos = [0, 0, 0] if pos is None else pos

        self.time = time.time()

        # ---- simulation state ------------------------------------------------
        self.field = (100, 100, 1)
        # The 8 in-plane neighbour directions.
        # NOTE(review): w_a below has 9 weights (rest + 8 directions), but the
        # rest vector [0, 0, 0] is left out here, so e_a has only 8 rows.  Of
        # w_a only w_a[0] is read (max_n).
        self.e_a = np.array([
            [1, 0, 0],
            [1, 1, 0],
            [0, 1, 0],
            [-1, 1, 0],
            [-1, 0, 0],
            [-1, -1, 0],
            [0, -1, 0],
            [1, -1, 0],
        ])

        self.relaxation_time = 0.55
        self.w_a = [
            4.0 / 9.0,
            1.0 / 9.0,
            1.0 / 36.0,
            1.0 / 9.0,
            1.0 / 36.0,
            1.0 / 9.0,
            1.0 / 36.0,
            1.0 / 9.0,
            1.0 / 36.0
        ]

        self.n_a = np.zeros((len(self.e_a),) + self.field)
        self.n_a_eq = np.zeros(self.n_a.shape)
        # The builtin ``int`` replaces ``np.int``, which NumPy 1.24 removed.
        self.n = np.zeros(self.field, int)
        # The upper half of the grid (y >= 50) starts occupied.
        self.n[:, 50:, :] += 1
        self.viscosity = np.reshape(self.n * 0.01, self.field + (1,))
        self.gravity_applies = np.zeros(self.field)
        self.n_a[0] = np.array(self.n)
        self.u = np.zeros(self.field + (self.e_a.shape[1],))
        self.f = np.zeros(self.field + (self.e_a.shape[1],))
        self.true_pos = np.zeros(self.field + (self.e_a.shape[1],))
        # pos_indices[x, y, z] == [x, y, z]
        self.pos_indices = np.stack(np.indices(self.field, dtype=int), axis=-1)

        self.compressible = True
        self.max_n = self.w_a[0]

        self.test_pixel = [40, 50, 0]

        if not test:
            glutMainLoop()
        else:
            self.display()
            self.resize(100, 100)

    @staticmethod
    def _compile_shader(path, shader_type):
        """Read GLSL source from ``path`` and compile it as ``shader_type``.

        :return: the GL shader id.
        :raises RuntimeError: with the shader info log if compilation fails.
        """
        with open(path, 'r') as f:
            source = f.read()
        shader_id = glCreateShader(shader_type)
        glShaderSource(shader_id, source)
        glCompileShader(shader_id)
        if glGetShaderiv(shader_id, GL_COMPILE_STATUS) != GL_TRUE:
            raise RuntimeError(glGetShaderInfoLog(shader_id))
        return shader_id

    @staticmethod
    def _set_frame_uniforms(program_ids, width, height):
        """Upload the ``width``/``height``/``near``/``far`` uniforms to each program."""
        for program_id in program_ids:
            glUseProgram(program_id)
            widthid = glGetUniformLocation(program_id, 'width')
            heightid = glGetUniformLocation(program_id, 'height')
            nearid = glGetUniformLocation(program_id, 'near')
            farid = glGetUniformLocation(program_id, 'far')
            glUniform1f(nearid, 0.01)
            # NOTE(review): far is 100 here, while the projection matrices use
            # MAX_DISTANCE as their far plane; confirm which value the shaders
            # expect.
            glUniform1f(farid, 100)
            glUniform1f(widthid, width)
            glUniform1f(heightid, height)

    def display(self):
        """GLUT display callback: draw a frame, then advance the simulation.

        Renders the scene, recolours the voxels from the current state, updates
        the forces, resolves collisions, moves the particles, prints the frame
        rate and schedules the next redraw.
        """
        self._render_scene()
        self._update_colors()
        self._apply_forces()
        timestep = self._resolve_collisions()
        self._move_particles(timestep)

        print(1.0 / (time.time() - self.time))
        self.time = time.time()
        glutPostRedisplay()

    def _render_scene(self):
        """Render the depth map of every nearby light, then the scene, and swap buffers."""
        glClearColor(0, 0, 0, 0)
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
        projMatrix = perspectiveMatrix(45, float(self.width) / float(self.height), 0.01, MAX_DISTANCE)

        world: World = self.world_provider.world
        lights = world.get_lights_to_render(self.pos, self.config['render_light_distance'])
        for light in lights:
            light.prepareForDepthMapping()
            glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
            light_mat = translate(light.pos[0], light.pos[1], light.pos[2]) * \
                        lookAt(0, 0, 0, -light.pos[0], -light.pos[1], -light.pos[2], 0, 1, 0) * \
                        perspectiveMatrix(90, float(light.map_size) / float(light.map_size), 0.01, MAX_DISTANCE)

            self._set_frame_uniforms(self.depth_program.values(), light.map_size, light.map_size)

            world.render(light_mat, rotate(0, 0, 0), self.depth_program)
            glFlush()
            light.finishDepthMapping()
            glClearColor(0, 0, 0, 0)
            glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)

        glClearColor(0, 0, 0, 0)
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)

        self._set_frame_uniforms(self.normal_program.values(), self.width, self.height)

        view = translate(self.pos[0], self.pos[1], self.pos[2]) * \
            lookAt(0, 0, 0, 0, 0, -self.pos[2], 0, 1, 0) * projMatrix
        world.render(view, rotate(0, 0, 0))
        glFlush()

        glutSwapBuffers()

    def _update_colors(self):
        """Print simulation statistics and recolour every voxel from the state.

        ``display`` calls this after the buffers are swapped, so the new colours
        appear in the next frame.  The ``test_pixel`` voxel is painted white.
        """
        min_value = 0
        max_value_n = np.max(self.n)

        vel = np.sqrt(np.sum(np.square(self.u), axis=3)) * self.n
        max_value_vel = np.max(vel)

        print('round')
        print('sum n: %f' % np.sum(self.n))
        print('max n: %f' % np.max(self.n))
        print('min n: %f' % np.min(self.n))
        print('sum vel: %f' % np.sum(vel))
        print('max vel: %f' % np.max(vel))
        print('min vel: %f' % np.min(vel))

        for x_pos in range(0, 100):
            for y_pos in range(0, 100):
                for z_pos in range(0, 1):
                    if self.state == 2:
                        r, g, b = value_to_color(int(self.gravity_applies[x_pos, y_pos, z_pos]), 0, 1)
                    if self.state == 1:
                        r, g, b = value_to_color(vel[x_pos, y_pos, z_pos], min_value, max_value_vel)
                    if self.state == 0:
                        r, g, b = value_to_color(self.n[x_pos, y_pos, z_pos], min_value, max_value_n)

                    self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
        self.world_provider.world.set_color(int(round(self.test_pixel[0])),
                                            int(round(self.test_pixel[1])),
                                            int(round(self.test_pixel[2])), 1.0, 1.0, 1.0)

    def _velocity(self):
        """Per-cell velocity: ``f`` direction scaled by ``sqrt(|f|) / n``; 0 in empty cells."""
        f_length = np.sqrt(np.sum(np.square(self.f), axis=3, keepdims=True))
        # Zero-force and empty cells produce 0/0 here; those entries are
        # overwritten with 0 right below, so the warnings are noise.
        with np.errstate(divide='ignore', invalid='ignore'):
            f_direction = self.f / f_length
            f_direction[f_length[:, :, :, 0] == 0] = 0
            # TODO replace self.n by mass
            velocity = f_direction * np.sqrt(f_length) / np.reshape(self.n, self.field + (1,))
        velocity[self.n == 0] = 0
        return velocity

    def _apply_forces(self):
        """Compute the next force field and store it in ``self.f``.

        Each cell's ``f * viscosity`` is spread to its neighbours along the
        ``e_a`` directions.  The shared part is subtracted (weighted by the
        summed occupancy of the 3x3x3 neighbourhood) and ``self.f`` is added
        back.  Gravity of ``-n`` is applied on the y axis, and empty cells are
        cleared.
        """
        new_f = np.zeros(self.f.shape)
        # viscosity: sum of n over each 3x3x3 neighbourhood, 0 for empty cells
        neighbours_filter = np.ones((3, 3, 3), int)
        neighbours = convolve(np.pad(self.n, 1, constant_values=0), neighbours_filter, mode='valid') * self.n
        neighbours = np.reshape(neighbours, self.field + (1,))
        forces_to_share_per_neighbor = self.f * self.viscosity
        length = np.sqrt(np.sum(np.square(forces_to_share_per_neighbor), axis=3, keepdims=True))
        with np.errstate(divide='ignore', invalid='ignore'):
            direction = forces_to_share_per_neighbor / length
        direction[(length == 0)[:, :, :, 0]] = 0

        ############################## experimental
        shape = new_f.shape
        for offset in self.e_a:
            unit = offset / np.sqrt(np.sum(np.square(offset)))
            scalar = np.sum(direction * unit, axis=3, keepdims=True)
            altered_direction = direction + scalar * unit
            altered_length = np.sqrt(np.sum(np.square(altered_direction), axis=3, keepdims=True))
            with np.errstate(divide='ignore', invalid='ignore'):
                altered_direction = altered_direction / altered_length
            altered_direction[(altered_length == 0)[:, :, :, 0]] = 0

            f_2_add = length * altered_direction

            # Shift f_2_add by ``offset`` cells.  Contributions that would land
            # outside the grid are dropped.
            dst = tuple(slice(max(0, o), min(s, s + o)) for o, s in zip(offset, shape))
            src = tuple(slice(max(0, -o), min(s, s - o)) for o, s in zip(offset, shape))
            new_f[dst] += f_2_add[src]

        # NOTE(review): self.f is added twice (in this line and the next);
        # kept as-is, confirm whether that is intended.
        new_f += self.f - forces_to_share_per_neighbor * neighbours
        #########################################
        new_f += self.f
        # TODO movement generating things
        # gravity
        new_f[:, :, :, 1] -= 1.0 * self.n
        # clean up: empty cells carry no force
        new_f[self.n == 0] = 0

        self.f = new_f

    def _resolve_collisions(self):
        """Choose a timestep and redistribute forces until no two particles
        would move into the same cell.

        Each iteration does the following:
          * Derive velocities from ``self.f``.
          * Pick the timestep so the fastest displacement is at most 0.5 cells,
            capped at 1.  It is divided by 10 once more than 20 iterations
            have run.
          * Negate force and velocity of particles whose target lies outside
            the grid.
          * Award each target cell to a particle already sitting there, or
            else to the fastest contender.  Every other contender hands half
            of its force to the winner.

        ``self.f`` is modified in place.

        NOTE(review): there is no hard iteration cap; if collisions never
        resolve, this loops forever.

        :return: the timestep of the final iteration.
        """
        occupied = self.n != 0
        collision = True
        iterations = 0
        while collision:
            velocity = self._velocity()
            timestep = min(1, 0.5 / np.max(np.sqrt(np.sum(np.square(velocity), axis=3))))
            if iterations > 20:
                print('Takes too long!')
                timestep /= 10
            new_pos = self.true_pos + velocity * timestep
            new_pos_round = np.round(new_pos)
            moved = np.any(new_pos_round != 0, axis=3)
            pos_change_targets = new_pos_round[moved] + self.pos_indices[moved]

            # handle bordercases: movers whose target lies outside the grid
            bordercase = np.array(moved)
            bordercase[moved] = np.any((pos_change_targets < 0) | (pos_change_targets >= self.field), axis=1)
            self.f[bordercase] *= -1
            velocity[bordercase] *= -1

            # recalculate targets
            new_pos = self.true_pos + velocity * timestep
            new_pos_target = np.round(new_pos) + self.pos_indices

            contenders = np.zeros(self.field)
            starts = self.pos_indices[occupied]
            targets = new_pos_target[occupied].astype(int)
            speeds = velocity[occupied]
            forces = self.f[occupied]
            max_speeds = np.zeros(self.field)
            fast_pos = {}
            is_stayer = set()  # set: O(1) membership in the loop below
            for index, target in enumerate(targets):
                cell = tuple(target)
                speed = np.sqrt(np.sum(np.square(speeds[index])))
                contenders[cell] += 1

                # new max speed and we do not update stayers
                if speed > max_speeds[cell] and cell not in is_stayer:
                    max_speeds[cell] = speed
                    fast_pos[cell] = index

                # particles already in their target cell (which stay there)
                # always win that cell
                if np.all(starts[index] == target):
                    fast_pos[cell] = index
                    is_stayer.add(cell)

            collision = np.max(contenders) > 1

            # collision handling is only needed if there is a collision
            if collision:
                for index, target in enumerate(targets):
                    cell = tuple(target)
                    # the winner of a contested cell keeps its force
                    if contenders[cell] > 1 and index != fast_pos[cell]:
                        fastest = fast_pos[cell]
                        # TODO use relative weight?
                        forces[fastest] += 0.5 * forces[index]
                        forces[index] *= 0.5
            self.f[occupied] = forces

            iterations += 1

        return timestep

    def _move_particles(self, timestep):
        """Advance every particle by ``velocity * timestep`` and rebuild the cell arrays.

        Particles whose rounded displacement is non-zero move to their target
        cell, taking ``n``, ``f`` and their velocity with them, and keep only
        the fractional part of the displacement.  The others stay and store
        the full new offset.  Replaces ``self.n``, ``self.f``, ``self.u`` and
        ``self.true_pos``.
        """
        velocity = self._velocity()
        new_pos = self.true_pos + velocity * timestep
        new_pos_round = np.round(new_pos)
        moved = np.any(new_pos_round != 0, axis=3)
        not_moved = ~moved
        pos_change_targets = (new_pos_round[moved] + self.pos_indices[moved]).astype(int)
        movers = self.pos_indices[moved]

        print('timestep: %f' % timestep)

        update_n = np.zeros(self.n.shape, int)
        update_f = np.zeros(self.f.shape)
        update_u = np.zeros(self.u.shape)
        update_true_pos = np.zeros(self.true_pos.shape)

        update_n[not_moved] = self.n[not_moved]
        update_f[not_moved] = self.f[not_moved]
        update_u[not_moved] = velocity[not_moved]
        update_true_pos[not_moved] = new_pos[not_moved]

        # Movers are written after the stayers, so a mover overwrites whatever
        # already occupies its target cell.
        for mover, target in zip(movers, pos_change_targets):
            src = tuple(mover)
            dst = tuple(target)
            update_n[dst] = self.n[src]
            update_f[dst] = self.f[src]
            update_u[dst] = velocity[src]
            update_true_pos[dst] = new_pos[src] - new_pos_round[src]

        self.n = update_n
        self.f = update_f
        self.u = update_u
        self.true_pos = update_true_pos

    def resize(self, w, h):
        """GLUT reshape callback: update the viewport and the stored window size.

        Sizes are clamped to at least 1 so the aspect-ratio division never
        divides by zero.  ``display`` derives its own projection from
        ``width``/``height``; ``projMatrix`` is only stored here.
        """
        w = max(w, 1)
        h = max(h, 1)
        glViewport(0, 0, w, h)
        self.projMatrix = perspectiveMatrix(45.0, float(w) / float(h), 0.01, MAX_DISTANCE)
        self.width = w
        self.height = h

    def keyboardHandler(self, key: bytes, x: int, y: int):
        """GLUT keyboard callback.

        Esc exits.  ``+``/``-`` change ``rx``, ``w``/``s`` change ``cy``,
        ``a``/``d`` change ``cx`` and ``q``/``e`` change ``opening``, each in
        steps of 0.25.  ``+`` also cycles ``state`` (the visualised quantity),
        and ``r`` prints ``cx``, ``cy`` and ``opening``.

        :param key: the pressed key as a one-byte ``bytes`` object.
        :param x: unused.
        :param y: unused.
        """
        if key == b'\x1b':  # Esc
            exit()

        if key == b'+':
            self.rx += 0.25
        if key == b'-':
            self.rx -= 0.25

        if key == b'w':
            self.cy += 0.25
        if key == b's':
            self.cy -= 0.25

        if key == b'a':
            self.cx -= 0.25
        if key == b'd':
            self.cx += 0.25

        if key == b'q':
            self.opening -= 0.25
        if key == b'e':
            self.opening += 0.25

        # NOTE(review): '+' is also handled above (rx += 0.25), so one press
        # both nudges rx and switches the visualisation; confirm whether a
        # separate key was intended for one of the two actions.
        if key == b'+':
            self.state = (self.state + 1) % 3

        if key == b'r':
            print(self.cx, self.cy, self.opening)

    def funcKeydHandler(self, key: int, x: int, y: int):
        """GLUT special-key callback: F11 toggles full-screen mode."""
        # 11 is GLUT_KEY_F11; glutFullScreenToggle is a freeglut extension.
        if key == 11:
            glutFullScreenToggle()


if __name__ == '__main__':
    # Camera position handed to Client (used by translate() in display).
    # NOTE(review): -50/-50 presumably centres the view on the 100x100 voxel
    # grid; confirm.
    start_pos = [-50, -50, -200]
    client = Client(pos=start_pos)