make evolution optional
This commit is contained in:
parent
26e7ffb12b
commit
bd56173379
4 changed files with 144 additions and 51 deletions
|
@ -30,8 +30,8 @@ class LabyrinthWorld(World):
|
|||
self.lastUpdate = time.time()
|
||||
self.nextTrain = self.randomBuffer
|
||||
self.round = 1
|
||||
self.evolve_timer = 10
|
||||
# self.evolve_timer = 1500
|
||||
# self.evolve_timer = 10
|
||||
self.evolve_timer = 1500
|
||||
|
||||
self.trailMix = np.zeros(self.board_shape)
|
||||
self.grass = np.zeros(self.board_shape)
|
||||
|
@ -163,9 +163,9 @@ class LabyrinthWorld(World):
|
|||
# adding subjects
|
||||
from labirinth_ai.Subject import Hunter, Herbivore
|
||||
from labirinth_ai.Population import Population
|
||||
self._hunters = Population(Hunter, self, 10)
|
||||
self._hunters = Population(Hunter, self, 10, do_evolve=False)
|
||||
|
||||
self._herbivores = Population(Herbivore, self, 40)
|
||||
self._herbivores = Population(Herbivore, self, 40, do_evolve=False)
|
||||
|
||||
self.subjectDict = self.build_subject_dict()
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from labirinth_ai.Models.BaseModel import device, BaseDataSet, create_loss_function, create_optimizer
|
||||
from labirinth_ai.Models.Genotype import Genotype
|
||||
|
@ -45,6 +44,8 @@ class EvolutionModel(nn.Module):
|
|||
|
||||
self.incoming_connections = {}
|
||||
for connection in self.genes.connections:
|
||||
if not connection.enabled:
|
||||
continue
|
||||
if connection.end not in self.incoming_connections.keys():
|
||||
self.incoming_connections[connection.end] = []
|
||||
self.incoming_connections[connection.end].append(connection)
|
||||
|
@ -158,7 +159,6 @@ class EvolutionModel(nn.Module):
|
|||
self.genes.nodes[key].bias = float(lin.bias[0])
|
||||
|
||||
|
||||
|
||||
class RecurrentDataSet(BaseDataSet):
|
||||
def __init__(self, states, targets, memory):
|
||||
super().__init__(states, targets)
|
||||
|
@ -172,7 +172,7 @@ class RecurrentDataSet(BaseDataSet):
|
|||
def train_recurrent(states, memory, targets, model, optimizer):
|
||||
for action in range(model.action_num):
|
||||
data_set = RecurrentDataSet(states[action], targets[action], memory[action])
|
||||
dataloader = DataLoader(data_set, batch_size=64, shuffle=True)
|
||||
dataloader = DataLoader(data_set, batch_size=512, shuffle=True)
|
||||
loss_fn = create_loss_function(action)
|
||||
|
||||
size = len(dataloader)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from abc import abstractmethod
|
||||
from typing import List, Dict
|
||||
from copy import copy
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -12,11 +13,15 @@ class NodeGene:
|
|||
self.node_id = node_id
|
||||
self.node_type = node_type
|
||||
if node_type == 'hidden':
|
||||
assert bias is not None, 'Expected a bias for hidden node types!'
|
||||
if bias is None:
|
||||
bias = np.random.random(1)[0] * 2 - 1.0
|
||||
self.bias = bias
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def __copy__(self):
|
||||
return NodeGene(self.node_id, self.node_type, bias=self.bias)
|
||||
|
||||
|
||||
class ConnectionGene:
|
||||
def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
|
||||
|
@ -30,12 +35,15 @@ class ConnectionGene:
|
|||
else:
|
||||
self.weight = weight
|
||||
|
||||
def __copy__(self):
|
||||
return ConnectionGene(self.start, self.end, self.enabled, self.innvovation_num, self.weight, self.recurrent)
|
||||
|
||||
|
||||
class Genotype:
|
||||
def __init__(self, action_num: int = None, num_input_nodes: int = None,
|
||||
nodes: Dict[int, NodeGene] = None, connections: List[ConnectionGene] = None):
|
||||
self.nodes = {}
|
||||
self.connections = []
|
||||
self.nodes: Dict[int, NodeGene] = {}
|
||||
self.connections: List[ConnectionGene] = []
|
||||
if action_num is not None and num_input_nodes is not None:
|
||||
node_id = 0
|
||||
for _ in range(num_input_nodes):
|
||||
|
@ -61,7 +69,8 @@ class Genotype:
|
|||
while len(nodes_to_rank) > 0:
|
||||
for list_index, (id, node) in enumerate(nodes_to_rank):
|
||||
incoming_connections = list(filter(lambda connection: connection.end == id and
|
||||
not connection.recurrent, self.connections))
|
||||
not connection.recurrent and connection.enabled,
|
||||
self.connections))
|
||||
if len(incoming_connections) == 0:
|
||||
rank_of_node[id] = 0
|
||||
nodes_to_rank.pop(list_index)
|
||||
|
@ -90,7 +99,7 @@ class Genotype:
|
|||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def cross(self, other):
|
||||
def cross(self, other, fitnes_self, fitness_other):
|
||||
raise NotImplementedError()
|
||||
# return self
|
||||
|
||||
|
@ -98,6 +107,11 @@ class Genotype:
|
|||
class NeatLike(Genotype):
|
||||
connection_add_thr = 0.3
|
||||
node_add_thr = 0.3
|
||||
disable_conn_thr = 0.1
|
||||
|
||||
# connection_add_thr = 0.0
|
||||
# node_add_thr = 0.0
|
||||
# disable_conn_thr = 0.0
|
||||
|
||||
def mutate(self, innovation_num, allow_recurrent=False) -> int:
|
||||
"""
|
||||
|
@ -107,7 +121,7 @@ class NeatLike(Genotype):
|
|||
:return: Updated innovation number
|
||||
"""
|
||||
# add connection
|
||||
if np.random.random(1)[0] < self.connection_add_thr or True:
|
||||
if np.random.random(1)[0] < self.connection_add_thr:
|
||||
nodes = list(self.nodes.keys())
|
||||
rank_of_node = self.calculate_rank_of_nodes()
|
||||
end_nodes = list(filter(lambda node: rank_of_node[node] > 0, nodes))
|
||||
|
@ -131,9 +145,82 @@ class NeatLike(Genotype):
|
|||
self.connections.append(
|
||||
ConnectionGene(nodes[start], end_nodes[end], True, innovation_num,
|
||||
recurrent=rank_of_node[nodes[start]] > rank_of_node[end_nodes[end]]))
|
||||
#todo add node
|
||||
|
||||
if np.random.random(1)[0] < self.node_add_thr:
|
||||
active_connections = list(filter(lambda connection: connection.enabled, self.connections))
|
||||
|
||||
n = np.random.randint(0, len(active_connections))
|
||||
old_connection = active_connections[n]
|
||||
|
||||
new_node = NodeGene(innovation_num, 'hidden')
|
||||
node_id = innovation_num
|
||||
connection_1 = ConnectionGene(old_connection.start, node_id, True, innovation_num,
|
||||
recurrent=old_connection.recurrent)
|
||||
innovation_num += 1
|
||||
connection_2 = ConnectionGene(node_id, old_connection.end, True, innovation_num)
|
||||
innovation_num += 1
|
||||
|
||||
old_connection.enabled = False
|
||||
self.nodes[node_id] = new_node
|
||||
self.connections.append(connection_1)
|
||||
self.connections.append(connection_2)
|
||||
|
||||
if np.random.random(1)[0] < self.disable_conn_thr:
|
||||
active_connections = list(filter(lambda connection: connection.enabled, self.connections))
|
||||
n = np.random.randint(0, len(active_connections))
|
||||
old_connection = active_connections[n]
|
||||
old_connection.enabled = not old_connection.enabled
|
||||
|
||||
return innovation_num
|
||||
|
||||
def cross(self, other):
|
||||
return self
|
||||
def cross(self, other, fitnes_self, fitness_other):
|
||||
new_genes = NeatLike()
|
||||
node_nums = set(map(lambda node: node[0], self.nodes.items())).union(
|
||||
set(map(lambda node: node[0], other.nodes.items())))
|
||||
|
||||
connections = {}
|
||||
for connection in self.connections:
|
||||
connections[connection.innvovation_num] = connection
|
||||
|
||||
other_connections = {}
|
||||
for connection in other.connections:
|
||||
other_connections[connection.innvovation_num] = connection
|
||||
|
||||
connection_nums = set(map(lambda connection: connection[0], connections.items())).union(
|
||||
set(map(lambda connection: connection[0], other_connections.items())))
|
||||
|
||||
for node_num in node_nums:
|
||||
if node_num in self.nodes.keys() and node_num in other.nodes.keys():
|
||||
if int(fitness_other) == int(fitnes_self):
|
||||
if np.random.randint(0, 2) == 0:
|
||||
new_genes.nodes[node_num] = copy(self.nodes[node_num])
|
||||
else:
|
||||
new_genes.nodes[node_num] = copy(other.nodes[node_num])
|
||||
elif fitnes_self > fitness_other:
|
||||
new_genes.nodes[node_num] = copy(self.nodes[node_num])
|
||||
else:
|
||||
new_genes.nodes[node_num] = copy(other.nodes[node_num])
|
||||
elif node_num in self.nodes.keys() and int(fitnes_self) >= int(fitness_other):
|
||||
new_genes.nodes[node_num] = copy(self.nodes[node_num])
|
||||
elif node_num in other.nodes.keys() and int(fitnes_self) <= int(fitness_other):
|
||||
new_genes.nodes[node_num] = copy(other.nodes[node_num])
|
||||
|
||||
for connection_num in connection_nums:
|
||||
if connection_num in connections.keys() and connection_num in other_connections.keys():
|
||||
if int(fitness_other) == int(fitnes_self):
|
||||
if np.random.randint(0, 2) == 0:
|
||||
connection = copy(connections[connection_num])
|
||||
else:
|
||||
connection = copy(other_connections[connection_num])
|
||||
elif fitnes_self > fitness_other:
|
||||
connection = copy(connections[connection_num])
|
||||
else:
|
||||
connection = copy(other_connections[connection_num])
|
||||
|
||||
new_genes.connections.append(connection)
|
||||
elif connection_num in connections.keys() and int(fitnes_self) >= int(fitness_other):
|
||||
new_genes.connections.append(copy(connections[connection_num]))
|
||||
elif connection_num in other_connections.keys() and int(fitnes_self) <= int(fitness_other):
|
||||
new_genes.connections.append(copy(other_connections[connection_num]))
|
||||
|
||||
return new_genes
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import random
|
||||
import numpy as np
|
||||
|
||||
from labirinth_ai.Models import EvolutionModel
|
||||
from labirinth_ai.Models.Genotype import NeatLike
|
||||
|
||||
|
||||
|
@ -14,7 +15,7 @@ def fib(n):
|
|||
|
||||
|
||||
class Population:
|
||||
def __init__(self, subject_class, world, subject_number):
|
||||
def __init__(self, subject_class, world, subject_number, do_evolve=True):
|
||||
self.subjects = []
|
||||
self.world = world
|
||||
for _ in range(subject_number):
|
||||
|
@ -22,6 +23,7 @@ class Population:
|
|||
self.subjects.append(subject_class(px, py, genotype_class=NeatLike))
|
||||
self.subject_number = subject_number
|
||||
self.subject_class = subject_class
|
||||
self.do_evolve = do_evolve
|
||||
|
||||
def select(self):
|
||||
ranked = list(self.subjects)
|
||||
|
@ -52,6 +54,8 @@ class Population:
|
|||
return out + cls.scatter(n - np.sum(fibs), buckets)
|
||||
|
||||
def evolve(self):
|
||||
if self.do_evolve:
|
||||
if len(self.subjects) > 1:
|
||||
# get updated weights from the models
|
||||
for subject in self.subjects:
|
||||
subject.model.update_genes_with_weights()
|
||||
|
@ -66,7 +70,8 @@ class Population:
|
|||
parent_1 = best_subjects[index]
|
||||
parent_2 = best_subjects[random.randint(index + 1, len(best_subjects) - 1)]
|
||||
|
||||
new_genes = parent_1.model.genes.cross(parent_2.model.genes)
|
||||
new_genes = parent_1.model.genes.cross(parent_2.model.genes,
|
||||
parent_1.accumulated_rewards, parent_2.accumulated_rewards)
|
||||
|
||||
# position doesn't matter, since mutation will set it
|
||||
new_subject = self.subject_class(0, 0, new_genes)
|
||||
|
@ -75,7 +80,8 @@ class Population:
|
|||
new_subjects.append(new_subject)
|
||||
|
||||
assert len(new_subjects) == self.subject_number, 'All generations should have constant size!'
|
||||
|
||||
else:
|
||||
new_subjects = self.subjects
|
||||
# mutate the pop
|
||||
mutated_subjects = []
|
||||
innovation_num = max(map(lambda subject: max(map(lambda connection: connection.innvovation_num,
|
||||
|
|
Loading…
Reference in a new issue