From 9cda7396983328df0b9ffc5e24cfd155361653a1 Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Sat, 13 Feb 2021 21:31:39 +0100
Subject: [PATCH 01/14] fluid sim

---
 Client/Client.py | 184 +++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 178 insertions(+), 6 deletions(-)

diff --git a/Client/Client.py b/Client/Client.py
index 9baea65..0aa5dd4 100644
--- a/Client/Client.py
+++ b/Client/Client.py
@@ -119,7 +119,7 @@ class Client:
 
         self.field = (100, 100, 1)
         self.e_a = np.array([
-            [0, 0, 0],
+            # [0, 0, 0],
             [1, 0, 0],
             [1, 1, 0],
             [0, 1, 0],
@@ -145,12 +145,20 @@ class Client:
 
         self.n_a = np.zeros((len(self.e_a),) + self.field)
         self.n_a_eq = np.zeros(self.n_a.shape)
-        self.n = np.zeros(self.field)
-        self.n[:, :, :] += 1.0
+        self.n = np.zeros(self.field, np.int)
+        self.n[:, 50:, :] += 1
+        self.viscosity = np.reshape(self.n * 0.01, self.field + (1,))
         self.gravity_applies = np.zeros(self.field)
         # self.n /= np.sum(self.n)
         self.n_a[0] = np.array(self.n)
         self.u = np.zeros(self.field + (self.e_a.shape[1],))
+        self.f = np.zeros(self.field + (self.e_a.shape[1],))
+        self.true_pos = np.zeros(self.field + (self.e_a.shape[1],))
+        self.pos_indices = np.zeros(self.field + (self.e_a.shape[1],), dtype=np.int)
+        for x in range(self.field[0]):
+            for y in range(self.field[1]):
+                for z in range(self.field[2]):
+                    self.pos_indices[x, y, z, :] = [x, y, z]
 
         self.compressible = True
         self.max_n = self.w_a[0]
@@ -218,7 +226,7 @@ class Client:
         max_value_n = np.max(self.n)
         # max_value_n = 1.0
 
-        vel = np.sqrt(np.sum(np.square(self.u), axis=3)) *self.n
+        vel = np.sqrt(np.sum(np.square(self.u), axis=3)) * self.n
         max_value_vel = np.max(vel)
         # max_value_vel = np.sqrt(3)
 
@@ -245,7 +253,171 @@ class Client:
                                             int(round(self.test_pixel[1])),
                                             int(round(self.test_pixel[2])), 1.0, 1.0, 1.0)
 
-        # print(1.0 / (time.time() - self.time))
+        new_f = np.zeros(self.f.shape)
+        # viscosity
+        neighbours_filter = np.zeros((3, 3, 3), np.int) + 1
+        neighbours = convolve(np.pad(self.n, 1, constant_values=0), neighbours_filter, mode='valid') * self.n
+        neighbours = np.reshape(neighbours, self.field + (1,))
+        forces_to_share_per_neighbor = self.f * self.viscosity
+        length = np.sqrt(np.sum(np.square(forces_to_share_per_neighbor), axis=3, keepdims=True))
+        direction = forces_to_share_per_neighbor / length
+        direction[(length == 0)[:, :, :, 0]] = 0
+        ############################## experimental
+        for a in range(len(self.e_a)):
+            unit = self.e_a[a] / np.sqrt(np.sum(np.square(self.e_a[a])))
+            scalar = np.sum(direction * unit, axis=3, keepdims=True)
+            altered_direction = direction + scalar * unit
+            altered_length = np.sqrt(np.sum(np.square(altered_direction), axis=3, keepdims=True))
+            altered_direction = altered_direction / altered_length
+            altered_direction[(altered_length == 0)[:, :, :, 0]] = 0
+
+            f_2_add = length * altered_direction
+
+            new_f[max(0, self.e_a[a][0]):min(new_f.shape[0], new_f.shape[0] + self.e_a[a][0]),
+            max(0, self.e_a[a][1]):min(new_f.shape[1], new_f.shape[1] + self.e_a[a][1]),
+            max(0, self.e_a[a][2]):min(new_f.shape[2], new_f.shape[2] + self.e_a[a][2])] += f_2_add[
+                                                                                            max(0, -self.e_a[a][0]):min(
+                                                                                                new_f.shape[0],
+                                                                                                new_f.shape[0] -
+                                                                                                self.e_a[a][0]),
+                                                                                            max(0, -self.e_a[a][1]):min(
+                                                                                                new_f.shape[1],
+                                                                                                new_f.shape[1] -
+                                                                                                self.e_a[a][1]),
+                                                                                            max(0, -self.e_a[a][2]):min(
+                                                                                                new_f.shape[2],
+                                                                                                new_f.shape[2] -
+                                                                                                self.e_a[a][2])]
+
+        new_f += self.f - forces_to_share_per_neighbor * neighbours
+        #########################################
+        new_f += self.f
+        # TODO movement generating things
+        # gravity
+        new_f[:, :, :, 1] -= 1.0 * self.n
+        # clean up
+        new_f[self.n == 0] = 0
+
+        self.f = new_f
+        # collision and movement
+        collision = True
+        iterations = 0
+        while collision:
+            f_length = np.sqrt(np.sum(np.square(self.f), axis=3, keepdims=True))
+            f_direction = self.f / f_length
+            f_direction[f_length[:, :, :, 0] == 0] = 0
+            velocity = f_direction * np.sqrt(f_length) / np.reshape(self.n, self.field + (1,)) # TODO replace self.n by mass
+            velocity[self.n == 0] = 0
+            timestep = min(1, 0.5 / np.max(np.sqrt(np.sum(np.square(velocity), axis=3))))
+            if iterations > 20:
+                print('Takes too long!')
+                timestep /= 10
+            new_pos = self.true_pos + velocity * timestep
+            new_pos_round = np.round(new_pos)
+            moved = np.logical_or.reduce(new_pos_round != [0, 0, 0], axis=3)
+            pos_change_targets = new_pos_round[moved] + self.pos_indices[moved]
+
+            # handle bordercases
+            bordercase = np.array(moved)
+            bordercase[moved] = np.logical_or(
+                np.logical_or(np.logical_or(0 > pos_change_targets[:, 0], pos_change_targets[:, 0] >= self.field[0]),
+                               np.logical_or(0 > pos_change_targets[:, 1], pos_change_targets[:, 1] >= self.field[1])),
+                np.logical_or(0 > pos_change_targets[:, 2], pos_change_targets[:, 2] >= self.field[2]))
+            self.f[bordercase] *= -1
+            velocity[bordercase] *= -1
+
+            # recalculate targets
+            new_pos = self.true_pos + velocity * timestep
+            new_pos_round = np.round(new_pos)
+            new_pos_target = new_pos_round + self.pos_indices
+
+            contenders = np.zeros(self.field)
+            starts = self.pos_indices[self.n != 0]
+            targets = new_pos_target[self.n != 0].astype(np.int)
+            speeds = velocity[self.n != 0]
+            forces = new_f[self.n != 0]
+            max_speeds = np.zeros(self.field)
+            fast_pos = {}
+            is_stayer = []
+            for index in range(len(targets)):
+                target = targets[index]
+                speed = np.sqrt(np.sum(np.square(speeds[index])))
+                contenders[target[0], target[1], target[2]] += 1
+
+                # new max speed and we do not update stayers
+                if speed > max_speeds[target[0], target[1], target[2]] and tuple(target) not in is_stayer:
+                    max_speeds[target[0], target[1], target[2]] = speed
+                    fast_pos[tuple(target)] = index
+
+                # atoms that are already there are there (and stay there) are the fastest
+                start = starts[index]
+                if np.all(start == target):
+                    fast_pos[tuple(target)] = index
+                    is_stayer.append(tuple(target))
+
+            collision = np.max(contenders) > 1
+
+            # we only need collision if there is one
+            if collision:
+                # go through the movers again
+                for index in range(len(targets)):
+                    target = targets[index]
+                    # collision here?
+                    if contenders[target[0], target[1], target[2]] > 1:
+                        # the fastest one does not need to do anything
+                        if index != fast_pos[tuple(target)]:
+                            force = forces[index]
+                            fastest = fast_pos[tuple(target)]
+                            # TODO use relative weight?
+                            forces[fastest] += 0.5*force
+                            forces[index] *= 0.5
+            new_f[self.n != 0] = forces
+
+
+            self.f = new_f
+
+            iterations += 1
+
+        # final calculation
+        f_length = np.sqrt(np.sum(np.square(self.f), axis=3, keepdims=True))
+        f_direction = self.f / f_length
+        f_direction[f_length[:, :, :, 0] == 0] = 0
+        velocity = f_direction * np.sqrt(f_length) / np.reshape(self.n, self.field + (1,))  # TODO replace self.n by mass
+        velocity[self.n == 0] = 0
+        #timestep = min(1, 0.5 / np.max(np.sqrt(np.sum(np.square(velocity), axis=3))))
+        new_pos = self.true_pos + velocity * timestep
+        new_pos_round = np.round(new_pos)
+        moved = np.logical_or.reduce(new_pos_round[:, :, :] != [0, 0, 0], axis=3)
+        not_moved = np.logical_and.reduce(new_pos_round[:, :, :] == [0, 0, 0], axis=3)
+        pos_change_targets = (new_pos_round[moved] + self.pos_indices[moved]).astype(np.int)
+        movers = self.pos_indices[moved]
+
+        print('timestep: %f' % timestep)
+
+        update_n = np.zeros(self.n.shape, np.int)
+        update_f = np.zeros(self.f.shape)
+        update_u = np.zeros(self.u.shape)
+        update_true_pos = np.zeros(self.true_pos.shape)
+
+        update_n[not_moved] = self.n[not_moved]
+        update_f[not_moved, :] = self.f[not_moved, :]
+        update_u[not_moved, :] = velocity[not_moved, :]
+        update_true_pos[not_moved, :] = new_pos[not_moved, :]
+
+        for indice in range(len(movers)):
+            mover = movers[indice]
+            target = pos_change_targets[indice]
+            update_n[target[0], target[1], target[2]] = self.n[mover[0], mover[1], mover[2]]
+            update_f[target[0], target[1], target[2], :] = self.f[mover[0], mover[1], mover[2], :]
+            update_u[target[0], target[1], target[2], :] = velocity[mover[0], mover[1], mover[2], :]
+            update_true_pos[target[0], target[1], target[2], :] = new_pos[mover[0], mover[1], mover[2], :] - new_pos_round[mover[0], mover[1], mover[2], :]
+
+        self.n = update_n
+        self.f = update_f
+        self.u = update_u
+        self.true_pos = update_true_pos
+
+        print(1.0 / (time.time() - self.time))
         self.time = time.time()
         glutPostRedisplay()
 
@@ -282,7 +454,7 @@ class Client:
             self.opening += 0.25
 
         if key == b'+':
-            self.state = (self.state +1) % 3
+            self.state = (self.state + 1) % 3
 
         if key == b'r':
             print(self.cx, self.cy, self.opening)

From 95bd3de45a60da51b91010718015a8fa0b2bcfea Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Tue, 3 Aug 2021 10:33:07 +0200
Subject: [PATCH 02/14] lattice Boltzmann example/playground

---
 FluidSim/LatticeBoltzmann.py | 181 +++++++++++++++++++++++++++++++++++
 1 file changed, 181 insertions(+)
 create mode 100644 FluidSim/LatticeBoltzmann.py

diff --git a/FluidSim/LatticeBoltzmann.py b/FluidSim/LatticeBoltzmann.py
new file mode 100644
index 0000000..3940a7a
--- /dev/null
+++ b/FluidSim/LatticeBoltzmann.py
@@ -0,0 +1,181 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
+"""
+Create Your Own Lattice Boltzmann Simulation (With Python)
+Philip Mocz (2020) Princeton Univeristy, @PMocz
+Simulate flow past cylinder
+for an isothermal fluid
+"""
+
+
+def main():
+    """ Finite Volume simulation """
+
+    # Simulation parameters
+    Nx = 400  # resolution x-dir
+    Ny = 100  # resolution y-dir
+    rho0 = 100  # average density
+    tau = 0.6  # collision timescale
+    Nt = 80000  # number of timesteps
+    plotRealTime = True  # switch on for plotting as the simulation goes along
+
+    # Lattice speeds / weights
+    NL = 9
+    idxs = np.arange(NL)
+    cxs = np.array([0, 0, 1, 1, 1, 0, -1, -1, -1])
+    cys = np.array([0, 1, 1, 0, -1, -1, -1, 0, 1])
+    weights = np.array([4 / 9, 1 / 9, 1 / 36, 1 / 9, 1 / 36, 1 / 9, 1 / 36, 1 / 9, 1 / 36])  # sums to 1
+
+    # Initial Conditions
+    F = np.ones((Ny, Nx, NL))  # * rho0 / NL
+    has_fluid = np.ones((Ny, Nx), dtype=np.bool)
+    has_fluid[int(Ny/2):, :] = False
+    np.random.seed(42)
+    F += 0.01 * np.random.randn(Ny, Nx, NL)
+    X, Y = np.meshgrid(range(Nx), range(Ny))
+    F[:, :, 3] += 2 * (1 + 0.2 * np.cos(2 * np.pi * X / Nx * 4))
+    # F[:, :, 5] += 1
+    rho = np.sum(F, 2)
+    for i in idxs:
+        F[:, :, i] *= rho0 / rho
+
+    # Cylinder boundary
+    X, Y = np.meshgrid(range(Nx), range(Ny))
+    cylinder = (X - Nx / 4) ** 2 + (Y - Ny / 2) ** 2 < (Ny / 4) ** 2
+    inner_cylinder = (X - Nx / 4) ** 2 + (Y - Ny / 2) ** 2 < (Ny / 4 - 2) ** 2
+    F[cylinder] = 0
+    F[0, :] = 0
+    F[Ny - 1, :] = 0
+    # F[int(Ny/2):, :] = 0
+
+    has_fluid[cylinder] = False
+    has_fluid[0, :] = False
+    has_fluid[Ny - 1, :] = False
+
+    # for i in idxs:
+    #     F[:, :, i] *= has_fluid
+
+    # Prep figure
+    fig = plt.figure(figsize=(4, 2), dpi=80)
+
+    reflection_mapping = [0, 5, 6, 7, 8, 1, 2, 3, 4]
+    # Simulation Main Loop
+    for it in range(Nt):
+        print(it)
+
+        # Drift
+        new_has_fluid = np.zeros((Ny, Nx))
+        F_sum = np.sum(F, 2)
+        for i, cx, cy in zip(idxs, cxs, cys):
+            F_part = F[:, :, i] / F_sum
+            F_part[F_sum == 0] = 0
+            to_move = F_part * (has_fluid * 1.0)
+            to_move = (np.roll(to_move, cx, axis=1))
+            to_move = (np.roll(to_move, cy, axis=0))
+
+            new_has_fluid += to_move
+
+            F[:, :, i] = np.roll(F[:, :, i], cx, axis=1)
+            F[:, :, i] = np.roll(F[:, :, i], cy, axis=0)
+
+        # has_fluid = new_has_fluid > 0.5
+        # new_has_fluid[F_sum == 0] += has_fluid[F_sum == 0] * 1.0
+        # new_has_fluid[(np.abs(F_sum) < 0.000000001)] = 0
+
+        fluid_sum = np.sum(has_fluid * 1.0)
+        has_fluid = (new_has_fluid / np.sum(new_has_fluid * 1.0)) * fluid_sum
+
+        print('fluid_cells: %d' % np.sum(has_fluid * 1))
+
+        # for i in idxs:
+        #     F[:, :, i] *= has_fluid
+
+        bndry = np.zeros((Ny, Nx), dtype=np.bool)
+        bndry[0, :] = True
+        bndry[Ny - 1, :] = True
+        # bndry[:, 0] = True
+        # bndry[:, Nx - 1] = True
+        bndry = np.logical_or(bndry, cylinder)
+
+        # bndry = np.logical_or(bndry, has_fluid < 0.5)
+
+        # Set reflective boundaries
+        bndryF = F[bndry, :]
+        bndryF = bndryF[:, reflection_mapping]
+
+        sum_f = np.sum(F)
+        print('Sum of Forces: %f' % sum_f)
+
+        # sum_f_cyl = np.sum(F[cylinder])
+        # print('Sum of Forces in cylinder: %f' % sum_f_cyl)
+
+        # sum_f_inner_cyl = np.sum(F[inner_cylinder])
+        # print('Sum of Forces in inner cylinder: %f' % sum_f_inner_cyl)
+
+        # if sum_f > 4000000.000000:
+        #     test = 1
+
+        # F[Ny - 1, :, 5] += 0.1
+        # F[0, :, 1] -= 0.1
+        # F[0, :, 5] += 0.1
+        # F[Ny - 1, :, 1] -= 0.1
+
+        # Calculate fluid variables
+        rho = np.sum(F, 2)
+        ux = np.sum(F * cxs, 2) / rho
+        uy = np.sum(F * cys, 2) / rho
+
+        ux[(np.abs(rho) < 0.000000001)] = 0
+        uy[(np.abs(rho) < 0.000000001)] = 0
+
+        # print('minimum rho: %f' % np.min(np.abs(rho)))
+        # print('Maximum F: %f' % np.max(F))
+        # print('Minimum F: %f' % np.min(F))
+
+        # Apply Collision
+        Feq = np.zeros(F.shape)
+        for i, cx, cy, w in zip(idxs, cxs, cys, weights):
+            Feq[:, :, i] = rho * w * (
+                        1 + 3 * (cx * ux + cy * uy) + 9 * (cx * ux + cy * uy) ** 2 / 2 - 3 * (ux ** 2 + uy ** 2) / 2)
+
+        F += -(1.0 / tau) * (F - Feq)
+
+        # Apply boundary
+        F[bndry, :] = bndryF
+
+        # plot in real time - color 1/2 particles blue, other half red
+        if (plotRealTime and (it % 10) == 0) or (it == Nt - 1):
+            plt.cla()
+            ux[cylinder] = 0
+            uy[cylinder] = 0
+            vorticity = (np.roll(ux, -1, axis=0) - np.roll(ux, 1, axis=0)) - (
+                        np.roll(uy, -1, axis=1) - np.roll(uy, 1, axis=1))
+            vorticity[cylinder] = np.nan
+
+            # vorticity *= has_fluid
+
+            cmap = plt.cm.bwr
+            cmap.set_bad('black')
+
+            # plt.imshow(vorticity, cmap='bwr')
+            plt.imshow(has_fluid * 2.0 - 1.0, cmap='bwr')
+            # plt.imshow(bndry * 2.0 - 1.0, cmap='bwr')
+
+            plt.clim(-.1, .1)
+            ax = plt.gca()
+            ax.invert_yaxis()
+            ax.get_xaxis().set_visible(False)
+            ax.get_yaxis().set_visible(False)
+            ax.set_aspect('equal')
+            plt.pause(0.001)
+
+    # Save figure
+    # plt.savefig('latticeboltzmann.png', dpi=240)
+    plt.show()
+
+    return 0
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file

From 8a5de47da3f02147c5128097e8bc5dcb13602637 Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Fri, 15 Oct 2021 19:21:17 +0200
Subject: [PATCH 03/14] lava lamp lattice boltzmann

---
 FluidSim/FluidSimParameters.py |  24 ++++
 FluidSim/LatticeBoltzmann.py   | 209 ++++++++++++++++++++++++++-------
 2 files changed, 188 insertions(+), 45 deletions(-)
 create mode 100644 FluidSim/FluidSimParameters.py

diff --git a/FluidSim/FluidSimParameters.py b/FluidSim/FluidSimParameters.py
new file mode 100644
index 0000000..2170be4
--- /dev/null
+++ b/FluidSim/FluidSimParameters.py
@@ -0,0 +1,24 @@
+class FluidSimParameter:
+    viscosity = 0.1 / 3.0
+    # Pr = 1.0
+    Pr = 100.0
+    # vc = 1.0
+    vc = 0.5
+
+    def __init__(self, height: int):
+        self.t1 = 3 * self.viscosity + 0.5
+        self.t2 = (2 * self.t1 - 1) / (2 * self.Pr) + 0.5
+        self.g = (self.vc ** 2) / height
+
+        self.R = self.Pr * self.g * (height ** 3) / (self.viscosity ** 2)
+
+
+class MagmaParameter(FluidSimParameter):
+    viscosity = 10 ** 19
+    Pr = 10 ** 25
+
+
+class WaterParameter(FluidSimParameter):
+    viscosity = 8.9 * 10 ** -4
+    Pr = 7.56
+    vc = 0.05
diff --git a/FluidSim/LatticeBoltzmann.py b/FluidSim/LatticeBoltzmann.py
index 3940a7a..b39a8b3 100644
--- a/FluidSim/LatticeBoltzmann.py
+++ b/FluidSim/LatticeBoltzmann.py
@@ -1,5 +1,6 @@
 import matplotlib.pyplot as plt
 import numpy as np
+from FluidSim.FluidSimParameters import *
 
 """
 Create Your Own Lattice Boltzmann Simulation (With Python)
@@ -13,48 +14,63 @@ def main():
     """ Finite Volume simulation """
 
     # Simulation parameters
+    epsilon = 0.000000001
     Nx = 400  # resolution x-dir
     Ny = 100  # resolution y-dir
-    rho0 = 100  # average density
+    rho0 = 1  # average density
     tau = 0.6  # collision timescale
     Nt = 80000  # number of timesteps
     plotRealTime = True  # switch on for plotting as the simulation goes along
 
+    params = FluidSimParameter(Ny)
+    # params = WaterParameter(Ny)
+    # params = MagmaParameter(Ny)
+
     # Lattice speeds / weights
     NL = 9
     idxs = np.arange(NL)
     cxs = np.array([0, 0, 1, 1, 1, 0, -1, -1, -1])
     cys = np.array([0, 1, 1, 0, -1, -1, -1, 0, 1])
     weights = np.array([4 / 9, 1 / 9, 1 / 36, 1 / 9, 1 / 36, 1 / 9, 1 / 36, 1 / 9, 1 / 36])  # sums to 1
+    xx, yy = np.meshgrid(range(Nx), range(Ny))
 
     # Initial Conditions
-    F = np.ones((Ny, Nx, NL))  # * rho0 / NL
+    N = np.ones((Ny, Nx, NL))  # * rho0 / NL
+    temperature = np.ones((Ny, Nx, NL), np.float)  # * rho0 / NL
     has_fluid = np.ones((Ny, Nx), dtype=np.bool)
     has_fluid[int(Ny/2):, :] = False
     np.random.seed(42)
-    F += 0.01 * np.random.randn(Ny, Nx, NL)
+    N += 0.01 * np.random.randn(Ny, Nx, NL)
     X, Y = np.meshgrid(range(Nx), range(Ny))
-    F[:, :, 3] += 2 * (1 + 0.2 * np.cos(2 * np.pi * X / Nx * 4))
-    # F[:, :, 5] += 1
-    rho = np.sum(F, 2)
+    N[:, :, 3] += 2 * (1 + 0.2 * np.cos(2 * np.pi * X / Nx * 4))
+    # N[:, :, 5] += 1
+    rho = np.sum(N, 2)
+    temperature_rho = np.sum(temperature, 2)
     for i in idxs:
-        F[:, :, i] *= rho0 / rho
+        N[:, :, i] *= rho0 / rho
+        temperature[:, :, i] *= 1 / temperature_rho
+    # N[50:, :] = 0
+    temperature[:, :] = 0
+    # temperature += 0.01 * np.random.randn(Ny, Nx, NL)
+
 
     # Cylinder boundary
     X, Y = np.meshgrid(range(Nx), range(Ny))
     cylinder = (X - Nx / 4) ** 2 + (Y - Ny / 2) ** 2 < (Ny / 4) ** 2
     inner_cylinder = (X - Nx / 4) ** 2 + (Y - Ny / 2) ** 2 < (Ny / 4 - 2) ** 2
-    F[cylinder] = 0
-    F[0, :] = 0
-    F[Ny - 1, :] = 0
-    # F[int(Ny/2):, :] = 0
+    N[cylinder] = 0
+    N[0, :] = 0
+    N[Ny - 1, :] = 0
+
+    temperature[cylinder] = 0
+    # N[int(Ny/2):, :] = 0
 
     has_fluid[cylinder] = False
     has_fluid[0, :] = False
     has_fluid[Ny - 1, :] = False
 
     # for i in idxs:
-    #     F[:, :, i] *= has_fluid
+    #     N[:, :, i] *= has_fluid
 
     # Prep figure
     fig = plt.figure(figsize=(4, 2), dpi=80)
@@ -66,9 +82,9 @@ def main():
 
         # Drift
         new_has_fluid = np.zeros((Ny, Nx))
-        F_sum = np.sum(F, 2)
+        F_sum = np.sum(N, 2)
         for i, cx, cy in zip(idxs, cxs, cys):
-            F_part = F[:, :, i] / F_sum
+            F_part = N[:, :, i] / F_sum
             F_part[F_sum == 0] = 0
             to_move = F_part * (has_fluid * 1.0)
             to_move = (np.roll(to_move, cx, axis=1))
@@ -76,8 +92,11 @@ def main():
 
             new_has_fluid += to_move
 
-            F[:, :, i] = np.roll(F[:, :, i], cx, axis=1)
-            F[:, :, i] = np.roll(F[:, :, i], cy, axis=0)
+            N[:, :, i] = np.roll(N[:, :, i], cx, axis=1)
+            N[:, :, i] = np.roll(N[:, :, i], cy, axis=0)
+
+            temperature[:, :, i] = np.roll(temperature[:, :, i], cx, axis=1)
+            temperature[:, :, i] = np.roll(temperature[:, :, i], cy, axis=0)
 
         # has_fluid = new_has_fluid > 0.5
         # new_has_fluid[F_sum == 0] += has_fluid[F_sum == 0] * 1.0
@@ -89,7 +108,7 @@ def main():
         print('fluid_cells: %d' % np.sum(has_fluid * 1))
 
         # for i in idxs:
-        #     F[:, :, i] *= has_fluid
+        #     N[:, :, i] *= has_fluid
 
         bndry = np.zeros((Ny, Nx), dtype=np.bool)
         bndry[0, :] = True
@@ -101,51 +120,138 @@ def main():
         # bndry = np.logical_or(bndry, has_fluid < 0.5)
 
         # Set reflective boundaries
-        bndryF = F[bndry, :]
-        bndryF = bndryF[:, reflection_mapping]
+        bndryN = N[bndry, :]
+        bndryN = bndryN[:, reflection_mapping]
 
-        sum_f = np.sum(F)
-        print('Sum of Forces: %f' % sum_f)
+        bndryTemp = temperature[bndry, :]
+        bndryTemp = bndryTemp[:, reflection_mapping]
 
-        # sum_f_cyl = np.sum(F[cylinder])
+        sum_f = np.sum(N)
+        print('Sum of Particles: %f' % sum_f)
+        print('Sum of Temperature: %f' % np.sum(temperature))
+
+        # sum_f_cyl = np.sum(N[cylinder])
         # print('Sum of Forces in cylinder: %f' % sum_f_cyl)
 
-        # sum_f_inner_cyl = np.sum(F[inner_cylinder])
+        # sum_f_inner_cyl = np.sum(N[inner_cylinder])
         # print('Sum of Forces in inner cylinder: %f' % sum_f_inner_cyl)
 
         # if sum_f > 4000000.000000:
         #     test = 1
 
-        # F[Ny - 1, :, 5] += 0.1
-        # F[0, :, 1] -= 0.1
-        # F[0, :, 5] += 0.1
-        # F[Ny - 1, :, 1] -= 0.1
+        # N[Ny - 1, :, 5] += 0.1
+        # N[0, :, 1] -= 0.1
+        # N[0, :, 5] += 0.1
+        # N[Ny - 1, :, 1] -= 0.1
 
         # Calculate fluid variables
-        rho = np.sum(F, 2)
-        ux = np.sum(F * cxs, 2) / rho
-        uy = np.sum(F * cys, 2) / rho
+        rho = np.sum(N, 2)
+        temp_rho = np.sum(temperature, 2)
+        ux = np.sum(N * cxs, 2) / rho
+        uy = np.sum(N * cys, 2) / rho
 
-        ux[(np.abs(rho) < 0.000000001)] = 0
-        uy[(np.abs(rho) < 0.000000001)] = 0
+        ux[(np.abs(rho) < epsilon)] = 0
+        uy[(np.abs(rho) < epsilon)] = 0
+
+        g = -params.g * (temp_rho - yy / Ny)
+        # uy[np.abs(rho) >= epsilon] += g[np.abs(rho) >= epsilon] / 2.0
+        uy += g / 2.0
+
+        # u_length = np.maximum(np.abs(ux), np.abs(uy))
+        u_length = np.sqrt(np.square(ux) + np.square(uy))
+
+        u_max_length = np.max(u_length)
+        if u_max_length > np.sqrt(2):
+            ux = (ux / u_max_length) * np.sqrt(2)
+            uy = (uy / u_max_length) * np.sqrt(2)
+
+        print('max vector part: %f' % u_max_length)
+        # ux /= u_max_length
+        # uy /= u_max_length
+
+        # scale = abs(np.max(np.maximum(np.abs(ux), np.abs(uy))) - 1.0) < epsilon
+        # if scale:
+        #     g = 0.01 * (temp_rho - yy / Ny)
+        #
+        #     # F = np.zeros((Ny, Nx), dtype=np.bool)
+        #     # F = -0.1 * rho
+        #
+        #     # uy[np.abs(rho) >= epsilon] += tau * F[np.abs(rho) >= epsilon] / rho[np.abs(rho) >= epsilon]
+        #     uy[np.abs(rho) >= epsilon] += g[np.abs(rho) >= epsilon] / 2.0
+        #
+        #     u_length = np.maximum(np.abs(ux), np.abs(uy))
+        #     u_max_length = np.max(u_length)
+        #
+        #     print('max vector part: %f' % u_max_length)
+        #
+        #     ux /= u_max_length
+        #     uy /= u_max_length
 
         # print('minimum rho: %f' % np.min(np.abs(rho)))
-        # print('Maximum F: %f' % np.max(F))
-        # print('Minimum F: %f' % np.min(F))
+        # print('Maximum N: %f' % np.max(N))
+        # print('Minimum N: %f' % np.min(N))
 
         # Apply Collision
-        Feq = np.zeros(F.shape)
+        temperature_eq = np.zeros(temperature.shape)
+        Neq = np.zeros(N.shape)
         for i, cx, cy, w in zip(idxs, cxs, cys, weights):
-            Feq[:, :, i] = rho * w * (
+            Neq[:, :, i] = rho * w * (
                         1 + 3 * (cx * ux + cy * uy) + 9 * (cx * ux + cy * uy) ** 2 / 2 - 3 * (ux ** 2 + uy ** 2) / 2)
 
-        F += -(1.0 / tau) * (F - Feq)
+            temperature_eq[:, :, i] = temp_rho * w * (
+                    1 + 3 * (cx * ux + cy * uy) + 9 * (cx * ux + cy * uy) ** 2 / 2 - 3 * (ux ** 2 + uy ** 2) / 2)
+        # test1 = np.sum(Neq)
+        test2 = np.sum(N-Neq)
+        if abs(test2) > 0.0001:
+            test = ''
+
+        print('Overall change: %f' % test2)
+
+        n_pre_sum = np.sum(N[np.logical_not(bndry)])
+        temperature_pre_sum = np.sum(temperature[np.logical_not(bndry)])
+
+        N += -(1.0 / params.t1) * (N - Neq)
+        temperature += -(1.0 / params.t2) * (temperature - temperature_eq)
 
         # Apply boundary
-        F[bndry, :] = bndryF
+        N[bndry, :] = bndryN
+        temperature[bndry, :] = bndryTemp
+
+        # temperature[0, :, 0] = 0
+        # temperature[1, :, 0] = 0
+
+        temperature[0, :, 0] /= 2
+        temperature[1, :, 0] /= 2
+
+        temperature[Ny - 1, :, 0] = 1
+        temperature[Ny - 2, :, 0] = 1
+
+        # n_sum = np.sum(N, 2)
+        # n_sum_min = np.min(n_sum)
+        # if n_sum_min < 0:
+        #     N[np.logical_not(bndry)] += abs(n_sum_min)
+        #     N[np.logical_not(bndry)] /= np.sum(N[np.logical_not(bndry)])
+        #     N[np.logical_not(bndry)] *= n_pre_sum
+        #     print('Sum of Forces: %f' % np.sum(N))
+
+        # temperature_sum = np.sum(temperature, 2)
+        # temperature_sum_min = np.min(temperature_sum)
+        # if temperature_sum_min < 0:
+        #     temperature[np.logical_not(bndry)] += abs(temperature_sum_min)
+        #     temperature[np.logical_not(bndry)] /= np.sum(temperature[np.logical_not(bndry)])
+        #     temperature[np.logical_not(bndry)] *= temperature_pre_sum
+        #     print('Sum of Temperature: %f' % np.sum(temperature))
+
+        no_cylinder_mask = np.sum(N, 2) != 0
+        print('min N: %f' % np.min(np.sum(N, 2)[no_cylinder_mask]))
+        print('max N: %f' % np.max(np.sum(N, 2)))
+
+        print('min Temp: %f' % np.min(np.sum(temperature, 2)[no_cylinder_mask]))
+        print('max Temp: %f' % np.max(np.sum(temperature, 2)))
 
         # plot in real time - color 1/2 particles blue, other half red
         if (plotRealTime and (it % 10) == 0) or (it == Nt - 1):
+            fig.clear()
             plt.cla()
             ux[cylinder] = 0
             uy[cylinder] = 0
@@ -158,16 +264,29 @@ def main():
             cmap = plt.cm.bwr
             cmap.set_bad('black')
 
-            # plt.imshow(vorticity, cmap='bwr')
-            plt.imshow(has_fluid * 2.0 - 1.0, cmap='bwr')
+            plt.subplot(2, 2, 1)
+            plt.imshow(vorticity, cmap='bwr')
+            plt.clim(-.1, .1)
+            # plt.imshow(has_fluid * 2.0 - 1.0, cmap='bwr')
             # plt.imshow(bndry * 2.0 - 1.0, cmap='bwr')
+            # plt.imshow(np.sum(N, 2) * 2.0 - 1.0, cmap='bwr')
+            # plt.imshow((np.sum(temperature, 2) / np.max(np.sum(temperature, 2))) * 2.0 - 1.0, cmap='bwr')
+            plt.subplot(2, 2, 2)
+            max_temp = np.max(np.sum(temperature, 2))
+            # plt.imshow(np.sum(temperature, 2) / max_temp * 2.0 - 1.0, cmap='bwr')
+            plt.imshow(np.sum(temperature, 2) * 2.0 - 1.0, cmap='bwr')
+            plt.clim(-.1, .1)
+
+            plt.subplot(2, 2, 3)
+            max_N = np.max(np.sum(N, 2))
+            plt.imshow(np.sum(N, 2) / max_N * 2.0 - 1.0, cmap='bwr')
 
             plt.clim(-.1, .1)
-            ax = plt.gca()
-            ax.invert_yaxis()
-            ax.get_xaxis().set_visible(False)
-            ax.get_yaxis().set_visible(False)
-            ax.set_aspect('equal')
+            # ax = plt.gca()
+            # ax.invert_yaxis()
+            # ax.get_xaxis().set_visible(False)
+            # ax.get_yaxis().set_visible(False)
+            # ax.set_aspect('equal')
             plt.pause(0.001)
 
     # Save figure

From 54bda855e5adedc4a6ba020970887f357db0ff14 Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Mon, 20 Dec 2021 17:21:46 +0100
Subject: [PATCH 04/14] travking fluid particle

---
 Client/Client.py               |  92 ++++++++++---
 FluidSim/FluidSimParameters.py |   2 +-
 FluidSim/FluidSimulator.py     | 185 ++++++++++++++++++++++++++
 FluidSim/LatticeBoltzmann.py   |  71 +++++++---
 FluidSim/StaggeredArray.py     | 108 ++++++++++++++++
 FluidSim/__init__.py           |   0
 Objects/World.py               | 229 +++++++++++++++++++++++++++++++++
 WorldProvider/WorldProvider.py |   1 +
 tests/test_FluidSimulator.py   |  29 +++++
 tests/test_Staggered_Array.py  |  22 ++++
 10 files changed, 702 insertions(+), 37 deletions(-)
 create mode 100644 FluidSim/FluidSimulator.py
 create mode 100644 FluidSim/StaggeredArray.py
 create mode 100644 FluidSim/__init__.py
 create mode 100644 tests/test_FluidSimulator.py
 create mode 100644 tests/test_Staggered_Array.py

diff --git a/Client/Client.py b/Client/Client.py
index 9baea65..d87efc3 100644
--- a/Client/Client.py
+++ b/Client/Client.py
@@ -103,6 +103,53 @@ class Client:
                     self.world_provider.world.put_object(x_pos, y_pos, z_pos, Cuboid().setColor(
                         random.randint(0, 100) / 100.0, random.randint(0, 100) / 100.0, random.randint(0, 100) / 100.0))
 
+        colors = {}
+        for plate in range(int(np.max(self.world_provider.world.plates))):
+            colors[plate + 1] = (random.randint(0, 100) / 100.0,
+                                 random.randint(0, 100) / 100.0,
+                                 random.randint(0, 100) / 100.0)
+
+        for x_pos in range(0, 100):
+            for y_pos in range(0, 100):
+                for z_pos in range(0, 1):
+                    if self.world_provider.world.plates[x_pos, y_pos] == -1:
+                        r, g, b, = 0, 0, 1 #0.5, 0.5, 0.5
+                    else:
+                        r, g, b = colors[int(self.world_provider.world.plates[x_pos, y_pos])]
+                    self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
+
+        total_x = self.world_provider.world.chunk_n_x * self.world_provider.world.chunk_size_x
+        total_y = self.world_provider.world.chunk_n_y * self.world_provider.world.chunk_size_y
+        for x_pos in range(0, 100):
+            for y_pos in range(0, 100):
+                if self.world_provider.world.faults[x_pos, y_pos] == -2:
+                    self.world_provider.world.set_color(x_pos, y_pos, 0, 0, 0, 0)
+
+        for line_index, line in enumerate(self.world_provider.world.fault_lines):
+            for x_pos in range(0, 100):
+                for y_pos in range(0, 100):
+                    if self.world_provider.world.faults[x_pos, y_pos] == line_index:
+                        if line_index != 9:
+                            self.world_provider.world.set_color(x_pos, y_pos, 0, 0, 0, 1)
+                        else:
+                            self.world_provider.world.set_color(x_pos, y_pos, 0, 1, 1, 1)
+
+        for x_pos in range(0, 100):
+            for y_pos in range(0, 100):
+                for z_pos in range(0, 1):
+                    if [x_pos, y_pos] in self.world_provider.world.fault_nodes:
+                        r, g, b = 1, 0, 0
+                        self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
+
+        # # visualize direction lengths
+        # lengths = np.sqrt(np.sum(np.square(self.world_provider.world.directions), axis=2))
+        # lengths = lengths / np.max(lengths)
+        # for x_pos in range(0, 100):
+        #     for y_pos in range(0, 100):
+        #         for z_pos in range(0, 1):
+        #             r, g, b = lengths[x_pos, y_pos], lengths[x_pos, y_pos], lengths[x_pos, y_pos]
+        #             self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
+
         self.projMatrix = perspectiveMatrix(45.0, 400 / 400, 0.01, MAX_DISTANCE)
 
         self.rx = self.cx = self.cy = 0
@@ -222,28 +269,31 @@ class Client:
         max_value_vel = np.max(vel)
         # max_value_vel = np.sqrt(3)
 
-        print('round')
-        print('sum n: %f' % np.sum(self.n))
-        print('max n: %f' % np.max(self.n))
-        print('min n: %f' % np.min(self.n))
-        print('sum vel: %f' % np.sum(vel))
-        print('max vel: %f' % np.max(vel))
-        print('min vel: %f' % np.min(vel))
+        # print('round')
+        # print('sum n: %f' % np.sum(self.n))
+        # print('max n: %f' % np.max(self.n))
+        # print('min n: %f' % np.min(self.n))
+        # print('sum vel: %f' % np.sum(vel))
+        # print('max vel: %f' % np.max(vel))
+        # print('min vel: %f' % np.min(vel))
 
-        for x_pos in range(0, 100):
-            for y_pos in range(0, 100):
-                for z_pos in range(0, 1):
-                    if self.state == 2:
-                        r, g, b = value_to_color(int(self.gravity_applies[x_pos, y_pos, z_pos]), 0, 1)
-                    if self.state == 1:
-                        r, g, b = value_to_color(vel[x_pos, y_pos, z_pos], min_value, max_value_vel)
-                    if self.state == 0:
-                        r, g, b = value_to_color(self.n[x_pos, y_pos, z_pos], min_value, max_value_n)
-
-                    self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
-        self.world_provider.world.set_color(int(round(self.test_pixel[0])),
-                                            int(round(self.test_pixel[1])),
-                                            int(round(self.test_pixel[2])), 1.0, 1.0, 1.0)
+        # for x_pos in range(0, 100):
+        #     for y_pos in range(0, 100):
+        #         for z_pos in range(0, 1):
+        #             # if self.state == 2:
+        #             #     r, g, b = value_to_color(int(self.gravity_applies[x_pos, y_pos, z_pos]), 0, 1)
+        #             # if self.state == 1:
+        #             #     r, g, b = value_to_color(vel[x_pos, y_pos, z_pos], min_value, max_value_vel)
+        #             # if self.state == 0:
+        #             #     r, g, b = value_to_color(self.n[x_pos, y_pos, z_pos], min_value, max_value_n)
+        #             r, g, b, = 128, 128, 128
+        #             if [x_pos, y_pos] in self.world_provider.world.fault_nodes:
+        #                 r, g, b = 128, 0, 0
+        #
+        #             self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
+        # self.world_provider.world.set_color(int(round(self.test_pixel[0])),
+        #                                     int(round(self.test_pixel[1])),
+        #                                     int(round(self.test_pixel[2])), 1.0, 1.0, 1.0)
 
         # print(1.0 / (time.time() - self.time))
         self.time = time.time()
diff --git a/FluidSim/FluidSimParameters.py b/FluidSim/FluidSimParameters.py
index 2170be4..0680ea5 100644
--- a/FluidSim/FluidSimParameters.py
+++ b/FluidSim/FluidSimParameters.py
@@ -1,7 +1,7 @@
 class FluidSimParameter:
     viscosity = 0.1 / 3.0
     # Pr = 1.0
-    Pr = 100.0
+    Pr = 1.0
     # vc = 1.0
     vc = 0.5
 
diff --git a/FluidSim/FluidSimulator.py b/FluidSim/FluidSimulator.py
new file mode 100644
index 0000000..ab520e0
--- /dev/null
+++ b/FluidSim/FluidSimulator.py
@@ -0,0 +1,185 @@
+from FluidSim.StaggeredArray import StaggeredArray2D
+import numpy as np
+import scipy
+import scipy.sparse
+import scipy.sparse.linalg
+
+class FluidSimulator2D:
+    def __init__(self, x_n: int, y_n: int):
+        self.x_n = x_n
+        self.y_n = y_n
+
+        self.array = StaggeredArray2D(self.x_n, self.y_n)
+
+        self.coordinate_array = np.zeros((x_n, y_n, 2), dtype=np.int)
+        for x in range(x_n):
+            for y in range(y_n):
+                self.coordinate_array[x, y, :] = x, y
+
+    def advect(self, field: np.ndarray, delta_t: float):
+        u_x, u_y = self.array.get_velocity_arrays()
+        u = np.stack([u_x, u_y], axis=2)
+
+        def runge_kutta_layer(input, time_elapsed, border_handling='clamp'):
+            shifted_pos = np.round(self.coordinate_array - u * time_elapsed).astype(np.int) * (u < 0) +\
+                          (self.coordinate_array - u * time_elapsed).astype(np.int) * (u > 0) + \
+                          self.coordinate_array * (u == 0)
+            # border handling
+            if border_handling == 'clamp':
+                shifted_pos = np.maximum(0, shifted_pos)
+                shifted_pos[:, :, 0] = np.minimum(len(input), shifted_pos[:, :, 0])
+                shifted_pos[:, :, 1] = np.minimum(len(input[0]), shifted_pos[:, :, 1])
+                pass
+            layer = np.zeros(field.shape, dtype=field.dtype)
+            for x in range(self.x_n):
+                for y in range(self.y_n):
+                    layer[x, y] = field[shifted_pos[x, y][0], shifted_pos[x, y][1]] - field[x, y]
+            return layer
+
+        k1 = runge_kutta_layer(field, delta_t)
+
+        k2 = runge_kutta_layer(field + 0.5 * delta_t * k1, 0.5 * delta_t)
+
+        k3 = runge_kutta_layer(field + 0.75 * delta_t * k2, 0.75 * delta_t) # maybe 0.25 instead?
+
+        # new_field = field + 2.0 / 9.0 * delta_t * k1 + 3.0 / 9.0 * delta_t * k2 + 4.0 / 9.0 * delta_t * k3
+        new_field = field + k1
+
+        return new_field
+
+    def get_timestep(self, f, h=1.0, k_cfl=1.0):
+        f_length = np.max(np.sqrt(np.sum(np.square(f), axis=-1)))
+
+        u_x, u_y = self.array.get_velocity_arrays()
+        u_length = np.max(np.sqrt(np.square(u_x) + np.square(u_y)))
+
+        # return k_cfl * h / (u_length + np.sqrt(h + f_length))
+        return k_cfl * h / u_length
+
+    def update_velocity(self, timestep, border_handling='constant'):
+        if border_handling == 'constant':
+            p_diff_x = np.pad(self.array.p, [(0, 1), (0, 0)], mode='constant', constant_values=0) -\
+                       np.pad(self.array.p, [(1, 0), (0, 0)], mode='constant', constant_values=0)
+            borders_fluid_x = np.pad(self.array.has_fluid, [(0, 1), (0, 0)], mode='constant', constant_values=False) +\
+                              np.pad(self.array.has_fluid, [(1, 0), (0, 0)], mode='constant', constant_values=False)
+
+            p_diff_y = np.pad(self.array.p, [(0, 0), (0, 1)], mode='constant', constant_values=0) -\
+                       np.pad(self.array.p, [(0, 0), (1, 0)], mode='constant', constant_values=0)
+            borders_fluid_y = np.pad(self.array.has_fluid, [(0, 0), (0, 1)], mode='constant', constant_values=False) +\
+                              np.pad(self.array.has_fluid, [(0, 0), (1, 0)],  mode='constant', constant_values=False)
+        else:
+            p_diff_x = 0
+            p_diff_y = 0
+            borders_fluid_x = False
+            borders_fluid_y = False
+
+        u_x_new = self.array.u_x - (timestep * p_diff_x) * (1.0 * borders_fluid_x)
+        u_y_new = self.array.u_y - (timestep * p_diff_y) * (1.0 * borders_fluid_y)
+
+        # clear all components that do not border the fluid
+        u_x_new *= (1.0 * borders_fluid_x)
+        u_y_new *= (1.0 * borders_fluid_y)
+
+        return u_x_new, u_y_new
+
+    def calculate_divergence(self, h=1.0):
+        dx_u = (self.array.u_x[1:, :] - self.array.u_x[:-1, :]) / h
+        dy_u = (self.array.u_y[:, 1:] - self.array.u_y[:, -1]) / h
+
+        return dx_u + dy_u
+
+    def pressure_solve(self, divergence):
+        new_p = np.zeros((self.array.x_n, self.array.y_n))
+        connection_matrix = np.zeros((self.array.x_n * self.array.y_n, self.array.x_n * self.array.y_n))
+        flat_divergence = np.zeros((self.array.x_n * self.array.y_n))
+
+        for x in range(self.array.x_n):
+            for y in range(self.array.y_n):
+                flat_divergence[x * self.array.y_n + y] = divergence[x, y]
+
+                neighbors = 4
+                if x == 0:
+                    neighbors -= 1
+                else:
+                    connection_matrix[x * self.array.y_n + y, (x - 1) * self.array.y_n + y] = 1
+                if y == 0:
+                    neighbors -= 1
+                else:
+                    connection_matrix[x * self.array.y_n + y, x * self.array.y_n + y - 1] = 1
+                if x == self.array.x_n - 1:
+                    neighbors -= 1
+                else:
+                    connection_matrix[x * self.array.y_n + y, (x + 1) * self.array.y_n + y] = 1
+                if y == self.array.y_n - 1:
+                    neighbors -= 1
+                else:
+                    connection_matrix[x * self.array.y_n + y, x * self.array.y_n + y + 1] = 1
+
+                connection_matrix[x * self.array.y_n + y, x * self.array.y_n + y] = -neighbors
+
+        p = scipy.sparse.linalg.spsolve(connection_matrix, -flat_divergence)
+
+        for x in range(self.array.x_n):
+            for y in range(self.array.y_n):
+                new_p[x, y] = p[x * self.array.y_n + y]
+        return new_p
+
+    def timestep(self, external_f, h=1.0, k_cfl=1.0):
+        """
+        :param external_f: external forces to be applied
+        :param h: grid size
+        :param k_cfl: speed up multiplier (reduces accuracy if > 1.0. Does not make much sense if smaller than 1.0)
+        :return:
+        """
+        delta_t = self.get_timestep(external_f, h, k_cfl)
+
+        has_fluid = self.advect(self.array.has_fluid * 1.0, delta_t)
+        self.array.density = self.advect(self.array.density, delta_t)
+
+        # temp_u_x = self.advect(self.array.u_x, delta_t)
+        # temp_u_y = self.advect(self.array.u_y, delta_t)
+        # self.array.u_x = temp_u_x
+        # self.array.u_y = temp_u_y
+        #TODO advect velocity
+        test_u = np.stack(self.array.get_velocity_arrays(), axis=2)
+        test = self.advect(np.stack(self.array.get_velocity_arrays(), axis=2), delta_t)
+
+        # TODO maybe use dynamic threshold to conserve amount of cells containing fluid
+        self.array.has_fluid = has_fluid >= 0.5
+        # add more stuff to advect here. For example temperature, density, and other things. Maybe advect velocity.
+
+        # self.array.u_x, self.array.u_y = self.update_velocity(delta_t)
+        # TODO add forces (a = F / m) -> add a to u
+
+        dx_u = (self.array.u_x[1:, :] - self.array.u_x[:-1, :]) / h
+        dy_u = (self.array.u_y[:, 1:] - self.array.u_y[:, -1]) / h
+
+        dx_u = self.advect(dx_u, delta_t)
+        dy_u = self.advect(dy_u, delta_t)
+
+        divergence = dx_u + dy_u
+        # divergence = self.calculate_divergence(h)
+
+        self.array.p = self.pressure_solve(divergence)
+
+        self.array.u_x, self.array.u_y = self.update_velocity(delta_t)
+
+
+if __name__ == '__main__':
+    fs = FluidSimulator2D(50, 50)
+
+    fs.array.has_fluid[10, 10] = True
+
+    for i in range(100):
+        fs.timestep(0)
+        print(i)
+
+    print(fs.array.has_fluid[10, 10])
+    print(np.sum(fs.array.has_fluid * 1.0))
+    # test = fs.advect(fs.array.has_fluid * 1.0, 1.0)
+    #
+    # test2 = fs.update_velocity(1.0)
+    #
+    # test3 = fs.calculate_divergence()
+    #
+    # test4 = fs.pressure_solve(test3)
diff --git a/FluidSim/LatticeBoltzmann.py b/FluidSim/LatticeBoltzmann.py
index b39a8b3..07d5acb 100644
--- a/FluidSim/LatticeBoltzmann.py
+++ b/FluidSim/LatticeBoltzmann.py
@@ -37,6 +37,10 @@ def main():
     # Initial Conditions
     N = np.ones((Ny, Nx, NL))  # * rho0 / NL
     temperature = np.ones((Ny, Nx, NL), np.float)  # * rho0 / NL
+
+    tracked_fluid = np.zeros((Ny, Nx, NL), np.float)  # * rho0 / NL
+    tracked_fluid[50:, :, 0] = 1
+
     has_fluid = np.ones((Ny, Nx), dtype=np.bool)
     has_fluid[int(Ny/2):, :] = False
     np.random.seed(42)
@@ -56,13 +60,17 @@ def main():
 
     # Cylinder boundary
     X, Y = np.meshgrid(range(Nx), range(Ny))
+
     cylinder = (X - Nx / 4) ** 2 + (Y - Ny / 2) ** 2 < (Ny / 4) ** 2
-    inner_cylinder = (X - Nx / 4) ** 2 + (Y - Ny / 2) ** 2 < (Ny / 4 - 2) ** 2
+    cylinder[:, :] = False
+
     N[cylinder] = 0
     N[0, :] = 0
     N[Ny - 1, :] = 0
 
     temperature[cylinder] = 0
+
+    tracked_fluid[cylinder] = 0
     # N[int(Ny/2):, :] = 0
 
     has_fluid[cylinder] = False
@@ -98,14 +106,17 @@ def main():
             temperature[:, :, i] = np.roll(temperature[:, :, i], cx, axis=1)
             temperature[:, :, i] = np.roll(temperature[:, :, i], cy, axis=0)
 
+            tracked_fluid[:, :, i] = np.roll(tracked_fluid[:, :, i], cx, axis=1)
+            tracked_fluid[:, :, i] = np.roll(tracked_fluid[:, :, i], cy, axis=0)
+
         # has_fluid = new_has_fluid > 0.5
         # new_has_fluid[F_sum == 0] += has_fluid[F_sum == 0] * 1.0
         # new_has_fluid[(np.abs(F_sum) < 0.000000001)] = 0
-
-        fluid_sum = np.sum(has_fluid * 1.0)
-        has_fluid = (new_has_fluid / np.sum(new_has_fluid * 1.0)) * fluid_sum
-
-        print('fluid_cells: %d' % np.sum(has_fluid * 1))
+        #
+        # fluid_sum = np.sum(has_fluid * 1.0)
+        # has_fluid = (new_has_fluid / np.sum(new_has_fluid * 1.0)) * fluid_sum
+        #
+        # print('fluid_cells: %d' % np.sum(has_fluid * 1))
 
         # for i in idxs:
         #     N[:, :, i] *= has_fluid
@@ -126,9 +137,13 @@ def main():
         bndryTemp = temperature[bndry, :]
         bndryTemp = bndryTemp[:, reflection_mapping]
 
+        bndryTracked = tracked_fluid[bndry, :]
+        bndryTracked = bndryTracked[:, reflection_mapping]
+
         sum_f = np.sum(N)
         print('Sum of Particles: %f' % sum_f)
         print('Sum of Temperature: %f' % np.sum(temperature))
+        print('Sum of tacked particles: %f' % np.sum(tracked_fluid))
 
         # sum_f_cyl = np.sum(N[cylinder])
         # print('Sum of Forces in cylinder: %f' % sum_f_cyl)
@@ -147,6 +162,9 @@ def main():
         # Calculate fluid variables
         rho = np.sum(N, 2)
         temp_rho = np.sum(temperature, 2)
+
+        tracked_rho = np.sum(tracked_fluid, 2)
+
         ux = np.sum(N * cxs, 2) / rho
         uy = np.sum(N * cys, 2) / rho
 
@@ -154,16 +172,28 @@ def main():
         uy[(np.abs(rho) < epsilon)] = 0
 
         g = -params.g * (temp_rho - yy / Ny)
+
+        uy1 = (uy + g * params.t1 * (tracked_rho * 0.9 + 0.1))
+        uy2 = (uy + g * params.t2 * (tracked_rho * 0.9 + 0.1))
+
         # uy[np.abs(rho) >= epsilon] += g[np.abs(rho) >= epsilon] / 2.0
-        uy += g / 2.0
+        # uy += g / 2.0
 
         # u_length = np.maximum(np.abs(ux), np.abs(uy))
-        u_length = np.sqrt(np.square(ux) + np.square(uy))
+        u_length1 = np.sqrt(np.square(ux) + np.square(uy1))
+        u_length2 = np.sqrt(np.square(ux) + np.square(uy2))
+
+        u_max_length1 = np.max(u_length1)
+        u_max_length2 = np.max(u_length2)
+        u_max_length = max(u_max_length1, u_max_length2)
+
+        if u_max_length > 2:
+            test = 1
 
-        u_max_length = np.max(u_length)
         if u_max_length > np.sqrt(2):
             ux = (ux / u_max_length) * np.sqrt(2)
-            uy = (uy / u_max_length) * np.sqrt(2)
+            uy1 = (uy1 / u_max_length) * np.sqrt(2)
+            uy2 = (uy2 / u_max_length) * np.sqrt(2)
 
         print('max vector part: %f' % u_max_length)
         # ux /= u_max_length
@@ -193,13 +223,17 @@ def main():
 
         # Apply Collision
         temperature_eq = np.zeros(temperature.shape)
+        tracked_eq = np.zeros(temperature.shape)
         Neq = np.zeros(N.shape)
         for i, cx, cy, w in zip(idxs, cxs, cys, weights):
             Neq[:, :, i] = rho * w * (
-                        1 + 3 * (cx * ux + cy * uy) + 9 * (cx * ux + cy * uy) ** 2 / 2 - 3 * (ux ** 2 + uy ** 2) / 2)
+                        1 + 3 * (cx * ux + cy * uy1) + 9 * (cx * ux + cy * uy1) ** 2 / 2 - 3 * (ux ** 2 + uy1 ** 2) / 2)
 
             temperature_eq[:, :, i] = temp_rho * w * (
-                    1 + 3 * (cx * ux + cy * uy) + 9 * (cx * ux + cy * uy) ** 2 / 2 - 3 * (ux ** 2 + uy ** 2) / 2)
+                    1 + 3 * (cx * ux + cy * uy2) + 9 * (cx * ux + cy * uy2) ** 2 / 2 - 3 * (ux ** 2 + uy2 ** 2) / 2)
+
+            tracked_eq[:, :, i] = tracked_rho * w * (
+                    1 + 3 * (cx * ux + cy * uy1) + 9 * (cx * ux + cy * uy1) ** 2 / 2 - 3 * (ux ** 2 + uy1 ** 2) / 2)
         # test1 = np.sum(Neq)
         test2 = np.sum(N-Neq)
         if abs(test2) > 0.0001:
@@ -212,10 +246,12 @@ def main():
 
         N += -(1.0 / params.t1) * (N - Neq)
         temperature += -(1.0 / params.t2) * (temperature - temperature_eq)
+        tracked_fluid += -(1.0 / params.t1) * (tracked_fluid - tracked_eq)
 
         # Apply boundary
         N[bndry, :] = bndryN
         temperature[bndry, :] = bndryTemp
+        tracked_fluid[bndry, :] = bndryTracked
 
         # temperature[0, :, 0] = 0
         # temperature[1, :, 0] = 0
@@ -223,8 +259,9 @@ def main():
         temperature[0, :, 0] /= 2
         temperature[1, :, 0] /= 2
 
-        temperature[Ny - 1, :, 0] = 1
-        temperature[Ny - 2, :, 0] = 1
+        if it <= 3000:
+            temperature[Ny - 1, :, 0] = 1
+            temperature[Ny - 2, :, 0] = 1
 
         # n_sum = np.sum(N, 2)
         # n_sum_min = np.min(n_sum)
@@ -280,8 +317,12 @@ def main():
             plt.subplot(2, 2, 3)
             max_N = np.max(np.sum(N, 2))
             plt.imshow(np.sum(N, 2) / max_N * 2.0 - 1.0, cmap='bwr')
-
             plt.clim(-.1, .1)
+
+            plt.subplot(2, 2, 4)
+            plt.imshow(np.sum(tracked_fluid, 2) * 2.0 - 1.0, cmap='bwr')
+            plt.clim(-.1, .1)
+
             # ax = plt.gca()
             # ax.invert_yaxis()
             # ax.get_xaxis().set_visible(False)
diff --git a/FluidSim/StaggeredArray.py b/FluidSim/StaggeredArray.py
new file mode 100644
index 0000000..ca17a93
--- /dev/null
+++ b/FluidSim/StaggeredArray.py
@@ -0,0 +1,108 @@
+import numpy as np
+from scipy.signal import convolve
+
+
+class StaggeredArray2D:
+    def __init__(self, x_n, y_n):
+        """
+        creates a staggered array
+        :param x_n: x size of the array
+        :param y_n: y size of the array
+        """
+        self.x_n = x_n
+        self.y_n = y_n
+        self.p = np.zeros((x_n, y_n), dtype=np.float)
+
+        self.u_x = np.zeros((x_n + 1, y_n), dtype=np.float)
+        self.u_y = np.zeros((x_n, y_n + 1), dtype=np.float)
+
+        self.has_fluid = np.zeros((x_n, y_n), dtype=np.bool)
+        self.density = np.zeros((x_n, y_n), dtype=np.float)
+
+    def get_velocity(self, x, y):
+        """
+        get mid point value for the coordinates
+        :param x: x coordinate
+        :param y: y coordinate
+        :return:
+        """
+        assert 0 <= x < self.x_n, 'x value out of bounds!'
+        assert 0 <= y < self.y_n, 'y value out of bounds!'
+
+        lower_x = self.u_x[x, y]
+        upper_x = self.u_x[x + 1, y]
+
+        lower_y = self.u_y[x, y]
+        upper_y = self.u_y[x, y + 1]
+
+        return (lower_x + upper_x) / 2.0, (lower_y + upper_y) / 2.0
+
+    def get_velocity_arrays(self):
+        c_x = np.array([[0.5], [0.5]])
+        u_x = convolve(self.u_x, c_x, mode='valid')
+
+        c_y = np.array([[0.5, 0.5]])
+        u_y = convolve(self.u_y, c_y, mode='valid')
+
+        return u_x, u_y
+
+
+class StaggeredArray3D:
+    def __init__(self, x_n, y_n, z_n):
+        """
+        creates a staggered array
+        :param x_n: x size of the array
+        :param y_n: y size of the array
+        :param z_n: z size of the array
+        """
+
+        self.x_n = x_n
+        self.y_n = y_n
+        self.z_n = z_n
+
+        self.p = np.zeros((x_n, y_n, z_n), dtype=np.float)
+
+        self.u_x = np.zeros((x_n + 1, y_n, z_n), dtype=np.float)
+        self.u_y = np.zeros((x_n, y_n + 1, z_n), dtype=np.float)
+        self.u_z = np.zeros((x_n, y_n, z_n + 1), dtype=np.float)
+
+        self.has_fluid = np.zeros((x_n, y_n, z_n), dtype=np.bool)
+
+    def get_velocity(self, x, y, z):
+        """
+        get mid point value for the coordinates
+        :param x: x coordinate
+        :param y: y coordinate
+        :param z: z coordinate
+        :return:
+        """
+        assert 0 <= x < self.x_n, 'x value out of bounds!'
+        assert 0 <= y < self.y_n, 'y value out of bounds!'
+        assert 0 <= z < self.z_n, 'y value out of bounds!'
+
+        lower_x = self.u_x[x, y, z]
+        upper_x = self.u_x[x + 1, y, z]
+
+        lower_y = self.u_y[x, y, z]
+        upper_y = self.u_y[x, y + 1, z]
+
+        lower_z = self.u_z[x, y, z]
+        upper_z = self.u_z[x, y, z + 1]
+
+        return (lower_x + upper_x) / 2.0, (lower_y + upper_y) / 2.0, (lower_z + upper_z) / 2.0
+
+
+if __name__ == '__main__':
+    sa = StaggeredArray2D(10, 10)
+
+    for x in range(11):
+        for y in range(10):
+            sa.u_x[x, y] = y
+
+    for x in range(10):
+        for y in range(11):
+            sa.u_y[x, y] = x
+
+    ux, uy = sa.get_velocity_arrays()
+
+    ux2, uy2 = sa.get_velocity(0, 0)
\ No newline at end of file
diff --git a/FluidSim/__init__.py b/FluidSim/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/Objects/World.py b/Objects/World.py
index 4443f01..e21bdfe 100644
--- a/Objects/World.py
+++ b/Objects/World.py
@@ -7,6 +7,8 @@ from OpenGL.GLU import *
 from OpenGL.GL import *
 import math
 import numpy as np
+import random
+import sys
 
 class WorldChunk(Structure):
     def __init__(self, width: int, length: int, height: int, programs: dict):
@@ -200,6 +202,11 @@ class World(Renderable):
         self.chunk_n_z = chunk_n_z
         self.programs = programs
 
+        self.fault_nodes = []
+        self.fault_lines = []
+        self.plates = None
+        self.directions = None
+
         self.chunks: [[[WorldChunk]]] = []
         for x in range(chunk_n_x):
             self.chunks.append([])
@@ -208,6 +215,228 @@ class World(Renderable):
                 for z in range(chunk_n_z):
                     self.chunks[x][y].append(None)
 
+    def generate(self, seed=None, sea_height=50, continental_height=200):
+        if seed is None:
+            seed = random.randrange(2**32)
+            seed = 229805811
+        print('Generation seed is %i' % seed)
+        random.seed(seed)
+        np.random.seed(seed)
+        node_n = self.chunk_n_x + self.chunk_n_y
+        total_x = self.chunk_n_x * self.chunk_size_x
+        total_y = self.chunk_n_y * self.chunk_size_y
+        nodes = []
+        for _ in range(node_n):
+            nodes.append([random.randint(0, total_x - 1), random.randint(0, total_y - 1)])
+
+        # connections = np.random.randint(2, 5, len(nodes)) #np.zeros(len(nodes)) + 3
+        connections = (np.abs(np.random.normal(0, 5, len(nodes))) + 2).astype(np.int)
+
+        def calc_min_vector(start, end):
+            dx = end[0] - start[0]
+            wrapped_dx = dx % total_x
+
+            dy = end[1] - start[1]
+            wrapped_dy = dy % total_y
+
+            vector = np.array([dx, dy])
+            if wrapped_dx < abs(dx):
+                vector[0] = wrapped_dx
+            if wrapped_dy < abs(dy):
+                vector[1] = wrapped_dy
+            return vector
+
+        def is_intersecting_any(start, end, edges):
+            vec1 = calc_min_vector(start, end)
+
+            for (start2_index, end2_index) in edges:
+                start2 = nodes[start2_index]
+                end2 = nodes[end2_index]
+
+                vec2 = calc_min_vector(start2, end2)
+
+                norm1 = vec1 / np.sqrt(np.sum(np.square(vec1)))
+                norm2 = vec2 / np.sqrt(np.sum(np.square(vec2)))
+
+                # parrallel
+                parallel_threshold = 0.0001
+                if np.sqrt(np.sum(np.square(norm1 - norm2))) < parallel_threshold or np.sqrt(np.sum(np.square(norm1 + norm2))) < parallel_threshold:
+                    t = (start[0] - start2[0]) / vec2[0]
+                    t2 = (end[0] - start2[0]) / vec2[0]
+                    s = (start2[0] - start[0]) / vec1[0]
+                    s2 = (end2[0] - start[0]) / vec1[0]
+                    if (start2[1] + t * vec2[1]) - start[1] < parallel_threshold:
+                        if (0 <= t <= 1.0 and 0 <= t2 <= 1.0) or (0 <= s <= 1.0 and 0 <= s2 <= 1.0):
+                            return True
+                else:
+                    if start != start2 and end != end2 and end != start2 and start != end2:
+                        t = (vec1[0] * start[1] + vec1[1] * start2[0] - vec1[1] * start[0] - start2[1] * vec1[0]) / (vec2[1] * vec1[0] - vec2[0] * vec1[1])
+                        if 0 <= t <= 1.0:
+                            intersection = np.array(start2) + vec2 * t
+                            s = (intersection[0] - start[0]) / vec1[0]
+                            if 0 <= s <= 1.0:
+                                return True
+
+            return False
+
+
+
+        for index, node in enumerate(nodes):
+            distances = []
+            for other_index, other_node in enumerate(nodes):
+                if node != other_node and (index, other_index) not in self.fault_lines and\
+                        (other_index, index) not in self.fault_lines:
+                    if (not is_intersecting_any(node, other_node, self.fault_lines)) and (not is_intersecting_any(other_node, node, self.fault_lines)):
+                        distances.append((other_index, np.sqrt(np.sum(np.square(calc_min_vector(node, other_node))))))
+
+            distances.sort(key=lambda element: element[1])
+            while connections[index] > 0 and len(distances) > 0:
+                self.fault_lines.append((index, distances[0][0]))
+                connections[distances[0][0]] -= 1
+                connections[index] -= 1
+                distances.pop(0)
+
+        self.fault_nodes = nodes
+
+        plates = np.zeros((total_x, total_y))
+        faults = np.zeros((total_x, total_y)) - 1
+        plate_bordering_fault = {}
+        # draw fault lines
+        for fault_index, fault_line in enumerate(self.fault_lines):
+            start = self.fault_nodes[fault_line[0]]
+            end = self.fault_nodes[fault_line[1]]
+            vector = calc_min_vector(start, end)
+            vector = vector / np.sqrt(np.sum(np.square(vector)))
+
+            point = np.array(start, dtype=np.float)
+            plate_bordering_fault[fault_index] = []
+
+            while np.sqrt(np.sum(np.square(point - np.array(end)))) > 0.5:
+                plates[int(point[0]), int(point[1])] = -1
+                if faults[int(point[0]), int(point[1])] == -1:
+                    faults[int(point[0]), int(point[1])] = fault_index
+                elif faults[int(point[0]), int(point[1])] != fault_index:
+                    faults[int(point[0]), int(point[1])] = -2
+                point += 0.5 * vector
+                point[0] %= total_x
+                point[1] %= total_y
+        self.faults = faults
+
+        plate = 1
+        while np.any(plates == 0):
+            start = np.where(plates == 0)
+            start = (start[0][0], start[1][0])
+            plates[start] = plate
+            work_list = [start]
+
+            while len(work_list) > 0:
+                work = work_list.pop()
+
+                up = (work[0], (work[1] + 1) % total_y)
+                down = (work[0], (work[1] - 1) % total_y)
+                left = ((work[0] - 1) % total_x, work[1])
+                right = ((work[0] + 1) % total_x, work[1])
+
+                if plates[up] == -1 and plates[down] == -1 and plates[left] == -1 and plates[right] == -1:
+                    plates[work] = -1
+                    continue
+
+                if plates[up] <= 0:
+                    if plates[up] == 0:
+                        work_list.append(up)
+                    plates[up] = plate
+                if plates[down] <= 0:
+                    if plates[down] == 0:
+                        work_list.append(down)
+                    plates[down] = plate
+                if plates[left] <= 0:
+                    if plates[left] == 0:
+                        work_list.append(left)
+                    plates[left] = plate
+                if plates[right] <= 0:
+                    if plates[right] == 0:
+                        work_list.append(right)
+                    plates[right] = plate
+                    
+                if faults[up] > -1:
+                    if plate not in plate_bordering_fault[faults[up]]:
+                        plate_bordering_fault[faults[up]].append(plate)
+                if faults[down] > -1:
+                    if plate not in plate_bordering_fault[faults[down]]:
+                        plate_bordering_fault[faults[down]].append(plate)
+                if faults[left] > -1:
+                    if plate not in plate_bordering_fault[faults[left]]:
+                        plate_bordering_fault[faults[left]].append(plate)
+                if faults[right] > -1:
+                    if plate not in plate_bordering_fault[faults[right]]:
+                        plate_bordering_fault[faults[right]].append(plate)
+            plate += 1
+
+        plate_num = plate
+        for plate in range(1, plate_num):
+            if np.sum(plates == plate) < 20:
+                plates[plates == plate] = -1
+                for key, item in plate_bordering_fault.items():
+                    if plate in item:
+                        item.remove(plate)
+
+        directions = np.zeros((total_x, total_y, 3))
+        heights = np.zeros((total_x, total_y))
+        for plate in range(1, plate_num):
+            if random.randint(1, 2) == 1:
+                heights[plates == plate] = sea_height
+            else:
+                heights[plates == plate] = continental_height
+
+        coords = np.zeros((total_x, total_y, 2))
+        for x in range(total_x):
+            for y in range(total_y):
+                coords[x, y, 0] = x
+                coords[x, y, 1] = y
+
+        for fault_index, fault_line in enumerate(self.fault_lines):
+            start = self.fault_nodes[fault_line[0]]
+            end = self.fault_nodes[fault_line[1]]
+            vector = calc_min_vector(start, end)
+            vector = vector / np.sqrt(np.sum(np.square(vector)))
+
+            perpendicular = np.array([vector[1], -vector[0]])
+
+            if len(plate_bordering_fault[fault_index]) == 2:
+                for plate in plate_bordering_fault[fault_index]:
+                    vecs = coords - np.array(start)
+                    lengths = np.sqrt(np.sum(np.square(vecs), axis=2, keepdims=True))
+                    norm_vecs = vecs / lengths
+                    scalars = np.sum(norm_vecs * perpendicular, axis=2, keepdims=True)
+                    scalars[lengths == 0] = 0
+
+                    end_vecs = coords - np.array(end)
+                    end_lengths = np.sqrt(np.sum(np.square(end_vecs), axis=2, keepdims=True))
+                    end_min_length = np.min(end_lengths[np.logical_and(plates == plate, end_lengths[:, :, 0] > 0)])
+                    end_min_length_scalar = scalars[np.logical_and(plates == plate, end_lengths[:, :, 0] == end_min_length)][0, 0]
+
+                    min_length = np.min(lengths[np.logical_and(plates == plate, lengths[:, :, 0] > 0)])
+                    min_length_scalar = scalars[np.logical_and(plates == plate, lengths[:, :, 0] == min_length)][0, 0]
+
+                    mean_scalar = np.mean(scalars[plates == plate])
+
+                    if (min_length_scalar / abs(min_length_scalar)) == (end_min_length_scalar / abs(end_min_length_scalar)):
+                        scalar = min_length_scalar
+                    else:
+                        if (min_length_scalar / abs(min_length_scalar)) == (mean_scalar / abs(mean_scalar)):
+                            scalar = min_length_scalar
+                        else:
+                            scalar = end_min_length_scalar
+
+                    directions[plates == plate, :2] += perpendicular * (scalar / abs(scalar))
+                    pass
+
+
+
+
+        self.plates = plates
+        self.directions = directions
+
     def set_color(self, x: int, y: int, z: int, r: float, g: float, b: float):
         x = x % (self.chunk_size_x * self.chunk_n_x)
         y = y % (self.chunk_size_y * self.chunk_n_y)
diff --git a/WorldProvider/WorldProvider.py b/WorldProvider/WorldProvider.py
index 0401eca..a9629b7 100644
--- a/WorldProvider/WorldProvider.py
+++ b/WorldProvider/WorldProvider.py
@@ -4,6 +4,7 @@ from Objects.World import World
 class WorldProvider:
     def __init__(self, programs):
         self.world: World = World(10, 10, 10, 10, 10, 10, programs)
+        self.world.generate()
 
     def update(self):
         pass
diff --git a/tests/test_FluidSimulator.py b/tests/test_FluidSimulator.py
new file mode 100644
index 0000000..bc6b31b
--- /dev/null
+++ b/tests/test_FluidSimulator.py
@@ -0,0 +1,29 @@
+from FluidSim.FluidSimulator import FluidSimulator2D
+import numpy as np
+
+
+def test_stand_still():
+    fs = FluidSimulator2D(50, 50)
+
+    fs.array.has_fluid[10, 10] = True
+
+    for i in range(100):
+        fs.timestep(0)
+
+    assert fs.array.has_fluid[10, 10], "Fluid not on the same spot anymore"
+    assert np.sum(fs.array.has_fluid * 1.0) == 1.0, "Fluid amount changed"
+
+
+def test_move_positive_x():
+    fs = FluidSimulator2D(50, 50)
+
+    fs.array.has_fluid[10, 10] = True
+    fs.array.u_x[10, 10] = 1.0
+    fs.array.u_x[11, 10] = 1.0
+    # fs.array.u_x[9, 10] = -1.0
+
+    for i in range(10):
+        fs.timestep(0)
+        assert np.sum(fs.array.has_fluid * 1.0) == 1.0, "Fluid amount changed"
+
+    assert fs.array.has_fluid[0, 10], "Fluid not on the right border"
diff --git a/tests/test_Staggered_Array.py b/tests/test_Staggered_Array.py
new file mode 100644
index 0000000..ea61960
--- /dev/null
+++ b/tests/test_Staggered_Array.py
@@ -0,0 +1,22 @@
+from FluidSim.StaggeredArray import StaggeredArray2D
+
+
+def test_staggered_array_2D():
+    sa = StaggeredArray2D(10, 10)
+
+    for x in range(11):
+        for y in range(10):
+            sa.u_x[x, y] = y
+
+    for x in range(10):
+        for y in range(11):
+            sa.u_y[x, y] = x
+
+    ux, uy = sa.get_velocity_arrays()
+
+    for x in range(10):
+        for y in range(10):
+            ux2, uy2 = sa.get_velocity(x, y)
+
+            assert ux[x, y] == ux2, 'value output should be consistent!'
+            assert uy[x, y] == uy2, 'value output should be consistent!'

From a783afbc7ed1997200a385528c8bf2d132b61148 Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Sun, 2 Jan 2022 11:33:28 +0100
Subject: [PATCH 05/14] more parameters and friction

---
 FluidSim/LatticeBoltzmann.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/FluidSim/LatticeBoltzmann.py b/FluidSim/LatticeBoltzmann.py
index 07d5acb..3ed13d4 100644
--- a/FluidSim/LatticeBoltzmann.py
+++ b/FluidSim/LatticeBoltzmann.py
@@ -21,6 +21,9 @@ def main():
     tau = 0.6  # collision timescale
     Nt = 80000  # number of timesteps
     plotRealTime = True  # switch on for plotting as the simulation goes along
+    render_frequency = 10
+    close_up_frequency = 10
+    friction = 0.0001
 
     params = FluidSimParameter(Ny)
     # params = WaterParameter(Ny)
@@ -180,6 +183,8 @@ def main():
         # uy += g / 2.0
 
         # u_length = np.maximum(np.abs(ux), np.abs(uy))
+
+        # safe guard against supersonic streams WIP
         u_length1 = np.sqrt(np.square(ux) + np.square(uy1))
         u_length2 = np.sqrt(np.square(ux) + np.square(uy2))
 
@@ -195,6 +200,11 @@ def main():
             uy1 = (uy1 / u_max_length) * np.sqrt(2)
             uy2 = (uy2 / u_max_length) * np.sqrt(2)
 
+        # apply friction
+        ux *= (1 - friction)
+        uy1 *= (1 - friction)
+        uy2 *= (1 - friction)
+
         print('max vector part: %f' % u_max_length)
         # ux /= u_max_length
         # uy /= u_max_length
@@ -286,8 +296,10 @@ def main():
         print('min Temp: %f' % np.min(np.sum(temperature, 2)[no_cylinder_mask]))
         print('max Temp: %f' % np.max(np.sum(temperature, 2)))
 
+        if it > render_frequency:
+            render_frequency = close_up_frequency
         # plot in real time - color 1/2 particles blue, other half red
-        if (plotRealTime and (it % 10) == 0) or (it == Nt - 1):
+        if (plotRealTime and (it % render_frequency) == 0) or (it == Nt - 1):
             fig.clear()
             plt.cla()
             ux[cylinder] = 0

From 6c5cae958b6d5a7cb609b57c38f69794d731ca33 Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Sat, 15 Jan 2022 14:31:19 +0100
Subject: [PATCH 06/14] starts plate movement and frame improvements

---
 Client/Client.py      |  46 ++++++++---------
 Objects/Renderable.py |  12 +++--
 Objects/Structure.py  |  75 +++++++++++++++++++++------
 Objects/World.py      | 117 +++++++++++++++++++++++++++++++++---------
 4 files changed, 182 insertions(+), 68 deletions(-)

diff --git a/Client/Client.py b/Client/Client.py
index d87efc3..2f6699d 100644
--- a/Client/Client.py
+++ b/Client/Client.py
@@ -118,28 +118,28 @@ class Client:
                         r, g, b = colors[int(self.world_provider.world.plates[x_pos, y_pos])]
                     self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
 
-        total_x = self.world_provider.world.chunk_n_x * self.world_provider.world.chunk_size_x
-        total_y = self.world_provider.world.chunk_n_y * self.world_provider.world.chunk_size_y
-        for x_pos in range(0, 100):
-            for y_pos in range(0, 100):
-                if self.world_provider.world.faults[x_pos, y_pos] == -2:
-                    self.world_provider.world.set_color(x_pos, y_pos, 0, 0, 0, 0)
-
-        for line_index, line in enumerate(self.world_provider.world.fault_lines):
-            for x_pos in range(0, 100):
-                for y_pos in range(0, 100):
-                    if self.world_provider.world.faults[x_pos, y_pos] == line_index:
-                        if line_index != 9:
-                            self.world_provider.world.set_color(x_pos, y_pos, 0, 0, 0, 1)
-                        else:
-                            self.world_provider.world.set_color(x_pos, y_pos, 0, 1, 1, 1)
-
-        for x_pos in range(0, 100):
-            for y_pos in range(0, 100):
-                for z_pos in range(0, 1):
-                    if [x_pos, y_pos] in self.world_provider.world.fault_nodes:
-                        r, g, b = 1, 0, 0
-                        self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
+        # total_x = self.world_provider.world.chunk_n_x * self.world_provider.world.chunk_size_x
+        # total_y = self.world_provider.world.chunk_n_y * self.world_provider.world.chunk_size_y
+        # for x_pos in range(0, 100):
+        #     for y_pos in range(0, 100):
+        #         if self.world_provider.world.faults[x_pos, y_pos] == -2:
+        #             self.world_provider.world.set_color(x_pos, y_pos, 0, 0, 0, 0)
+        #
+        # for line_index, line in enumerate(self.world_provider.world.fault_lines):
+        #     for x_pos in range(0, 100):
+        #         for y_pos in range(0, 100):
+        #             if self.world_provider.world.faults[x_pos, y_pos] == line_index:
+        #                 if line_index != 9:
+        #                     self.world_provider.world.set_color(x_pos, y_pos, 0, 0, 0, 1)
+        #                 else:
+        #                     self.world_provider.world.set_color(x_pos, y_pos, 0, 1, 1, 1)
+        #
+        # for x_pos in range(0, 100):
+        #     for y_pos in range(0, 100):
+        #         for z_pos in range(0, 1):
+        #             if [x_pos, y_pos] in self.world_provider.world.fault_nodes:
+        #                 r, g, b = 1, 0, 0
+        #                 self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
 
         # # visualize direction lengths
         # lengths = np.sqrt(np.sum(np.square(self.world_provider.world.directions), axis=2))
@@ -295,7 +295,7 @@ class Client:
         #                                     int(round(self.test_pixel[1])),
         #                                     int(round(self.test_pixel[2])), 1.0, 1.0, 1.0)
 
-        # print(1.0 / (time.time() - self.time))
+        print(1.0 / (time.time() - self.time))
         self.time = time.time()
         glutPostRedisplay()
 
diff --git a/Objects/Renderable.py b/Objects/Renderable.py
index a24d804..fa1f9b8 100644
--- a/Objects/Renderable.py
+++ b/Objects/Renderable.py
@@ -1,8 +1,12 @@
-from OpenGL.GLU import gluErrorString
-from OpenGL.GL import glGetError, GL_NO_ERROR
+from OpenGL.GLU import *
+from OpenGL.GL import *
+
+import numpy as np
+
 
 class Renderable:
-    def render(self, projMatrix, geometryRotMatrix, alternateprograms=None):
+    def render(self, projMatrix, geometryRotMatrix, alternateprograms=None,
+               preselected_program=None, projection_pos=None, rot_pos=None):
         pass
 
     @staticmethod
@@ -15,4 +19,4 @@ class Renderable:
             else:
                 print(hex(gl_error))
             return True
-        return False
\ No newline at end of file
+        return False
diff --git a/Objects/Structure.py b/Objects/Structure.py
index 5b71802..5c49893 100644
--- a/Objects/Structure.py
+++ b/Objects/Structure.py
@@ -15,10 +15,41 @@ from Objects.Renderable import Renderable
 
 
 class Structure(Renderable):
-    def __init__(self):
+    def __init__(self, x_offset=0, y_offset=1, z_offset=0):
         self.Objects = {}
         self.vais = {}
-        self.dirty = False
+        self.dirty = True
+
+        self.x_offset = x_offset
+        self.y_offset = y_offset
+        self.z_offset = z_offset
+
+    @property
+    def x_offset(self):
+        return self._x_offset
+
+    @x_offset.setter
+    def x_offset(self, value):
+        self.dirty = True
+        self._x_offset = value
+
+    @property
+    def y_offset(self):
+        return self._y_offset
+
+    @y_offset.setter
+    def y_offset(self, value):
+        self.dirty = True
+        self._y_offset = value
+
+    @property
+    def z_offset(self):
+        return self._z_offset
+
+    @z_offset.setter
+    def z_offset(self, value):
+        self.dirty = True
+        self._z_offset = value
 
     def addShape(self, program, shape):
         if not program in self.Objects.keys():
@@ -59,9 +90,9 @@ class Structure(Renderable):
                 glBindBuffer(GL_ARRAY_BUFFER, tpbi)
                 positions = []
                 for o in objects:
-                    positions.append(o.pos[0])
-                    positions.append(o.pos[1])
-                    positions.append(o.pos[2])
+                    positions.append(o.pos[0] + self.x_offset)
+                    positions.append(o.pos[1] + self.y_offset)
+                    positions.append(o.pos[2] + self.z_offset)
                 glBufferData(GL_ARRAY_BUFFER, np.array(positions, dtype=np.float32), GL_STATIC_DRAW)
                 glVertexAttribPointer(vid, 3, GL_FLOAT, GL_FALSE, 0, None)
                 self.check_error("Could not create position buffer")
@@ -115,29 +146,39 @@ class Structure(Renderable):
             glDeleteVertexArrays(1, vertex_array_ids[0])
             self.check_error("Could not destroy vertex array")
 
-    def render(self, projMatrix, geometryRotMatrix, alternateprograms=None):
+    def render(self, projMatrix, geometryRotMatrix, alternateprograms=None,
+               preselected_program=None, projection_pos=None, rot_pos=None):
         self.buildvertexArrays()
         for key, vertex_array_ids in self.vais.items():
             if alternateprograms == None:
                 program_id = key
             else:
-                assert key in alternateprograms
+                assert key in alternateprograms.keys()
                 program_id = alternateprograms[key]
-            glUseProgram(program_id)
-            self.check_error("Renderingprogram is not initialized!")
+            # check if a program was preloaded
+            if preselected_program is not None:
+                # if preloaded we only want to render the matching vertex arrays
+                if preselected_program != program_id:
+                    continue
+            else:
+                glUseProgram(program_id)
+                self.check_error("Renderingprogram is not initialized!")
 
-            projection = glGetUniformLocation(program_id, 'projModelViewMatrix')
-            rot = glGetUniformLocation(program_id, 'rotMatrix')
+            if rot_pos is None:
+                rot = glGetUniformLocation(program_id, 'rotMatrix')
+                glUniformMatrix3fv(rot, 1, GL_FALSE, np.array(geometryRotMatrix))
 
-            glUniformMatrix4fv(projection, 1, GL_FALSE, np.array(projMatrix))
-            glUniformMatrix3fv(rot, 1, GL_FALSE, np.array(geometryRotMatrix))
+            if projection_pos is None:
+                projection = glGetUniformLocation(program_id, 'projModelViewMatrix')
+                glUniformMatrix4fv(projection, 1, GL_FALSE, np.array(projMatrix))
 
             glBindVertexArray(vertex_array_ids[0])
             glDrawArrays(GL_POINTS, 0, vertex_array_ids[4])
             self.check_error("Rendering problem")
 
             glBindVertexArray(0)
-            glUseProgram(0)
+            if preselected_program is None:
+                glUseProgram(0)
 
     def __eq__(self, other):
         if type(other) is type(self):
@@ -154,9 +195,11 @@ class CompoundStructure(Renderable):
                      R: np.matrix = np.identity(3, np.float)):
         self.Structures.append((structure, M, R))
 
-    def render(self, projMatrix, geometryRotMatrix, alternateprograms=None):
+    def render(self, projMatrix, geometryRotMatrix, alternateprograms=None,
+               preselected_program=None, projection_pos=None, rot_pos=None):
         for (structure, M, R) in self.Structures:
-            structure.render(M * projMatrix, R * geometryRotMatrix, alternateprograms)
+            structure.render(M * projMatrix, R * geometryRotMatrix, alternateprograms,
+                             preselected_program, projection_pos, rot_pos)
 
     def __eq__(self, other):
         if type(other) is type(self):
diff --git a/Objects/World.py b/Objects/World.py
index e21bdfe..473ccf2 100644
--- a/Objects/World.py
+++ b/Objects/World.py
@@ -10,6 +10,18 @@ import numpy as np
 import random
 import sys
 
+# Plate Types
+SEA_PLATE = 0
+CONTINENTAL_PLATE = 1
+
+# Rock types
+EMPTY = 0
+SEA_PLATE_STONE = 1
+MAGMATIC_STONE = 2
+METAMORPH_STONE = 3
+SEDIMENTAL_STONE = 4
+SEDIMENT = 5
+
 class WorldChunk(Structure):
     def __init__(self, width: int, length: int, height: int, programs: dict):
         assert width > 0, 'Width must be greater than 0'
@@ -135,9 +147,9 @@ class WorldChunk(Structure):
                 glBindBuffer(GL_ARRAY_BUFFER, tpbi)
                 positions = []
                 for o in object_list:
-                    positions.append(o.pos[0])
-                    positions.append(o.pos[1])
-                    positions.append(o.pos[2])
+                    positions.append(o.pos[0] + self.x_offset)
+                    positions.append(o.pos[1] + self.y_offset)
+                    positions.append(o.pos[2] + self.z_offset)
                 glBufferData(GL_ARRAY_BUFFER, np.array(positions, dtype=np.float32), GL_STATIC_DRAW)
                 glVertexAttribPointer(vid, 3, GL_FLOAT, GL_FALSE, 0, None)
                 self.check_error("Could not create position buffer")
@@ -175,11 +187,14 @@ class WorldChunk(Structure):
                 self.vais[key] = (tvai, tpbi, tcbi, tsbi, counts[key])
             self.dirty = False
 
-    def render(self, proj_matrix, geometry_rot_matrix, alternate_programs=None):
-        super(WorldChunk, self).render(proj_matrix, geometry_rot_matrix, alternate_programs)
+    def render(self, proj_matrix, geometry_rot_matrix, alternate_programs=None,
+               preselected_program=None, projection_pos=None, rot_pos=None):
+        super(WorldChunk, self).render(proj_matrix, geometry_rot_matrix, alternate_programs,
+                          preselected_program, projection_pos, rot_pos)
 
         for entity in self.entities:
-            entity.render(proj_matrix, geometry_rot_matrix, alternate_programs)
+            entity.render(proj_matrix, geometry_rot_matrix, alternate_programs,
+                          preselected_program, projection_pos, rot_pos)
 
     def set_color(self, x: int, y: int, z: int, r: float, g: float, b: float):
         assert 0 <= x < self.width, 'Put out of bounds for x coordinate! Must be between 0 and %i' % self.width
@@ -206,6 +221,9 @@ class World(Renderable):
         self.fault_lines = []
         self.plates = None
         self.directions = None
+        self.num_plates = 0
+        self.stone = None
+        self.faults = None
 
         self.chunks: [[[WorldChunk]]] = []
         for x in range(chunk_n_x):
@@ -215,7 +233,7 @@ class World(Renderable):
                 for z in range(chunk_n_z):
                     self.chunks[x][y].append(None)
 
-    def generate(self, seed=None, sea_height=50, continental_height=200):
+    def generate(self, seed: int=None, sea_plate_height: int = 50, continental_plate_height: int = 200):
         if seed is None:
             seed = random.randrange(2**32)
             seed = 229805811
@@ -381,12 +399,6 @@ class World(Renderable):
                         item.remove(plate)
 
         directions = np.zeros((total_x, total_y, 3))
-        heights = np.zeros((total_x, total_y))
-        for plate in range(1, plate_num):
-            if random.randint(1, 2) == 1:
-                heights[plates == plate] = sea_height
-            else:
-                heights[plates == plate] = continental_height
 
         coords = np.zeros((total_x, total_y, 2))
         for x in range(total_x):
@@ -431,11 +443,37 @@ class World(Renderable):
                     directions[plates == plate, :2] += perpendicular * (scalar / abs(scalar))
                     pass
 
-
-
+        for x in range(total_x):
+            for y in range(total_y):
+                if plates[x, y] == -1:
+                    plate = np.max(plates[x - 1: x + 1, y - 1: y + 1])
+                    plates[x, y] = plate
 
         self.plates = plates
         self.directions = directions
+        self.num_plates = plate_num
+
+        # max height will be three times the continental height
+        # sea level will be at one and a half time continental height
+        # with the continental plates top end ending there
+        # sea plates will be flush at the bottom end
+        max_height = 3 * continental_plate_height
+        sea_level = int(1.5 * continental_plate_height)
+        lower_level = sea_level - continental_plate_height
+        upper_sea_plate_level = lower_level + sea_plate_height
+
+        # stone kinds: 0: lava/air, 1: sea_plate, 2: magmatic_continental, 3: metamorph, 4: sedimental_rock, 5: sediment
+        self.stone = np.zeros((total_x, total_y, max_height), np.int)
+        plate_to_type = {}
+        for plate in range(1, plate_num):
+            if random.randint(1, 2) == 1:
+                self.stone[plates == plate, lower_level:upper_sea_plate_level] = SEA_PLATE_STONE
+                plate_to_type[plate] = SEA_PLATE
+            else:
+                self.stone[plates == plate, lower_level:sea_level] = MAGMATIC_STONE
+                plate_to_type[plate] = CONTINENTAL_PLATE
+
+        pass
 
     def set_color(self, x: int, y: int, z: int, r: float, g: float, b: float):
         x = x % (self.chunk_size_x * self.chunk_n_x)
@@ -462,6 +500,9 @@ class World(Renderable):
 
         if self.chunks[chunk_x][chunk_y][chunk_z] is None:
             self.chunks[chunk_x][chunk_y][chunk_z] = WorldChunk(self.chunk_size_x, self.chunk_size_y, self.chunk_size_z, self.programs)
+            self.chunks[chunk_x][chunk_y][chunk_z].x_offset = chunk_x * self.chunk_size_x
+            self.chunks[chunk_x][chunk_y][chunk_z].y_offset = chunk_y * self.chunk_size_z
+            self.chunks[chunk_x][chunk_y][chunk_z].z_offset = chunk_z * self.chunk_size_y
 
         carry_overs = self.chunks[chunk_x][chunk_y][chunk_z].put_object(x % self.chunk_size_x,
                                                                         y % self.chunk_size_y,
@@ -529,17 +570,38 @@ class World(Renderable):
                                                                  y % self.chunk_size_y,
                                                                  z % self.chunk_size_z)
 
-    def render(self, proj_matrix, geometry_rot_matrix, alternate_programs=None):
-        for x in range(self.chunk_n_x):
-            for y in range(self.chunk_n_y):
-                for z in range(self.chunk_n_z):
-                    if self.chunks[x][y][z] is not None:
-                        self.chunks[x][y][z].render(translate(x * self.chunk_size_x,
-                                                              y * self.chunk_size_y,
-                                                              z * self.chunk_size_z) * proj_matrix,
-                                                    geometry_rot_matrix, alternate_programs)
+    def render(self, proj_matrix, geometry_rot_matrix, alternate_programs=None,
+               preselected_program=None, projection_pos=None, rot_pos=None):
+        if preselected_program is not None:
+            for x in range(self.chunk_n_x):
+                for y in range(self.chunk_n_y):
+                    for z in range(self.chunk_n_z):
+                        if self.chunks[x][y][z] is not None:
+                            self.chunks[x][y][z].render(proj_matrix,
+                                                        geometry_rot_matrix, alternate_programs,
+                                                        preselected_program, projection_pos, rot_pos)
+        else:
+            for _, program_id in self.programs.items():
+                if alternate_programs == None:
+                    used_program_id = program_id
+                else:
+                    assert program_id in alternate_programs.keys()
+                    used_program_id = alternate_programs[program_id]
+                glUseProgram(used_program_id)
+                self.check_error("Renderingprogram is not initialized!")
+                projection = glGetUniformLocation(used_program_id, 'projModelViewMatrix')
+                rot = glGetUniformLocation(used_program_id, 'rotMatrix')
+                glUniformMatrix3fv(rot, 1, GL_FALSE, np.array(geometry_rot_matrix))
+                glUniformMatrix4fv(projection, 1, GL_FALSE, np.array(proj_matrix))
+                for x in range(self.chunk_n_x):
+                    for y in range(self.chunk_n_y):
+                        for z in range(self.chunk_n_z):
+                            if self.chunks[x][y][z] is not None:
+                                self.chunks[x][y][z].render(proj_matrix,
+                                                            geometry_rot_matrix, alternate_programs,
+                                                            used_program_id, projection, rot)
 
-    def add_light(self, x: float, y: float, z: float, l: Light):
+    def add_light(self, x: float, y: float, z: float, l: Light)-> Light:
         x = x % (self.chunk_size_x * self.chunk_n_x)
         y = y % (self.chunk_size_y * self.chunk_n_y)
         z = z % (self.chunk_size_z * self.chunk_n_z)
@@ -550,10 +612,15 @@ class World(Renderable):
 
         if self.chunks[chunk_x][chunk_y][chunk_z] is None:
             self.chunks[chunk_x][chunk_y][chunk_z] = WorldChunk(self.chunk_size_x, self.chunk_size_y, self.chunk_size_z, self.programs)
+            self.chunks[chunk_x][chunk_y][chunk_z].x_offset = chunk_x * self.chunk_size_x
+            self.chunks[chunk_x][chunk_y][chunk_z].y_offset = chunk_y * self.chunk_size_z
+            self.chunks[chunk_x][chunk_y][chunk_z].z_offset = chunk_z * self.chunk_size_y
 
         self.chunks[chunk_x][chunk_y][chunk_z].lights.append(l)
         l.pos = [x, y, z]
 
+        return l
+
     def remove_light(self, l: Light):
         chunk_x = int(l.pos[0] / self.chunk_size_x)
         chunk_y = int(l.pos[1] / self.chunk_size_y)

From 0638d5e6662ff3bf4e0ccdd63080ce6e18f6f25b Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Mon, 7 Feb 2022 21:08:45 +0100
Subject: [PATCH 07/14] adds labyrinth and subjects as well as performance
 increases

---
 Client/Client.py                  |  146 +---
 Objects/Structure.py              |  126 ++--
 Objects/World.py                  |  176 +++--
 WorldProvider/WorldProvider.py    |    4 +-
 labirinth_ai/LabyrinthClient.py   |   43 ++
 labirinth_ai/LabyrinthProvider.py |    6 +
 labirinth_ai/LabyrinthWorld.py    |  232 +++++++
 labirinth_ai/Subject.py           | 1055 +++++++++++++++++++++++++++++
 labirinth_ai/__init__.py          |    0
 labirinth_ai/loss.py              |   37 +
 10 files changed, 1591 insertions(+), 234 deletions(-)
 create mode 100644 labirinth_ai/LabyrinthClient.py
 create mode 100644 labirinth_ai/LabyrinthProvider.py
 create mode 100644 labirinth_ai/LabyrinthWorld.py
 create mode 100644 labirinth_ai/Subject.py
 create mode 100644 labirinth_ai/__init__.py
 create mode 100644 labirinth_ai/loss.py

diff --git a/Client/Client.py b/Client/Client.py
index 2f6699d..e64554a 100644
--- a/Client/Client.py
+++ b/Client/Client.py
@@ -41,10 +41,30 @@ def value_to_color(v, min_value, max_value):
 
 
 class Client:
-    def __init__(self, test=False, pos=[0, 0, 0]):
+    def __init__(self, test=False, pos=[0, 0, 0], world_class=WorldProvider):
         self.state = 0
         with open('./config.json', 'r') as f:
             self.config = json.load(f)
+        self.init_shaders()
+
+        self.world_provider = world_class(self.normal_program)
+        self.draw_world()
+
+        self.pos = pos
+        self.time = time.time()
+
+        glutReshapeFunc(self.resize)
+        glutDisplayFunc(self.display)
+        glutKeyboardFunc(self.keyboardHandler)
+        glutSpecialFunc(self.funcKeydHandler)
+
+        if not test:
+            glutMainLoop()
+        else:
+            self.display()
+            self.resize(100, 100)
+
+    def init_shaders(self):
         glutInit(sys.argv)
         self.width = 1920
         self.height = 1080
@@ -96,7 +116,7 @@ class Client:
             self.depth_program[self.normal_program[key]] = Spotlight.getDepthProgram(self.vertex_shader_id,
                                                                                      key.GeometryShaderId)
 
-        self.world_provider = WorldProvider(self.normal_program)
+    def draw_world(self):
         for x_pos in range(0, 100):
             for y_pos in range(0, 100):
                 for z_pos in range(0, 1):
@@ -118,97 +138,11 @@ class Client:
                         r, g, b = colors[int(self.world_provider.world.plates[x_pos, y_pos])]
                     self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
 
-        # total_x = self.world_provider.world.chunk_n_x * self.world_provider.world.chunk_size_x
-        # total_y = self.world_provider.world.chunk_n_y * self.world_provider.world.chunk_size_y
-        # for x_pos in range(0, 100):
-        #     for y_pos in range(0, 100):
-        #         if self.world_provider.world.faults[x_pos, y_pos] == -2:
-        #             self.world_provider.world.set_color(x_pos, y_pos, 0, 0, 0, 0)
-        #
-        # for line_index, line in enumerate(self.world_provider.world.fault_lines):
-        #     for x_pos in range(0, 100):
-        #         for y_pos in range(0, 100):
-        #             if self.world_provider.world.faults[x_pos, y_pos] == line_index:
-        #                 if line_index != 9:
-        #                     self.world_provider.world.set_color(x_pos, y_pos, 0, 0, 0, 1)
-        #                 else:
-        #                     self.world_provider.world.set_color(x_pos, y_pos, 0, 1, 1, 1)
-        #
-        # for x_pos in range(0, 100):
-        #     for y_pos in range(0, 100):
-        #         for z_pos in range(0, 1):
-        #             if [x_pos, y_pos] in self.world_provider.world.fault_nodes:
-        #                 r, g, b = 1, 0, 0
-        #                 self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
-
-        # # visualize direction lengths
-        # lengths = np.sqrt(np.sum(np.square(self.world_provider.world.directions), axis=2))
-        # lengths = lengths / np.max(lengths)
-        # for x_pos in range(0, 100):
-        #     for y_pos in range(0, 100):
-        #         for z_pos in range(0, 1):
-        #             r, g, b = lengths[x_pos, y_pos], lengths[x_pos, y_pos], lengths[x_pos, y_pos]
-        #             self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
-
         self.projMatrix = perspectiveMatrix(45.0, 400 / 400, 0.01, MAX_DISTANCE)
 
         self.rx = self.cx = self.cy = 0
         self.opening = 45
 
-        glutReshapeFunc(self.resize)
-        glutDisplayFunc(self.display)
-        glutKeyboardFunc(self.keyboardHandler)
-        glutSpecialFunc(self.funcKeydHandler)
-
-        self.pos = pos
-
-        self.time = time.time()
-
-        self.field = (100, 100, 1)
-        self.e_a = np.array([
-            [0, 0, 0],
-            [1, 0, 0],
-            [1, 1, 0],
-            [0, 1, 0],
-            [-1, 1, 0],
-            [-1, 0, 0],
-            [-1, -1, 0],
-            [0, -1, 0],
-            [1, -1, 0],
-        ])
-
-        self.relaxation_time = 0.55  # 0.55
-        self.w_a = [
-            4.0 / 9.0,
-            1.0 / 9.0,
-            1.0 / 36.0,
-            1.0 / 9.0,
-            1.0 / 36.0,
-            1.0 / 9.0,
-            1.0 / 36.0,
-            1.0 / 9.0,
-            1.0 / 36.0
-        ]
-
-        self.n_a = np.zeros((len(self.e_a),) + self.field)
-        self.n_a_eq = np.zeros(self.n_a.shape)
-        self.n = np.zeros(self.field)
-        self.n[:, :, :] += 1.0
-        self.gravity_applies = np.zeros(self.field)
-        # self.n /= np.sum(self.n)
-        self.n_a[0] = np.array(self.n)
-        self.u = np.zeros(self.field + (self.e_a.shape[1],))
-
-        self.compressible = True
-        self.max_n = self.w_a[0]
-
-        self.test_pixel = [40, 50, 0]
-
-        if not test:
-            glutMainLoop()
-        else:
-            self.display()
-            self.resize(100, 100)
 
     def display(self):
         glClearColor(0, 0, 0, 0)
@@ -261,41 +195,7 @@ class Client:
 
         glutSwapBuffers()
 
-        min_value = 0
-        max_value_n = np.max(self.n)
-        # max_value_n = 1.0
-
-        vel = np.sqrt(np.sum(np.square(self.u), axis=3)) *self.n
-        max_value_vel = np.max(vel)
-        # max_value_vel = np.sqrt(3)
-
-        # print('round')
-        # print('sum n: %f' % np.sum(self.n))
-        # print('max n: %f' % np.max(self.n))
-        # print('min n: %f' % np.min(self.n))
-        # print('sum vel: %f' % np.sum(vel))
-        # print('max vel: %f' % np.max(vel))
-        # print('min vel: %f' % np.min(vel))
-
-        # for x_pos in range(0, 100):
-        #     for y_pos in range(0, 100):
-        #         for z_pos in range(0, 1):
-        #             # if self.state == 2:
-        #             #     r, g, b = value_to_color(int(self.gravity_applies[x_pos, y_pos, z_pos]), 0, 1)
-        #             # if self.state == 1:
-        #             #     r, g, b = value_to_color(vel[x_pos, y_pos, z_pos], min_value, max_value_vel)
-        #             # if self.state == 0:
-        #             #     r, g, b = value_to_color(self.n[x_pos, y_pos, z_pos], min_value, max_value_n)
-        #             r, g, b, = 128, 128, 128
-        #             if [x_pos, y_pos] in self.world_provider.world.fault_nodes:
-        #                 r, g, b = 128, 0, 0
-        #
-        #             self.world_provider.world.set_color(x_pos, y_pos, z_pos, r, g, b)
-        # self.world_provider.world.set_color(int(round(self.test_pixel[0])),
-        #                                     int(round(self.test_pixel[1])),
-        #                                     int(round(self.test_pixel[2])), 1.0, 1.0, 1.0)
-
-        print(1.0 / (time.time() - self.time))
+        print('fps', 1.0 / (time.time() - self.time))
         self.time = time.time()
         glutPostRedisplay()
 
diff --git a/Objects/Structure.py b/Objects/Structure.py
index 5c49893..02cd5a4 100644
--- a/Objects/Structure.py
+++ b/Objects/Structure.py
@@ -19,6 +19,9 @@ class Structure(Renderable):
         self.Objects = {}
         self.vais = {}
         self.dirty = True
+        self.dirty_pos = True
+        self.dirty_color = True
+        self.dirty_size = True
 
         self.x_offset = x_offset
         self.y_offset = y_offset
@@ -31,6 +34,7 @@ class Structure(Renderable):
     @x_offset.setter
     def x_offset(self, value):
         self.dirty = True
+        self.dirty_pos = True
         self._x_offset = value
 
     @property
@@ -40,6 +44,7 @@ class Structure(Renderable):
     @y_offset.setter
     def y_offset(self, value):
         self.dirty = True
+        self.dirty_pos = True
         self._y_offset = value
 
     @property
@@ -49,6 +54,7 @@ class Structure(Renderable):
     @z_offset.setter
     def z_offset(self, value):
         self.dirty = True
+        self.dirty_pos = True
         self._z_offset = value
 
     def addShape(self, program, shape):
@@ -56,6 +62,9 @@ class Structure(Renderable):
             self.Objects[program] = []
         self.Objects[program].append(shape)
         self.dirty = True
+        self.dirty_color = True
+        self.dirty_pos = True
+        self.dirty_size = True
 
     def removeShape(self, program, shape):
         if program in self.Objects.keys():
@@ -63,72 +72,89 @@ class Structure(Renderable):
             if len(self.Objects[program]) == 0:
                 self.Objects.pop(program)
         self.dirty = True
+        self.dirty_color = True
+        self.dirty_pos = True
+        self.dirty_size = True
 
     def buildvertexArrays(self):
         if self.dirty:
-            self.clearVertexArrays()
+            # self.clearVertexArrays()
             glEnableClientState(GL_VERTEX_ARRAY)
             glEnableClientState(GL_TEXTURE_COORD_ARRAY)
             glEnableClientState(GL_NORMAL_ARRAY)
             glEnableClientState(GL_COLOR_ARRAY)
-            self.vais = {}
 
             for key, objects in self.Objects.items():
-                tvai = GLuint(0)
-                tpbi = GLuint(0)
-                tcbi = GLuint(0)
-                tsbi = GLuint(0)
-                num = len(objects)
-
-                glGenVertexArrays(1, tvai)
+                needs_new_buffers = key not in self.vais.keys()
+                if needs_new_buffers:
+                    tvai = GLuint(0)
+                    tpbi = GLuint(0)
+                    tcbi = GLuint(0)
+                    tsbi = GLuint(0)
+                    num = len(objects)
+                else:
+                    tvai, tpbi, tcbi, tsbi, num = self.vais[key]
+                if needs_new_buffers:
+                    glGenVertexArrays(1, tvai)
                 glBindVertexArray(tvai)
+                if self.dirty_pos:
+                    if needs_new_buffers:
+                        vid = glGetAttribLocation(key, "in_position")
+                        glEnableVertexAttribArray(vid)
 
-                vid = glGetAttribLocation(key, "in_position")
-                glEnableVertexAttribArray(vid)
-
-                tpbi = glGenBuffers(1)
-                glBindBuffer(GL_ARRAY_BUFFER, tpbi)
-                positions = []
-                for o in objects:
-                    positions.append(o.pos[0] + self.x_offset)
-                    positions.append(o.pos[1] + self.y_offset)
-                    positions.append(o.pos[2] + self.z_offset)
-                glBufferData(GL_ARRAY_BUFFER, np.array(positions, dtype=np.float32), GL_STATIC_DRAW)
-                glVertexAttribPointer(vid, 3, GL_FLOAT, GL_FALSE, 0, None)
-                self.check_error("Could not create position buffer")
-
-                colors = []
-                for o in objects:
-                    colors.append(o.color[0])
-                    colors.append(o.color[1])
-                    colors.append(o.color[2])
-                tcbi = glGenBuffers(1)
-                glBindBuffer(GL_ARRAY_BUFFER, tcbi)
-                glBufferData(GL_ARRAY_BUFFER, np.array(colors, dtype=np.float32), GL_STATIC_DRAW)
-                vc = glGetAttribLocation(key, "MyInColor")
-                if vc != -1:
-                    glEnableVertexAttribArray(vc)
-                    glVertexAttribPointer(vc, 3, GL_FLOAT, GL_FALSE, 0, None)
-                    self.check_error("Could not create color buffer")
-
-                if hasattr(objects[0], 'size'):
-                    sizes = []
+                    tpbi = glGenBuffers(1)
+                    glBindBuffer(GL_ARRAY_BUFFER, tpbi)
+                    positions = []
                     for o in objects:
-                        sizes.append(o.size[0])
-                        sizes.append(o.size[1])
-                        sizes.append(o.size[2])
-                    tsbi = glGenBuffers(1)
-                    glBindBuffer(GL_ARRAY_BUFFER, tsbi)
-                    glBufferData(GL_ARRAY_BUFFER, np.array(sizes, dtype=np.float32), GL_STATIC_DRAW)
-                    vs = glGetAttribLocation(key, "MyInSize")
-                    if vs != -1:
-                        glEnableVertexAttribArray(vs)
-                        glVertexAttribPointer(vs, 3, GL_FLOAT, GL_FALSE, 0, None)
-                        self.check_error("Could not create size buffer")
+                        positions.append(o.pos[0] + self.x_offset)
+                        positions.append(o.pos[1] + self.y_offset)
+                        positions.append(o.pos[2] + self.z_offset)
+                    glBufferData(GL_ARRAY_BUFFER, np.array(positions, dtype=np.float32), GL_STATIC_DRAW)
+                    if needs_new_buffers:
+                        glVertexAttribPointer(vid, 3, GL_FLOAT, GL_FALSE, 0, None)
+                    self.check_error("Could not create position buffer")
+
+                if self.dirty_color:
+                    colors = []
+                    for o in objects:
+                        colors.append(o.color[0])
+                        colors.append(o.color[1])
+                        colors.append(o.color[2])
+                    if needs_new_buffers:
+                        tcbi = glGenBuffers(1)
+                    glBindBuffer(GL_ARRAY_BUFFER, tcbi)
+                    glBufferData(GL_ARRAY_BUFFER, np.array(colors, dtype=np.float32), GL_STATIC_DRAW)
+                    if needs_new_buffers:
+                        vc = glGetAttribLocation(key, "MyInColor")
+                        if vc != -1:
+                            glEnableVertexAttribArray(vc)
+                            glVertexAttribPointer(vc, 3, GL_FLOAT, GL_FALSE, 0, None)
+                            self.check_error("Could not create color buffer")
+
+                if self.dirty_size:
+                    if hasattr(objects[0], 'size'):
+                        sizes = []
+                        for o in objects:
+                            sizes.append(o.size[0])
+                            sizes.append(o.size[1])
+                            sizes.append(o.size[2])
+                        if needs_new_buffers:
+                            tsbi = glGenBuffers(1)
+                        glBindBuffer(GL_ARRAY_BUFFER, tsbi)
+                        glBufferData(GL_ARRAY_BUFFER, np.array(sizes, dtype=np.float32), GL_STATIC_DRAW)
+                        if needs_new_buffers:
+                            vs = glGetAttribLocation(key, "MyInSize")
+                            if vs != -1:
+                                glEnableVertexAttribArray(vs)
+                                glVertexAttribPointer(vs, 3, GL_FLOAT, GL_FALSE, 0, None)
+                                self.check_error("Could not create size buffer")
 
                 glBindVertexArray(0)
                 self.vais[key] = (tvai, tpbi, tcbi, tsbi, num)
             self.dirty = False
+            self.dirty_pos = False
+            self.dirty_color = False
+            self.dirty_size = False
 
     def clearVertexArrays(self):
         temp = dict(self.vais)
diff --git a/Objects/World.py b/Objects/World.py
index 473ccf2..1f0d4a3 100644
--- a/Objects/World.py
+++ b/Objects/World.py
@@ -1,3 +1,5 @@
+import time
+
 from Lights.Lights import Light
 from Objects.Objects import Object
 from Objects.Renderable import Renderable
@@ -9,7 +11,8 @@ import math
 import numpy as np
 import random
 import sys
-
+import ctypes
+float_pointer = ctypes.POINTER(ctypes.c_float)
 # Plate Types
 SEA_PLATE = 0
 CONTINENTAL_PLATE = 1
@@ -22,6 +25,7 @@ METAMORPH_STONE = 3
 SEDIMENTAL_STONE = 4
 SEDIMENT = 5
 
+
 class WorldChunk(Structure):
     def __init__(self, width: int, length: int, height: int, programs: dict):
         assert width > 0, 'Width must be greater than 0'
@@ -38,6 +42,8 @@ class WorldChunk(Structure):
         self.height = height
         self.programs = programs
 
+        self.objects = {}
+
         for x in range(width):
             self.content.append([])
             self.visible.append([])
@@ -54,6 +60,7 @@ class WorldChunk(Structure):
         assert 0 <= z < self.height, 'Put out of bounds for z coordinate! Must be between 0 and %i' % self.height
         no_visibility_changes = (self.content[x][y][z] is None) == (new_object is None)
 
+        old_object = self.content[x][y][z]
         self.content[x][y][z] = new_object
         new_object.translate(translate(x, y, z))
 
@@ -87,6 +94,32 @@ class WorldChunk(Structure):
             else:
                 self.visible[x][y][z - 1] += change
 
+        # todo: add visibility check for object listing
+        added = False
+        if old_object is not None:
+            if new_object is not None and type(old_object) == type(new_object):
+                new_object.buffer_id = old_object.buffer_id
+                self.objects[self.programs[type(old_object)]][old_object.buffer_id] = new_object
+                added = True
+            else:
+                # todo: maybe replace the element with a placeholder that is skipped when rendering/ saving and have a
+                #  cleanup task, since this could be exploited to lower update rates
+                leading = self.objects[self.programs[type(old_object)]][:old_object.buffer_id]
+                following = self.objects[self.programs[type(old_object)]][old_object.buffer_id + 1:]
+                for element in following:
+                    element.buffer_id -= 1
+                self.objects[self.programs[type(old_object)]] = leading + following
+
+        if not added and new_object is not None:
+            if self.programs[type(new_object)] not in self.objects.keys():
+                self.objects[self.programs[type(new_object)]] = []
+            new_object.buffer_id = len(self.objects[self.programs[type(new_object)]])
+            self.objects[self.programs[type(new_object)]].append(new_object)
+
+        self.dirty = True
+        self.dirty_pos = True
+        self.dirty_color = True
+        self.dirty_size = True
         return visible_carry_over
 
     def get_object(self, x: int, y: int, z: int):
@@ -112,80 +145,92 @@ class WorldChunk(Structure):
 
     def buildvertexArrays(self):
         if self.dirty:
-            self.clearVertexArrays()
+            # self.clearVertexArrays()
             glEnableClientState(GL_VERTEX_ARRAY)
             glEnableClientState(GL_TEXTURE_COORD_ARRAY)
             glEnableClientState(GL_NORMAL_ARRAY)
             glEnableClientState(GL_COLOR_ARRAY)
-            self.vais = {}
 
-            objects = {}
-            counts = {}
-            for x in range(self.width):
-                for y in range(self.length):
-                    for z in range(self.height):
-                        if self.content[x][y][z] is not None:  # and self.visible[x][y][z] > 0: TODO: check visibility...
-                            if self.programs[type(self.content[x][y][z])] not in objects.keys():
-                                objects[self.programs[type(self.content[x][y][z])]] = []
-                                counts[self.programs[type(self.content[x][y][z])]] = 0
-                            objects[self.programs[type(self.content[x][y][z])]].append(self.content[x][y][z])
-                            counts[self.programs[type(self.content[x][y][z])]] += 1
+            for key, object_list in self.objects.items():
+                needs_new_buffers = key not in self.vais.keys()
+                if needs_new_buffers:
+                    tvai = GLuint(0)
+                    tpbi = GLuint(0)
+                    tcbi = GLuint(0)
+                    tsbi = GLuint(0)
 
-            for key, object_list in objects.items():
-                tvai = GLuint(0)
-                tpbi = GLuint(0)
-                tcbi = GLuint(0)
-                tsbi = GLuint(0)
-
-                glGenVertexArrays(1, tvai)
+                    glGenVertexArrays(1, tvai)
+                else:
+                    tvai, tpbi, tcbi, tsbi, old_len = self.vais[key]
                 glBindVertexArray(tvai)
 
-                vid = glGetAttribLocation(key, "in_position")
-                glEnableVertexAttribArray(vid)
+                if self.dirty_pos:
+                    if needs_new_buffers:
+                        vid = glGetAttribLocation(key, "in_position")
+                        glEnableVertexAttribArray(vid)
+                        tpbi = glGenBuffers(1)
+                    glBindBuffer(GL_ARRAY_BUFFER, tpbi)
+                    positions = []
+                    for index, o in enumerate(object_list):
+                        o.buffer_id = index
+                        positions.append(o.pos[0] + self.x_offset)
+                        positions.append(o.pos[1] + self.y_offset)
+                        positions.append(o.pos[2] + self.z_offset)
 
-                tpbi = glGenBuffers(1)
-                glBindBuffer(GL_ARRAY_BUFFER, tpbi)
-                positions = []
-                for o in object_list:
-                    positions.append(o.pos[0] + self.x_offset)
-                    positions.append(o.pos[1] + self.y_offset)
-                    positions.append(o.pos[2] + self.z_offset)
-                glBufferData(GL_ARRAY_BUFFER, np.array(positions, dtype=np.float32), GL_STATIC_DRAW)
-                glVertexAttribPointer(vid, 3, GL_FLOAT, GL_FALSE, 0, None)
-                self.check_error("Could not create position buffer")
+                    glBufferData(GL_ARRAY_BUFFER, np.array(positions, dtype=np.float32), GL_STATIC_DRAW)
 
-                colors = []
-                for o in object_list:
-                    colors.append(o.color[0])
-                    colors.append(o.color[1])
-                    colors.append(o.color[2])
-                tcbi = glGenBuffers(1)
-                glBindBuffer(GL_ARRAY_BUFFER, tcbi)
-                glBufferData(GL_ARRAY_BUFFER, np.array(colors, dtype=np.float32), GL_STATIC_DRAW)
-                vc = glGetAttribLocation(key, "MyInColor")
-                if vc != -1:
-                    glEnableVertexAttribArray(vc)
-                    glVertexAttribPointer(vc, 3, GL_FLOAT, GL_FALSE, 0, None)
+                    if needs_new_buffers:
+                        glVertexAttribPointer(vid, 3, GL_FLOAT, GL_FALSE, 0, None)
+                    self.check_error("Could not create position buffer")
+
+                if self.dirty_color:
+                    colors = []
+                    for o in object_list:
+                        colors.append(o.color[0])
+                        colors.append(o.color[1])
+                        colors.append(o.color[2])
+                    if needs_new_buffers:
+                        tcbi = glGenBuffers(1)
+                    glBindBuffer(GL_ARRAY_BUFFER, tcbi)
+                    if needs_new_buffers or old_len != len(object_list):
+                        glBufferData(GL_ARRAY_BUFFER, np.array(colors, dtype=np.float32), GL_STATIC_DRAW)
+                    else:
+                        # todo: check if this improves anything. Timewise it seems to be the same
+                        ptr = ctypes.cast(glMapBuffer(GL_ARRAY_BUFFER, GL_READ_WRITE), float_pointer)
+                        for index, value in enumerate(colors):
+                            ptr[index] = value
+                        glUnmapBuffer(GL_ARRAY_BUFFER)
+                    if needs_new_buffers:
+                        vc = glGetAttribLocation(key, "MyInColor")
+                        if vc != -1:
+                            glEnableVertexAttribArray(vc)
+                            glVertexAttribPointer(vc, 3, GL_FLOAT, GL_FALSE, 0, None)
                     self.check_error("Could not create color buffer")
 
-                if hasattr(object_list[0], 'size'):
-                    sizes = []
-                    for o in object_list:
-                        sizes.append(o.size[0])
-                        sizes.append(o.size[1])
-                        sizes.append(o.size[2])
-                    tsbi = glGenBuffers(1)
-                    glBindBuffer(GL_ARRAY_BUFFER, tsbi)
-                    glBufferData(GL_ARRAY_BUFFER, np.array(sizes, dtype=np.float32), GL_STATIC_DRAW)
-                    vs = glGetAttribLocation(key, "MyInSize")
-                    if vs != -1:
-                        glEnableVertexAttribArray(vs)
-                        glVertexAttribPointer(vs, 3, GL_FLOAT, GL_FALSE, 0, None)
+                if self.dirty_size:
+                    if hasattr(object_list[0], 'size'):
+                        sizes = []
+                        for o in object_list:
+                            sizes.append(o.size[0])
+                            sizes.append(o.size[1])
+                            sizes.append(o.size[2])
+                        if needs_new_buffers:
+                            tsbi = glGenBuffers(1)
+                        glBindBuffer(GL_ARRAY_BUFFER, tsbi)
+                        glBufferData(GL_ARRAY_BUFFER, np.array(sizes, dtype=np.float32), GL_STATIC_DRAW)
+                        if needs_new_buffers:
+                            vs = glGetAttribLocation(key, "MyInSize")
+                            if vs != -1:
+                                glEnableVertexAttribArray(vs)
+                                glVertexAttribPointer(vs, 3, GL_FLOAT, GL_FALSE, 0, None)
                         self.check_error("Could not create size buffer")
 
                 glBindVertexArray(0)
-                self.vais[key] = (tvai, tpbi, tcbi, tsbi, counts[key])
+                self.vais[key] = (tvai, tpbi, tcbi, tsbi, len(object_list))
             self.dirty = False
+            self.dirty_pos = False
+            self.dirty_color = False
+            self.dirty_size = False
 
     def render(self, proj_matrix, geometry_rot_matrix, alternate_programs=None,
                preselected_program=None, projection_pos=None, rot_pos=None):
@@ -204,6 +249,17 @@ class WorldChunk(Structure):
         if self.content[x][y][z] is not None:
             self.content[x][y][z].setColor(r, g, b)
             self.dirty = True
+            self.dirty_color = True
+
+    def load(self):
+        for x in range(self.width):
+            for y in range(self.length):
+                for z in range(self.height):
+                    if self.content[x][y][z] is not None:  # and self.visible[x][y][z] > 0: TODO: check visibility...
+                        if self.programs[type(self.content[x][y][z])] not in self.objects.keys():
+                            self.objects[self.programs[type(self.content[x][y][z])]] = []
+                        self.objects[self.programs[type(self.content[x][y][z])]].append(self.content[x][y][z])
+
 
 class World(Renderable):
     def __init__(self, chunk_size_x: int, chunk_size_y: int, chunk_size_z: int,
@@ -488,6 +544,8 @@ class World(Renderable):
                                                               y % self.chunk_size_y,
                                                               z % self.chunk_size_z,
                                                               r, g, b)
+        else:
+            print('Changing color of nonexistant element!')
 
     def put_object(self, x: int, y: int, z: int, new_object: Object):
         x = x % (self.chunk_size_x * self.chunk_n_x)
diff --git a/WorldProvider/WorldProvider.py b/WorldProvider/WorldProvider.py
index a9629b7..1e6367c 100644
--- a/WorldProvider/WorldProvider.py
+++ b/WorldProvider/WorldProvider.py
@@ -2,8 +2,8 @@ from Objects.World import World
 
 
 class WorldProvider:
-    def __init__(self, programs):
-        self.world: World = World(10, 10, 10, 10, 10, 10, programs)
+    def __init__(self, programs, world_class=World):
+        self.world: World = world_class(10, 10, 10, 10, 10, 10, programs)
         self.world.generate()
 
     def update(self):
diff --git a/labirinth_ai/LabyrinthClient.py b/labirinth_ai/LabyrinthClient.py
new file mode 100644
index 0000000..fdcb22e
--- /dev/null
+++ b/labirinth_ai/LabyrinthClient.py
@@ -0,0 +1,43 @@
+import time
+
+from Client.Client import Client, MAX_DISTANCE
+from MatrixStuff.Transformations import perspectiveMatrix
+from labirinth_ai.LabyrinthProvider import LabyrinthProvider
+
+import numpy as np
+
+class LabyrinthClient(Client):
+    def __init__(self, test=False, pos=[0, 0, 0], world_class=LabyrinthProvider):
+        super(LabyrinthClient, self).__init__(test, pos, world_class)
+
+    def draw_world(self):
+        start_time = time.time()
+        for x in range(self.world_provider.world.chunk_size_x * self.world_provider.world.chunk_n_x):
+            for y in range(self.world_provider.world.chunk_size_y * self.world_provider.world.chunk_n_y):
+                if self.world_provider.world.board[x, y] in [1, 2]:
+                    r, g, b = 57, 92, 152
+                    if 1.5 >= self.world_provider.world.hunter_grass[x, y] > 0.5:
+                        r, g, b = 25, 149, 156
+                    if 3 >= self.world_provider.world.hunter_grass[x, y] > 1.5:
+                        r, g, b = 112, 198, 169
+                    self.world_provider.world.set_color(x, y, 0, r / 255.0, g / 255.0, b / 255.0)
+                if self.world_provider.world.board[x, y] == 3:
+                    self.world_provider.world.set_color(x, y, 0, 139 / 255.0, 72 / 255.0, 82 / 255.0)
+
+        for sub in self.world_provider.world.subjects:
+            if not sub.random:
+                # pyxel.rectb(sub.x * 4 + 1, sub.y * 4 + 1, 2, 2, sub.col)
+                self.world_provider.world.set_color(sub.x, sub.y, 0, sub.r / 255.0, sub.g / 255.0, sub.b / 255.0)
+            else:
+                self.world_provider.world.set_color(sub.x, sub.y, 0, 212 / 255.0, 150 / 255.0, 222 / 255.0)
+
+        self.projMatrix = perspectiveMatrix(45.0, 400 / 400, 0.01, MAX_DISTANCE)
+        print('redraw', time.time() - start_time)
+
+    def display(self):
+        super(LabyrinthClient, self).display()
+        self.draw_world()
+        self.world_provider.world.update()
+
+if __name__ == '__main__':
+    client = LabyrinthClient(pos=[-50, -50, -200])
diff --git a/labirinth_ai/LabyrinthProvider.py b/labirinth_ai/LabyrinthProvider.py
new file mode 100644
index 0000000..4af8345
--- /dev/null
+++ b/labirinth_ai/LabyrinthProvider.py
@@ -0,0 +1,6 @@
+from WorldProvider.WorldProvider import WorldProvider
+from labirinth_ai.LabyrinthWorld import LabyrinthWorld
+
+class LabyrinthProvider(WorldProvider):
+    def __init__(self, programs):
+        super(LabyrinthProvider, self).__init__(programs, LabyrinthWorld)
diff --git a/labirinth_ai/LabyrinthWorld.py b/labirinth_ai/LabyrinthWorld.py
new file mode 100644
index 0000000..2a2e3e7
--- /dev/null
+++ b/labirinth_ai/LabyrinthWorld.py
@@ -0,0 +1,232 @@
+import time
+
+from Objects.Cube.Cube import Cube
+from Objects.World import World
+import numpy as np
+import random
+
+
+class LabyrinthWorld(World):
+    randomBuffer = 0
+    batchsize = 1000
+    randomBuffer = max(4 * batchsize, randomBuffer)
+
+    def __init__(self, chunk_size_x: int, chunk_size_y: int, chunk_size_z: int,
+                 chunk_n_x: int, chunk_n_y: int, chunk_n_z: int, programs: dict):
+        self.board_shape = (chunk_size_x * chunk_n_x, chunk_size_y * chunk_n_y)
+        self.board = np.zeros(self.board_shape)
+        super(LabyrinthWorld, self).__init__(chunk_size_x, chunk_size_y, chunk_size_z,
+                                             chunk_n_x, chunk_n_y, chunk_n_z, programs)
+        self.max_room_dim = 20
+
+        self.min_room_dim = 6
+
+        self.max_room_num = 32
+        self.max_corridors = 4 * self.max_room_num
+
+        self.max_crates = self.max_room_num
+
+        self.subjects = []
+        self.ins = []
+        self.actions = []
+        self.targets = []
+
+        self.model = None
+        self.lastUpdate = time.time()
+        self.nextTrain = self.randomBuffer
+        self.round = 0
+
+        self.trailMix = np.zeros(self.board_shape)
+        self.grass = np.zeros(self.board_shape)
+        self.hunter_grass = np.zeros(self.board_shape)
+        self.subjectDict = {}
+
+    def generate(self, seed: int = None, sea_plate_height: int = 50, continental_plate_height: int = 200):
+        board = np.zeros(self.board_shape)
+        random.seed(seed)
+        np.random.seed(seed)
+
+        # find random starting point
+        px = random.randint(self.max_room_dim, (self.board_shape[0] - 1) - self.max_room_dim)
+        py = random.randint(self.max_room_dim, (self.board_shape[1] - 1) - self.max_room_dim)
+
+        # 0, 0 is top left
+        right = (1, 0)
+        left = (-1, 0)
+        up = (0, -1)
+        down = (0, 1)
+
+        # place rooms
+        room_num = 0
+        corridor_num = 0
+        while room_num < self.max_room_num and corridor_num < self.max_corridors:
+            # try to place Room
+            w = random.randint(self.min_room_dim, self.max_room_dim)
+            h = random.randint(self.min_room_dim, self.max_room_dim)
+            can_place_room = np.sum(
+                board[px - int(w / 2.0):px + int(w / 2.0), py - int(h / 2.0):py + int(h / 2.0)] == 1) == 0 and px - int(
+                w / 2.0) >= 0 and px + int(w / 2.0) < self.board_shape[0] and \
+                             py - int(h / 2.0) >= 0 and py + int(h / 2.0) < self.board_shape[1]
+
+            if can_place_room:
+                # place Room
+                board[px - int(w / 2.0):px + int(w / 2.0), py - int(h / 2.0):py + int(h / 2.0)] = 1
+                room_num += 1
+            else:
+                # move && place Corridor
+                directions = []
+                while len(directions) == 0:
+                    movable = []
+                    corridor_length = random.randint(self.min_room_dim, self.max_room_dim)
+                    if px - corridor_length >= 0:
+                        movable.append(left)
+                        if board[px - 1, py] != 2:
+                            directions.append(left)
+
+                    if px + corridor_length < self.board_shape[0]:
+                        movable.append(right)
+                        if board[px + 1, py] != 2:
+                            directions.append(right)
+
+                    if py - corridor_length >= 0:
+                        movable.append(up)
+                        if board[px, py - 1] != 2:
+                            directions.append(up)
+
+                    if py + corridor_length < self.board_shape[1]:
+                        movable.append(down)
+                        if board[px, py + 1] != 2:
+                            directions.append(down)
+
+                    if len(directions) != 0:
+                        if len(directions) > 1:
+                            d = directions[random.randint(0, len(directions) - 1)]
+                        else:
+                            d = directions[0]
+                        changed = False
+                        for _ in range(corridor_length):
+                            if board[px, py] != 1 and board[px, py] != 2:
+                                board[px, py] = 2
+                                if (-d[0], -d[1]) not in movable or board[px - d[0], py - d[1]] != 2:
+                                    changed = True
+                            px += d[0]
+                            py += d[1]
+                        if changed:
+                            corridor_num += 1
+                    else:
+                        if len(movable) != 0:
+                            if len(movable) > 1:
+                                d = movable[random.randint(0, len(movable) - 1)]
+                            else:
+                                d = movable[0]
+                            for _ in range(corridor_length):
+                                px += d[0]
+                                py += d[1]
+
+        crates = 0
+        while crates < self.max_crates:
+            px = random.randint(0, (self.board_shape[0] - 1))
+            py = random.randint(0, (self.board_shape[1] - 1))
+
+            if board[px, py] == 1:
+                board[px, py] = 3
+                crates += 1
+
+        board[board == 2] = 1
+
+        print((room_num, self.max_room_num))
+        print((corridor_num, self.max_corridors))
+        self.board = board
+
+        # setting up the board
+        for x_pos in range(0, self.board_shape[0]):
+            for y_pos in range(0, self.board_shape[1]):
+                for z_pos in range(0, 1):
+                    self.put_object(x_pos, y_pos, z_pos, Cube().setColor(1, 1, 1))
+
+        # adding subjects
+        from labirinth_ai.Subject import Hunter, Herbivore
+        while len(self.subjects) < 2:
+            px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
+            py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
+            if self.board[px, py] == 1:
+                self.subjects.append(Hunter(px, py))
+                self.ins += self.subjects[-1].x_in
+                self.actions += self.subjects[-1].actions
+                self.targets += self.subjects[-1].target
+
+        while len(self.subjects) < 10:
+            px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
+            py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
+            if self.board[px, py] == 1:
+                self.subjects.append(Herbivore(px, py))
+                self.ins += self.subjects[-1].x_in
+                self.actions += self.subjects[-1].actions
+                self.targets += self.subjects[-1].target
+
+        for x in range(self.board_shape[0]):
+            for y in range(self.board_shape[1]):
+                self.subjectDict[(x, y)] = []
+
+        for sub in self.subjects:
+            self.subjectDict[(sub.x, sub.y)].append(sub)
+
+    def update(self):
+        # start = time.time()
+        if self.model is None:
+            for sub in self.subjects:
+                sub.calculateAction(self)
+        else:
+            states = list(map(lambda e: e.createState(self), self.subjects))
+            states = sum(list(map(lambda e: [e, e, e, e], states)), [])
+            vals = self.model.predict(states)
+            vals = np.reshape(np.transpose(np.reshape(vals, (len(self.subjects), 4, 2)), (0, 2, 1)),
+                              (len(self.subjects), 1, 8))
+            list(map(lambda e: e[1].calculateAction(self, vals[e[0]], states[e[0]]), enumerate(self.subjects)))
+
+        for sub in self.subjects:
+            if sub.alive:
+                sub.update(self, doTrain=self.model is None)
+            sub.tick += 1
+
+        if self.model is not None:
+            if self.round >= self.nextTrain:
+                samples = list(map(lambda e: e.generateSamples(), self.subjects))
+                states = sum(list(map(lambda e: e[0], samples)), [])
+                targets = sum(list(map(lambda e: e[1], samples)), [])
+                self.model.fit(states, targets)
+                self.nextTrain = self.batchsize / 5
+                self.round = 0
+                for sub in self.subjects:
+                    if len(sub.samples) > 20*self.batchsize:
+                        sub.samples = sub.samples[:-20*self.batchsize]
+            else:
+                self.round += 1
+
+        new_subjects = []
+        kill_table = {}
+        live_table = {}
+        for sub in self.subjects:
+            if sub.name not in kill_table.keys():
+                kill_table[sub.name] = 0
+                live_table[sub.name] = 0
+            kill_table[sub.name] += sub.kills
+            live_table[sub.name] += sub.lives
+            if sub.alive:
+                new_subjects.append(sub)
+            else:
+                px = random.randint(self.max_room_dim, (self.board_shape[0] - 1) - self.max_room_dim)
+                py = random.randint(self.max_room_dim, (self.board_shape[1] - 1) - self.max_room_dim)
+                while self.board[px, py] == 0:
+                    px = random.randint(self.max_room_dim, (self.board_shape[0] - 1) - self.max_room_dim)
+                    py = random.randint(self.max_room_dim, (self.board_shape[1] - 1) - self.max_room_dim)
+                sub.respawnUpdate(px, py, self)
+                new_subjects.append(sub)
+
+        self.subjects = new_subjects
+        self.trailMix *= 0.99
+
+        self.grass = np.minimum(self.grass + 0.01 * (self.board != 0), 3)
+        self.hunter_grass = np.minimum(self.hunter_grass + 0.01 * (self.board != 0), 3)
+
+        self.trailMix *= (self.trailMix > 0.01)
diff --git a/labirinth_ai/Subject.py b/labirinth_ai/Subject.py
new file mode 100644
index 0000000..ec0593c
--- /dev/null
+++ b/labirinth_ai/Subject.py
@@ -0,0 +1,1055 @@
+import random
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+
+from labirinth_ai.LabyrinthWorld import LabyrinthWorld
+from labirinth_ai.loss import loss2, loss3
+
+# import torch
+# dtype = torch.float
+# device = torch.device("cpu")
+
+
+class Subject:
+    name = 'random'
+    col = 8
+    num = 0
+    random = True
+    r = 255
+    g = 255
+    b = 255
+
+    def __init__(self, x, y):
+        self.alive = True
+        self.x = x
+        self.y = y
+        self.kills = 0
+        self.lives = 1
+        self.tick = 0
+
+        self.id = self.num
+        Subject.num += 1
+
+    def update(self, world: LabyrinthWorld):
+        # 0, 0 is top left
+        right = (1, 0)
+        left = (-1, 0)
+        up = (0, -1)
+        down = (0, 1)
+        directions = []
+
+        if self.x - 1 >= 0:
+            if world.board[self.x - 1, self.y] != 0:
+                directions.append(left)
+
+        if self.x + 1 < world.board_shape[0]:
+            if world.board[self.x + 1, self.y] != 0:
+                directions.append(right)
+
+        if self.y - 1 >= 0:
+            if world.board[self.x, self.y - 1] != 0:
+                directions.append(up)
+
+        if self.y + 1 < world.board_shape[1]:
+            if world.board[self.x, self.y + 1] != 0:
+                directions.append(down)
+
+        if directions != [] and self.alive:
+            if len(directions) > 1:
+                d = directions[random.randint(0, len(directions) - 1)]
+            else:
+                d = directions[0]
+
+            if len(world.subjectDict[(self.x + d[0], self.y + d[1])]) > 0:
+                for sub in world.subjectDict[(self.x + d[0], self.y + d[1])]:
+                    if sub.alive:
+                        self.kills += 1
+                    sub.alive = False
+                    self.alive = True
+
+            world.subjectDict[(self.x, self.y)].remove(self)
+            world.trailMix[self.x, self.y] += 1
+            self.x += d[0]
+            self.y += d[1]
+            world.subjectDict[(self.x, self.y)].append(self)
+
+    def respawnUpdate(self, x, y, world: LabyrinthWorld):
+        world.subjectDict[(self.x, self.y)].remove(self)
+        self.x = x
+        self.y = y
+        world.subjectDict[(self.x, self.y)].append(self)
+        self.alive = True
+        self.lives += 1
+
+
+class QLearner(Subject):
+    name = 'QLearner'
+    col = 14
+    learningRate = 0.25
+    discountFactor = 0.5
+    random = False
+
+    Q = {}
+    def __init__(self, x, y):
+        super(QLearner, self).__init__(x, y)
+        # self.Q = {}
+        self.viewD = 3
+        self.lastAction = None
+        self.lastState = None
+        self.lastReward = 0
+
+    def respawnUpdate(self, x, y, world: LabyrinthWorld):
+        super(QLearner, self).respawnUpdate(x, y, world)
+        self.lastReward -= 20
+
+    def createState(self, world: LabyrinthWorld):
+        state = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.int)  # - 1
+
+        # # floodfill state
+        # queued = [(0, 0)]
+        # todo = [(0, 0, 0)]
+        # while todo != []:
+        #     doing = todo.pop(0)
+        #
+        #     if self.x + doing[0] >= 0 and self.x + doing[0] < 64 and self.y + doing[1] >= 0 and self.y + doing[1] < 64:
+        #         value = world.board[self.x + doing[0], self.y + doing[1]]
+        #         state[self.viewD + doing[0], self.viewD + doing[1]] = value
+        #
+        #         # if value == 3:
+        #         #     state[self.viewD + doing[0], self.viewD + doing[1]] = value
+        #
+        #         if value != 0 and doing[2] < self.viewD:
+        #             for i in range(-1, 2, 1):
+        #                 for j in range(-1, 2, 1):
+        #                     # 4-neighbour. without it it is 8-neighbour
+        #                     if abs(i) + abs(j) == 1:
+        #                         if (doing[0] + i, doing[1] + j) not in queued:
+        #                             queued.append((doing[0] + i, doing[1] + j))
+        #                             todo.append((doing[0] + i, doing[1] + j, doing[2] + 1))
+        #
+        # for sub in world.subjects:
+        #     if (sub.x - self.x, sub.y - self.y) in queued and state[
+        #         self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] != 3:
+        #         state[self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] = state[
+        #                                                                               self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] * 100 + sub.col
+
+        maxdirleft = self.x - max(self.x - (self.viewD), 0)
+        maxdirright = min(self.x + (self.viewD), (world.board_shape[0] - 1)) - self.x
+        maxdirup = self.y - max(self.y - (self.viewD), 0)
+        maxdirdown = min(self.y + (self.viewD), (world.board_shape[1] - 1)) - self.y
+
+        # state[self.viewD - maxdirleft: self.viewD + maxdirright, self.viewD - maxdirup: self.viewD + maxdirdown] = world.board[self.x - maxdirleft: self.x + maxdirright, self.y - maxdirup: self.y + maxdirdown]
+        for sub in world.subjects:
+            if abs(sub.x - self.x) < self.viewD and abs(sub.y - self.y) < self.viewD:
+                if state[self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] != 3:
+                    state[self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] = state[self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] * 100 + 1# sub.col
+
+        return state
+
+    def update(self, world: LabyrinthWorld):
+        # 0, 0 is top left
+        right = (1, 0)
+        left = (-1, 0)
+        up = (0, -1)
+        down = (0, 1)
+        directions = []
+
+        if self.x - 1 >= 0:
+            if world.board[self.x - 1, self.y] != 0:
+                directions.append(left)
+
+        if self.x + 1 < world.board_shape[0]:
+            if world.board[self.x + 1, self.y] != 0:
+                directions.append(right)
+
+        if self.y - 1 >= 0:
+            if world.board[self.x, self.y - 1] != 0:
+                directions.append(up)
+
+        if self.y + 1 < world.board_shape[1]:
+            if world.board[self.x, self.y + 1] != 0:
+                directions.append(down)
+
+        if directions != [] and self.alive:
+            state = self.createState(world)
+
+            if str(state) not in self.Q.keys():
+                self.Q[str(state)] = {}
+            for dir in directions:
+                if dir not in self.Q[str(state)].keys():
+                    self.Q[str(state)][dir] = random.randint(0, 5)
+
+            allowedActions = dict(filter(lambda elem: elem[0] in directions,self.Q[str(state)].items()))
+            action = max(allowedActions, key=allowedActions.get)
+
+            if self.learningRate != 0:
+                self.Q[str(state)][action] = (1 - self.learningRate) * self.Q[str(state)][action] + self.learningRate * (self.lastReward + self.discountFactor * self.Q[str(state)][action])
+
+            self.lastAction = action
+            self.lastState = state
+            self.lastReward = 0
+
+            if len(action) == 2:
+                if len(world.subjectDict[(self.x + action[0], self.y + action[1])]) > 0:
+                    for sub in world.subjectDict[(self.x + action[0], self.y + action[1])]:
+                        if sub.alive:
+                            self.kills += 1
+                        sub.alive = False
+                        self.alive = True
+                        self.lastReward += 10
+
+                world.subjectDict[(self.x, self.y)].remove(self)
+                self.x += action[0]
+                self.y += action[1]
+                world.subjectDict[(self.x, self.y)].append(self)
+            pass
+
+
+class DoubleQLearner(QLearner):
+    name = 'DoubleQLearner'
+    col = 11
+    learningRate = 0.5
+    discountFactor = 0.5
+    random = False
+
+    QA = {}
+    QB = {}
+    def __init__(self, x, y):
+        super(DoubleQLearner, self).__init__(x, y)
+        self.viewD = 3
+        self.lastAction = None
+        self.lastState = None
+        self.lastReward = 0
+
+    def respawnUpdate(self, x, y, world: LabyrinthWorld):
+        super(DoubleQLearner, self).respawnUpdate(x, y, world)
+
+    def update(self, world: LabyrinthWorld):
+        # 0, 0 is top left
+        right = (1, 0)
+        left = (-1, 0)
+        up = (0, -1)
+        down = (0, 1)
+        directions = []
+
+        if self.x - 1 >= 0:
+            if world.board[self.x - 1, self.y] != 0:
+                directions.append(left)
+
+        if self.x + 1 < world.board_shape[0]:
+            if world.board[self.x + 1, self.y] != 0:
+                directions.append(right)
+
+        if self.y - 1 >= 0:
+            if world.board[self.x, self.y - 1] != 0:
+                directions.append(up)
+
+        if self.y + 1 < world.board_shape[1]:
+            if world.board[self.x, self.y + 1] != 0:
+                directions.append(down)
+
+        if directions != [] and self.alive:
+            state = self.createState(world)
+
+            if str(state) not in self.QA.keys():
+                self.QA[str(state)] = {}
+                self.QB[str(state)] = {}
+            for dir in directions:
+                if dir not in self.QA[str(state)].keys():
+                    self.QA[str(state)][dir] = random.randint(0, 5)
+                    self.QB[str(state)][dir] = random.randint(0, 5)
+
+            allowedActionsA = dict(filter(lambda elem: elem[0] in directions, self.QA[str(state)].items()))
+            allowedActionsB = dict(filter(lambda elem: elem[0] in directions, self.QB[str(state)].items()))
+            allowedActions = {}
+            for key in allowedActionsA.keys():
+                allowedActions[key] = allowedActionsA[key] + allowedActionsB[key]
+
+            actionA = max(allowedActionsA, key=allowedActionsA.get)
+            actionB = max(allowedActionsB, key=allowedActionsB.get)
+            action = max(allowedActions, key=allowedActions.get)
+
+            if self.learningRate != 0:
+                if random.randint(0, 1) == 0:
+                    valA = self.QA[str(state)][action]
+                    self.QA[str(state)][action] = valA + self.learningRate * (self.lastReward + self.discountFactor * self.QB[str(state)][actionA] - valA)
+                else:
+                    valB = self.QB[str(state)][action]
+                    self.QB[str(state)][action] = valB + self.learningRate * (self.lastReward + self.discountFactor * self.QA[str(state)][actionB] - valB)
+
+            self.lastAction = action
+            self.lastState = state
+            self.lastReward = 0
+
+            if len(action) == 2:
+                if len(world.subjectDict[(self.x + action[0], self.y + action[1])]) > 0:
+                    for sub in world.subjectDict[(self.x + action[0], self.y + action[1])]:
+                        if sub.alive:
+                            self.kills += 1
+                        sub.alive = False
+                        self.alive = True
+                        self.lastReward += 10
+
+                world.subjectDict[(self.x, self.y)].remove(self)
+                self.x += action[0]
+                self.y += action[1]
+                world.subjectDict[(self.x, self.y)].append(self)
+            pass
+
+
+class NetLearner(Subject):
+    right = (1, 0)
+    left = (-1, 0)
+    up = (0, -1)
+    down = (0, 1)
+    act2IDict = {right: 0, left: 1, up: 2, down: 3}
+
+    name = 'NetLearner'
+    col = 15
+    viewD = 3
+    historyLength = 2
+    channels = 4
+
+    learningRate = 0.001
+    discountFactor = 0.5
+    randomBuffer = 0
+    batchsize = 1000
+    randomBuffer = max(4*batchsize, randomBuffer)
+    randomChance = 9
+
+    historySizeMul = 20
+
+    # samples = []
+
+    # x_in = keras.Input(shape=(4 * (2 * viewD + 1) * (2 * viewD + 1) + 2))
+    # target = keras.Input(shape=(10, 1))
+    # inVec = keras.layers.Flatten()(x_in)
+    # # kernel_regularizer=keras.regularizers.l2(0.01)
+    # actions = keras.layers.Dense((3 * (2 * viewD + 1) * (2 * viewD + 1)), activation='relu')(inVec)
+    # actions = keras.layers.Dense(((2 * viewD + 1) * (2 * viewD + 1)), activation='relu')(actions)
+    # actions = keras.layers.Dense(8, activation='linear', use_bias=False)(actions)
+    #
+    # model = keras.Model(inputs=x_in, outputs=actions)
+    #
+    # # model.compile(optimizer='adam', loss=loss, target_tensors=[target])
+    # model.compile(optimizer=tf.keras.optimizers.RMSprop(learningRate), loss=loss, target_tensors=[target])
+
+    def respawnUpdate(self, x, y, world: LabyrinthWorld):
+        super(NetLearner, self).respawnUpdate(x, y, world)
+        # self.lastReward -= 20
+
+        if len(self.samples) < self.randomBuffer or random.randint(0, 10) > self.randomChance:
+            self.random = True
+            # print('Rando ' + self.name)
+            pass
+        else:
+            self.random = False
+            # print('Slau ' + self.name)
+
+        self.strikes = 0
+
+    def __init__(self, x, y):
+        super(NetLearner, self).__init__(x, y)
+
+        self.action = None
+        self.state = None
+        self.actDict = {}
+
+        self.history = []
+        self.lastAction = None
+        self.lastState = None
+        self.lastReward = 0
+        self.lastVal = 0
+        self.random = False
+        self.nextTrain = self.randomBuffer
+
+        self.samples = []
+
+        self.x_in = []
+        self.actions = []
+        self.target = []
+        for i in range(4):
+            x_in = keras.Input(shape=(self.channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2))
+            self.x_in.append(x_in)
+            inVec = keras.layers.Flatten()(x_in)
+            actions = keras.layers.Dense(((2 * self.viewD + 1) * (2 * self.viewD + 1)), activation='elu',
+                                         kernel_regularizer=keras.regularizers.l2(0.001),
+                                         name=self.name + str(self.id) + 'Dense' + str(i) + 'l1')(inVec)
+            actions = keras.layers.Dense(((self.viewD + 1) * (self.viewD + 1)), activation='elu',
+                                         kernel_regularizer=keras.regularizers.l2(0.001))(actions)
+            self.target.append(keras.Input(shape=(2, 1)))
+            self.actions.append(keras.layers.Dense(2, activation='linear', use_bias=False, kernel_regularizer=keras.regularizers.l2(0.001))(actions))
+
+        self.model = keras.Model(inputs=self.x_in, outputs=self.actions)
+
+        self.model.compile(optimizer=tf.keras.optimizers.RMSprop(self.learningRate), loss=loss3,
+                           target_tensors=self.target)
+
+        if len(self.samples) < self.randomBuffer:
+            self.random = True
+        else:
+            self.random = False
+
+        self.strikes = 0
+
+        self.lastRewards = []
+
+    def visualize(self):
+        print(self.name)
+        layers = self.model.get_weights()
+        # layers.reverse()
+        layersN = [[0, 1, 8, 9, 16], [2, 3, 10, 11, 17], [4, 5, 12, 13, 18], [6, 7, 14, 15, 19]]
+        for action in range(8):
+            v = np.zeros((1, 2))
+            v[0][0 if action < 4 else 1] = 1.0
+            layerN = list(layersN[action % 4])
+            layerN.reverse()
+            for n in layerN:
+                l = layers[n]
+                if len(l.shape) == 2:
+                    layer = np.transpose(l)
+                    v = np.dot(v, layer)
+                else:
+                    layer = np.array([l])
+                    v = v + layer
+            lastAction = v[0, -2:]
+            v = np.reshape(v[0, :-2], (4, (2 * self.viewD + 1), (2 * self.viewD + 1)))
+
+            # right, left, up, down
+            dir = {0: 'right', 1: 'left', 2: 'up', 3: 'down'}
+            dir = dir[action % 4]
+            #0-3 current
+            #4-8 future
+            if action < 4:
+                time = 'current '
+            else:
+                time = 'future '
+            import matplotlib
+            import matplotlib.pyplot as plt
+            fig, axs = plt.subplots(2, 2, figsize=(5, 5))
+
+            fig.suptitle(time + dir)
+            im = axs[0, 0].pcolor(np.rot90(v[0]))
+            fig.colorbar(im, ax=axs[0, 0])
+            axs[0, 0].set_title('board')
+
+            axs[0, 1].pcolor(np.rot90(v[1]))
+            fig.colorbar(im, ax=axs[0, 1])
+            axs[0, 1].set_title('subjects')
+
+            axs[1, 0].pcolor(np.rot90(v[2]))
+            fig.colorbar(im, ax=axs[1, 0])
+            axs[1, 0].set_title('trail')
+
+            axs[1, 1].pcolor(np.rot90(v[3]))
+            fig.colorbar(im, ax=axs[1, 1])
+            axs[1, 1].set_title('grass')
+            plt.show(block=True)
+
+
+    def createState(self, world: LabyrinthWorld):
+        state = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float)  # - 1
+        state2 = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float)  # - 1
+        state3 = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float)  # - 1
+        state4 = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float)  # - 1
+
+        maxdirleft = self.x - max(self.x - (self.viewD), 0)
+        maxdirright = min(self.x + (self.viewD), (world.board_shape[0] - 1)) - self.x
+        maxdirup = self.y - max(self.y - (self.viewD), 0)
+        maxdirdown = min(self.y + (self.viewD), (world.board_shape[1] - 1)) - self.y
+
+        state[self.viewD - maxdirleft: self.viewD + maxdirright, self.viewD - maxdirup: self.viewD + maxdirdown] = world.board[self.x - maxdirleft: self.x + maxdirright, self.y - maxdirup: self.y + maxdirdown]
+        # for sub in world.subjects:
+        #     if abs(sub.x - self.x) < self.viewD and abs(sub.y - self.y) < self.viewD:
+        #         if state[self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] != 3:
+        #             state2[self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] = sub.col
+        for x in range(-maxdirleft, maxdirright, 1):
+            for y in range(-maxdirup, maxdirdown, 1):
+                if world.subjectDict[(self.x + x, self.y + y)] != []:
+                    state2[x + maxdirleft, y + maxdirup] = 1#world.subjectDict[(self.x + x, self.y + y)][0].col
+
+        state3[self.viewD - maxdirleft: self.viewD + maxdirright, self.viewD - maxdirup: self.viewD + maxdirdown] = world.trailMix[self.x - maxdirleft: self.x + maxdirright, self.y - maxdirup: self.y + maxdirdown]
+        state4[self.viewD - maxdirleft: self.viewD + maxdirright, self.viewD - maxdirup: self.viewD + maxdirdown] = world.hunter_grass[self.x - maxdirleft: self.x + maxdirright, self.y - maxdirup: self.y + maxdirdown]
+
+        if not self.random:
+            test=1
+
+        area = np.reshape(np.stack((state, state2, state3, state4)), (4 * (2 * self.viewD + 1) * (2 * self.viewD + 1)))
+        action = [0, 0]
+        if self.lastAction is not None:
+            action = self.lastAction
+        return np.reshape(np.concatenate((area, action)), (1, 4 * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2))
+
+    def calculateAction(self, world: LabyrinthWorld, vals=None, state=None):
+        # 0, 0 is top left
+        directions = []
+
+        if self.x - 1 >= 0:
+            if world.board[self.x - 1, self.y] != 0:
+                directions.append(self.left)
+
+        if self.x + 1 < world.board_shape[0]:
+            if world.board[self.x + 1, self.y] != 0:
+                directions.append(self.right)
+
+        if self.y - 1 >= 0:
+            if world.board[self.x, self.y - 1] != 0:
+                directions.append(self.up)
+
+        if self.y + 1 < world.board_shape[1]:
+            if world.board[self.x, self.y + 1] != 0:
+                directions.append(self.down)
+
+        if directions == []:
+            print('Wut?')
+
+        if directions != [] and self.alive:
+            if state is None:
+                state = self.createState(world)
+            if vals is None:
+                vals = self.model.predict([state, state, state, state])
+                vals = np.reshape(np.transpose(np.reshape(vals, (4, 2)), (1, 0)),
+                                  (1, 8))
+
+            self.actDict = {self.right: vals[0][0] + vals[0][4], self.left: vals[0][1] + vals[0][5], self.up: vals[0][2] + vals[0][6], self.down: vals[0][3] + vals[0][7]}
+
+            allowedActions = dict(filter(lambda elem: elem[0] in directions, self.actDict.items()))
+
+            # if self.name == 'Herbivore' and self.id == 11 and not self.random:
+            #     print(allowedActions)
+            #     print(self.lastReward)
+            if self.strikes <= 0:
+                self.random = False
+
+            if not self.random:
+                self.action = max(allowedActions, key=allowedActions.get)
+            else:
+                self.action = self.randomAct(world)
+
+            self.state = state
+
+    def update(self, world: LabyrinthWorld, doTrain=True):
+        if self.lastAction is not None:
+            if not self.random:
+                if self.lastAction[0] + self.action[0] == 0 and self.lastAction[1] + self.action[1] == 0:
+                    self.strikes += 1
+                else:
+                    self.strikes -= 1
+                if self.strikes > 100:
+                    self.random = True
+            else:
+                self.strikes -= 1
+
+            if len(self.history) >= self.historyLength:
+                self.history.pop(0)
+            self.history.append((self.lastState.copy(), int(self.act2IDict[self.lastAction]), int(self.lastVal), float(self.lastReward), np.array(self.lastRewards)))
+
+            # if self.lastReward != 0 or random.randint(0, 9) == 0:
+            if len(self.history) == self.historyLength:
+                self.samples.append(self.history.copy())
+
+            # if len(self.samples) % self.batchsize == 0 and len(self.samples) >= self.randomBuffer:
+            if len(self.samples) > self.nextTrain and doTrain:
+                print('train')
+                self.train()
+                self.nextTrain = min(self.batchsize + self.nextTrain, (self.historySizeMul + 1) * self.batchsize)
+
+        self.lastAction = self.action
+        self.lastState = self.state
+        self.lastReward = 0
+        self.lastVal = self.actDict[self.action]
+
+        maxVal = 0
+
+        self.executeAction(world, self.action)
+
+    def randomAct(self, world: LabyrinthWorld):
+        right = (1, 0)
+        left = (-1, 0)
+        up = (0, -1)
+        down = (0, 1)
+        directions = []
+
+        if self.x - 1 >= 0:
+            if world.board[self.x - 1, self.y] != 0:
+                directions.append(left)
+
+        if self.x + 1 < world.board_shape[0]:
+            if world.board[self.x + 1, self.y] != 0:
+                directions.append(right)
+
+        if self.y - 1 >= 0:
+            if world.board[self.x, self.y - 1] != 0:
+                directions.append(up)
+
+        if self.y + 1 < world.board_shape[1]:
+            if world.board[self.x, self.y + 1] != 0:
+                directions.append(down)
+
+        d = random.randint(0, len(directions) - 1)
+        action = directions[d]
+
+        return action
+
+    def executeAction(self, world: LabyrinthWorld, action):
+        pass
+
+    def generateSamples(self):
+        # history element: (self.lastState.copy(), self.act2IDict[self.lastAction], self.lastVal, self.lastReward, np.array(self.lastRewards))
+        # history: [t-2, t-1]
+        states = []
+        targets = []
+        for i in range(4):
+            true_batch = int(self.batchsize/4)
+            target = np.zeros((true_batch, 2, 1))
+            samples = np.array(self.samples[:-self.batchsize])
+            # print('Samples for ' + str(i))
+            # print(len(samples))
+            samples = np.array(list(filter(lambda e: e[0, 1] == i, list(samples))))
+            # print(len(samples))
+            partTwo = True
+            if len(samples) == 0:
+                print('No samples for:' + str(i))
+                partTwo = False
+                samples = np.array(self.samples[:-self.batchsize])
+            buffer_size = len(samples)
+            index = np.random.choice(np.arange(buffer_size),
+                                     size=true_batch,
+                                     replace=True)
+            samples = samples[index]
+            # self.samples = []
+            if partTwo:
+                target[:, 1, 0] = samples[:, 1, 3] #reward t-2 got
+
+                nextState = np.concatenate(samples[:, 1, 0]) #states of t-1
+                nextVals = self.model.predict([nextState, nextState, nextState, nextState])
+
+                nextVals2 = nextVals[i][:,  0] + nextVals[i][:, 1]
+                target[:, 0, 0] = nextVals2 #best q t-1
+            else:
+                target[:, 1, 0] = np.array(list(map(lambda elem: list(elem), list(np.array(samples[:, 1, 4])))))[:, i]  # reward t-2 got
+
+            targets.append(target)
+
+            states.append(np.concatenate(samples[:, 0, 0])) #states of t-2
+
+        return states, targets
+
+    def train(self):
+        print(self.name)
+        states, target = self.generateSamples()
+        self.model.fit(states, target, epochs=1)
+
+        self.samples = self.samples[-self.historySizeMul*self.batchsize:]
+
+        # print(self.model.get_weights())
+
+        pass
+
+
+class Herbivore(NetLearner):
+    name = 'Herbivore'
+    col = 9
+    r = 255
+    g = 255
+    b = 0
+    viewD = 3
+    historyLength = 2
+
+    learningRate = 0.001
+    discountFactor = 0.5
+    randomBuffer = 0
+    batchsize = 1000
+    randomBuffer = max(2 * batchsize, randomBuffer)
+    randomChance = 9
+
+    samples = []
+
+    # x_in = keras.Input(shape=(4 * (2 * viewD + 1) * (2 * viewD + 1) + 2))
+    # target = keras.Input(shape=(10, 1))
+    # inVec = keras.layers.Flatten()(x_in)
+    # # kernel_regularizer=keras.regularizers.l2(0.01)
+    # actions = keras.layers.Dense((4 * (2 * viewD + 1) * (2 * viewD + 1)), activation='elu')(inVec)
+    # actions = keras.layers.Dense(((2 * viewD + 1) * (2 * viewD + 1)), activation='elu')(actions)
+    # actions = keras.layers.Dense(8, activation='linear', use_bias=False)(actions)
+    # # actions = keras.layers.Dense(4, activation='linear', use_bias=False)(inVec)
+    #
+    # model = keras.Model(inputs=x_in, outputs=actions)
+    #
+    # # model.compile(optimizer='adam', loss=loss2, target_tensors=[target])
+    # model.compile(optimizer=tf.keras.optimizers.RMSprop(learningRate), loss=loss2, target_tensors=[target])
+
+    # def __init__(self, x, y):
+    #     super(Herbivore, self).__init__(x, y)
+
+    def createState(self, world: LabyrinthWorld):
+        state = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float)  # - 1
+        state2 = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float)  # - 1
+        state3 = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float)  # - 1
+        state4 = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float)  # - 1
+
+        maxdirleft = self.x - max(self.x - (self.viewD), 0)
+        maxdirright = min(self.x + (self.viewD), (world.board_shape[0] - 1)) - self.x
+        maxdirup = self.y - max(self.y - (self.viewD), 0)
+        maxdirdown = min(self.y + (self.viewD), (world.board_shape[1] - 1)) - self.y
+
+        state[self.viewD - maxdirleft: self.viewD + maxdirright, self.viewD - maxdirup: self.viewD + maxdirdown] = world.board[self.x - maxdirleft: self.x + maxdirright, self.y - maxdirup: self.y + maxdirdown]
+        # for sub in world.subjects:
+        #     if abs(sub.x - self.x) < self.viewD and abs(sub.y - self.y) < self.viewD:
+        #         if state[self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] != 3:
+        #             state2[self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] = sub.col
+        for x in range(-maxdirleft, maxdirright, 1):
+            for y in range(-maxdirup, maxdirdown, 1):
+                if world.subjectDict[(self.x + x, self.y + y)] != []:
+                    state2[x + maxdirleft, y + maxdirup] = 1#world.subjectDict[(self.x + x, self.y + y)][0].col
+
+        state3[self.viewD - maxdirleft: self.viewD + maxdirright, self.viewD - maxdirup: self.viewD + maxdirdown] = world.trailMix[self.x - maxdirleft: self.x + maxdirright, self.y - maxdirup: self.y + maxdirdown]
+        state4[self.viewD - maxdirleft: self.viewD + maxdirright, self.viewD - maxdirup: self.viewD + maxdirdown] = world.grass[self.x - maxdirleft: self.x + maxdirright, self.y - maxdirup: self.y + maxdirdown]
+
+        if not self.random:
+            test=1
+
+        area = np.reshape(np.stack((state, state2, state3, state4)), (4 * (2 * self.viewD + 1) * (2 * self.viewD + 1)))
+        action = [0, 0]
+        if self.lastAction is not None:
+            action = self.lastAction
+        return np.reshape(np.concatenate((area, action)), (1, 4 * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2))
+
+    def executeAction(self, world: LabyrinthWorld, action):
+        right = (1, 0)
+        left = (-1, 0)
+        up = (0, -1)
+        down = (0, 1)
+        directions = []
+
+        if self.x - 1 >= 0:
+            if world.board[self.x - 1, self.y] != 0:
+                directions.append(left)
+
+        if self.x + 1 < world.board_shape[0]:
+            if world.board[self.x + 1, self.y] != 0:
+                directions.append(right)
+
+        if self.y - 1 >= 0:
+            if world.board[self.x, self.y - 1] != 0:
+                directions.append(up)
+
+        if self.y + 1 < world.board_shape[1]:
+            if world.board[self.x, self.y + 1] != 0:
+                directions.append(down)
+        if len(action) == 2:
+            if len(world.subjectDict[(self.x + action[0], self.y + action[1])]) > 0:
+                for sub in world.subjectDict[(self.x + action[0], self.y + action[1])]:
+                    if sub.alive:
+                        self.kills += 1
+                    sub.alive = False
+                    self.alive = True
+
+            self.lastRewards = []
+            if right in directions:
+                self.lastRewards.append(world.grass[self.x + 1, self.y])
+            else:
+                self.lastRewards.append(0)
+            if left in directions:
+                self.lastRewards.append(world.grass[self.x - 1, self.y])
+            else:
+                self.lastRewards.append(0)
+            if up in directions:
+                self.lastRewards.append(world.grass[self.x, self.y - 1])
+            else:
+                self.lastRewards.append(0)
+            if down in directions:
+               self.lastRewards.append(world.grass[self.x, self.y + 1])
+            else:
+                self.lastRewards.append(0)
+            assert len(self.lastRewards) == 4, 'Last Rewards not filled correctly!'
+
+            world.subjectDict[(self.x, self.y)].remove(self)
+            self.lastReward += world.trailMix[self.x, self.y]
+            self.x += action[0]
+            self.y += action[1]
+            world.subjectDict[(self.x, self.y)].append(self)
+            world.trailMix[self.x, self.y] = max(1.0, world.trailMix[self.x, self.y])
+            self.lastReward += (world.grass[self.x, self.y] - 0.0)
+            world.grass[self.x, self.y] = 0
+            world.hunter_grass[self.x, self.y] = 0
+
+    def randomAct(self, world: LabyrinthWorld):
+        right = (1, 0)
+        left = (-1, 0)
+        up = (0, -1)
+        down = (0, 1)
+        directions = []
+        actDict = {}
+
+        if self.x - 1 >= 0:
+            if world.board[self.x - 1, self.y] != 0:
+                directions.append(left)
+                actDict[left] = world.grass[self.x - 1, self.y]
+
+        if self.x + 1 < world.board_shape[0]:
+            if world.board[self.x + 1, self.y] != 0:
+                directions.append(right)
+                actDict[right] = world.grass[self.x + 1, self.y]
+
+        if self.y - 1 >= 0:
+            if world.board[self.x, self.y - 1] != 0:
+                directions.append(up)
+                actDict[up] = world.grass[self.x, self.y - 1]
+
+        if self.y + 1 < world.board_shape[1]:
+            if world.board[self.x, self.y + 1] != 0:
+                directions.append(down)
+                actDict[down] = world.grass[self.x, self.y + 1]
+
+        allowedActions = dict(filter(lambda elem: elem[0] in directions, actDict.items()))
+        action = max(allowedActions, key=allowedActions.get)
+
+        return action
+
+
+class Hunter(NetLearner):
+    name = 'Hunter'
+    hunterGrassScale = 0.5
+    r = 0
+    g = 255
+    b = 255
+    def randomAct(self, world: LabyrinthWorld):
+        right = (1, 0)
+        left = (-1, 0)
+        up = (0, -1)
+        down = (0, 1)
+        directions = []
+        actDict = {}
+
+        if self.x - 1 >= 0:
+            if world.board[self.x - 1, self.y] > 0.01:
+                directions.append(left)
+
+                sub = self.getClosestSubject(world, self.x - 1, self.y)
+                dist = self.viewD
+                if sub is not None:
+                    dist = np.sqrt(np.square(self.x - 1 - sub.x) + np.square(self.y - sub.y))
+                distReward = self.viewD - dist
+
+                actDict[left] = world.trailMix[self.x - 1, self.y] + world.hunter_grass[self.x - 1, self.y] * self.hunterGrassScale + distReward
+                if len(world.subjectDict[(self.x + left[0], self.y + left[1])]) > 0:
+                    for sub in world.subjectDict[(self.x + left[0], self.y + left[1])]:
+                        if sub.col != self.col:
+                            actDict[left] += 10
+
+        if self.x + 1 < world.board_shape[0]:
+            if world.board[self.x + 1, self.y] > 0.01:
+                directions.append(right)
+
+                sub = self.getClosestSubject(world, self.x + 1, self.y)
+                dist = self.viewD
+                if sub is not None:
+                    dist = np.sqrt(np.square(self.x + 1 - sub.x) + np.square(self.y - sub.y))
+                distReward = self.viewD - dist
+
+                actDict[right] = world.trailMix[self.x + 1, self.y] + world.hunter_grass[self.x + 1, self.y] * self.hunterGrassScale + distReward
+                if len(world.subjectDict[(self.x + right[0], self.y + right[1])]) > 0:
+                    for sub in world.subjectDict[(self.x + right[0], self.y + right[1])]:
+                        if sub.col != self.col:
+                            actDict[right] += 10
+
+        if self.y - 1 >= 0:
+            if world.board[self.x, self.y - 1] > 0.01:
+                directions.append(up)
+
+                sub = self.getClosestSubject(world, self.x, self.y - 1)
+                dist = self.viewD
+                if sub is not None:
+                    dist = np.sqrt(np.square(self.x - sub.x) + np.square(self.y - 1 - sub.y))
+                distReward = self.viewD - dist
+
+                actDict[up] = world.trailMix[self.x, self.y - 1] + world.hunter_grass[self.x, self.y - 1] * self.hunterGrassScale + distReward
+                if len(world.subjectDict[(self.x + up[0], self.y + up[1])]) > 0:
+                    for sub in world.subjectDict[(self.x + up[0], self.y + up[1])]:
+                        if sub.col != self.col:
+                            actDict[up] += 10
+
+        if self.y + 1 < world.board_shape[1]:
+            if world.board[self.x, self.y + 1] > 0.01:
+                directions.append(down)
+
+                sub = self.getClosestSubject(world, self.x, self.y + 1)
+                dist = self.viewD
+                if sub is not None:
+                    dist = np.sqrt(np.square(self.x - sub.x) + np.square(self.y + 1 - sub.y))
+                distReward = self.viewD - dist
+
+                actDict[down] = world.trailMix[self.x, self.y + 1] + world.hunter_grass[self.x, self.y + 1] * self.hunterGrassScale + distReward
+                if len(world.subjectDict[(self.x + down[0], self.y + down[1])]) > 0:
+                    for sub in world.subjectDict[(self.x + down[0], self.y + down[1])]:
+                        if sub.col != self.col:
+                            actDict[down] += 10
+
+        if len(actDict) > 0:
+            allowedActions = dict(filter(lambda elem: elem[0] in directions, actDict.items()))
+        else:
+            return super(Hunter, self).randomAct(world)
+        action = max(allowedActions, key=allowedActions.get)
+
+        return action
+
+    def respawnUpdate(self, x, y, world: LabyrinthWorld):
+        super(Hunter, self).respawnUpdate(x, y, world)
+        self.lastReward -= 1
+
+    def getClosestSubject(self, world, x, y):
+        for dist in range(1, self.viewD):
+            dy = dist
+            for dx in range(-dist, dist):
+                if world.board_shape[0] > x + dx >= 0 and world.board_shape[1] > y + dy >= 0:
+                    for sub in world.subjectDict[(x + dx, y + dy)]:
+                        if sub.alive and sub.col != self.col:
+                            return sub
+
+            dy = -dist
+            for dx in range(-dist, dist):
+                if world.board_shape[0] > x + dx >= 0 and world.board_shape[1] > y + dy >= 0:
+                    for sub in world.subjectDict[(x + dx, y + dy)]:
+                        if sub.alive and sub.col != self.col:
+                            return sub
+
+            dx = dist
+            for dy in range(-dist, dist):
+                if world.board_shape[0] > x + dx >= 0 and world.board_shape[1] > y + dy >= 0:
+                    for sub in world.subjectDict[(x + dx, y + dy)]:
+                        if sub.alive and sub.col != self.col:
+                            return sub
+
+            dx = -dist
+            for dy in range(-dist, dist):
+                if world.board_shape[0] > x + dx >= 0 and world.board_shape[1] > y + dy >= 0:
+                    for sub in world.subjectDict[(x + dx, y + dy)]:
+                        if sub.alive and sub.col != self.col:
+                            return sub
+        return None
+
+    def executeAction(self, world: LabyrinthWorld, action):
+        grass_factor = 0.5
+
+        right = (1, 0)
+        left = (-1, 0)
+        up = (0, -1)
+        down = (0, 1)
+        directions = []
+
+        if self.x - 1 >= 0:
+            if world.board[self.x - 1, self.y] != 0:
+                directions.append(left)
+
+        if self.x + 1 < world.board_shape[0]:
+            if world.board[self.x + 1, self.y] != 0:
+                directions.append(right)
+
+        if self.y - 1 >= 0:
+            if world.board[self.x, self.y - 1] != 0:
+                directions.append(up)
+
+        if self.y + 1 < world.board_shape[1]:
+            if world.board[self.x, self.y + 1] != 0:
+                directions.append(down)
+
+        if len(action) == 2:
+            right_kill = left_kill = up_kill = down_kill = False
+            if right in directions:
+                for sub in world.subjectDict[(self.x + action[0], self.y + action[1])]:
+                    if sub.alive:
+                        if sub.col != self.col:
+                            right_kill = True
+            if left in directions:
+                for sub in world.subjectDict[(self.x + left[0], self.y + left[1])]:
+                    if sub.alive:
+                        if sub.col != self.col:
+                            left_kill = True
+            if up in directions:
+                for sub in world.subjectDict[(self.x + up[0], self.y + up[1])]:
+                    if sub.alive:
+                        if sub.col != self.col:
+                            up_kill = True
+            if down in directions:
+                for sub in world.subjectDict[(self.x + down[0], self.y + down[1])]:
+                    if sub.alive:
+                        if sub.col != self.col:
+                            down_kill = True
+            
+            if len(world.subjectDict[(self.x + action[0], self.y + action[1])]) > 0:
+                for sub in world.subjectDict[(self.x + action[0], self.y + action[1])]:
+                    if sub.alive:
+                        self.kills += 1
+                        if sub.col != self.col:
+                            self.lastReward += 10
+                    sub.alive = False
+                    self.alive = True
+
+            self.lastRewards = []
+            if right in directions:
+                sub = self.getClosestSubject(world, self.x + 1, self.y)
+                dist = self.viewD
+                if sub is not None:
+                    dist = np.sqrt(np.square(self.x + 1 - sub.x) + np.square(self.y - sub.y))
+                distReward = self.viewD - dist
+                if right_kill:
+                    self.lastRewards.append(10 + world.trailMix[self.x + 1, self.y] + world.hunter_grass[self.x + 1, self.y] * grass_factor + distReward)
+                else:
+                    self.lastRewards.append(world.trailMix[self.x + 1, self.y] + world.hunter_grass[self.x + 1, self.y] * grass_factor + distReward)
+            else:
+                self.lastRewards.append(0)
+            if left in directions:
+                sub = self.getClosestSubject(world, self.x - 1, self.y)
+                dist = self.viewD
+                if sub is not None:
+                    dist = np.sqrt(np.square(self.x - 1 - sub.x) + np.square(self.y - sub.y))
+                distReward = self.viewD - dist
+                if left_kill:
+                    self.lastRewards.append(10 + world.trailMix[self.x - 1, self.y] + world.hunter_grass[self.x - 1, self.y] * grass_factor + distReward)
+                else:
+                    self.lastRewards.append(world.trailMix[self.x - 1, self.y] + world.hunter_grass[self.x - 1, self.y] * grass_factor + distReward)
+            else:
+                self.lastRewards.append(0)
+            if up in directions:
+                sub = self.getClosestSubject(world, self.x, self.y - 1)
+                dist = self.viewD
+                if sub is not None:
+                    dist = np.sqrt(np.square(self.x - sub.x) + np.square(self.y - sub.y - 1))
+                distReward = self.viewD - dist
+                if up_kill:
+                    self.lastRewards.append(10 + world.trailMix[self.x, self.y - 1] + world.hunter_grass[self.x, self.y - 1] * grass_factor + distReward)
+                else:
+                    self.lastRewards.append(world.trailMix[self.x, self.y - 1] + world.hunter_grass[self.x, self.y - 1] * grass_factor + distReward)
+            else:
+                self.lastRewards.append(0)
+            if down in directions:
+                sub = self.getClosestSubject(world, self.x, self.y + 1)
+                dist = self.viewD
+                if sub is not None:
+                    dist = np.sqrt(np.square(self.x - sub.x) + np.square(self.y + 1 - sub.y))
+                distReward = self.viewD - dist
+                if down_kill:
+                    self.lastRewards.append(10 + world.trailMix[self.x, self.y + 1] + world.hunter_grass[self.x, self.y + 1] * grass_factor + distReward)
+                else:
+                    self.lastRewards.append(world.trailMix[self.x, self.y + 1] + world.hunter_grass[self.x, self.y + 1] * grass_factor + distReward)
+            else:
+                self.lastRewards.append(0)
+            assert len(self.lastRewards) == 4, 'Last Rewards not filled correctly!'
+
+            world.subjectDict[(self.x, self.y)].remove(self)
+            self.x += action[0]
+            self.y += action[1]
+            self.lastReward += world.trailMix[self.x, self.y]
+            world.subjectDict[(self.x, self.y)].append(self)
+            self.lastReward += (world.hunter_grass[self.x, self.y] * 0.1)
+            world.hunter_grass[self.x, self.y] = 0
+
+            sub = self.getClosestSubject(world, self.x, self.y)
+            dist = self.viewD
+            if sub is not None:
+                dist = np.sqrt(np.square(self.x - sub.x) + np.square(self.y - sub.y))
+            distReward = self.viewD - dist
+
+            self.lastReward += distReward
diff --git a/labirinth_ai/__init__.py b/labirinth_ai/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/labirinth_ai/loss.py b/labirinth_ai/loss.py
new file mode 100644
index 0000000..333a9a4
--- /dev/null
+++ b/labirinth_ai/loss.py
@@ -0,0 +1,37 @@
+import tensorflow as tf
+
+
+def loss(nextState, actions):
+    # return tf.reduce_sum(tf.square(nextState[:, 2:, 0] * (0.5 * (nextState[:, 0] + 0.25 * nextState[:, 1] - actions))), axis=1)
+    return tf.reduce_mean(tf.square(nextState[:, 0] + 0.25 * nextState[:, 1] - tf.reduce_sum(
+        nextState[:, 2:6, 0] * (actions[:, :4] + actions[:, 4:]), axis=1))) + tf.reduce_mean(
+        tf.reduce_sum(tf.square(nextState[:, 6:, 0] - actions[:, :4]), axis=1), axis=0)
+
+
+def loss2(nextState, actions):
+    # return tf.reduce_sum(tf.square(nextState[:, 2:, 0] * (0.5 * (nextState[:, 0] + 0.25 * nextState[:, 1] - actions))), axis=1)
+
+    # return 0.1 * tf.reduce_mean(tf.square(0.75 * nextState[:, 1] - tf.reduce_sum(nextState[:, 2:6, 0] * (actions[:, 4:] + actions[:, :4]),axis=1))) + 0.9 * tf.reduce_mean(tf.reduce_sum(tf.square(nextState[:, 6:, 0] - actions[:, :4]), axis=1), axis=0)
+
+    # return 0.0 * tf.reduce_mean(tf.square(0.75 * nextState[:, 1] - tf.reduce_sum(nextState[:, 2:6, 0] * (actions[:, :4]),axis=1))) + 1.0 * tf.reduce_mean(tf.reduce_sum(tf.square(nextState[:, 6:, 0] - actions[:, :4]), axis=1), axis=0)
+
+    return tf.reduce_mean(
+        tf.reduce_max(nextState[:, 2:6, 0] * tf.square((nextState[:, 6:, 0] - (actions[:, :4] + actions[:, 4:]))),
+                      axis=1), axis=0)
+
+    # action = nextState[:, 3] * 1 + nextState[:, 4] * 2 + nextState[:, 5] * 3
+    # action = tf.cast(action, tf.int32)
+    # action = tf.reshape(action, (-1,))
+    #
+    # # test = actions[:, action[:]]
+    #
+    # test1 = tf.slice(actions[:, :4], action, (-1, 1))
+    # test2 = tf.slice(actions[:, 4:], action, (-1, 1))
+    #
+    # return 1.0 * tf.reduce_mean(tf.reduce_sum(tf.square((0.1 * nextState[:, 1] + nextState[:, 6:, 0]) - (test1 + test2)), axis=1)) + 0.0 * tf.reduce_mean(tf.reduce_sum(tf.square(nextState[:, 6:, 0] - actions[:, :4]), axis=1), axis=0)
+    # return 1.0 * tf.reduce_mean(tf.reduce_sum(tf.square((0.1 * nextState[:, 1] + nextState[:, 6:, 0]) - (actions[:, :4] + actions[:, 4:])), axis=1)) + 0.0 * tf.reduce_mean(tf.reduce_sum(tf.square(nextState[:, 6:, 0] - actions[:, :4]), axis=1), axis=0)
+
+
+def loss3(target, pred):
+    return tf.reduce_mean(0.5 * tf.square(0.1 * target[:, 0, 0] + target[:, 1, 0] - (pred[:, 0] + pred[:, 1]))
+                          + 0.5 * tf.square(target[:, 1, 0] - pred[:, 0]), axis=0)

From e718873caaa9728a93219c8d786708448ca761ec Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Sat, 12 Feb 2022 17:35:15 +0100
Subject: [PATCH 08/14] move to pytorch

---
 labirinth_ai/LabyrinthClient.py  |   4 +-
 labirinth_ai/Models/BaseModel.py | 125 +++++++++++++++++++++++++++++++
 labirinth_ai/Models/__init__.py  |   0
 labirinth_ai/Subject.py          |  28 ++-----
 4 files changed, 135 insertions(+), 22 deletions(-)
 create mode 100644 labirinth_ai/Models/BaseModel.py
 create mode 100644 labirinth_ai/Models/__init__.py

diff --git a/labirinth_ai/LabyrinthClient.py b/labirinth_ai/LabyrinthClient.py
index fdcb22e..daa8e93 100644
--- a/labirinth_ai/LabyrinthClient.py
+++ b/labirinth_ai/LabyrinthClient.py
@@ -17,9 +17,9 @@ class LabyrinthClient(Client):
                 if self.world_provider.world.board[x, y] in [1, 2]:
                     r, g, b = 57, 92, 152
                     if 1.5 >= self.world_provider.world.hunter_grass[x, y] > 0.5:
-                        r, g, b = 25, 149, 156
-                    if 3 >= self.world_provider.world.hunter_grass[x, y] > 1.5:
                         r, g, b = 112, 198, 169
+                    if 3 >= self.world_provider.world.hunter_grass[x, y] > 1.5:
+                        r, g, b = 25, 149, 156
                     self.world_provider.world.set_color(x, y, 0, r / 255.0, g / 255.0, b / 255.0)
                 if self.world_provider.world.board[x, y] == 3:
                     self.world_provider.world.set_color(x, y, 0, 139 / 255.0, 72 / 255.0, 82 / 255.0)
diff --git a/labirinth_ai/Models/BaseModel.py b/labirinth_ai/Models/BaseModel.py
new file mode 100644
index 0000000..e4210ad
--- /dev/null
+++ b/labirinth_ai/Models/BaseModel.py
@@ -0,0 +1,125 @@
+import torch
+from torch import nn
+import numpy as np
+import tqdm
+from torch.utils.data import Dataset, DataLoader
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using {device} device")
+
+
+# Define model
+class BaseModel(nn.Module):
+    def __init__(self, view_dimension, action_num, channels):
+        super(BaseModel, self).__init__()
+        self.flatten = nn.Flatten()
+        self.actions = []
+        self.action_num = action_num
+        self.viewD = view_dimension
+        self.channels = channels
+        for action in range(action_num):
+            action_sequence = nn.Sequential(
+                nn.Linear(channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2,
+                          (2 * self.viewD + 1) * (2 * self.viewD + 1)),
+                nn.ELU(),
+                nn.Linear((2 * self.viewD + 1) * (2 * self.viewD + 1), (self.viewD + 1) * (self.viewD + 1)),
+                nn.ELU(),
+                nn.Linear((self.viewD + 1) * (self.viewD + 1), 2)
+            )
+            self.add_module('action_' + str(action), action_sequence)
+            self.actions.append(action_sequence)
+
+    def forward(self, x):
+        x_flat = self.flatten(x)
+        actions = []
+        for action in range(self.action_num):
+            actions.append(self.actions[action](x_flat))
+        return torch.stack(actions, dim=1)
+
+
+class BaseDataSet(Dataset):
+    def __init__(self, states, targets):
+        assert len(states) == len(targets), "Needs to have as many states as targets!"
+        self.states = torch.tensor(states, dtype=torch.float32)
+        self.targets = torch.tensor(targets, dtype=torch.float32)
+
+    def __len__(self):
+        return len(self.states)
+
+    def __getitem__(self, idx):
+        return self.states[idx], self.targets[idx]
+
+
+def create_optimizer(model):
+    return torch.optim.RMSprop(model.parameters(), lr=1e-3)
+
+
+def create_loss_function(action):
+    def custom_loss(prediction, target):
+        return torch.mean(0.5 * torch.square(
+            0.1 * target[:, 0, 0] + target[:, 1, 0] - (
+                        prediction[:, action, 0] + prediction[:, action, 1])) + 0.5 * torch.square(
+            target[:, 1, 0] - prediction[:, action, 0]), dim=0)
+
+    return custom_loss
+
+
+def from_numpy(x):
+    return torch.tensor(x, dtype=torch.float32)
+
+
+def train(states, targets, model, optimizer):
+    for action in range(model.action_num):
+        data_set = BaseDataSet(states[action], targets[action])
+        dataloader = DataLoader(data_set, batch_size=64, shuffle=True)
+        loss_fn = create_loss_function(action)
+
+        size = len(dataloader)
+        model.train()
+        for batch, (X, y) in enumerate(dataloader):
+            X, y = X.to(device), y.to(device)
+
+            # Compute prediction error
+            pred = model(X)
+            loss = loss_fn(pred, y)
+
+            # Backpropagation
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            if batch % 100 == 0:
+                loss, current = loss.item(), batch * len(X)
+                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
+        model.eval()
+
+
+if __name__ == '__main__':
+    sample = np.random.random((1, 4, 11, 11))
+
+    model = BaseModel(5, 4, 4).to(device)
+    print(model)
+
+    test = model(torch.tensor(sample, dtype=torch.float32))
+    # test = test.cpu().detach().numpy()
+    print(test)
+
+    state = np.random.random((4, 11, 11))
+    target = np.random.random((4, 2))
+    states = [
+        [state],
+        [state],
+        [state],
+        [state],
+    ]
+    targets = [
+        [target],
+        [target],
+        [target],
+        [target],
+    ]
+
+    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
+
+    train(states, targets, model, optimizer)
+
diff --git a/labirinth_ai/Models/__init__.py b/labirinth_ai/Models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/labirinth_ai/Subject.py b/labirinth_ai/Subject.py
index ec0593c..3fe291a 100644
--- a/labirinth_ai/Subject.py
+++ b/labirinth_ai/Subject.py
@@ -5,6 +5,7 @@ from tensorflow import keras
 
 from labirinth_ai.LabyrinthWorld import LabyrinthWorld
 from labirinth_ai.loss import loss2, loss3
+from labirinth_ai.Models.BaseModel import BaseModel, train, create_optimizer, device, from_numpy
 
 # import torch
 # dtype = torch.float
@@ -369,22 +370,9 @@ class NetLearner(Subject):
         self.x_in = []
         self.actions = []
         self.target = []
-        for i in range(4):
-            x_in = keras.Input(shape=(self.channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2))
-            self.x_in.append(x_in)
-            inVec = keras.layers.Flatten()(x_in)
-            actions = keras.layers.Dense(((2 * self.viewD + 1) * (2 * self.viewD + 1)), activation='elu',
-                                         kernel_regularizer=keras.regularizers.l2(0.001),
-                                         name=self.name + str(self.id) + 'Dense' + str(i) + 'l1')(inVec)
-            actions = keras.layers.Dense(((self.viewD + 1) * (self.viewD + 1)), activation='elu',
-                                         kernel_regularizer=keras.regularizers.l2(0.001))(actions)
-            self.target.append(keras.Input(shape=(2, 1)))
-            self.actions.append(keras.layers.Dense(2, activation='linear', use_bias=False, kernel_regularizer=keras.regularizers.l2(0.001))(actions))
-
-        self.model = keras.Model(inputs=self.x_in, outputs=self.actions)
-
-        self.model.compile(optimizer=tf.keras.optimizers.RMSprop(self.learningRate), loss=loss3,
-                           target_tensors=self.target)
+        self.model = BaseModel(self.viewD, 4, 4)
+        self.model.to(device)
+        self.optimizer = create_optimizer(self.model)
 
         if len(self.samples) < self.randomBuffer:
             self.random = True
@@ -508,7 +496,7 @@ class NetLearner(Subject):
             if state is None:
                 state = self.createState(world)
             if vals is None:
-                vals = self.model.predict([state, state, state, state])
+                vals = self.model(from_numpy(state)).detach().numpy()
                 vals = np.reshape(np.transpose(np.reshape(vals, (4, 2)), (1, 0)),
                                   (1, 8))
 
@@ -623,9 +611,9 @@ class NetLearner(Subject):
                 target[:, 1, 0] = samples[:, 1, 3] #reward t-2 got
 
                 nextState = np.concatenate(samples[:, 1, 0]) #states of t-1
-                nextVals = self.model.predict([nextState, nextState, nextState, nextState])
+                nextVals = self.model(from_numpy(nextState)).detach().numpy()
 
-                nextVals2 = nextVals[i][:,  0] + nextVals[i][:, 1]
+                nextVals2 = nextVals[:, i,  0] + nextVals[:, i, 1]
                 target[:, 0, 0] = nextVals2 #best q t-1
             else:
                 target[:, 1, 0] = np.array(list(map(lambda elem: list(elem), list(np.array(samples[:, 1, 4])))))[:, i]  # reward t-2 got
@@ -639,7 +627,7 @@ class NetLearner(Subject):
     def train(self):
         print(self.name)
         states, target = self.generateSamples()
-        self.model.fit(states, target, epochs=1)
+        train(states, target, self.model, self.optimizer)
 
         self.samples = self.samples[-self.historySizeMul*self.batchsize:]
 

From 33b5d9c83e9f5ea7fa1de7af7e1adccbb1e7c0af Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Sat, 12 Feb 2022 19:30:03 +0100
Subject: [PATCH 09/14] solves exiting

---
 labirinth_ai/Models/BaseModel.py | 7 ++++++-
 labirinth_ai/Subject.py          | 3 +--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/labirinth_ai/Models/BaseModel.py b/labirinth_ai/Models/BaseModel.py
index e4210ad..e87a3c9 100644
--- a/labirinth_ai/Models/BaseModel.py
+++ b/labirinth_ai/Models/BaseModel.py
@@ -4,6 +4,9 @@ import numpy as np
 import tqdm
 from torch.utils.data import Dataset, DataLoader
 
+import os
+os.environ["TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT"] = "0"
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using {device} device")
 
@@ -36,7 +39,6 @@ class BaseModel(nn.Module):
             actions.append(self.actions[action](x_flat))
         return torch.stack(actions, dim=1)
 
-
 class BaseDataSet(Dataset):
     def __init__(self, states, targets):
         assert len(states) == len(targets), "Needs to have as many states as targets!"
@@ -93,6 +95,9 @@ def train(states, targets, model, optimizer):
                 print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
         model.eval()
 
+        del data_set
+        del dataloader
+
 
 if __name__ == '__main__':
     sample = np.random.random((1, 4, 11, 11))
diff --git a/labirinth_ai/Subject.py b/labirinth_ai/Subject.py
index 3fe291a..5afa2a7 100644
--- a/labirinth_ai/Subject.py
+++ b/labirinth_ai/Subject.py
@@ -370,8 +370,7 @@ class NetLearner(Subject):
         self.x_in = []
         self.actions = []
         self.target = []
-        self.model = BaseModel(self.viewD, 4, 4)
-        self.model.to(device)
+        self.model = BaseModel(self.viewD, 4, 4).to(device)
         self.optimizer = create_optimizer(self.model)
 
         if len(self.samples) < self.randomBuffer:

From 4a05baa10398eb81008ba13664da1af714ccff67 Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Fri, 11 Mar 2022 14:19:55 +0100
Subject: [PATCH 10/14] base evolution model. needs different memory handling

---
 labirinth_ai/LabyrinthWorld.py        |  64 ++++------
 labirinth_ai/Models/BaseModel.py      |   8 +-
 labirinth_ai/Models/EvolutionModel.py | 176 ++++++++++++++++++++++++++
 labirinth_ai/Subject.py               |  18 ++-
 4 files changed, 218 insertions(+), 48 deletions(-)
 create mode 100644 labirinth_ai/Models/EvolutionModel.py

diff --git a/labirinth_ai/LabyrinthWorld.py b/labirinth_ai/LabyrinthWorld.py
index 2a2e3e7..b22a0ea 100644
--- a/labirinth_ai/LabyrinthWorld.py
+++ b/labirinth_ai/LabyrinthWorld.py
@@ -146,23 +146,27 @@ class LabyrinthWorld(World):
 
         # adding subjects
         from labirinth_ai.Subject import Hunter, Herbivore
-        while len(self.subjects) < 2:
-            px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
-            py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
-            if self.board[px, py] == 1:
-                self.subjects.append(Hunter(px, py))
-                self.ins += self.subjects[-1].x_in
-                self.actions += self.subjects[-1].actions
-                self.targets += self.subjects[-1].target
+        for _ in range(10):
+            while True:
+                px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
+                py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
+                if self.board[px, py] == 1:
+                    self.subjects.append(Hunter(px, py))
+                    self.ins += self.subjects[-1].x_in
+                    self.actions += self.subjects[-1].actions
+                    self.targets += self.subjects[-1].target
+                    break
 
-        while len(self.subjects) < 10:
-            px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
-            py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
-            if self.board[px, py] == 1:
-                self.subjects.append(Herbivore(px, py))
-                self.ins += self.subjects[-1].x_in
-                self.actions += self.subjects[-1].actions
-                self.targets += self.subjects[-1].target
+        for _ in range(40):
+            while True:
+                px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
+                py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
+                if self.board[px, py] == 1:
+                    self.subjects.append(Herbivore(px, py))
+                    self.ins += self.subjects[-1].x_in
+                    self.actions += self.subjects[-1].actions
+                    self.targets += self.subjects[-1].target
+                    break
 
         for x in range(self.board_shape[0]):
             for y in range(self.board_shape[1]):
@@ -173,36 +177,14 @@ class LabyrinthWorld(World):
 
     def update(self):
         # start = time.time()
-        if self.model is None:
-            for sub in self.subjects:
-                sub.calculateAction(self)
-        else:
-            states = list(map(lambda e: e.createState(self), self.subjects))
-            states = sum(list(map(lambda e: [e, e, e, e], states)), [])
-            vals = self.model.predict(states)
-            vals = np.reshape(np.transpose(np.reshape(vals, (len(self.subjects), 4, 2)), (0, 2, 1)),
-                              (len(self.subjects), 1, 8))
-            list(map(lambda e: e[1].calculateAction(self, vals[e[0]], states[e[0]]), enumerate(self.subjects)))
+        for sub in self.subjects:
+            sub.calculateAction(self)
 
         for sub in self.subjects:
             if sub.alive:
-                sub.update(self, doTrain=self.model is None)
+                sub.update(self)
             sub.tick += 1
 
-        if self.model is not None:
-            if self.round >= self.nextTrain:
-                samples = list(map(lambda e: e.generateSamples(), self.subjects))
-                states = sum(list(map(lambda e: e[0], samples)), [])
-                targets = sum(list(map(lambda e: e[1], samples)), [])
-                self.model.fit(states, targets)
-                self.nextTrain = self.batchsize / 5
-                self.round = 0
-                for sub in self.subjects:
-                    if len(sub.samples) > 20*self.batchsize:
-                        sub.samples = sub.samples[:-20*self.batchsize]
-            else:
-                self.round += 1
-
         new_subjects = []
         kill_table = {}
         live_table = {}
diff --git a/labirinth_ai/Models/BaseModel.py b/labirinth_ai/Models/BaseModel.py
index e87a3c9..9678f50 100644
--- a/labirinth_ai/Models/BaseModel.py
+++ b/labirinth_ai/Models/BaseModel.py
@@ -13,6 +13,7 @@ print(f"Using {device} device")
 
 # Define model
 class BaseModel(nn.Module):
+    evolutionary = False
     def __init__(self, view_dimension, action_num, channels):
         super(BaseModel, self).__init__()
         self.flatten = nn.Flatten()
@@ -39,6 +40,7 @@ class BaseModel(nn.Module):
             actions.append(self.actions[action](x_flat))
         return torch.stack(actions, dim=1)
 
+
 class BaseDataSet(Dataset):
     def __init__(self, states, targets):
         assert len(states) == len(targets), "Needs to have as many states as targets!"
@@ -87,7 +89,7 @@ def train(states, targets, model, optimizer):
 
             # Backpropagation
             optimizer.zero_grad()
-            loss.backward()
+            loss.backward(retain_graph=True)
             optimizer.step()
 
             if batch % 100 == 0:
@@ -100,7 +102,7 @@ def train(states, targets, model, optimizer):
 
 
 if __name__ == '__main__':
-    sample = np.random.random((1, 4, 11, 11))
+    sample = np.random.random((1, 486))
 
     model = BaseModel(5, 4, 4).to(device)
     print(model)
@@ -109,7 +111,7 @@ if __name__ == '__main__':
     # test = test.cpu().detach().numpy()
     print(test)
 
-    state = np.random.random((4, 11, 11))
+    state = np.random.random((486,))
     target = np.random.random((4, 2))
     states = [
         [state],
diff --git a/labirinth_ai/Models/EvolutionModel.py b/labirinth_ai/Models/EvolutionModel.py
new file mode 100644
index 0000000..38276f6
--- /dev/null
+++ b/labirinth_ai/Models/EvolutionModel.py
@@ -0,0 +1,176 @@
+import torch
+from torch import nn
+import numpy as np
+import tqdm
+from torch.utils.data import Dataset, DataLoader
+from labirinth_ai.Models.BaseModel import device
+
+
+class NodeGene:
+    valid_types = ['sensor', 'hidden', 'output']
+
+    def __init__(self, node_id, node_type, bias=None):
+        assert node_type in self.valid_types, 'Unknown node type!'
+        self.node_id = node_id
+        self.node_type = node_type
+        if node_type == 'hidden':
+            assert bias is not None, 'Expected a bias for hidden node types!'
+            self.bias = bias
+        else:
+            self.bias = None
+
+
+class ConnectionGene:
+    def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
+        self.start = start
+        self.end = end
+        self.enabled = enabled
+        self.innvovation_num = innovation_num
+        self.recurrent = recurrent
+        if weight is None:
+            self.weight = np.random.random(1)[0] * 2 - 1.0
+        else:
+            self.weight = weight
+
+
+class EvolutionModel(nn.Module):
+    evolutionary = True
+
+    def __init__(self, view_dimension, action_num, channels, genes=None):
+        super(EvolutionModel, self).__init__()
+        self.flatten = nn.Flatten()
+
+        self.action_num = action_num
+        self.viewD = view_dimension
+        self.channels = channels
+
+        if genes is None:
+            self.num_input_nodes = channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2
+
+            self.genes = {'nodes': {}, 'connections': []}
+            node_id = 0
+            for _ in range(self.num_input_nodes):
+                self.genes['nodes'][node_id] = NodeGene(node_id, 'sensor')
+                node_id += 1
+            first_action = node_id
+            for _ in range(action_num * 2):
+                self.genes['nodes'][node_id] = NodeGene(node_id, 'output')
+                node_id += 1
+
+            for index in range(self.num_input_nodes):
+                for action in range(action_num * 2):
+                    self.genes['connections'].append(
+                        ConnectionGene(index, first_action + action, True, index*(action_num * 2) + action)
+                    )
+
+        self.incoming_connections = {}
+        for connection in self.genes['connections']:
+            if connection.end not in self.incoming_connections.keys():
+                self.incoming_connections[connection.end] = []
+            self.incoming_connections[connection.end].append(connection)
+
+        self.layers = {}
+        self.indices = {}
+
+        self.has_recurrent = False
+        non_recurrent_indices = {}
+        with torch.no_grad():
+            for key, value in self.incoming_connections.items():
+                value.sort(key=lambda element: element.start)
+
+                lin = nn.Linear(len(value), 1, bias=self.genes['nodes'][key].bias is not None)
+                for index, connection in enumerate(value):
+                    lin.weight[0, index] = value[index].weight
+                if self.genes['nodes'][key].bias is not None:
+                    lin.bias[0] = self.genes['nodes'][key].bias
+
+                non_lin = nn.ELU()
+                sequence = nn.Sequential(
+                    lin,
+                    non_lin
+                )
+                self.add_module('layer_' + str(key), sequence)
+                self.layers[key] = sequence
+                self.indices[key] = list(map(lambda element: element.start, value))
+
+                non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
+                if not self.has_recurrent and len(non_recurrent_indices[key]) != len(self.indices[key]):
+                    self.has_recurrent = True
+                non_recurrent_indices[key] = list(map(lambda element: element.start, non_recurrent_indices[key]))
+        rank_of_node = {}
+        for i in range(self.num_input_nodes):
+            rank_of_node[i] = 0
+
+        layers_to_add = list(non_recurrent_indices.items())
+        while len(layers_to_add) > 0:
+            for index, (key, incoming_nodes) in enumerate(list(layers_to_add)):
+                max_rank = -1
+                all_ranks_found = True
+
+                for incoming_node in incoming_nodes:
+                    if incoming_node in rank_of_node.keys():
+                        max_rank = max(max_rank, rank_of_node[incoming_node])
+                    else:
+                        all_ranks_found = False
+
+                if all_ranks_found:
+                    rank_of_node[key] = max_rank + 1
+
+            layers_to_add = list(filter(lambda element: element[0] not in rank_of_node.keys(), layers_to_add))
+        ranked_layers = list(rank_of_node.items())
+        ranked_layers.sort(key=lambda element: element[1])
+        ranked_layers = list(filter(lambda element: element[1] > 0, ranked_layers))
+        self.layer_order = list(map(lambda element: element[0], ranked_layers))
+        self.memory = torch.Tensor((max(map(lambda element: element[1].node_id, self.genes['nodes'].items())) + 1))
+
+    def forward(self, x, memory=None):
+        x_flat = self.flatten(x)
+        if memory is None:
+            memory = torch.Tensor(self.memory)
+            outs = []
+            for batch_element in x_flat:
+                memory[0:self.num_input_nodes] = batch_element
+                for layer_index in self.layer_order:
+                    memory[layer_index] = self.layers[layer_index](memory[self.indices[layer_index]])
+                outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2])
+            outs = torch.stack(outs)
+            self.memory = torch.Tensor(memory)
+            return torch.reshape(outs, (x.shape[0], 4, 2))
+        else:
+            memory[:, 0:self.num_input_nodes] = x
+            for layer_index in self.layer_order:
+                memory[:, layer_index] = self.layers[layer_index](memory[:, self.indices[layer_index]])
+            return torch.reshape(
+                memory[:, self.num_input_nodes: self.num_input_nodes + self.action_num * 2],
+                (x.shape[0], 4, 2))
+
+
+if __name__ == '__main__':
+    sample = np.random.random((1, 486))
+
+    model = EvolutionModel(5, 4, 4).to(device)
+    print(model)
+    print(model.has_recurrent)
+
+    test = model(torch.tensor(sample, dtype=torch.float32))
+    # test = test.cpu().detach().numpy()
+    print(test)
+
+    state = np.random.random((1, 486))
+    target = np.random.random((4, 2))
+    states = [
+        [state],
+        [state],
+        [state],
+        [state],
+    ]
+    targets = [
+        [target],
+        [target],
+        [target],
+        [target],
+    ]
+
+    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
+    from labirinth_ai.Models.BaseModel import train
+    train(states, targets, model, optimizer)
diff --git a/labirinth_ai/Subject.py b/labirinth_ai/Subject.py
index 5afa2a7..f9426eb 100644
--- a/labirinth_ai/Subject.py
+++ b/labirinth_ai/Subject.py
@@ -382,6 +382,8 @@ class NetLearner(Subject):
 
         self.lastRewards = []
 
+        self.accumulated_rewards = 0
+
     def visualize(self):
         print(self.name)
         layers = self.model.get_weights()
@@ -542,6 +544,8 @@ class NetLearner(Subject):
                 self.train()
                 self.nextTrain = min(self.batchsize + self.nextTrain, (self.historySizeMul + 1) * self.batchsize)
 
+        self.accumulated_rewards += self.lastReward
+
         self.lastAction = self.action
         self.lastState = self.state
         self.lastReward = 0
@@ -728,10 +732,12 @@ class Herbivore(NetLearner):
         if len(action) == 2:
             if len(world.subjectDict[(self.x + action[0], self.y + action[1])]) > 0:
                 for sub in world.subjectDict[(self.x + action[0], self.y + action[1])]:
-                    if sub.alive:
-                        self.kills += 1
-                    sub.alive = False
-                    self.alive = True
+                    if isinstance(sub, Hunter):
+                        if sub.alive:
+                            sub.kills += 1
+                            sub.alive = True
+                            sub.lastReward += 10
+                            self.alive = False
 
             self.lastRewards = []
             if right in directions:
@@ -795,6 +801,10 @@ class Herbivore(NetLearner):
 
         return action
 
+    def respawnUpdate(self, x, y, world: LabyrinthWorld):
+        super(Herbivore, self).respawnUpdate(x, y, world)
+        self.lastReward -= 1
+
 
 class Hunter(NetLearner):
     name = 'Hunter'

From cf4d773c1081f4e0087139d34c38f41bb9c0e8b3 Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Fri, 12 Aug 2022 15:48:30 +0200
Subject: [PATCH 11/14] neat implementation up to mutate

---
 Client/Client.py                      |   3 +-
 labirinth_ai/LabyrinthClient.py       |  24 ++-
 labirinth_ai/LabyrinthWorld.py        |  85 ++++++----
 labirinth_ai/Models/BaseModel.py      |   6 +-
 labirinth_ai/Models/EvolutionModel.py | 229 +++++++++++++++++---------
 labirinth_ai/Models/Genotype.py       | 139 ++++++++++++++++
 labirinth_ai/Population.py            |  97 +++++++++++
 labirinth_ai/Subject.py               |  29 +---
 8 files changed, 468 insertions(+), 144 deletions(-)
 create mode 100644 labirinth_ai/Models/Genotype.py
 create mode 100644 labirinth_ai/Population.py

diff --git a/Client/Client.py b/Client/Client.py
index e64554a..80f146e 100644
--- a/Client/Client.py
+++ b/Client/Client.py
@@ -52,6 +52,7 @@ class Client:
 
         self.pos = pos
         self.time = time.time()
+        self.projMatrix = perspectiveMatrix(45.0, 400 / 400, 0.01, MAX_DISTANCE)
 
         glutReshapeFunc(self.resize)
         glutDisplayFunc(self.display)
@@ -195,7 +196,7 @@ class Client:
 
         glutSwapBuffers()
 
-        print('fps', 1.0 / (time.time() - self.time))
+        # print('fps', 1.0 / (time.time() - self.time))
         self.time = time.time()
         glutPostRedisplay()
 
diff --git a/labirinth_ai/LabyrinthClient.py b/labirinth_ai/LabyrinthClient.py
index daa8e93..227c2a0 100644
--- a/labirinth_ai/LabyrinthClient.py
+++ b/labirinth_ai/LabyrinthClient.py
@@ -1,13 +1,16 @@
 import time
 
-from Client.Client import Client, MAX_DISTANCE
+from Client.Client import Client, MAX_DISTANCE, glutPostRedisplay
 from MatrixStuff.Transformations import perspectiveMatrix
 from labirinth_ai.LabyrinthProvider import LabyrinthProvider
 
 import numpy as np
 
+
 class LabyrinthClient(Client):
     def __init__(self, test=False, pos=[0, 0, 0], world_class=LabyrinthProvider):
+        self.render = True
+        self.round_timer = time.time()
         super(LabyrinthClient, self).__init__(test, pos, world_class)
 
     def draw_world(self):
@@ -32,12 +35,25 @@ class LabyrinthClient(Client):
                 self.world_provider.world.set_color(sub.x, sub.y, 0, 212 / 255.0, 150 / 255.0, 222 / 255.0)
 
         self.projMatrix = perspectiveMatrix(45.0, 400 / 400, 0.01, MAX_DISTANCE)
-        print('redraw', time.time() - start_time)
+        # print('redraw', time.time() - start_time)
 
     def display(self):
-        super(LabyrinthClient, self).display()
-        self.draw_world()
+        if self.render:
+            super(LabyrinthClient, self).display()
+            self.draw_world()
+        else:
+            glutPostRedisplay()
         self.world_provider.world.update()
+        # round_end = time.time()
+        # print('round time', round_end - self.round_timer)
+        # self.round_timer = round_end
+
+    def keyboardHandler(self, key: int, x: int, y: int):
+        super().keyboardHandler(key, x, y)
+
+        if key == b' ':
+            self.render = not self.render
+
 
 if __name__ == '__main__':
     client = LabyrinthClient(pos=[-50, -50, -200])
diff --git a/labirinth_ai/LabyrinthWorld.py b/labirinth_ai/LabyrinthWorld.py
index b22a0ea..f2adaf9 100644
--- a/labirinth_ai/LabyrinthWorld.py
+++ b/labirinth_ai/LabyrinthWorld.py
@@ -1,11 +1,11 @@
 import time
+from typing import Tuple
 
 from Objects.Cube.Cube import Cube
 from Objects.World import World
 import numpy as np
 import random
 
-
 class LabyrinthWorld(World):
     randomBuffer = 0
     batchsize = 1000
@@ -26,21 +26,37 @@ class LabyrinthWorld(World):
 
         self.max_crates = self.max_room_num
 
-        self.subjects = []
-        self.ins = []
-        self.actions = []
-        self.targets = []
-
         self.model = None
         self.lastUpdate = time.time()
         self.nextTrain = self.randomBuffer
-        self.round = 0
+        self.round = 1
+        self.evolve_timer = 10
+        # self.evolve_timer = 1500
 
         self.trailMix = np.zeros(self.board_shape)
         self.grass = np.zeros(self.board_shape)
         self.hunter_grass = np.zeros(self.board_shape)
         self.subjectDict = {}
 
+        self._hunters = None
+        self._herbivores = None
+
+    @property
+    def hunters(self):
+        if self._hunters is None:
+            return []
+        return self._hunters.subjects
+
+    @property
+    def herbivores(self):
+        if self._herbivores is None:
+            return []
+        return self._herbivores.subjects
+
+    @property
+    def subjects(self):
+        return self.hunters + self.herbivores
+
     def generate(self, seed: int = None, sea_plate_height: int = 50, continental_plate_height: int = 200):
         board = np.zeros(self.board_shape)
         random.seed(seed)
@@ -146,36 +162,40 @@ class LabyrinthWorld(World):
 
         # adding subjects
         from labirinth_ai.Subject import Hunter, Herbivore
-        for _ in range(10):
-            while True:
-                px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
-                py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
-                if self.board[px, py] == 1:
-                    self.subjects.append(Hunter(px, py))
-                    self.ins += self.subjects[-1].x_in
-                    self.actions += self.subjects[-1].actions
-                    self.targets += self.subjects[-1].target
-                    break
+        from labirinth_ai.Population import Population
+        self._hunters = Population(Hunter, self, 10)
 
-        for _ in range(40):
-            while True:
-                px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
-                py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
-                if self.board[px, py] == 1:
-                    self.subjects.append(Herbivore(px, py))
-                    self.ins += self.subjects[-1].x_in
-                    self.actions += self.subjects[-1].actions
-                    self.targets += self.subjects[-1].target
-                    break
+        self._herbivores = Population(Herbivore, self, 40)
 
+        self.subjectDict = self.build_subject_dict()
+
+    def generate_free_coordinates(self) -> Tuple[int, int]:
+        while True:
+            px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
+            py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
+            if self.board[px, py] == 1:
+                return px, py
+
+    def build_subject_dict(self):
+        subject_dict = {}
         for x in range(self.board_shape[0]):
             for y in range(self.board_shape[1]):
-                self.subjectDict[(x, y)] = []
+                subject_dict[(x, y)] = []
 
         for sub in self.subjects:
-            self.subjectDict[(sub.x, sub.y)].append(sub)
+            subject_dict[(sub.x, sub.y)].append(sub)
+        return subject_dict
 
     def update(self):
+
+        if self.round % self.evolve_timer == 0:
+            print('Evolve population')
+            self.round = 0
+            self._hunters.evolve()
+            self._herbivores.evolve()
+            self.subjectDict = self.build_subject_dict()
+        self.round += 1
+
         # start = time.time()
         for sub in self.subjects:
             sub.calculateAction(self)
@@ -185,7 +205,6 @@ class LabyrinthWorld(World):
                 sub.update(self)
             sub.tick += 1
 
-        new_subjects = []
         kill_table = {}
         live_table = {}
         for sub in self.subjects:
@@ -194,18 +213,14 @@ class LabyrinthWorld(World):
                 live_table[sub.name] = 0
             kill_table[sub.name] += sub.kills
             live_table[sub.name] += sub.lives
-            if sub.alive:
-                new_subjects.append(sub)
-            else:
+            if not sub.alive:
                 px = random.randint(self.max_room_dim, (self.board_shape[0] - 1) - self.max_room_dim)
                 py = random.randint(self.max_room_dim, (self.board_shape[1] - 1) - self.max_room_dim)
                 while self.board[px, py] == 0:
                     px = random.randint(self.max_room_dim, (self.board_shape[0] - 1) - self.max_room_dim)
                     py = random.randint(self.max_room_dim, (self.board_shape[1] - 1) - self.max_room_dim)
                 sub.respawnUpdate(px, py, self)
-                new_subjects.append(sub)
 
-        self.subjects = new_subjects
         self.trailMix *= 0.99
 
         self.grass = np.minimum(self.grass + 0.01 * (self.board != 0), 3)
diff --git a/labirinth_ai/Models/BaseModel.py b/labirinth_ai/Models/BaseModel.py
index 9678f50..2434b61 100644
--- a/labirinth_ai/Models/BaseModel.py
+++ b/labirinth_ai/Models/BaseModel.py
@@ -44,8 +44,8 @@ class BaseModel(nn.Module):
 class BaseDataSet(Dataset):
     def __init__(self, states, targets):
         assert len(states) == len(targets), "Needs to have as many states as targets!"
-        self.states = torch.tensor(states, dtype=torch.float32)
-        self.targets = torch.tensor(targets, dtype=torch.float32)
+        self.states = torch.tensor(np.array(states), dtype=torch.float32)
+        self.targets = torch.tensor(np.array(targets), dtype=torch.float32)
 
     def __len__(self):
         return len(self.states)
@@ -69,7 +69,7 @@ def create_loss_function(action):
 
 
 def from_numpy(x):
-    return torch.tensor(x, dtype=torch.float32)
+    return torch.tensor(np.array(x), dtype=torch.float32)
 
 
 def train(states, targets, model, optimizer):
diff --git a/labirinth_ai/Models/EvolutionModel.py b/labirinth_ai/Models/EvolutionModel.py
index 38276f6..8a180d5 100644
--- a/labirinth_ai/Models/EvolutionModel.py
+++ b/labirinth_ai/Models/EvolutionModel.py
@@ -3,40 +3,16 @@ from torch import nn
 import numpy as np
 import tqdm
 from torch.utils.data import Dataset, DataLoader
-from labirinth_ai.Models.BaseModel import device
-
-
-class NodeGene:
-    valid_types = ['sensor', 'hidden', 'output']
-
-    def __init__(self, node_id, node_type, bias=None):
-        assert node_type in self.valid_types, 'Unknown node type!'
-        self.node_id = node_id
-        self.node_type = node_type
-        if node_type == 'hidden':
-            assert bias is not None, 'Expected a bias for hidden node types!'
-            self.bias = bias
-        else:
-            self.bias = None
-
-
-class ConnectionGene:
-    def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
-        self.start = start
-        self.end = end
-        self.enabled = enabled
-        self.innvovation_num = innovation_num
-        self.recurrent = recurrent
-        if weight is None:
-            self.weight = np.random.random(1)[0] * 2 - 1.0
-        else:
-            self.weight = weight
+from labirinth_ai.Models.BaseModel import device, BaseDataSet, create_loss_function, create_optimizer
+from labirinth_ai.Models.Genotype import Genotype
 
 
 class EvolutionModel(nn.Module):
     evolutionary = True
 
-    def __init__(self, view_dimension, action_num, channels, genes=None):
+    def __init__(self, view_dimension, action_num, channels, genes: Genotype = None, genotype_class=None):
+        if genotype_class is None:
+            genotype_class = Genotype
         super(EvolutionModel, self).__init__()
         self.flatten = nn.Flatten()
 
@@ -46,25 +22,29 @@ class EvolutionModel(nn.Module):
 
         if genes is None:
             self.num_input_nodes = channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2
+            self.genes = genotype_class(action_num, self.num_input_nodes)
+        else:
+            self.num_input_nodes = len(list(filter(lambda element: element[1].node_type == 'sensor', genes.nodes.items())))
+            assert self.num_input_nodes > 0, 'Network needs to have sensor nodes!'
+            is_input_over = False
+            is_output_over = False
+            for key, node in genes.nodes.items():
+                if node.node_type == 'sensor':
+                    if is_input_over:
+                        raise ValueError('Node genes need to follow the order sensor, output, hidden!')
 
-            self.genes = {'nodes': {}, 'connections': []}
-            node_id = 0
-            for _ in range(self.num_input_nodes):
-                self.genes['nodes'][node_id] = NodeGene(node_id, 'sensor')
-                node_id += 1
-            first_action = node_id
-            for _ in range(action_num * 2):
-                self.genes['nodes'][node_id] = NodeGene(node_id, 'output')
-                node_id += 1
+                if node.node_type == 'output':
+                    is_input_over = True
+                    if is_output_over:
+                        raise ValueError('Node genes need to follow the order sensor, output, hidden!')
 
-            for index in range(self.num_input_nodes):
-                for action in range(action_num * 2):
-                    self.genes['connections'].append(
-                        ConnectionGene(index, first_action + action, True, index*(action_num * 2) + action)
-                    )
+                if node.node_type == 'hidden':
+                    is_output_over = True
+
+            self.genes = genes
 
         self.incoming_connections = {}
-        for connection in self.genes['connections']:
+        for connection in self.genes.connections:
             if connection.end not in self.incoming_connections.keys():
                 self.incoming_connections[connection.end] = []
             self.incoming_connections[connection.end].append(connection)
@@ -73,16 +53,17 @@ class EvolutionModel(nn.Module):
         self.indices = {}
 
         self.has_recurrent = False
-        non_recurrent_indices = {}
+        self.non_recurrent_indices = {}
+        self.recurrent_indices = {}
         with torch.no_grad():
             for key, value in self.incoming_connections.items():
                 value.sort(key=lambda element: element.start)
 
-                lin = nn.Linear(len(value), 1, bias=self.genes['nodes'][key].bias is not None)
+                lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
                 for index, connection in enumerate(value):
                     lin.weight[0, index] = value[index].weight
-                if self.genes['nodes'][key].bias is not None:
-                    lin.bias[0] = self.genes['nodes'][key].bias
+                if self.genes.nodes[key].bias is not None:
+                    lin.bias[0] = self.genes.nodes[key].bias
 
                 non_lin = nn.ELU()
                 sequence = nn.Sequential(
@@ -93,15 +74,17 @@ class EvolutionModel(nn.Module):
                 self.layers[key] = sequence
                 self.indices[key] = list(map(lambda element: element.start, value))
 
-                non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
-                if not self.has_recurrent and len(non_recurrent_indices[key]) != len(self.indices[key]):
+                self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
+                self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
+                if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
                     self.has_recurrent = True
-                non_recurrent_indices[key] = list(map(lambda element: element.start, non_recurrent_indices[key]))
+                self.non_recurrent_indices[key] = list(map(lambda element: element.start, self.non_recurrent_indices[key]))
+                self.recurrent_indices[key] = list(map(lambda element: element.start, self.recurrent_indices[key]))
         rank_of_node = {}
         for i in range(self.num_input_nodes):
             rank_of_node[i] = 0
 
-        layers_to_add = list(non_recurrent_indices.items())
+        layers_to_add = list(self.non_recurrent_indices.items())
         while len(layers_to_add) > 0:
             for index, (key, incoming_nodes) in enumerate(list(layers_to_add)):
                 max_rank = -1
@@ -120,44 +103,123 @@ class EvolutionModel(nn.Module):
         ranked_layers = list(rank_of_node.items())
         ranked_layers.sort(key=lambda element: element[1])
         ranked_layers = list(filter(lambda element: element[1] > 0, ranked_layers))
-        self.layer_order = list(map(lambda element: element[0], ranked_layers))
-        self.memory = torch.Tensor((max(map(lambda element: element[1].node_id, self.genes['nodes'].items())) + 1))
 
-    def forward(self, x, memory=None):
+        ranked_layers = list(map(lambda element: (element, 0),
+                                 filter(lambda recurrent_element:
+                                        recurrent_element not in list(
+                                            map(lambda ranked_layer: ranked_layer[0], ranked_layers)
+                                        ),
+                                        list(filter(lambda recurrent_keys:
+                                                    len(self.recurrent_indices[recurrent_keys]) > 0,
+                                                    self.recurrent_indices.keys()))))) + ranked_layers
+
+        self.layer_order = list(map(lambda element: element[0], ranked_layers))
+        self.memory_size = (max(map(lambda element: element[1].node_id, self.genes.nodes.items())) + 1)
+        self.memory = torch.Tensor(self.memory_size)
+        self.output_range = range(self.num_input_nodes, self.num_input_nodes + self.action_num * 2)
+
+    def forward(self, x, last_memory=None):
         x_flat = self.flatten(x)
-        if memory is None:
-            memory = torch.Tensor(self.memory)
-            outs = []
-            for batch_element in x_flat:
-                memory[0:self.num_input_nodes] = batch_element
-                for layer_index in self.layer_order:
-                    memory[layer_index] = self.layers[layer_index](memory[self.indices[layer_index]])
-                outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2])
-            outs = torch.stack(outs)
-            self.memory = torch.Tensor(memory)
-            return torch.reshape(outs, (x.shape[0], 4, 2))
-        else:
-            memory[:, 0:self.num_input_nodes] = x
+        if last_memory is not None:
+            last_memory_flat = self.flatten(last_memory)
+        elif self.has_recurrent:
+            raise ValueError('Recurrent networks need to be passed their previous memory!')
+
+        memory = torch.Tensor(self.memory_size)
+        outs = []
+        for batch_index, batch_element in enumerate(x_flat):
+            memory[0:self.num_input_nodes] = batch_element
             for layer_index in self.layer_order:
-                memory[:, layer_index] = self.layers[layer_index](memory[:, self.indices[layer_index]])
-            return torch.reshape(
-                memory[:, self.num_input_nodes: self.num_input_nodes + self.action_num * 2],
-                (x.shape[0], 4, 2))
+                non_recurrent_in = memory[self.non_recurrent_indices[layer_index]]
+                non_recurrent_in = torch.stack([non_recurrent_in])
+                if self.has_recurrent and len(self.recurrent_indices[layer_index]) > 0:
+                    recurrent_in = last_memory_flat[batch_index, self.recurrent_indices[layer_index]]
+                    recurrent_in = torch.stack([recurrent_in])
+
+                    combined_in = torch.concat([non_recurrent_in, recurrent_in], dim=1)
+                else:
+                    combined_in = non_recurrent_in
+
+                memory[layer_index] = self.layers[layer_index](combined_in)
+            outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2])
+        outs = torch.stack(outs)
+        self.memory = torch.Tensor(memory)
+        return torch.reshape(outs, (x.shape[0], outs.shape[1]//2, 2))
+
+    def update_genes_with_weights(self):
+        for key, value in self.incoming_connections.items():
+            value.sort(key=lambda element: element.start)
+
+            sequence = self.layers[key]
+            lin = sequence[0]
+            for index, connection in enumerate(value):
+                value[index].weight = float(lin.weight[0, index])
+            if self.genes.nodes[key].bias is not None:
+                self.genes.nodes[key].bias = float(lin.bias[0])
+
+
+
+class RecurrentDataSet(BaseDataSet):
+    def __init__(self, states, targets, memory):
+        super().__init__(states, targets)
+        assert len(states) == len(memory), "Needs to have as many states as memories!"
+        self.memory = torch.tensor(np.array(memory), dtype=torch.float32)
+
+    def __getitem__(self, idx):
+        return self.states[idx], self.memory[idx], self.targets[idx]
+
+
+def train_recurrent(states, memory, targets, model, optimizer):
+    for action in range(model.action_num):
+        data_set = RecurrentDataSet(states[action], targets[action], memory[action])
+        dataloader = DataLoader(data_set, batch_size=64, shuffle=True)
+        loss_fn = create_loss_function(action)
+
+        size = len(dataloader)
+        model.train()
+        for batch, (X, M, y) in enumerate(dataloader):
+            X, y, M = X.to(device), y.to(device), M.to(device)
+
+            # Compute prediction error
+            pred = model(X, M)
+            loss = loss_fn(pred, y)
+
+            # Backpropagation
+            optimizer.zero_grad()
+            loss.backward(retain_graph=True)
+            optimizer.step()
+
+            if batch % 100 == 0:
+                loss, current = loss.item(), batch * len(X)
+                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
+        model.eval()
+
+        del data_set
+        del dataloader
 
 
 if __name__ == '__main__':
-    sample = np.random.random((1, 486))
+    sample = np.random.random((1, 1))
+    last_memory = np.zeros((1, 3))
 
-    model = EvolutionModel(5, 4, 4).to(device)
-    print(model)
+    from labirinth_ai.Models.Genotype import NodeGene, ConnectionGene, Genotype
+    genes = Genotype(nodes={0: NodeGene(0, 'sensor'), 1: NodeGene(1, 'output'), 2: NodeGene(2, 'hidden', 1)},
+                     connections=[ConnectionGene(0, 2, True, 0, recurrent=True), ConnectionGene(2, 1, True, 1, 1)])
+
+    model = EvolutionModel(1, 1, 1, genes)
+
+    model = model.to(device)
+    # print(model)
     print(model.has_recurrent)
 
-    test = model(torch.tensor(sample, dtype=torch.float32))
+    test = model(torch.tensor(sample, dtype=torch.float32), torch.tensor(last_memory, dtype=torch.float32))
     # test = test.cpu().detach().numpy()
-    print(test)
+    # print(test)
 
-    state = np.random.random((1, 486))
-    target = np.random.random((4, 2))
+    state = np.random.random((1, 1))
+    memory = np.random.random((1, 1))
+
+    target = np.random.random((2, 1))
     states = [
         [state],
         [state],
@@ -170,7 +232,12 @@ if __name__ == '__main__':
         [target],
         [target],
     ]
+    memories = [
+        [memory],
+        [memory],
+        [memory],
+        [memory],
+    ]
 
     optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
-    from labirinth_ai.Models.BaseModel import train
-    train(states, targets, model, optimizer)
+    train_recurrent(states, memories, targets, model, optimizer)
diff --git a/labirinth_ai/Models/Genotype.py b/labirinth_ai/Models/Genotype.py
new file mode 100644
index 0000000..4bea59f
--- /dev/null
+++ b/labirinth_ai/Models/Genotype.py
@@ -0,0 +1,139 @@
+from abc import abstractmethod
+from typing import List, Dict
+
+import numpy as np
+
+
+class NodeGene:
+    valid_types = ['sensor', 'hidden', 'output']
+
+    def __init__(self, node_id, node_type, bias=None):
+        assert node_type in self.valid_types, 'Unknown node type!'
+        self.node_id = node_id
+        self.node_type = node_type
+        if node_type == 'hidden':
+            assert bias is not None, 'Expected a bias for hidden node types!'
+            self.bias = bias
+        else:
+            self.bias = None
+
+
+class ConnectionGene:
+    def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
+        self.start = start
+        self.end = end
+        self.enabled = enabled
+        self.innvovation_num = innovation_num
+        self.recurrent = recurrent
+        if weight is None:
+            self.weight = np.random.random(1)[0] * 2 - 1.0
+        else:
+            self.weight = weight
+
+
+class Genotype:
+    def __init__(self, action_num: int = None, num_input_nodes: int = None,
+                 nodes: Dict[int, NodeGene] = None, connections: List[ConnectionGene] = None):
+        self.nodes = {}
+        self.connections = []
+        if action_num is not None and num_input_nodes is not None:
+            node_id = 0
+            for _ in range(num_input_nodes):
+                self.nodes[node_id] = NodeGene(node_id, 'sensor')
+                node_id += 1
+            first_action = node_id
+            for _ in range(action_num * 2):
+                self.nodes[node_id] = NodeGene(node_id, 'output')
+                node_id += 1
+
+            for index in range(num_input_nodes):
+                for action in range(action_num * 2):
+                    self.connections.append(
+                        ConnectionGene(index, first_action + action, True, index * (action_num * 2) + action)
+                    )
+        if nodes is not None and connections is not None:
+            self.nodes = nodes
+            self.connections = connections
+
+    def calculate_rank_of_nodes(self):
+        rank_of_node = {}
+        nodes_to_rank = list(self.nodes.items())
+        while len(nodes_to_rank) > 0:
+            for list_index, (id, node) in enumerate(nodes_to_rank):
+                incoming_connections = list(filter(lambda connection: connection.end == id and
+                                                                      not connection.recurrent, self.connections))
+                if len(incoming_connections) == 0:
+                    rank_of_node[id] = 0
+                    nodes_to_rank.pop(list_index)
+                    break
+
+                incoming_connections_starts = list(map(lambda connection: connection.start, incoming_connections))
+                start_ranks = list(map(lambda element: rank_of_node[element[0]],
+                                       filter(lambda start_node: start_node[0] in incoming_connections_starts and
+                                                                 start_node[0] in rank_of_node.keys(),
+                                              self.nodes.items())))
+                if len(start_ranks) == len(incoming_connections):
+                    rank_of_node[id] = max(start_ranks) + 1
+                    nodes_to_rank.pop(list_index)
+                    break
+        return rank_of_node
+
+    @abstractmethod
+    def mutate(self, innovation_num) -> int:
+        """
+        Decides whether or not to mutate this network. Then returns the new innovation number.
+        :param innovation_num: Current innovation number
+        :return: Updated innovation number
+        """
+
+        # return innovation_num
+        raise NotImplementedError()
+
+    @abstractmethod
+    def cross(self, other):
+        raise NotImplementedError()
+        # return self
+
+
+class NeatLike(Genotype):
+    connection_add_thr = 0.3
+    node_add_thr = 0.3
+
+    def mutate(self, innovation_num, allow_recurrent=False) -> int:
+        """
+        Decides whether or not to mutate this network. Then returns the new innovation number.
+        :param allow_recurrent: Optional parameter allowing or disallowing recurrent connections to form
+        :param innovation_num: Current innovation number
+        :return: Updated innovation number
+        """
+        # add connection
+        if np.random.random(1)[0] < self.connection_add_thr or True:
+            nodes = list(self.nodes.keys())
+            rank_of_node = self.calculate_rank_of_nodes()
+            end_nodes = list(filter(lambda node: rank_of_node[node] > 0, nodes))
+
+            connection_tuple = list(map(lambda connection: (connection.start, connection.end), self.connections))
+
+            start = np.random.randint(0, len(nodes))
+            end = np.random.randint(0, len(end_nodes))
+
+            tries = 50
+            while (rank_of_node[end_nodes[end]] == 0 or
+                   ((not allow_recurrent) and rank_of_node[nodes[start]] > rank_of_node[end_nodes[end]])
+                   or nodes[start] == end_nodes[end] or (nodes[start], end_nodes[end]) in connection_tuple) and\
+                    tries > 0:
+                end = np.random.randint(0, len(end_nodes))
+                if (not allow_recurrent) and rank_of_node[nodes[start]] > rank_of_node[end_nodes[end]]:
+                    start = np.random.randint(0, len(nodes))
+                tries -= 1
+            if tries > 0:
+                innovation_num += 1
+                self.connections.append(
+                    ConnectionGene(nodes[start], end_nodes[end], True, innovation_num,
+                                   recurrent=rank_of_node[nodes[start]] > rank_of_node[end_nodes[end]]))
+        #todo add node
+
+        return innovation_num
+
+    def cross(self, other):
+        return self
diff --git a/labirinth_ai/Population.py b/labirinth_ai/Population.py
new file mode 100644
index 0000000..70eef4f
--- /dev/null
+++ b/labirinth_ai/Population.py
@@ -0,0 +1,97 @@
+import random
+import numpy as np
+
+from labirinth_ai.Models.Genotype import NeatLike
+
+
+def fib(n):
+    if n == 0:
+        return [1]
+    elif n < 0:
+        return [0]
+    else:
+        return [fib(n - 1)[0] + fib(n - 2)[0]] + fib(n - 1)
+
+
+class Population:
+    def __init__(self, subject_class, world, subject_number):
+        self.subjects = []
+        self.world = world
+        for _ in range(subject_number):
+            px, py = self.world.generate_free_coordinates()
+            self.subjects.append(subject_class(px, py, genotype_class=NeatLike))
+        self.subject_number = subject_number
+        self.subject_class = subject_class
+
+    def select(self):
+        ranked = list(self.subjects)
+        ranked.sort(key=lambda subject: subject.accumulated_rewards, reverse=True)
+
+        return ranked[:int(self.subject_number / 2)]
+
+    @classmethod
+    def scatter(cls, n, buckets):
+        out = np.zeros(buckets)
+        if n == 0:
+            return out
+
+        fib_number = 0
+        fibs = fib(fib_number)
+        while np.sum(fibs) <= n and len(fibs) <= buckets:
+            fib_number += 1
+            fibs = fib(fib_number)
+        fib_number -= 1
+        fibs = fib(fib_number)
+
+        for bucket in range(buckets):
+            if bucket < len(fibs):
+                out[bucket] += fibs[bucket]
+            else:
+                break
+
+        return out + cls.scatter(n - np.sum(fibs), buckets)
+
+    def evolve(self):
+        # get updated weights from the models
+        for subject in self.subjects:
+            subject.model.update_genes_with_weights()
+
+        # crossbreed the current pop
+        best_subjects = self.select()
+        distribution = list(self.scatter(self.subject_number - int(self.subject_number / 2), int(self.subject_number / 2)))
+
+        new_subjects = list(best_subjects)
+        for index, offspring_num in enumerate(distribution):
+            for _ in range(int(offspring_num)):
+                parent_1 = best_subjects[index]
+                parent_2 = best_subjects[random.randint(index + 1, len(best_subjects) - 1)]
+
+                new_genes = parent_1.model.genes.cross(parent_2.model.genes)
+
+                # position doesn't matter, since mutation will set it
+                new_subject = self.subject_class(0, 0, new_genes)
+                new_subject.history = parent_1.history
+                new_subject.samples = parent_1.samples + parent_2.samples
+                new_subjects.append(new_subject)
+
+        assert len(new_subjects) == self.subject_number, 'All generations should have constant size!'
+
+        # mutate the pop
+        mutated_subjects = []
+        innovation_num = max(map(lambda subject: max(map(lambda connection: connection.innvovation_num,
+                                                         subject.model.genes.connections
+                                                         )
+                                                     )
+                             , new_subjects))
+        for subject in new_subjects:
+            subject.accumulated_rewards = 0
+
+            innovation_num = subject.model.genes.mutate(innovation_num)
+
+            px, py = self.world.generate_free_coordinates()
+            new_subject = self.subject_class(px, py, subject.model.genes)
+            new_subject.history = subject.history
+            new_subject.samples = subject.samples
+            mutated_subjects.append(new_subject)
+
+        self.subjects = mutated_subjects
diff --git a/labirinth_ai/Subject.py b/labirinth_ai/Subject.py
index f9426eb..dc8e886 100644
--- a/labirinth_ai/Subject.py
+++ b/labirinth_ai/Subject.py
@@ -4,6 +4,7 @@ import tensorflow as tf
 from tensorflow import keras
 
 from labirinth_ai.LabyrinthWorld import LabyrinthWorld
+from labirinth_ai.Models.EvolutionModel import EvolutionModel
 from labirinth_ai.loss import loss2, loss3
 from labirinth_ai.Models.BaseModel import BaseModel, train, create_optimizer, device, from_numpy
 
@@ -350,7 +351,7 @@ class NetLearner(Subject):
 
         self.strikes = 0
 
-    def __init__(self, x, y):
+    def __init__(self, x, y, genes=None, genotype_class=None):
         super(NetLearner, self).__init__(x, y)
 
         self.action = None
@@ -370,7 +371,10 @@ class NetLearner(Subject):
         self.x_in = []
         self.actions = []
         self.target = []
-        self.model = BaseModel(self.viewD, 4, 4).to(device)
+
+        # self.model = BaseModel(self.viewD, 4, 4).to(device)
+        self.model = EvolutionModel(self.viewD, 4, 4, genes=genes, genotype_class=genotype_class).to(device)
+
         self.optimizer = create_optimizer(self.model)
 
         if len(self.samples) < self.randomBuffer:
@@ -540,9 +544,11 @@ class NetLearner(Subject):
 
             # if len(self.samples) % self.batchsize == 0 and len(self.samples) >= self.randomBuffer:
             if len(self.samples) > self.nextTrain and doTrain:
-                print('train')
+                print('train', len(self.samples))
                 self.train()
+                self.nextTrain = len(self.samples)
                 self.nextTrain = min(self.batchsize + self.nextTrain, (self.historySizeMul + 1) * self.batchsize)
+                print(len(self.samples), self.nextTrain)
 
         self.accumulated_rewards += self.lastReward
 
@@ -657,23 +663,6 @@ class Herbivore(NetLearner):
 
     samples = []
 
-    # x_in = keras.Input(shape=(4 * (2 * viewD + 1) * (2 * viewD + 1) + 2))
-    # target = keras.Input(shape=(10, 1))
-    # inVec = keras.layers.Flatten()(x_in)
-    # # kernel_regularizer=keras.regularizers.l2(0.01)
-    # actions = keras.layers.Dense((4 * (2 * viewD + 1) * (2 * viewD + 1)), activation='elu')(inVec)
-    # actions = keras.layers.Dense(((2 * viewD + 1) * (2 * viewD + 1)), activation='elu')(actions)
-    # actions = keras.layers.Dense(8, activation='linear', use_bias=False)(actions)
-    # # actions = keras.layers.Dense(4, activation='linear', use_bias=False)(inVec)
-    #
-    # model = keras.Model(inputs=x_in, outputs=actions)
-    #
-    # # model.compile(optimizer='adam', loss=loss2, target_tensors=[target])
-    # model.compile(optimizer=tf.keras.optimizers.RMSprop(learningRate), loss=loss2, target_tensors=[target])
-
-    # def __init__(self, x, y):
-    #     super(Herbivore, self).__init__(x, y)
-
     def createState(self, world: LabyrinthWorld):
         state = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float)  # - 1
         state2 = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float)  # - 1

From 26e7ffb12b5534f8697e6b49f62aa86b9e8985ba Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Mon, 14 Nov 2022 11:17:00 +0100
Subject: [PATCH 12/14] reworks loss function and training data, cleans up
 directions in code

---
 labirinth_ai/Models/BaseModel.py |  49 +++---
 labirinth_ai/Subject.py          | 255 ++++++++++++-------------------
 2 files changed, 129 insertions(+), 175 deletions(-)

diff --git a/labirinth_ai/Models/BaseModel.py b/labirinth_ai/Models/BaseModel.py
index 2434b61..250e3b5 100644
--- a/labirinth_ai/Models/BaseModel.py
+++ b/labirinth_ai/Models/BaseModel.py
@@ -1,7 +1,7 @@
 import torch
 from torch import nn
 import numpy as np
-import tqdm
+from tqdm import tqdm
 from torch.utils.data import Dataset, DataLoader
 
 import os
@@ -14,6 +14,7 @@ print(f"Using {device} device")
 # Define model
 class BaseModel(nn.Module):
     evolutionary = False
+
     def __init__(self, view_dimension, action_num, channels):
         super(BaseModel, self).__init__()
         self.flatten = nn.Flatten()
@@ -59,11 +60,18 @@ def create_optimizer(model):
 
 
 def create_loss_function(action):
+    lambda_factor = 0.0
+    split_factor = 1.0
     def custom_loss(prediction, target):
-        return torch.mean(0.5 * torch.square(
-            0.1 * target[:, 0, 0] + target[:, 1, 0] - (
-                        prediction[:, action, 0] + prediction[:, action, 1])) + 0.5 * torch.square(
-            target[:, 1, 0] - prediction[:, action, 0]), dim=0)
+        return torch.mean(split_factor * torch.square(
+            # discounted best estimate the old weights made for t+1
+            lambda_factor * target[:, 0, 0] +
+            # actual reward for t
+            target[:, 1, 0] -
+            # estimate for current weights
+            (prediction[:, action, 0] + prediction[:, action, 1])) +
+                          # trying to learn present reward separate from future reward
+                          (1.0 - split_factor) * torch.square(target[:, 1, 0] - prediction[:, action, 0]), dim=0)
 
     return custom_loss
 
@@ -75,26 +83,31 @@ def from_numpy(x):
 def train(states, targets, model, optimizer):
     for action in range(model.action_num):
         data_set = BaseDataSet(states[action], targets[action])
-        dataloader = DataLoader(data_set, batch_size=64, shuffle=True)
+        dataloader = DataLoader(data_set, batch_size=256, shuffle=True)
         loss_fn = create_loss_function(action)
 
         size = len(dataloader)
         model.train()
-        for batch, (X, y) in enumerate(dataloader):
-            X, y = X.to(device), y.to(device)
 
-            # Compute prediction error
-            pred = model(X)
-            loss = loss_fn(pred, y)
+        epochs = 1
+        with tqdm(range(epochs)) as progress_bar:
+            for _ in enumerate(progress_bar):
+                losses = []
+                for batch, (X, y) in enumerate(dataloader):
+                    X, y = X.to(device), y.to(device)
 
-            # Backpropagation
-            optimizer.zero_grad()
-            loss.backward(retain_graph=True)
-            optimizer.step()
+                    # Compute prediction error
+                    pred = model(X)
+                    loss = loss_fn(pred, y)
 
-            if batch % 100 == 0:
-                loss, current = loss.item(), batch * len(X)
-                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
+                    # Backpropagation
+                    optimizer.zero_grad()
+                    loss.backward(retain_graph=True)
+                    optimizer.step()
+
+                    losses.append(loss.item())
+                progress_bar.set_postfix(loss=np.average(losses))
+                progress_bar.update()
         model.eval()
 
         del data_set
diff --git a/labirinth_ai/Subject.py b/labirinth_ai/Subject.py
index dc8e886..762d9df 100644
--- a/labirinth_ai/Subject.py
+++ b/labirinth_ai/Subject.py
@@ -108,34 +108,6 @@ class QLearner(Subject):
     def createState(self, world: LabyrinthWorld):
         state = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.int)  # - 1
 
-        # # floodfill state
-        # queued = [(0, 0)]
-        # todo = [(0, 0, 0)]
-        # while todo != []:
-        #     doing = todo.pop(0)
-        #
-        #     if self.x + doing[0] >= 0 and self.x + doing[0] < 64 and self.y + doing[1] >= 0 and self.y + doing[1] < 64:
-        #         value = world.board[self.x + doing[0], self.y + doing[1]]
-        #         state[self.viewD + doing[0], self.viewD + doing[1]] = value
-        #
-        #         # if value == 3:
-        #         #     state[self.viewD + doing[0], self.viewD + doing[1]] = value
-        #
-        #         if value != 0 and doing[2] < self.viewD:
-        #             for i in range(-1, 2, 1):
-        #                 for j in range(-1, 2, 1):
-        #                     # 4-neighbour. without it it is 8-neighbour
-        #                     if abs(i) + abs(j) == 1:
-        #                         if (doing[0] + i, doing[1] + j) not in queued:
-        #                             queued.append((doing[0] + i, doing[1] + j))
-        #                             todo.append((doing[0] + i, doing[1] + j, doing[2] + 1))
-        #
-        # for sub in world.subjects:
-        #     if (sub.x - self.x, sub.y - self.y) in queued and state[
-        #         self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] != 3:
-        #         state[self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] = state[
-        #                                                                               self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] * 100 + sub.col
-
         maxdirleft = self.x - max(self.x - (self.viewD), 0)
         maxdirright = min(self.x + (self.viewD), (world.board_shape[0] - 1)) - self.x
         maxdirup = self.y - max(self.y - (self.viewD), 0)
@@ -300,6 +272,7 @@ class DoubleQLearner(QLearner):
             pass
 
 
+RECALCULATE = False
 class NetLearner(Subject):
     right = (1, 0)
     left = (-1, 0)
@@ -440,7 +413,6 @@ class NetLearner(Subject):
             axs[1, 1].set_title('grass')
             plt.show(block=True)
 
-
     def createState(self, world: LabyrinthWorld):
         state = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float)  # - 1
         state2 = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float)  # - 1
@@ -474,10 +446,8 @@ class NetLearner(Subject):
             action = self.lastAction
         return np.reshape(np.concatenate((area, action)), (1, 4 * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2))
 
-    def calculateAction(self, world: LabyrinthWorld, vals=None, state=None):
-        # 0, 0 is top left
+    def generate_valid_directions(self, world: LabyrinthWorld):
         directions = []
-
         if self.x - 1 >= 0:
             if world.board[self.x - 1, self.y] != 0:
                 directions.append(self.left)
@@ -493,9 +463,15 @@ class NetLearner(Subject):
         if self.y + 1 < world.board_shape[1]:
             if world.board[self.x, self.y + 1] != 0:
                 directions.append(self.down)
+        return directions
+
+    def calculateAction(self, world: LabyrinthWorld, vals=None, state=None):
+        # 0, 0 is top left
+        directions = self.generate_valid_directions(world)
 
         if directions == []:
             print('Wut?')
+            return
 
         if directions != [] and self.alive:
             if state is None:
@@ -550,7 +526,8 @@ class NetLearner(Subject):
                 self.nextTrain = min(self.batchsize + self.nextTrain, (self.historySizeMul + 1) * self.batchsize)
                 print(len(self.samples), self.nextTrain)
 
-        self.accumulated_rewards += self.lastReward
+        if not self.random:
+            self.accumulated_rewards += self.lastReward
 
         self.lastAction = self.action
         self.lastState = self.state
@@ -562,27 +539,10 @@ class NetLearner(Subject):
         self.executeAction(world, self.action)
 
     def randomAct(self, world: LabyrinthWorld):
-        right = (1, 0)
-        left = (-1, 0)
-        up = (0, -1)
-        down = (0, 1)
-        directions = []
+        directions = self.generate_valid_directions(world)
 
-        if self.x - 1 >= 0:
-            if world.board[self.x - 1, self.y] != 0:
-                directions.append(left)
-
-        if self.x + 1 < world.board_shape[0]:
-            if world.board[self.x + 1, self.y] != 0:
-                directions.append(right)
-
-        if self.y - 1 >= 0:
-            if world.board[self.x, self.y - 1] != 0:
-                directions.append(up)
-
-        if self.y + 1 < world.board_shape[1]:
-            if world.board[self.x, self.y + 1] != 0:
-                directions.append(down)
+        if len(directions) == 0:
+            return 0, 0
 
         d = random.randint(0, len(directions) - 1)
         action = directions[d]
@@ -616,16 +576,16 @@ class NetLearner(Subject):
                                      replace=True)
             samples = samples[index]
             # self.samples = []
+            target[:, 1, 0] = samples[:, 0, 3]  # reward t-2 got
             if partTwo:
-                target[:, 1, 0] = samples[:, 1, 3] #reward t-2 got
+                if RECALCULATE:
+                    nextState = np.concatenate(samples[:, 1, 0]) #states of t-1
+                    nextVals = self.model(from_numpy(nextState)).detach().numpy()
 
-                nextState = np.concatenate(samples[:, 1, 0]) #states of t-1
-                nextVals = self.model(from_numpy(nextState)).detach().numpy()
-
-                nextVals2 = nextVals[:, i,  0] + nextVals[:, i, 1]
-                target[:, 0, 0] = nextVals2 #best q t-1
-            else:
-                target[:, 1, 0] = np.array(list(map(lambda elem: list(elem), list(np.array(samples[:, 1, 4])))))[:, i]  # reward t-2 got
+                    nextVals2 = np.max(nextVals[:, :,  0] + nextVals[:, :, 1], axis=1)
+                    target[:, 0, 0] = nextVals2 #best q t-1
+                else:
+                    target[:, 0, 0] = samples[:, 1, 2] #best q t-1
 
             targets.append(target)
 
@@ -697,27 +657,7 @@ class Herbivore(NetLearner):
         return np.reshape(np.concatenate((area, action)), (1, 4 * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2))
 
     def executeAction(self, world: LabyrinthWorld, action):
-        right = (1, 0)
-        left = (-1, 0)
-        up = (0, -1)
-        down = (0, 1)
-        directions = []
-
-        if self.x - 1 >= 0:
-            if world.board[self.x - 1, self.y] != 0:
-                directions.append(left)
-
-        if self.x + 1 < world.board_shape[0]:
-            if world.board[self.x + 1, self.y] != 0:
-                directions.append(right)
-
-        if self.y - 1 >= 0:
-            if world.board[self.x, self.y - 1] != 0:
-                directions.append(up)
-
-        if self.y + 1 < world.board_shape[1]:
-            if world.board[self.x, self.y + 1] != 0:
-                directions.append(down)
+        directions = self.generate_valid_directions(world)
         if len(action) == 2:
             if len(world.subjectDict[(self.x + action[0], self.y + action[1])]) > 0:
                 for sub in world.subjectDict[(self.x + action[0], self.y + action[1])]:
@@ -729,26 +669,26 @@ class Herbivore(NetLearner):
                             self.alive = False
 
             self.lastRewards = []
-            if right in directions:
+            if self.right in directions:
                 self.lastRewards.append(world.grass[self.x + 1, self.y])
             else:
                 self.lastRewards.append(0)
-            if left in directions:
+            if self.left in directions:
                 self.lastRewards.append(world.grass[self.x - 1, self.y])
             else:
                 self.lastRewards.append(0)
-            if up in directions:
+            if self.up in directions:
                 self.lastRewards.append(world.grass[self.x, self.y - 1])
             else:
                 self.lastRewards.append(0)
-            if down in directions:
+            if self.down in directions:
                self.lastRewards.append(world.grass[self.x, self.y + 1])
             else:
                 self.lastRewards.append(0)
             assert len(self.lastRewards) == 4, 'Last Rewards not filled correctly!'
 
             world.subjectDict[(self.x, self.y)].remove(self)
-            self.lastReward += world.trailMix[self.x, self.y]
+            # self.lastReward += world.trailMix[self.x, self.y]
             self.x += action[0]
             self.y += action[1]
             world.subjectDict[(self.x, self.y)].append(self)
@@ -757,33 +697,58 @@ class Herbivore(NetLearner):
             world.grass[self.x, self.y] = 0
             world.hunter_grass[self.x, self.y] = 0
 
+    def generate_valid_directions(self, world: LabyrinthWorld):
+        directions = []
+        if self.x - 1 >= 0:
+            if world.board[self.x - 1, self.y] != 0:
+                if not world.subjectDict[(self.x - 1, self.y)]:
+                    directions.append(self.left)
+
+        if self.x + 1 < world.board_shape[0]:
+            if world.board[self.x + 1, self.y] != 0:
+                if not world.subjectDict[(self.x + 1, self.y)]:
+                    directions.append(self.right)
+
+        if self.y - 1 >= 0:
+            if world.board[self.x, self.y - 1] != 0:
+                if not world.subjectDict[(self.x, self.y - 1)]:
+                    directions.append(self.up)
+
+        if self.y + 1 < world.board_shape[1]:
+            if world.board[self.x, self.y + 1] != 0:
+                if not world.subjectDict[(self.x, self.y + 1)]:
+                    directions.append(self.down)
+        return directions
+
     def randomAct(self, world: LabyrinthWorld):
-        right = (1, 0)
-        left = (-1, 0)
-        up = (0, -1)
-        down = (0, 1)
         directions = []
         actDict = {}
 
         if self.x - 1 >= 0:
             if world.board[self.x - 1, self.y] != 0:
-                directions.append(left)
-                actDict[left] = world.grass[self.x - 1, self.y]
+                if not world.subjectDict[(self.x - 1, self.y)]:
+                    directions.append(self.left)
+                    actDict[self.left] = world.grass[self.x - 1, self.y]
 
         if self.x + 1 < world.board_shape[0]:
             if world.board[self.x + 1, self.y] != 0:
-                directions.append(right)
-                actDict[right] = world.grass[self.x + 1, self.y]
+                if not world.subjectDict[(self.x + 1, self.y)]:
+                    directions.append(self.right)
+                    actDict[self.right] = world.grass[self.x + 1, self.y]
 
         if self.y - 1 >= 0:
             if world.board[self.x, self.y - 1] != 0:
-                directions.append(up)
-                actDict[up] = world.grass[self.x, self.y - 1]
+                if not world.subjectDict[(self.x, self.y - 1)]:
+                    directions.append(self.up)
+                    actDict[self.up] = world.grass[self.x, self.y - 1]
 
         if self.y + 1 < world.board_shape[1]:
             if world.board[self.x, self.y + 1] != 0:
-                directions.append(down)
-                actDict[down] = world.grass[self.x, self.y + 1]
+                if not world.subjectDict[(self.x, self.y + 1)]:
+                    directions.append(self.down)
+                    actDict[self.down] = world.grass[self.x, self.y + 1]
+        if len(directions) == 0:
+            return 0, 0
 
         allowedActions = dict(filter(lambda elem: elem[0] in directions, actDict.items()))
         action = max(allowedActions, key=allowedActions.get)
@@ -792,7 +757,7 @@ class Herbivore(NetLearner):
 
     def respawnUpdate(self, x, y, world: LabyrinthWorld):
         super(Herbivore, self).respawnUpdate(x, y, world)
-        self.lastReward -= 1
+        # self.lastReward -= 1
 
 
 class Hunter(NetLearner):
@@ -802,16 +767,12 @@ class Hunter(NetLearner):
     g = 255
     b = 255
     def randomAct(self, world: LabyrinthWorld):
-        right = (1, 0)
-        left = (-1, 0)
-        up = (0, -1)
-        down = (0, 1)
         directions = []
         actDict = {}
 
         if self.x - 1 >= 0:
             if world.board[self.x - 1, self.y] > 0.01:
-                directions.append(left)
+                directions.append(self.left)
 
                 sub = self.getClosestSubject(world, self.x - 1, self.y)
                 dist = self.viewD
@@ -819,15 +780,15 @@ class Hunter(NetLearner):
                     dist = np.sqrt(np.square(self.x - 1 - sub.x) + np.square(self.y - sub.y))
                 distReward = self.viewD - dist
 
-                actDict[left] = world.trailMix[self.x - 1, self.y] + world.hunter_grass[self.x - 1, self.y] * self.hunterGrassScale + distReward
-                if len(world.subjectDict[(self.x + left[0], self.y + left[1])]) > 0:
-                    for sub in world.subjectDict[(self.x + left[0], self.y + left[1])]:
+                actDict[self.left] = world.trailMix[self.x - 1, self.y] + world.hunter_grass[self.x - 1, self.y] * self.hunterGrassScale + distReward
+                if len(world.subjectDict[(self.x + self.left[0], self.y + self.left[1])]) > 0:
+                    for sub in world.subjectDict[(self.x + self.left[0], self.y + self.left[1])]:
                         if sub.col != self.col:
-                            actDict[left] += 10
+                            actDict[self.left] += 10
 
         if self.x + 1 < world.board_shape[0]:
             if world.board[self.x + 1, self.y] > 0.01:
-                directions.append(right)
+                directions.append(self.right)
 
                 sub = self.getClosestSubject(world, self.x + 1, self.y)
                 dist = self.viewD
@@ -835,15 +796,15 @@ class Hunter(NetLearner):
                     dist = np.sqrt(np.square(self.x + 1 - sub.x) + np.square(self.y - sub.y))
                 distReward = self.viewD - dist
 
-                actDict[right] = world.trailMix[self.x + 1, self.y] + world.hunter_grass[self.x + 1, self.y] * self.hunterGrassScale + distReward
-                if len(world.subjectDict[(self.x + right[0], self.y + right[1])]) > 0:
-                    for sub in world.subjectDict[(self.x + right[0], self.y + right[1])]:
+                actDict[self.right] = world.trailMix[self.x + 1, self.y] + world.hunter_grass[self.x + 1, self.y] * self.hunterGrassScale + distReward
+                if len(world.subjectDict[(self.x + self.right[0], self.y + self.right[1])]) > 0:
+                    for sub in world.subjectDict[(self.x + self.right[0], self.y + self.right[1])]:
                         if sub.col != self.col:
-                            actDict[right] += 10
+                            actDict[self.right] += 10
 
         if self.y - 1 >= 0:
             if world.board[self.x, self.y - 1] > 0.01:
-                directions.append(up)
+                directions.append(self.up)
 
                 sub = self.getClosestSubject(world, self.x, self.y - 1)
                 dist = self.viewD
@@ -851,15 +812,15 @@ class Hunter(NetLearner):
                     dist = np.sqrt(np.square(self.x - sub.x) + np.square(self.y - 1 - sub.y))
                 distReward = self.viewD - dist
 
-                actDict[up] = world.trailMix[self.x, self.y - 1] + world.hunter_grass[self.x, self.y - 1] * self.hunterGrassScale + distReward
-                if len(world.subjectDict[(self.x + up[0], self.y + up[1])]) > 0:
-                    for sub in world.subjectDict[(self.x + up[0], self.y + up[1])]:
+                actDict[self.up] = world.trailMix[self.x, self.y - 1] + world.hunter_grass[self.x, self.y - 1] * self.hunterGrassScale + distReward
+                if len(world.subjectDict[(self.x + self.up[0], self.y + self.up[1])]) > 0:
+                    for sub in world.subjectDict[(self.x + self.up[0], self.y + self.up[1])]:
                         if sub.col != self.col:
-                            actDict[up] += 10
+                            actDict[self.up] += 10
 
         if self.y + 1 < world.board_shape[1]:
             if world.board[self.x, self.y + 1] > 0.01:
-                directions.append(down)
+                directions.append(self.down)
 
                 sub = self.getClosestSubject(world, self.x, self.y + 1)
                 dist = self.viewD
@@ -867,11 +828,11 @@ class Hunter(NetLearner):
                     dist = np.sqrt(np.square(self.x - sub.x) + np.square(self.y + 1 - sub.y))
                 distReward = self.viewD - dist
 
-                actDict[down] = world.trailMix[self.x, self.y + 1] + world.hunter_grass[self.x, self.y + 1] * self.hunterGrassScale + distReward
-                if len(world.subjectDict[(self.x + down[0], self.y + down[1])]) > 0:
-                    for sub in world.subjectDict[(self.x + down[0], self.y + down[1])]:
+                actDict[self.down] = world.trailMix[self.x, self.y + 1] + world.hunter_grass[self.x, self.y + 1] * self.hunterGrassScale + distReward
+                if len(world.subjectDict[(self.x + self.down[0], self.y + self.down[1])]) > 0:
+                    for sub in world.subjectDict[(self.x + self.down[0], self.y + self.down[1])]:
                         if sub.col != self.col:
-                            actDict[down] += 10
+                            actDict[self.down] += 10
 
         if len(actDict) > 0:
             allowedActions = dict(filter(lambda elem: elem[0] in directions, actDict.items()))
@@ -919,47 +880,27 @@ class Hunter(NetLearner):
     def executeAction(self, world: LabyrinthWorld, action):
         grass_factor = 0.5
 
-        right = (1, 0)
-        left = (-1, 0)
-        up = (0, -1)
-        down = (0, 1)
-        directions = []
-
-        if self.x - 1 >= 0:
-            if world.board[self.x - 1, self.y] != 0:
-                directions.append(left)
-
-        if self.x + 1 < world.board_shape[0]:
-            if world.board[self.x + 1, self.y] != 0:
-                directions.append(right)
-
-        if self.y - 1 >= 0:
-            if world.board[self.x, self.y - 1] != 0:
-                directions.append(up)
-
-        if self.y + 1 < world.board_shape[1]:
-            if world.board[self.x, self.y + 1] != 0:
-                directions.append(down)
+        directions = self.generate_valid_directions(world)
 
         if len(action) == 2:
             right_kill = left_kill = up_kill = down_kill = False
-            if right in directions:
-                for sub in world.subjectDict[(self.x + action[0], self.y + action[1])]:
+            if self.right in directions:
+                for sub in world.subjectDict[(self.x + self.right[0], self.y + self.right[1])]:
                     if sub.alive:
                         if sub.col != self.col:
                             right_kill = True
-            if left in directions:
-                for sub in world.subjectDict[(self.x + left[0], self.y + left[1])]:
+            if self.left in directions:
+                for sub in world.subjectDict[(self.x + self.left[0], self.y + self.left[1])]:
                     if sub.alive:
                         if sub.col != self.col:
                             left_kill = True
-            if up in directions:
-                for sub in world.subjectDict[(self.x + up[0], self.y + up[1])]:
+            if self.up in directions:
+                for sub in world.subjectDict[(self.x + self.up[0], self.y + self.up[1])]:
                     if sub.alive:
                         if sub.col != self.col:
                             up_kill = True
-            if down in directions:
-                for sub in world.subjectDict[(self.x + down[0], self.y + down[1])]:
+            if self.down in directions:
+                for sub in world.subjectDict[(self.x + self.down[0], self.y + self.down[1])]:
                     if sub.alive:
                         if sub.col != self.col:
                             down_kill = True
@@ -974,7 +915,7 @@ class Hunter(NetLearner):
                     self.alive = True
 
             self.lastRewards = []
-            if right in directions:
+            if self.right in directions:
                 sub = self.getClosestSubject(world, self.x + 1, self.y)
                 dist = self.viewD
                 if sub is not None:
@@ -986,7 +927,7 @@ class Hunter(NetLearner):
                     self.lastRewards.append(world.trailMix[self.x + 1, self.y] + world.hunter_grass[self.x + 1, self.y] * grass_factor + distReward)
             else:
                 self.lastRewards.append(0)
-            if left in directions:
+            if self.left in directions:
                 sub = self.getClosestSubject(world, self.x - 1, self.y)
                 dist = self.viewD
                 if sub is not None:
@@ -998,7 +939,7 @@ class Hunter(NetLearner):
                     self.lastRewards.append(world.trailMix[self.x - 1, self.y] + world.hunter_grass[self.x - 1, self.y] * grass_factor + distReward)
             else:
                 self.lastRewards.append(0)
-            if up in directions:
+            if self.up in directions:
                 sub = self.getClosestSubject(world, self.x, self.y - 1)
                 dist = self.viewD
                 if sub is not None:
@@ -1010,7 +951,7 @@ class Hunter(NetLearner):
                     self.lastRewards.append(world.trailMix[self.x, self.y - 1] + world.hunter_grass[self.x, self.y - 1] * grass_factor + distReward)
             else:
                 self.lastRewards.append(0)
-            if down in directions:
+            if self.down in directions:
                 sub = self.getClosestSubject(world, self.x, self.y + 1)
                 dist = self.viewD
                 if sub is not None:

From bd561733795d8ed3d04c52a8b79b5bdc5dd7050c Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Mon, 14 Nov 2022 11:18:57 +0100
Subject: [PATCH 13/14] make evolution optional

---
 labirinth_ai/LabyrinthWorld.py        |   8 +-
 labirinth_ai/Models/EvolutionModel.py |   6 +-
 labirinth_ai/Models/Genotype.py       | 105 +++++++++++++++++++++++---
 labirinth_ai/Population.py            |  76 ++++++++++---------
 4 files changed, 144 insertions(+), 51 deletions(-)

diff --git a/labirinth_ai/LabyrinthWorld.py b/labirinth_ai/LabyrinthWorld.py
index f2adaf9..bea0ee5 100644
--- a/labirinth_ai/LabyrinthWorld.py
+++ b/labirinth_ai/LabyrinthWorld.py
@@ -30,8 +30,8 @@ class LabyrinthWorld(World):
         self.lastUpdate = time.time()
         self.nextTrain = self.randomBuffer
         self.round = 1
-        self.evolve_timer = 10
-        # self.evolve_timer = 1500
+        # self.evolve_timer = 10
+        self.evolve_timer = 1500
 
         self.trailMix = np.zeros(self.board_shape)
         self.grass = np.zeros(self.board_shape)
@@ -163,9 +163,9 @@ class LabyrinthWorld(World):
         # adding subjects
         from labirinth_ai.Subject import Hunter, Herbivore
         from labirinth_ai.Population import Population
-        self._hunters = Population(Hunter, self, 10)
+        self._hunters = Population(Hunter, self, 10, do_evolve=False)
 
-        self._herbivores = Population(Herbivore, self, 40)
+        self._herbivores = Population(Herbivore, self, 40, do_evolve=False)
 
         self.subjectDict = self.build_subject_dict()
 
diff --git a/labirinth_ai/Models/EvolutionModel.py b/labirinth_ai/Models/EvolutionModel.py
index 8a180d5..1b7d79f 100644
--- a/labirinth_ai/Models/EvolutionModel.py
+++ b/labirinth_ai/Models/EvolutionModel.py
@@ -1,7 +1,6 @@
 import torch
 from torch import nn
 import numpy as np
-import tqdm
 from torch.utils.data import Dataset, DataLoader
 from labirinth_ai.Models.BaseModel import device, BaseDataSet, create_loss_function, create_optimizer
 from labirinth_ai.Models.Genotype import Genotype
@@ -45,6 +44,8 @@ class EvolutionModel(nn.Module):
 
         self.incoming_connections = {}
         for connection in self.genes.connections:
+            if not connection.enabled:
+                continue
             if connection.end not in self.incoming_connections.keys():
                 self.incoming_connections[connection.end] = []
             self.incoming_connections[connection.end].append(connection)
@@ -158,7 +159,6 @@ class EvolutionModel(nn.Module):
                 self.genes.nodes[key].bias = float(lin.bias[0])
 
 
-
 class RecurrentDataSet(BaseDataSet):
     def __init__(self, states, targets, memory):
         super().__init__(states, targets)
@@ -172,7 +172,7 @@ class RecurrentDataSet(BaseDataSet):
 def train_recurrent(states, memory, targets, model, optimizer):
     for action in range(model.action_num):
         data_set = RecurrentDataSet(states[action], targets[action], memory[action])
-        dataloader = DataLoader(data_set, batch_size=64, shuffle=True)
+        dataloader = DataLoader(data_set, batch_size=512, shuffle=True)
         loss_fn = create_loss_function(action)
 
         size = len(dataloader)
diff --git a/labirinth_ai/Models/Genotype.py b/labirinth_ai/Models/Genotype.py
index 4bea59f..782525b 100644
--- a/labirinth_ai/Models/Genotype.py
+++ b/labirinth_ai/Models/Genotype.py
@@ -1,5 +1,6 @@
 from abc import abstractmethod
 from typing import List, Dict
+from copy import copy
 
 import numpy as np
 
@@ -12,11 +13,15 @@ class NodeGene:
         self.node_id = node_id
         self.node_type = node_type
         if node_type == 'hidden':
-            assert bias is not None, 'Expected a bias for hidden node types!'
+            if bias is None:
+                bias = np.random.random(1)[0] * 2 - 1.0
             self.bias = bias
         else:
             self.bias = None
 
+    def __copy__(self):
+        return NodeGene(self.node_id, self.node_type, bias=self.bias)
+
 
 class ConnectionGene:
     def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
@@ -30,12 +35,15 @@ class ConnectionGene:
         else:
             self.weight = weight
 
+    def __copy__(self):
+        return ConnectionGene(self.start, self.end, self.enabled, self.innvovation_num, self.weight, self.recurrent)
+
 
 class Genotype:
     def __init__(self, action_num: int = None, num_input_nodes: int = None,
                  nodes: Dict[int, NodeGene] = None, connections: List[ConnectionGene] = None):
-        self.nodes = {}
-        self.connections = []
+        self.nodes: Dict[int, NodeGene] = {}
+        self.connections: List[ConnectionGene] = []
         if action_num is not None and num_input_nodes is not None:
             node_id = 0
             for _ in range(num_input_nodes):
@@ -61,7 +69,8 @@ class Genotype:
         while len(nodes_to_rank) > 0:
             for list_index, (id, node) in enumerate(nodes_to_rank):
                 incoming_connections = list(filter(lambda connection: connection.end == id and
-                                                                      not connection.recurrent, self.connections))
+                                                                      not connection.recurrent and connection.enabled,
+                                                   self.connections))
                 if len(incoming_connections) == 0:
                     rank_of_node[id] = 0
                     nodes_to_rank.pop(list_index)
@@ -90,7 +99,7 @@ class Genotype:
         raise NotImplementedError()
 
     @abstractmethod
-    def cross(self, other):
+    def cross(self, other, fitnes_self, fitness_other):
         raise NotImplementedError()
         # return self
 
@@ -98,6 +107,11 @@ class Genotype:
 class NeatLike(Genotype):
     connection_add_thr = 0.3
     node_add_thr = 0.3
+    disable_conn_thr = 0.1
+
+    # connection_add_thr = 0.0
+    # node_add_thr = 0.0
+    # disable_conn_thr = 0.0
 
     def mutate(self, innovation_num, allow_recurrent=False) -> int:
         """
@@ -107,7 +121,7 @@ class NeatLike(Genotype):
         :return: Updated innovation number
         """
         # add connection
-        if np.random.random(1)[0] < self.connection_add_thr or True:
+        if np.random.random(1)[0] < self.connection_add_thr:
             nodes = list(self.nodes.keys())
             rank_of_node = self.calculate_rank_of_nodes()
             end_nodes = list(filter(lambda node: rank_of_node[node] > 0, nodes))
@@ -131,9 +145,82 @@ class NeatLike(Genotype):
                 self.connections.append(
                     ConnectionGene(nodes[start], end_nodes[end], True, innovation_num,
                                    recurrent=rank_of_node[nodes[start]] > rank_of_node[end_nodes[end]]))
-        #todo add node
+
+        if np.random.random(1)[0] < self.node_add_thr:
+            active_connections = list(filter(lambda connection: connection.enabled, self.connections))
+
+            n = np.random.randint(0, len(active_connections))
+            old_connection = active_connections[n]
+
+            new_node = NodeGene(innovation_num, 'hidden')
+            node_id = innovation_num
+            connection_1 = ConnectionGene(old_connection.start, node_id, True, innovation_num,
+                                          recurrent=old_connection.recurrent)
+            innovation_num += 1
+            connection_2 = ConnectionGene(node_id, old_connection.end, True, innovation_num)
+            innovation_num += 1
+
+            old_connection.enabled = False
+            self.nodes[node_id] = new_node
+            self.connections.append(connection_1)
+            self.connections.append(connection_2)
+
+        if np.random.random(1)[0] < self.disable_conn_thr:
+            active_connections = list(filter(lambda connection: connection.enabled, self.connections))
+            n = np.random.randint(0, len(active_connections))
+            old_connection = active_connections[n]
+            old_connection.enabled = not old_connection.enabled
 
         return innovation_num
 
-    def cross(self, other):
-        return self
+    def cross(self, other, fitnes_self, fitness_other):
+        new_genes = NeatLike()
+        node_nums = set(map(lambda node: node[0], self.nodes.items())).union(
+            set(map(lambda node: node[0], other.nodes.items())))
+
+        connections = {}
+        for connection in self.connections:
+            connections[connection.innvovation_num] = connection
+
+        other_connections = {}
+        for connection in other.connections:
+            other_connections[connection.innvovation_num] = connection
+
+        connection_nums = set(map(lambda connection: connection[0], connections.items())).union(
+            set(map(lambda connection: connection[0], other_connections.items())))
+
+        for node_num in node_nums:
+            if node_num in self.nodes.keys() and node_num in other.nodes.keys():
+                if int(fitness_other) == int(fitnes_self):
+                    if np.random.randint(0, 2) == 0:
+                        new_genes.nodes[node_num] = copy(self.nodes[node_num])
+                    else:
+                        new_genes.nodes[node_num] = copy(other.nodes[node_num])
+                elif fitnes_self > fitness_other:
+                    new_genes.nodes[node_num] = copy(self.nodes[node_num])
+                else:
+                    new_genes.nodes[node_num] = copy(other.nodes[node_num])
+            elif node_num in self.nodes.keys() and int(fitnes_self) >= int(fitness_other):
+                new_genes.nodes[node_num] = copy(self.nodes[node_num])
+            elif node_num in other.nodes.keys() and int(fitnes_self) <= int(fitness_other):
+                new_genes.nodes[node_num] = copy(other.nodes[node_num])
+
+        for connection_num in connection_nums:
+            if connection_num in connections.keys() and connection_num in other_connections.keys():
+                if int(fitness_other) == int(fitnes_self):
+                    if np.random.randint(0, 2) == 0:
+                        connection = copy(connections[connection_num])
+                    else:
+                        connection = copy(other_connections[connection_num])
+                elif fitnes_self > fitness_other:
+                    connection = copy(connections[connection_num])
+                else:
+                    connection = copy(other_connections[connection_num])
+
+                new_genes.connections.append(connection)
+            elif connection_num in connections.keys() and int(fitnes_self) >= int(fitness_other):
+                new_genes.connections.append(copy(connections[connection_num]))
+            elif connection_num in other_connections.keys() and int(fitnes_self) <= int(fitness_other):
+                new_genes.connections.append(copy(other_connections[connection_num]))
+
+        return new_genes
diff --git a/labirinth_ai/Population.py b/labirinth_ai/Population.py
index 70eef4f..af3b0c9 100644
--- a/labirinth_ai/Population.py
+++ b/labirinth_ai/Population.py
@@ -1,6 +1,7 @@
 import random
 import numpy as np
 
+from labirinth_ai.Models import EvolutionModel
 from labirinth_ai.Models.Genotype import NeatLike
 
 
@@ -14,7 +15,7 @@ def fib(n):
 
 
 class Population:
-    def __init__(self, subject_class, world, subject_number):
+    def __init__(self, subject_class, world, subject_number, do_evolve=True):
         self.subjects = []
         self.world = world
         for _ in range(subject_number):
@@ -22,6 +23,7 @@ class Population:
             self.subjects.append(subject_class(px, py, genotype_class=NeatLike))
         self.subject_number = subject_number
         self.subject_class = subject_class
+        self.do_evolve = do_evolve
 
     def select(self):
         ranked = list(self.subjects)
@@ -52,46 +54,50 @@ class Population:
         return out + cls.scatter(n - np.sum(fibs), buckets)
 
     def evolve(self):
-        # get updated weights from the models
-        for subject in self.subjects:
-            subject.model.update_genes_with_weights()
+        if self.do_evolve:
+            if len(self.subjects) > 1:
+                # get updated weights from the models
+                for subject in self.subjects:
+                    subject.model.update_genes_with_weights()
 
-        # crossbreed the current pop
-        best_subjects = self.select()
-        distribution = list(self.scatter(self.subject_number - int(self.subject_number / 2), int(self.subject_number / 2)))
+                # crossbreed the current pop
+                best_subjects = self.select()
+                distribution = list(self.scatter(self.subject_number - int(self.subject_number / 2), int(self.subject_number / 2)))
 
-        new_subjects = list(best_subjects)
-        for index, offspring_num in enumerate(distribution):
-            for _ in range(int(offspring_num)):
-                parent_1 = best_subjects[index]
-                parent_2 = best_subjects[random.randint(index + 1, len(best_subjects) - 1)]
+                new_subjects = list(best_subjects)
+                for index, offspring_num in enumerate(distribution):
+                    for _ in range(int(offspring_num)):
+                        parent_1 = best_subjects[index]
+                        parent_2 = best_subjects[random.randint(index + 1, len(best_subjects) - 1)]
 
-                new_genes = parent_1.model.genes.cross(parent_2.model.genes)
+                        new_genes = parent_1.model.genes.cross(parent_2.model.genes,
+                                                               parent_1.accumulated_rewards, parent_2.accumulated_rewards)
 
-                # position doesn't matter, since mutation will set it
-                new_subject = self.subject_class(0, 0, new_genes)
-                new_subject.history = parent_1.history
-                new_subject.samples = parent_1.samples + parent_2.samples
-                new_subjects.append(new_subject)
+                        # position doesn't matter, since mutation will set it
+                        new_subject = self.subject_class(0, 0, new_genes)
+                        new_subject.history = parent_1.history
+                        new_subject.samples = parent_1.samples + parent_2.samples
+                        new_subjects.append(new_subject)
 
-        assert len(new_subjects) == self.subject_number, 'All generations should have constant size!'
-
-        # mutate the pop
-        mutated_subjects = []
-        innovation_num = max(map(lambda subject: max(map(lambda connection: connection.innvovation_num,
-                                                         subject.model.genes.connections
+                assert len(new_subjects) == self.subject_number, 'All generations should have constant size!'
+            else:
+                new_subjects = self.subjects
+            # mutate the pop
+            mutated_subjects = []
+            innovation_num = max(map(lambda subject: max(map(lambda connection: connection.innvovation_num,
+                                                             subject.model.genes.connections
+                                                             )
                                                          )
-                                                     )
-                             , new_subjects))
-        for subject in new_subjects:
-            subject.accumulated_rewards = 0
+                                 , new_subjects))
+            for subject in new_subjects:
+                subject.accumulated_rewards = 0
 
-            innovation_num = subject.model.genes.mutate(innovation_num)
+                innovation_num = subject.model.genes.mutate(innovation_num)
 
-            px, py = self.world.generate_free_coordinates()
-            new_subject = self.subject_class(px, py, subject.model.genes)
-            new_subject.history = subject.history
-            new_subject.samples = subject.samples
-            mutated_subjects.append(new_subject)
+                px, py = self.world.generate_free_coordinates()
+                new_subject = self.subject_class(px, py, subject.model.genes)
+                new_subject.history = subject.history
+                new_subject.samples = subject.samples
+                mutated_subjects.append(new_subject)
 
-        self.subjects = mutated_subjects
+            self.subjects = mutated_subjects

From b0d22f6bf1ec2c5c43b248ba9c096ca7ca5d8475 Mon Sep 17 00:00:00 2001
From: zomseffen <steffen@tom.bi>
Date: Wed, 21 Dec 2022 16:08:22 +0100
Subject: [PATCH 14/14] beundling weights

---
 labirinth_ai/Models/EvolutionModel.py | 104 ++++++++++++++++----------
 1 file changed, 64 insertions(+), 40 deletions(-)

diff --git a/labirinth_ai/Models/EvolutionModel.py b/labirinth_ai/Models/EvolutionModel.py
index 1b7d79f..8b9fd99 100644
--- a/labirinth_ai/Models/EvolutionModel.py
+++ b/labirinth_ai/Models/EvolutionModel.py
@@ -51,36 +51,40 @@ class EvolutionModel(nn.Module):
             self.incoming_connections[connection.end].append(connection)
 
         self.layers = {}
+        self.layer_non_recurrent_inputs = {}
+        self.layer_recurrent_inputs = {}
+        self.layer_results = {}
+        self.layer_num = 1
         self.indices = {}
 
         self.has_recurrent = False
         self.non_recurrent_indices = {}
         self.recurrent_indices = {}
-        with torch.no_grad():
-            for key, value in self.incoming_connections.items():
-                value.sort(key=lambda element: element.start)
 
-                lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
-                for index, connection in enumerate(value):
-                    lin.weight[0, index] = value[index].weight
-                if self.genes.nodes[key].bias is not None:
-                    lin.bias[0] = self.genes.nodes[key].bias
+        for key, value in self.incoming_connections.items():
+            value.sort(key=lambda element: element.start)
 
-                non_lin = nn.ELU()
-                sequence = nn.Sequential(
-                    lin,
-                    non_lin
-                )
-                self.add_module('layer_' + str(key), sequence)
-                self.layers[key] = sequence
-                self.indices[key] = list(map(lambda element: element.start, value))
+            # lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
+            # for index, connection in enumerate(value):
+            #     lin.weight[0, index] = value[index].weight
+            # if self.genes.nodes[key].bias is not None:
+            #     lin.bias[0] = self.genes.nodes[key].bias
+            #
+            # non_lin = nn.ELU()
+            # sequence = nn.Sequential(
+            #     lin,
+            #     non_lin
+            # )
+            # self.add_module('layer_' + str(key), sequence)
+            # self.layers[key] = sequence
+            self.indices[key] = list(map(lambda element: element.start, value))
 
-                self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
-                self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
-                if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
-                    self.has_recurrent = True
-                self.non_recurrent_indices[key] = list(map(lambda element: element.start, self.non_recurrent_indices[key]))
-                self.recurrent_indices[key] = list(map(lambda element: element.start, self.recurrent_indices[key]))
+            self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
+            self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
+            if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
+                self.has_recurrent = True
+            self.non_recurrent_indices[key] = list(map(lambda element: element.start, self.non_recurrent_indices[key]))
+            self.recurrent_indices[key] = list(map(lambda element: element.start, self.recurrent_indices[key]))
         rank_of_node = {}
         for i in range(self.num_input_nodes):
             rank_of_node[i] = 0
@@ -101,20 +105,39 @@ class EvolutionModel(nn.Module):
                     rank_of_node[key] = max_rank + 1
 
             layers_to_add = list(filter(lambda element: element[0] not in rank_of_node.keys(), layers_to_add))
-        ranked_layers = list(rank_of_node.items())
-        ranked_layers.sort(key=lambda element: element[1])
-        ranked_layers = list(filter(lambda element: element[1] > 0, ranked_layers))
 
-        ranked_layers = list(map(lambda element: (element, 0),
-                                 filter(lambda recurrent_element:
-                                        recurrent_element not in list(
-                                            map(lambda ranked_layer: ranked_layer[0], ranked_layers)
-                                        ),
-                                        list(filter(lambda recurrent_keys:
-                                                    len(self.recurrent_indices[recurrent_keys]) > 0,
-                                                    self.recurrent_indices.keys()))))) + ranked_layers
+        with torch.no_grad():
+            self.layer_num = max_rank = max(map(lambda element: element[1], rank_of_node.items()))
+            #todo: handle solely recurrent nodes
+            for rank in range(1, max_rank + 1):
+                # get nodes
+                nodes = list(map(lambda element: element[0], filter(lambda item: item[1] == rank, rank_of_node.items())))
+                non_recurrent_inputs = list(set.union(*map(lambda node: set(self.non_recurrent_indices[node]), nodes)))
+                non_recurrent_inputs.sort()
+
+                recurrent_inputs = list(set.union(*map(lambda node: set(self.recurrent_indices[node]), nodes)))
+                recurrent_inputs.sort()
+
+                lin = nn.Linear(len(non_recurrent_inputs) + len(recurrent_inputs), len(nodes), bias=True)
+
+                # todo: load weights
+
+                # for index, connection in enumerate(value):
+                #     lin.weight[0, index] = value[index].weight
+                # if self.genes.nodes[key].bias is not None:
+                #     lin.bias[0] = self.genes.nodes[key].bias
+                #
+                non_lin = nn.ELU()
+                sequence = nn.Sequential(
+                    lin,
+                    non_lin
+                )
+                self.add_module('layer_' + str(rank), sequence)
+                self.layers[rank] = sequence
+                self.layer_results[rank] = nodes
+                self.layer_non_recurrent_inputs[rank] = non_recurrent_inputs
+                self.layer_recurrent_inputs[rank] = recurrent_inputs
 
-        self.layer_order = list(map(lambda element: element[0], ranked_layers))
         self.memory_size = (max(map(lambda element: element[1].node_id, self.genes.nodes.items())) + 1)
         self.memory = torch.Tensor(self.memory_size)
         self.output_range = range(self.num_input_nodes, self.num_input_nodes + self.action_num * 2)
@@ -130,24 +153,25 @@ class EvolutionModel(nn.Module):
         outs = []
         for batch_index, batch_element in enumerate(x_flat):
             memory[0:self.num_input_nodes] = batch_element
-            for layer_index in self.layer_order:
-                non_recurrent_in = memory[self.non_recurrent_indices[layer_index]]
+            for layer_index in range(1, self.layer_num + 1):
+                non_recurrent_in = memory[self.layer_non_recurrent_inputs[layer_index]]
                 non_recurrent_in = torch.stack([non_recurrent_in])
-                if self.has_recurrent and len(self.recurrent_indices[layer_index]) > 0:
-                    recurrent_in = last_memory_flat[batch_index, self.recurrent_indices[layer_index]]
+                if self.has_recurrent and len(self.layer_recurrent_inputs[layer_index]) > 0:
+                    recurrent_in = last_memory_flat[batch_index, self.layer_recurrent_inputs[layer_index]]
                     recurrent_in = torch.stack([recurrent_in])
 
                     combined_in = torch.concat([non_recurrent_in, recurrent_in], dim=1)
                 else:
                     combined_in = non_recurrent_in
 
-                memory[layer_index] = self.layers[layer_index](combined_in)
-            outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2])
+                memory[self.layer_results[layer_index]] = self.layers[layer_index](combined_in)
+            outs.append(memory[self.output_range])
         outs = torch.stack(outs)
         self.memory = torch.Tensor(memory)
         return torch.reshape(outs, (x.shape[0], outs.shape[1]//2, 2))
 
     def update_genes_with_weights(self):
+        # todo rework
         for key, value in self.incoming_connections.items():
             value.sort(key=lambda element: element.start)