make evolution optional

This commit is contained in:
zomseffen 2022-11-14 11:18:57 +01:00
parent 26e7ffb12b
commit bd56173379
4 changed files with 144 additions and 51 deletions

View file

@ -30,8 +30,8 @@ class LabyrinthWorld(World):
self.lastUpdate = time.time() self.lastUpdate = time.time()
self.nextTrain = self.randomBuffer self.nextTrain = self.randomBuffer
self.round = 1 self.round = 1
self.evolve_timer = 10 # self.evolve_timer = 10
# self.evolve_timer = 1500 self.evolve_timer = 1500
self.trailMix = np.zeros(self.board_shape) self.trailMix = np.zeros(self.board_shape)
self.grass = np.zeros(self.board_shape) self.grass = np.zeros(self.board_shape)
@ -163,9 +163,9 @@ class LabyrinthWorld(World):
# adding subjects # adding subjects
from labirinth_ai.Subject import Hunter, Herbivore from labirinth_ai.Subject import Hunter, Herbivore
from labirinth_ai.Population import Population from labirinth_ai.Population import Population
self._hunters = Population(Hunter, self, 10) self._hunters = Population(Hunter, self, 10, do_evolve=False)
self._herbivores = Population(Herbivore, self, 40) self._herbivores = Population(Herbivore, self, 40, do_evolve=False)
self.subjectDict = self.build_subject_dict() self.subjectDict = self.build_subject_dict()

View file

@ -1,7 +1,6 @@
import torch import torch
from torch import nn from torch import nn
import numpy as np import numpy as np
import tqdm
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from labirinth_ai.Models.BaseModel import device, BaseDataSet, create_loss_function, create_optimizer from labirinth_ai.Models.BaseModel import device, BaseDataSet, create_loss_function, create_optimizer
from labirinth_ai.Models.Genotype import Genotype from labirinth_ai.Models.Genotype import Genotype
@ -45,6 +44,8 @@ class EvolutionModel(nn.Module):
self.incoming_connections = {} self.incoming_connections = {}
for connection in self.genes.connections: for connection in self.genes.connections:
if not connection.enabled:
continue
if connection.end not in self.incoming_connections.keys(): if connection.end not in self.incoming_connections.keys():
self.incoming_connections[connection.end] = [] self.incoming_connections[connection.end] = []
self.incoming_connections[connection.end].append(connection) self.incoming_connections[connection.end].append(connection)
@ -158,7 +159,6 @@ class EvolutionModel(nn.Module):
self.genes.nodes[key].bias = float(lin.bias[0]) self.genes.nodes[key].bias = float(lin.bias[0])
class RecurrentDataSet(BaseDataSet): class RecurrentDataSet(BaseDataSet):
def __init__(self, states, targets, memory): def __init__(self, states, targets, memory):
super().__init__(states, targets) super().__init__(states, targets)
@ -172,7 +172,7 @@ class RecurrentDataSet(BaseDataSet):
def train_recurrent(states, memory, targets, model, optimizer): def train_recurrent(states, memory, targets, model, optimizer):
for action in range(model.action_num): for action in range(model.action_num):
data_set = RecurrentDataSet(states[action], targets[action], memory[action]) data_set = RecurrentDataSet(states[action], targets[action], memory[action])
dataloader = DataLoader(data_set, batch_size=64, shuffle=True) dataloader = DataLoader(data_set, batch_size=512, shuffle=True)
loss_fn = create_loss_function(action) loss_fn = create_loss_function(action)
size = len(dataloader) size = len(dataloader)

View file

