NEAT implementation up to mutate

zomseffen 2022-08-12 15:48:30 +02:00
parent 4a05baa103
commit cf4d773c10
8 changed files with 468 additions and 144 deletions


@@ -44,8 +44,8 @@ class BaseModel(nn.Module):
class BaseDataSet(Dataset):
def __init__(self, states, targets):
assert len(states) == len(targets), "Needs to have as many states as targets!"
self.states = torch.tensor(states, dtype=torch.float32)
self.targets = torch.tensor(targets, dtype=torch.float32)
self.states = torch.tensor(np.array(states), dtype=torch.float32)
self.targets = torch.tensor(np.array(targets), dtype=torch.float32)
def __len__(self):
return len(self.states)
@@ -69,7 +69,7 @@ def create_loss_function(action):
def from_numpy(x):
return torch.tensor(x, dtype=torch.float32)
return torch.tensor(np.array(x), dtype=torch.float32)
def train(states, targets, model, optimizer):


@@ -3,40 +3,16 @@ from torch import nn
import numpy as np
import tqdm
from torch.utils.data import Dataset, DataLoader
from labirinth_ai.Models.BaseModel import device
class NodeGene:
valid_types = ['sensor', 'hidden', 'output']
def __init__(self, node_id, node_type, bias=None):
assert node_type in self.valid_types, 'Unknown node type!'
self.node_id = node_id
self.node_type = node_type
if node_type == 'hidden':
assert bias is not None, 'Expected a bias for hidden node types!'
self.bias = bias
else:
self.bias = None
class ConnectionGene:
def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
self.start = start
self.end = end
self.enabled = enabled
self.innvovation_num = innovation_num
self.recurrent = recurrent
if weight is None:
self.weight = np.random.random(1)[0] * 2 - 1.0
else:
self.weight = weight
from labirinth_ai.Models.BaseModel import device, BaseDataSet, create_loss_function, create_optimizer
from labirinth_ai.Models.Genotype import Genotype
class EvolutionModel(nn.Module):
evolutionary = True
def __init__(self, view_dimension, action_num, channels, genes=None):
def __init__(self, view_dimension, action_num, channels, genes: Genotype = None, genotype_class=None):
if genotype_class is None:
genotype_class = Genotype
super(EvolutionModel, self).__init__()
self.flatten = nn.Flatten()
@@ -46,25 +22,29 @@ class EvolutionModel(nn.Module):
if genes is None:
self.num_input_nodes = channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2
self.genes = genotype_class(action_num, self.num_input_nodes)
else:
self.num_input_nodes = len(list(filter(lambda element: element[1].node_type == 'sensor', genes.nodes.items())))
assert self.num_input_nodes > 0, 'Network needs to have sensor nodes!'
is_input_over = False
is_output_over = False
for key, node in genes.nodes.items():
if node.node_type == 'sensor':
if is_input_over:
raise ValueError('Node genes need to follow the order sensor, output, hidden!')
self.genes = {'nodes': {}, 'connections': []}
node_id = 0
for _ in range(self.num_input_nodes):
self.genes['nodes'][node_id] = NodeGene(node_id, 'sensor')
node_id += 1
first_action = node_id
for _ in range(action_num * 2):
self.genes['nodes'][node_id] = NodeGene(node_id, 'output')
node_id += 1
if node.node_type == 'output':
is_input_over = True
if is_output_over:
raise ValueError('Node genes need to follow the order sensor, output, hidden!')
for index in range(self.num_input_nodes):
for action in range(action_num * 2):
self.genes['connections'].append(
ConnectionGene(index, first_action + action, True, index*(action_num * 2) + action)
)
if node.node_type == 'hidden':
is_output_over = True
self.genes = genes
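# collect, per end node, every connection feeding into it; each group becomes one small Linear layer below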
self.incoming_connections = {}
for connection in self.genes['connections']:
for connection in self.genes.connections:
if connection.end not in self.incoming_connections.keys():
self.incoming_connections[connection.end] = []
self.incoming_connections[connection.end].append(connection)
@@ -73,16 +53,17 @@ class EvolutionModel(nn.Module):
self.indices = {}
self.has_recurrent = False
non_recurrent_indices = {}
self.non_recurrent_indices = {}
self.recurrent_indices = {}
with torch.no_grad():
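# build one Linear (incoming connections -> 1 value) plus ELU per non-input node,
# initialised from the genome's connection weights and node bias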
for key, value in self.incoming_connections.items():
value.sort(key=lambda element: element.start)
lin = nn.Linear(len(value), 1, bias=self.genes['nodes'][key].bias is not None)
lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
for index, connection in enumerate(value):
lin.weight[0, index] = value[index].weight
if self.genes['nodes'][key].bias is not None:
lin.bias[0] = self.genes['nodes'][key].bias
if self.genes.nodes[key].bias is not None:
lin.bias[0] = self.genes.nodes[key].bias
non_lin = nn.ELU()
sequence = nn.Sequential(
@@ -93,15 +74,17 @@ class EvolutionModel(nn.Module):
self.layers[key] = sequence
self.indices[key] = list(map(lambda element: element.start, value))
non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
if not self.has_recurrent and len(non_recurrent_indices[key]) != len(self.indices[key]):
self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
self.has_recurrent = True
non_recurrent_indices[key] = list(map(lambda element: element.start, non_recurrent_indices[key]))
self.non_recurrent_indices[key] = list(map(lambda element: element.start, self.non_recurrent_indices[key]))
self.recurrent_indices[key] = list(map(lambda element: element.start, self.recurrent_indices[key]))
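# rank each node by the longest chain of non-recurrent connections leading to it;
# sorting by rank yields a feed-forward evaluation order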
rank_of_node = {}
for i in range(self.num_input_nodes):
rank_of_node[i] = 0
layers_to_add = list(non_recurrent_indices.items())
layers_to_add = list(self.non_recurrent_indices.items())
while len(layers_to_add) > 0:
for index, (key, incoming_nodes) in enumerate(list(layers_to_add)):
max_rank = -1
@@ -120,44 +103,123 @@ class EvolutionModel(nn.Module):
ranked_layers = list(rank_of_node.items())
ranked_layers.sort(key=lambda element: element[1])
ranked_layers = list(filter(lambda element: element[1] > 0, ranked_layers))
self.layer_order = list(map(lambda element: element[0], ranked_layers))
self.memory = torch.Tensor((max(map(lambda element: element[1].node_id, self.genes['nodes'].items())) + 1))
def forward(self, x, memory=None):
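# nodes reached only through recurrent connections never receive a positive rank,
# so they are prepended with rank 0 to keep them in the evaluation order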
ranked_layers = list(map(lambda element: (element, 0),
filter(lambda recurrent_element:
recurrent_element not in list(
map(lambda ranked_layer: ranked_layer[0], ranked_layers)
),
list(filter(lambda recurrent_keys:
len(self.recurrent_indices[recurrent_keys]) > 0,
self.recurrent_indices.keys()))))) + ranked_layers
self.layer_order = list(map(lambda element: element[0], ranked_layers))
self.memory_size = (max(map(lambda element: element[1].node_id, self.genes.nodes.items())) + 1)
self.memory = torch.Tensor(self.memory_size)
self.output_range = range(self.num_input_nodes, self.num_input_nodes + self.action_num * 2)
def forward(self, x, last_memory=None):
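# evaluate one batch element at a time: inputs are written into a flat memory vector,
# each node applies its layer to its non-recurrent inputs (plus recurrent inputs taken
# from last_memory), and the output-node slots are collected as the result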
x_flat = self.flatten(x)
if memory is None:
memory = torch.Tensor(self.memory)
outs = []
for batch_element in x_flat:
memory[0:self.num_input_nodes] = batch_element
for layer_index in self.layer_order:
memory[layer_index] = self.layers[layer_index](memory[self.indices[layer_index]])
outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2])
outs = torch.stack(outs)
self.memory = torch.Tensor(memory)
return torch.reshape(outs, (x.shape[0], 4, 2))
else:
memory[:, 0:self.num_input_nodes] = x
if last_memory is not None:
last_memory_flat = self.flatten(last_memory)
elif self.has_recurrent:
raise ValueError('Recurrent networks need to be passed their previous memory!')
memory = torch.Tensor(self.memory_size)
outs = []
for batch_index, batch_element in enumerate(x_flat):
memory[0:self.num_input_nodes] = batch_element
for layer_index in self.layer_order:
memory[:, layer_index] = self.layers[layer_index](memory[:, self.indices[layer_index]])
return torch.reshape(
memory[:, self.num_input_nodes: self.num_input_nodes + self.action_num * 2],
(x.shape[0], 4, 2))
non_recurrent_in = memory[self.non_recurrent_indices[layer_index]]
non_recurrent_in = torch.stack([non_recurrent_in])
if self.has_recurrent and len(self.recurrent_indices[layer_index]) > 0:
recurrent_in = last_memory_flat[batch_index, self.recurrent_indices[layer_index]]
recurrent_in = torch.stack([recurrent_in])
combined_in = torch.concat([non_recurrent_in, recurrent_in], dim=1)
else:
combined_in = non_recurrent_in
memory[layer_index] = self.layers[layer_index](combined_in)
outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2])
outs = torch.stack(outs)
self.memory = torch.Tensor(memory)
return torch.reshape(outs, (x.shape[0], outs.shape[1]//2, 2))
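# write the trained torch weights and biases back into the corresponding connection and node genes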
def update_genes_with_weights(self):
for key, value in self.incoming_connections.items():
value.sort(key=lambda element: element.start)
sequence = self.layers[key]
lin = sequence[0]
for index, connection in enumerate(value):
value[index].weight = float(lin.weight[0, index])
if self.genes.nodes[key].bias is not None:
self.genes.nodes[key].bias = float(lin.bias[0])
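# dataset variant that additionally stores the previous memory state required by recurrent networks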
class RecurrentDataSet(BaseDataSet):
def __init__(self, states, targets, memory):
super().__init__(states, targets)
assert len(states) == len(memory), "Needs to have as many states as memories!"
self.memory = torch.tensor(np.array(memory), dtype=torch.float32)
def __getitem__(self, idx):
return self.states[idx], self.memory[idx], self.targets[idx]
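# mirrors the plain training loop, but additionally feeds the stored memory to the model as last_memory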
def train_recurrent(states, memory, targets, model, optimizer):
for action in range(model.action_num):
data_set = RecurrentDataSet(states[action], targets[action], memory[action])
dataloader = DataLoader(data_set, batch_size=64, shuffle=True)
loss_fn = create_loss_function(action)
size = len(dataloader)
model.train()
for batch, (X, M, y) in enumerate(dataloader):
X, y, M = X.to(device), y.to(device), M.to(device)
# Compute prediction error
pred = model(X, M)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward(retain_graph=True)
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
model.eval()
del data_set
del dataloader
if __name__ == '__main__':
sample = np.random.random((1, 486))
sample = np.random.random((1, 1))
last_memory = np.zeros((1, 3))
model = EvolutionModel(5, 4, 4).to(device)
print(model)
from labirinth_ai.Models.Genotype import NodeGene, ConnectionGene, Genotype
genes = Genotype(nodes={0: NodeGene(0, 'sensor'), 1: NodeGene(1, 'output'), 2: NodeGene(2, 'hidden', 1)},
connections=[ConnectionGene(0, 2, True, 0, recurrent=True), ConnectionGene(2, 1, True, 1, 1)])
model = EvolutionModel(1, 1, 1, genes)
model = model.to(device)
# print(model)
print(model.has_recurrent)
test = model(torch.tensor(sample, dtype=torch.float32))
test = model(torch.tensor(sample, dtype=torch.float32), torch.tensor(last_memory, dtype=torch.float32))
# test = test.cpu().detach().numpy()
print(test)
# print(test)
state = np.random.random((1, 486))
target = np.random.random((4, 2))
state = np.random.random((1, 1))
memory = np.random.random((1, 1))
target = np.random.random((2, 1))
states = [
[state],
[state],
@@ -170,7 +232,12 @@ if __name__ == '__main__':
[target],
[target],
]
memories = [
[memory],
[memory],
[memory],
[memory],
]
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
from labirinth_ai.Models.BaseModel import train
train(states, targets, model, optimizer)
train_recurrent(states, memories, targets, model, optimizer)


@@ -0,0 +1,139 @@
from abc import abstractmethod
from typing import List, Dict
import numpy as np
class NodeGene:
valid_types = ['sensor', 'hidden', 'output']
def __init__(self, node_id, node_type, bias=None):
assert node_type in self.valid_types, 'Unknown node type!'
self.node_id = node_id
self.node_type = node_type
if node_type == 'hidden':
assert bias is not None, 'Expected a bias for hidden node types!'
self.bias = bias
else:
self.bias = None
class ConnectionGene:
def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
self.start = start
self.end = end
self.enabled = enabled
self.innvovation_num = innovation_num
self.recurrent = recurrent
if weight is None:
self.weight = np.random.random(1)[0] * 2 - 1.0
else:
self.weight = weight
class Genotype:
def __init__(self, action_num: int = None, num_input_nodes: int = None,
nodes: Dict[int, NodeGene] = None, connections: List[ConnectionGene] = None):
self.nodes = {}
self.connections = []
if action_num is not None and num_input_nodes is not None:
node_id = 0
for _ in range(num_input_nodes):
self.nodes[node_id] = NodeGene(node_id, 'sensor')
node_id += 1
first_action = node_id
for _ in range(action_num * 2):
self.nodes[node_id] = NodeGene(node_id, 'output')
node_id += 1
for index in range(num_input_nodes):
for action in range(action_num * 2):
self.connections.append(
ConnectionGene(index, first_action + action, True, index * (action_num * 2) + action)
)
if nodes is not None and connections is not None:
self.nodes = nodes
self.connections = connections
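# rank = length of the longest chain of non-recurrent connections from an input node;
# recurrent connections are ignored so the ranking terminates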
def calculate_rank_of_nodes(self):
rank_of_node = {}
nodes_to_rank = list(self.nodes.items())
while len(nodes_to_rank) > 0:
for list_index, (id, node) in enumerate(nodes_to_rank):
incoming_connections = list(filter(lambda connection: connection.end == id and
not connection.recurrent, self.connections))
if len(incoming_connections) == 0:
rank_of_node[id] = 0
nodes_to_rank.pop(list_index)
break
incoming_connections_starts = list(map(lambda connection: connection.start, incoming_connections))
start_ranks = list(map(lambda element: rank_of_node[element[0]],
filter(lambda start_node: start_node[0] in incoming_connections_starts and
start_node[0] in rank_of_node.keys(),
self.nodes.items())))
if len(start_ranks) == len(incoming_connections):
rank_of_node[id] = max(start_ranks) + 1
nodes_to_rank.pop(list_index)
break
return rank_of_node
@abstractmethod
def mutate(self, innovation_num) -> int:
"""
Decides whether to mutate this network, then returns the new innovation number.
:param innovation_num: Current innovation number
:return: Updated innovation number
"""
# return innovation_num
raise NotImplementedError()
@abstractmethod
def cross(self, other):
raise NotImplementedError()
# return self
class NeatLike(Genotype):
connection_add_thr = 0.3
node_add_thr = 0.3
def mutate(self, innovation_num, allow_recurrent=False) -> int:
"""
Decides whether to mutate this network, then returns the new innovation number.
:param allow_recurrent: Optional parameter allowing or disallowing recurrent connections to form
:param innovation_num: Current innovation number
:return: Updated innovation number
"""
# add connection
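# note: the trailing 'or True' currently makes the connection mutation unconditional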
if np.random.random(1)[0] < self.connection_add_thr or True:
nodes = list(self.nodes.keys())
rank_of_node = self.calculate_rank_of_nodes()
end_nodes = list(filter(lambda node: rank_of_node[node] > 0, nodes))
connection_tuple = list(map(lambda connection: (connection.start, connection.end), self.connections))
start = np.random.randint(0, len(nodes))
end = np.random.randint(0, len(end_nodes))
tries = 50
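# resample (up to 50 tries) until the pair is valid: the end node must not be an input,
# start and end must differ, the connection must not already exist, and without
# allow_recurrent the start's rank may not exceed the end's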
while (rank_of_node[end_nodes[end]] == 0 or
((not allow_recurrent) and rank_of_node[nodes[start]] > rank_of_node[end_nodes[end]])
or nodes[start] == end_nodes[end] or (nodes[start], end_nodes[end]) in connection_tuple) and\
tries > 0:
end = np.random.randint(0, len(end_nodes))
if (not allow_recurrent) and rank_of_node[nodes[start]] > rank_of_node[end_nodes[end]]:
start = np.random.randint(0, len(nodes))
tries -= 1
if tries > 0:
innovation_num += 1
self.connections.append(
ConnectionGene(nodes[start], end_nodes[end], True, innovation_num,
recurrent=rank_of_node[nodes[start]] > rank_of_node[end_nodes[end]]))
# TODO: add node
return innovation_num
def cross(self, other):
return self
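
Usage sketch (not part of the commit): the following shows how the new NeatLike genotype and the rebuilt EvolutionModel could fit together. The EvolutionModel import path and the seeding of the innovation counter are assumptions made for illustration only.

import torch
from labirinth_ai.Models.Genotype import NeatLike
from labirinth_ai.Models.EvolutionModel import EvolutionModel  # assumed module path

# Start from a fully connected genome: 3 sensor nodes and 2 actions (4 output nodes).
genome = NeatLike(action_num=2, num_input_nodes=3)
innovation_num = len(genome.connections) - 1  # highest innovation number used so far
innovation_num = genome.mutate(innovation_num)  # may append one new ConnectionGene

# Rebuild the phenotype network from the (possibly mutated) genome and run it once.
model = EvolutionModel(1, 2, 1, genes=genome)
out = model(torch.rand(1, 3), torch.zeros(1, model.memory_size))
print(model.has_recurrent, out.shape)  # out: (1, 2, 2) -> two actions, two values each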