make evolution optional
This commit is contained in:
parent
26e7ffb12b
commit
bd56173379
4 changed files with 144 additions and 51 deletions
|
@ -30,8 +30,8 @@ class LabyrinthWorld(World):
|
||||||
self.lastUpdate = time.time()
|
self.lastUpdate = time.time()
|
||||||
self.nextTrain = self.randomBuffer
|
self.nextTrain = self.randomBuffer
|
||||||
self.round = 1
|
self.round = 1
|
||||||
self.evolve_timer = 10
|
# self.evolve_timer = 10
|
||||||
# self.evolve_timer = 1500
|
self.evolve_timer = 1500
|
||||||
|
|
||||||
self.trailMix = np.zeros(self.board_shape)
|
self.trailMix = np.zeros(self.board_shape)
|
||||||
self.grass = np.zeros(self.board_shape)
|
self.grass = np.zeros(self.board_shape)
|
||||||
|
@ -163,9 +163,9 @@ class LabyrinthWorld(World):
|
||||||
# adding subjects
|
# adding subjects
|
||||||
from labirinth_ai.Subject import Hunter, Herbivore
|
from labirinth_ai.Subject import Hunter, Herbivore
|
||||||
from labirinth_ai.Population import Population
|
from labirinth_ai.Population import Population
|
||||||
self._hunters = Population(Hunter, self, 10)
|
self._hunters = Population(Hunter, self, 10, do_evolve=False)
|
||||||
|
|
||||||
self._herbivores = Population(Herbivore, self, 40)
|
self._herbivores = Population(Herbivore, self, 40, do_evolve=False)
|
||||||
|
|
||||||
self.subjectDict = self.build_subject_dict()
|
self.subjectDict = self.build_subject_dict()
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tqdm
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from labirinth_ai.Models.BaseModel import device, BaseDataSet, create_loss_function, create_optimizer
|
from labirinth_ai.Models.BaseModel import device, BaseDataSet, create_loss_function, create_optimizer
|
||||||
from labirinth_ai.Models.Genotype import Genotype
|
from labirinth_ai.Models.Genotype import Genotype
|
||||||
|
@ -45,6 +44,8 @@ class EvolutionModel(nn.Module):
|
||||||
|
|
||||||
self.incoming_connections = {}
|
self.incoming_connections = {}
|
||||||
for connection in self.genes.connections:
|
for connection in self.genes.connections:
|
||||||
|
if not connection.enabled:
|
||||||
|
continue
|
||||||
if connection.end not in self.incoming_connections.keys():
|
if connection.end not in self.incoming_connections.keys():
|
||||||
self.incoming_connections[connection.end] = []
|
self.incoming_connections[connection.end] = []
|
||||||
self.incoming_connections[connection.end].append(connection)
|
self.incoming_connections[connection.end].append(connection)
|
||||||
|
@ -158,7 +159,6 @@ class EvolutionModel(nn.Module):
|
||||||
self.genes.nodes[key].bias = float(lin.bias[0])
|
self.genes.nodes[key].bias = float(lin.bias[0])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RecurrentDataSet(BaseDataSet):
|
class RecurrentDataSet(BaseDataSet):
|
||||||
def __init__(self, states, targets, memory):
|
def __init__(self, states, targets, memory):
|
||||||
super().__init__(states, targets)
|
super().__init__(states, targets)
|
||||||
|
@ -172,7 +172,7 @@ class RecurrentDataSet(BaseDataSet):
|
||||||
def train_recurrent(states, memory, targets, model, optimizer):
|
def train_recurrent(states, memory, targets, model, optimizer):
|
||||||
for action in range(model.action_num):
|
for action in range(model.action_num):
|
||||||
data_set = RecurrentDataSet(states[action], targets[action], memory[action])
|
data_set = RecurrentDataSet(states[action], targets[action], memory[action])
|
||||||
dataloader = DataLoader(data_set, batch_size=64, shuffle=True)
|
dataloader = DataLoader(data_set, batch_size=512, shuffle=True)
|
||||||
loss_fn = create_loss_function(action)
|
loss_fn = create_loss_function(action)
|
||||||
|
|
||||||
size = len(dataloader)
|
size = len(dataloader)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -12,11 +13,15 @@ class NodeGene:
|
||||||
self.node_id = node_id
|
self.node_id = node_id
|
||||||
self.node_type = node_type
|
self.node_type = node_type
|
||||||
if node_type == 'hidden':
|
if node_type == 'hidden':
|
||||||
assert bias is not None, 'Expected a bias for hidden node types!'
|
if bias is None:
|
||||||
|
bias = np.random.random(1)[0] * 2 - 1.0
|
||||||
self.bias = bias
|
self.bias = bias
|
||||||
else:
|
else:
|
||||||
self.bias = None
|
self.bias = None
|
||||||
|
|
||||||
|
def __copy__(self):
|
||||||
|
return NodeGene(self.node_id, self.node_type, bias=self.bias)
|
||||||
|
|
||||||
|
|
||||||
class ConnectionGene:
|
class ConnectionGene:
|
||||||
def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
|
def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
|
||||||
|
@ -30,12 +35,15 @@ class ConnectionGene:
|
||||||
else:
|
else:
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
|
|
||||||
|
def __copy__(self):
|
||||||
|
return ConnectionGene(self.start, self.end, self.enabled, self.innvovation_num, self.weight, self.recurrent)
|
||||||
|
|
||||||
|
|
||||||
class Genotype:
|
class Genotype:
|
||||||
def __init__(self, action_num: int = None, num_input_nodes: int = None,
|
def __init__(self, action_num: int = None, num_input_nodes: int = None,
|
||||||
nodes: Dict[int, NodeGene] = None, connections: List[ConnectionGene] = None):
|
nodes: Dict[int, NodeGene] = None, connections: List[ConnectionGene] = None):
|
||||||
self.nodes = {}
|
self.nodes: Dict[int, NodeGene] = {}
|
||||||
self.connections = []
|
self.connections: List[ConnectionGene] = []
|
||||||
if action_num is not None and num_input_nodes is not None:
|
if action_num is not None and num_input_nodes is not None:
|
||||||
node_id = 0
|
node_id = 0
|
||||||
for _ in range(num_input_nodes):
|
for _ in range(num_input_nodes):
|
||||||
|
@ -61,7 +69,8 @@ class Genotype:
|
||||||
while len(nodes_to_rank) > 0:
|
while len(nodes_to_rank) > 0:
|
||||||
for list_index, (id, node) in enumerate(nodes_to_rank):
|
for list_index, (id, node) in enumerate(nodes_to_rank):
|
||||||
incoming_connections = list(filter(lambda connection: connection.end == id and
|
incoming_connections = list(filter(lambda connection: connection.end == id and
|
||||||
not connection.recurrent, self.connections))
|
not connection.recurrent and connection.enabled,
|
||||||
|
self.connections))
|
||||||
if len(incoming_connections) == 0:
|
if len(incoming_connections) == 0:
|
||||||
rank_of_node[id] = 0
|
rank_of_node[id] = 0
|
||||||
nodes_to_rank.pop(list_index)
|
nodes_to_rank.pop(list_index)
|
||||||
|
@ -90,7 +99,7 @@ class Genotype:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cross(self, other):
|
def cross(self, other, fitnes_self, fitness_other):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
# return self
|
# return self
|
||||||
|
|
||||||
|
@ -98,6 +107,11 @@ class Genotype:
|
||||||
class NeatLike(Genotype):
|
class NeatLike(Genotype):
|
||||||
connection_add_thr = 0.3
|
connection_add_thr = 0.3
|
||||||
node_add_thr = 0.3
|
node_add_thr = 0.3
|
||||||
|
disable_conn_thr = 0.1
|
||||||
|
|
||||||
|
# connection_add_thr = 0.0
|
||||||
|
# node_add_thr = 0.0
|
||||||
|
# disable_conn_thr = 0.0
|
||||||
|
|
||||||
def mutate(self, innovation_num, allow_recurrent=False) -> int:
|
def mutate(self, innovation_num, allow_recurrent=False) -> int:
|
||||||
"""
|
"""
|
||||||
|
@ -107,7 +121,7 @@ class NeatLike(Genotype):
|
||||||
:return: Updated innovation number
|
:return: Updated innovation number
|
||||||
"""
|
"""
|
||||||
# add connection
|
# add connection
|
||||||
if np.random.random(1)[0] < self.connection_add_thr or True:
|
if np.random.random(1)[0] < self.connection_add_thr:
|
||||||
nodes = list(self.nodes.keys())
|
nodes = list(self.nodes.keys())
|
||||||
rank_of_node = self.calculate_rank_of_nodes()
|
rank_of_node = self.calculate_rank_of_nodes()
|
||||||
end_nodes = list(filter(lambda node: rank_of_node[node] > 0, nodes))
|
end_nodes = list(filter(lambda node: rank_of_node[node] > 0, nodes))
|
||||||
|
@ -131,9 +145,82 @@ class NeatLike(Genotype):
|
||||||
self.connections.append(
|
self.connections.append(
|
||||||
ConnectionGene(nodes[start], end_nodes[end], True, innovation_num,
|
ConnectionGene(nodes[start], end_nodes[end], True, innovation_num,
|
||||||
recurrent=rank_of_node[nodes[start]] > rank_of_node[end_nodes[end]]))
|
recurrent=rank_of_node[nodes[start]] > rank_of_node[end_nodes[end]]))
|
||||||
#todo add node
|
|
||||||
|
if np.random.random(1)[0] < self.node_add_thr:
|
||||||
|
active_connections = list(filter(lambda connection: connection.enabled, self.connections))
|
||||||
|
|
||||||
|
n = np.random.randint(0, len(active_connections))
|
||||||
|
old_connection = active_connections[n]
|
||||||
|
|
||||||
|
new_node = NodeGene(innovation_num, 'hidden')
|
||||||
|
node_id = innovation_num
|
||||||
|
connection_1 = ConnectionGene(old_connection.start, node_id, True, innovation_num,
|
||||||
|
recurrent=old_connection.recurrent)
|
||||||
|
innovation_num += 1
|
||||||
|
connection_2 = ConnectionGene(node_id, old_connection.end, True, innovation_num)
|
||||||
|
innovation_num += 1
|
||||||
|
|
||||||
|
old_connection.enabled = False
|
||||||
|
self.nodes[node_id] = new_node
|
||||||
|
self.connections.append(connection_1)
|
||||||
|
self.connections.append(connection_2)
|
||||||
|
|
||||||
|
if np.random.random(1)[0] < self.disable_conn_thr:
|
||||||
|
active_connections = list(filter(lambda connection: connection.enabled, self.connections))
|
||||||
|
n = np.random.randint(0, len(active_connections))
|
||||||
|
old_connection = active_connections[n]
|
||||||
|
old_connection.enabled = not old_connection.enabled
|
||||||
|
|
||||||
return innovation_num
|
return innovation_num
|
||||||
|
|
||||||
def cross(self, other):
|
def cross(self, other, fitnes_self, fitness_other):
|
||||||
return self
|
new_genes = NeatLike()
|
||||||
|
node_nums = set(map(lambda node: node[0], self.nodes.items())).union(
|
||||||
|
set(map(lambda node: node[0], other.nodes.items())))
|
||||||
|
|
||||||
|
connections = {}
|
||||||
|
for connection in self.connections:
|
||||||
|
connections[connection.innvovation_num] = connection
|
||||||
|
|
||||||
|
other_connections = {}
|
||||||
|
for connection in other.connections:
|
||||||
|
other_connections[connection.innvovation_num] = connection
|
||||||
|
|
||||||
|
connection_nums = set(map(lambda connection: connection[0], connections.items())).union(
|
||||||
|
set(map(lambda connection: connection[0], other_connections.items())))
|
||||||
|
|
||||||
|
for node_num in node_nums:
|
||||||
|
if node_num in self.nodes.keys() and node_num in other.nodes.keys():
|
||||||
|
if int(fitness_other) == int(fitnes_self):
|
||||||
|
if np.random.randint(0, 2) == 0:
|
||||||
|
new_genes.nodes[node_num] = copy(self.nodes[node_num])
|
||||||
|
else:
|
||||||
|
new_genes.nodes[node_num] = copy(other.nodes[node_num])
|
||||||
|
elif fitnes_self > fitness_other:
|
||||||
|
new_genes.nodes[node_num] = copy(self.nodes[node_num])
|
||||||
|
else:
|
||||||
|
new_genes.nodes[node_num] = copy(other.nodes[node_num])
|
||||||
|
elif node_num in self.nodes.keys() and int(fitnes_self) >= int(fitness_other):
|
||||||
|
new_genes.nodes[node_num] = copy(self.nodes[node_num])
|
||||||
|
elif node_num in other.nodes.keys() and int(fitnes_self) <= int(fitness_other):
|
||||||
|
new_genes.nodes[node_num] = copy(other.nodes[node_num])
|
||||||
|
|
||||||
|
for connection_num in connection_nums:
|
||||||
|
if connection_num in connections.keys() and connection_num in other_connections.keys():
|
||||||
|
if int(fitness_other) == int(fitnes_self):
|
||||||
|
if np.random.randint(0, 2) == 0:
|
||||||
|
connection = copy(connections[connection_num])
|
||||||
|
else:
|
||||||
|
connection = copy(other_connections[connection_num])
|
||||||
|
elif fitnes_self > fitness_other:
|
||||||
|
connection = copy(connections[connection_num])
|
||||||
|
else:
|
||||||
|
connection = copy(other_connections[connection_num])
|
||||||
|
|
||||||
|
new_genes.connections.append(connection)
|
||||||
|
elif connection_num in connections.keys() and int(fitnes_self) >= int(fitness_other):
|
||||||
|
new_genes.connections.append(copy(connections[connection_num]))
|
||||||
|
elif connection_num in other_connections.keys() and int(fitnes_self) <= int(fitness_other):
|
||||||
|
new_genes.connections.append(copy(other_connections[connection_num]))
|
||||||
|
|
||||||
|
return new_genes
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import random
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from labirinth_ai.Models import EvolutionModel
|
||||||
from labirinth_ai.Models.Genotype import NeatLike
|
from labirinth_ai.Models.Genotype import NeatLike
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,7 +15,7 @@ def fib(n):
|
||||||
|
|
||||||
|
|
||||||
class Population:
|
class Population:
|
||||||
def __init__(self, subject_class, world, subject_number):
|
def __init__(self, subject_class, world, subject_number, do_evolve=True):
|
||||||
self.subjects = []
|
self.subjects = []
|
||||||
self.world = world
|
self.world = world
|
||||||
for _ in range(subject_number):
|
for _ in range(subject_number):
|
||||||
|
@ -22,6 +23,7 @@ class Population:
|
||||||
self.subjects.append(subject_class(px, py, genotype_class=NeatLike))
|
self.subjects.append(subject_class(px, py, genotype_class=NeatLike))
|
||||||
self.subject_number = subject_number
|
self.subject_number = subject_number
|
||||||
self.subject_class = subject_class
|
self.subject_class = subject_class
|
||||||
|
self.do_evolve = do_evolve
|
||||||
|
|
||||||
def select(self):
|
def select(self):
|
||||||
ranked = list(self.subjects)
|
ranked = list(self.subjects)
|
||||||
|
@ -52,6 +54,8 @@ class Population:
|
||||||
return out + cls.scatter(n - np.sum(fibs), buckets)
|
return out + cls.scatter(n - np.sum(fibs), buckets)
|
||||||
|
|
||||||
def evolve(self):
|
def evolve(self):
|
||||||
|
if self.do_evolve:
|
||||||
|
if len(self.subjects) > 1:
|
||||||
# get updated weights from the models
|
# get updated weights from the models
|
||||||
for subject in self.subjects:
|
for subject in self.subjects:
|
||||||
subject.model.update_genes_with_weights()
|
subject.model.update_genes_with_weights()
|
||||||
|
@ -66,7 +70,8 @@ class Population:
|
||||||
parent_1 = best_subjects[index]
|
parent_1 = best_subjects[index]
|
||||||
parent_2 = best_subjects[random.randint(index + 1, len(best_subjects) - 1)]
|
parent_2 = best_subjects[random.randint(index + 1, len(best_subjects) - 1)]
|
||||||
|
|
||||||
new_genes = parent_1.model.genes.cross(parent_2.model.genes)
|
new_genes = parent_1.model.genes.cross(parent_2.model.genes,
|
||||||
|
parent_1.accumulated_rewards, parent_2.accumulated_rewards)
|
||||||
|
|
||||||
# position doesn't matter, since mutation will set it
|
# position doesn't matter, since mutation will set it
|
||||||
new_subject = self.subject_class(0, 0, new_genes)
|
new_subject = self.subject_class(0, 0, new_genes)
|
||||||
|
@ -75,7 +80,8 @@ class Population:
|
||||||
new_subjects.append(new_subject)
|
new_subjects.append(new_subject)
|
||||||
|
|
||||||
assert len(new_subjects) == self.subject_number, 'All generations should have constant size!'
|
assert len(new_subjects) == self.subject_number, 'All generations should have constant size!'
|
||||||
|
else:
|
||||||
|
new_subjects = self.subjects
|
||||||
# mutate the pop
|
# mutate the pop
|
||||||
mutated_subjects = []
|
mutated_subjects = []
|
||||||
innovation_num = max(map(lambda subject: max(map(lambda connection: connection.innvovation_num,
|
innovation_num = max(map(lambda subject: max(map(lambda connection: connection.innvovation_num,
|
||||||
|
|
Loading…
Reference in a new issue