@ -1,5 +1,6 @@
from abc import abstractmethod from abc import abstractmethod
from typing import List, Dict from typing import List, Dict
from copy import copy
import numpy as np import numpy as np
@ -12,11 +13,15 @@ class NodeGene:
self.node_id = node_id self.node_id = node_id
self.node_type = node_type self.node_type = node_type
if node_type == 'hidden': if node_type == 'hidden':
assert bias is not None, 'Expected a bias for hidden node types!' if bias is None:
bias = np.random.random(1)[0] * 2 - 1.0
self.bias = bias self.bias = bias
else: else:
self.bias = None self.bias = None
def __copy__(self):
return NodeGene(self.node_id, self.node_type, bias=self.bias)
class ConnectionGene: class ConnectionGene:
def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False): def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
@ -30,12 +35,15 @@ class ConnectionGene:
else: else:
self.weight = weight self.weight = weight
def __copy__(self):
return ConnectionGene(self.start, self.end, self.enabled, self.innvovation_num, self.weight, self.recurrent)
class Genotype: class Genotype:
def __init__(self, action_num: int = None, num_input_nodes: int = None, def __init__(self, action_num: int = None, num_input_nodes: int = None,
nodes: Dict[int, NodeGene] = None, connections: List[ConnectionGene] = None): nodes: Dict[int, NodeGene] = None, connections: List[ConnectionGene] = None):
self.nodes = {} self.nodes: Dict[int, NodeGene] = {}
self.connections = [] self.connections: List[ConnectionGene] = []
if action_num is not None and num_input_nodes is not None: if action_num is not None and num_input_nodes is not None:
node_id = 0 node_id = 0
for _ in range(num_input_nodes): for _ in range(num_input_nodes):
@ -61,7 +69,8 @@ class Genotype:
while len(nodes_to_rank) > 0: while len(nodes_to_rank) > 0:
for list_index, (id, node) in enumerate(nodes_to_rank): for list_index, (id, node) in enumerate(nodes_to_rank):
incoming_connections = list(filter(lambda connection: connection.end == id and incoming_connections = list(filter(lambda connection: connection.end == id and
not connection.recurrent, self.connections)) not connection.recurrent and connection.enabled,
self.connections))
if len(incoming_connections) == 0: if len(incoming_connections) == 0:
rank_of_node[id] = 0 rank_of_node[id] = 0
nodes_to_rank.pop(list_index) nodes_to_rank.pop(list_index)
@ -90,7 +99,7 @@ class Genotype:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def cross(self, other): def cross(self, other, fitnes_self, fitness_other):
raise NotImplementedError() raise NotImplementedError()
# return self # return self
@ -98,6 +107,11 @@ class Genotype:
class NeatLike(Genotype): class NeatLike(Genotype):
connection_add_thr = 0.3 connection_add_thr = 0.3
node_add_thr = 0.3 node_add_thr = 0.3
disable_conn_thr = 0.1
# connection_add_thr = 0.0
# node_add_thr = 0.0
# disable_conn_thr = 0.0
def mutate(self, innovation_num, allow_recurrent=False) -> int: def mutate(self, innovation_num, allow_recurrent=False) -> int:
""" """
@ -107,7 +121,7 @@ class NeatLike(Genotype):
:return: Updated innovation number :return: Updated innovation number
""" """
# add connection # add connection
if np.random.random(1)[0] < self.connection_add_thr or True: if np.random.random(1)[0] < self.connection_add_thr:
nodes = list(self.nodes.keys()) nodes = list(self.nodes.keys())
rank_of_node = self.calculate_rank_of_nodes() rank_of_node = self.calculate_rank_of_nodes()
end_nodes = list(filter(lambda node: rank_of_node[node] > 0, nodes)) end_nodes = list(filter(lambda node: rank_of_node[node] > 0, nodes))
@ -131,9 +145,82 @@ class NeatLike(Genotype):
self.connections.append( self.connections.append(
ConnectionGene(nodes[start], end_nodes[end], True, innovation_num, ConnectionGene(nodes[start], end_nodes[end], True, innovation_num,
recurrent=rank_of_node[nodes[start]] > rank_of_node[end_nodes[end]])) recurrent=rank_of_node[nodes[start]] > rank_of_node[end_nodes[end]]))
#todo add node
if np.random.random(1)[0] < self.node_add_thr:
active_connections = list(filter(lambda connection: connection.enabled, self.connections))
n = np.random.randint(0, len(active_connections))
old_connection = active_connections[n]
new_node = NodeGene(innovation_num, 'hidden')
node_id = innovation_num
connection_1 = ConnectionGene(old_connection.start, node_id, True, innovation_num,
recurrent=old_connection.recurrent)
innovation_num += 1
connection_2 = ConnectionGene(node_id, old_connection.end, True, innovation_num)
innovation_num += 1
old_connection.enabled = False
self.nodes[node_id] = new_node
self.connections.append(connection_1)
self.connections.append(connection_2)
if np.random.random(1)[0] < self.disable_conn_thr:
active_connections = list(filter(lambda connection: connection.enabled, self.connections))
n = np.random.randint(0, len(active_connections))
old_connection = active_connections[n]
old_connection.enabled = not old_connection.enabled
return innovation_num return innovation_num
def cross(self, other): def cross(self, other, fitnes_self, fitness_other):
return self new_genes = NeatLike()
node_nums = set(map(lambda node: node[0], self.nodes.items())).union(
set(map(lambda node: node[0], other.nodes.items())))
connections = {}
for connection in self.connections:
connections[connection.innvovation_num] = connection
other_connections = {}
for connection in other.connections:
other_connections[connection.innvovation_num] = connection
connection_nums = set(map(lambda connection: connection[0], connections.items())).union(
set(map(lambda connection: connection[0], other_connections.items())))
for node_num in node_nums:
if node_num in self.nodes.keys() and node_num in other.nodes.keys():
if int(fitness_other) == int(fitnes_self):
if np.random.randint(0, 2) == 0:
new_genes.nodes[node_num] = copy(self.nodes[node_num])
else:
new_genes.nodes[node_num] = copy(other.nodes[node_num])
elif fitnes_self > fitness_other:
new_genes.nodes[node_num] = copy(self.nodes[node_num])
else:
new_genes.nodes[node_num] = copy(other.nodes[node_num])
elif node_num in self.nodes.keys() and int(fitnes_self) >= int(fitness_other):
new_genes.nodes[node_num] = copy(self.nodes[node_num])
elif node_num in other.nodes.keys() and int(fitnes_self) <= int(fitness_other):
new_genes.nodes[node_num] = copy(other.nodes[node_num])
for connection_num in connection_nums:
if connection_num in connections.keys() and connection_num in other_connections.keys():
if int(fitness_other) == int(fitnes_self):
if np.random.randint(0, 2) == 0:
connection = copy(connections[connection_num])
else:
connection = copy(other_connections[connection_num])
elif fitnes_self > fitness_other:
connection = copy(connections[connection_num])
else:
connection = copy(other_connections[connection_num])
new_genes.connections.append(connection)
elif connection_num in connections.keys() and int(fitnes_self) >= int(fitness_other):
new_genes.connections.append(copy(connections[connection_num]))
elif connection_num in other_connections.keys() and int(fitnes_self) <= int(fitness_other):
new_genes.connections.append(copy(other_connections[connection_num]))
return new_genes

View file

@ -1,6 +1,7 @@
import random import random
import numpy as np import numpy as np
from labirinth_ai.Models import EvolutionModel
from labirinth_ai.Models.Genotype import NeatLike from labirinth_ai.Models.Genotype import NeatLike
@ -14,7 +15,7 @@ def fib(n):
class Population: class Population:
def __init__(self, subject_class, world, subject_number): def __init__(self, subject_class, world, subject_number, do_evolve=True):
self.subjects = [] self.subjects = []
self.world = world self.world = world
for _ in range(subject_number): for _ in range(subject_number):
@ -22,6 +23,7 @@ class Population:
self.subjects.append(subject_class(px, py, genotype_class=NeatLike)) self.subjects.append(subject_class(px, py, genotype_class=NeatLike))
self.subject_number = subject_number self.subject_number = subject_number
self.subject_class = subject_class self.subject_class = subject_class
self.do_evolve = do_evolve
def select(self): def select(self):
ranked = list(self.subjects) ranked = list(self.subjects)
@ -52,6 +54,8 @@ class Population:
return out + cls.scatter(n - np.sum(fibs), buckets) return out + cls.scatter(n - np.sum(fibs), buckets)
def evolve(self): def evolve(self):
if self.do_evolve:
if len(self.subjects) > 1:
# get updated weights from the models # get updated weights from the models
for subject in self.subjects: for subject in self.subjects:
subject.model.update_genes_with_weights() subject.model.update_genes_with_weights()
@ -66,7 +70,8 @@ class Population:
parent_1 = best_subjects[index] parent_1 = best_subjects[index]
parent_2 = best_subjects[random.randint(index + 1, len(best_subjects) - 1)] parent_2 = best_subjects[random.randint(index + 1, len(best_subjects) - 1)]
new_genes = parent_1.model.genes.cross(parent_2.model.genes) new_genes = parent_1.model.genes.cross(parent_2.model.genes,
parent_1.accumulated_rewards, parent_2.accumulated_rewards)
# position doesn't matter, since mutation will set it # position doesn't matter, since mutation will set it
new_subject = self.subject_class(0, 0, new_genes) new_subject = self.subject_class(0, 0, new_genes)
@ -75,7 +80,8 @@ class Population:
new_subjects.append(new_subject) new_subjects.append(new_subject)
assert len(new_subjects) == self.subject_number, 'All generations should have constant size!' assert len(new_subjects) == self.subject_number, 'All generations should have constant size!'
else:
new_subjects = self.subjects
# mutate the pop # mutate the pop
mutated_subjects = [] mutated_subjects = []
innovation_num = max(map(lambda subject: max(map(lambda connection: connection.innvovation_num, innovation_num = max(map(lambda subject: max(map(lambda connection: connection.innvovation_num,