243 lines
9.9 KiB
Python
243 lines
9.9 KiB
Python
import torch
|
|
from torch import nn
|
|
import numpy as np
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from labirinth_ai.Models.BaseModel import device, BaseDataSet, create_loss_function, create_optimizer
|
|
from labirinth_ai.Models.Genotype import Genotype
|
|
|
|
|
|
class EvolutionModel(nn.Module):
    """Network built node-by-node from an evolved genotype.

    Every node with at least one enabled incoming connection gets its own
    ``Linear(n_inputs, 1) -> ELU`` layer. ``forward`` evaluates those layers in
    ``self.layer_order`` and writes each result into a flat per-sample memory
    vector indexed by node id; recurrent connections read from the previous
    step's memory instead of the current one.
    """

    # Class-level flag; presumably read elsewhere in the project to detect
    # evolutionary models — TODO confirm.
    evolutionary = True

    def __init__(self, view_dimension, action_num, channels, genes: 'Genotype' = None, genotype_class=None):
        """Build the layers for ``genes`` (or for a fresh genotype when ``genes`` is None).

        Args:
            view_dimension: half-width of the square view; only used to size the
                input when a fresh genotype is created.
            action_num: number of actions; the outputs are the ``2 * action_num``
                memory slots right after the sensors.
            channels: input channels; only used to size a fresh genotype.
            genes: existing genotype. Its ``nodes`` must be ordered sensor, output,
                hidden.
            genotype_class: class used to create a fresh genotype (defaults to
                ``Genotype``); ignored when ``genes`` is given.

        Raises:
            AssertionError: if ``genes`` has no sensor nodes.
            ValueError: if ``genes`` nodes are not ordered sensor, output, hidden,
                or if the non-recurrent connections form a cycle.
        """
        super(EvolutionModel, self).__init__()

        self.flatten = nn.Flatten()

        self.action_num = action_num
        self.viewD = view_dimension
        self.channels = channels

        if genes is None:
            if genotype_class is None:
                genotype_class = Genotype
            # A (2 * viewD + 1)^2 window per channel, plus two extra input values.
            self.num_input_nodes = channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2
            self.genes = genotype_class(action_num, self.num_input_nodes)
        else:
            self.num_input_nodes = sum(1 for node in genes.nodes.values() if node.node_type == 'sensor')
            assert self.num_input_nodes > 0, 'Network needs to have sensor nodes!'
            self._check_node_order(genes)
            self.genes = genes

        # Group the enabled connections by the node they feed into.
        self.incoming_connections = {}
        for connection in self.genes.connections:
            if connection.enabled:
                self.incoming_connections.setdefault(connection.end, []).append(connection)

        self.layers = {}
        self.indices = {}

        self.has_recurrent = False
        self.non_recurrent_indices = {}
        self.recurrent_indices = {}
        with torch.no_grad():
            for key, connections in self.incoming_connections.items():
                # forward() feeds each layer its non-recurrent inputs first and its
                # recurrent inputs second, so the weights must use that same layout.
                connections.sort(key=self._connection_order_key)

                bias = self.genes.nodes[key].bias
                lin = nn.Linear(len(connections), 1, bias=bias is not None)
                for index, connection in enumerate(connections):
                    lin.weight[0, index] = connection.weight
                if bias is not None:
                    lin.bias[0] = bias

                sequence = nn.Sequential(lin, nn.ELU())
                self.add_module('layer_' + str(key), sequence)
                self.layers[key] = sequence

                self.indices[key] = [connection.start for connection in connections]
                self.non_recurrent_indices[key] = [c.start for c in connections if not c.recurrent]
                self.recurrent_indices[key] = [c.start for c in connections if c.recurrent]
                if self.recurrent_indices[key]:
                    self.has_recurrent = True

        self.layer_order = self._compute_layer_order()

        self.memory_size = max(node.node_id for node in self.genes.nodes.values()) + 1
        self.memory = torch.zeros(self.memory_size)
        self.output_range = range(self.num_input_nodes, self.num_input_nodes + self.action_num * 2)

    @staticmethod
    def _connection_order_key(connection):
        """Sort key: non-recurrent connections before recurrent ones, each by start node."""
        return bool(connection.recurrent), connection.start

    @staticmethod
    def _check_node_order(genes):
        """Raise ValueError unless the node genes are ordered sensor, output, hidden."""
        position = {'sensor': 0, 'output': 1, 'hidden': 2}
        furthest = 0
        for node in genes.nodes.values():
            current = position.get(node.node_type)
            if current is None:
                # Node types other than the three known ones are not checked.
                continue
            if current < furthest:
                raise ValueError('Node genes need to follow the order sensor, output, hidden!')
            furthest = current

    def _compute_layer_order(self):
        """Return the node ids whose layers forward() evaluates, in dependency order.

        Nodes fed only by recurrent connections come first (their inputs come from
        the previous memory). The remaining layers are sorted by their depth in
        the graph of non-recurrent connections.

        Raises:
            ValueError: if the non-recurrent connections contain a cycle.
        """
        # Depth 0: the sensor slots 0 .. num_input_nodes - 1 (assumes sensor ids
        # are exactly that range, as forward() writes the input there), plus every
        # node without enabled incoming connections. Such nodes have no layer, so
        # apart from the sensors their memory value stays 0.
        rank_of_node = {node_id: 0 for node_id in range(self.num_input_nodes)}
        for node_id in self.genes.nodes:
            if node_id not in self.incoming_connections:
                rank_of_node.setdefault(node_id, 0)

        pending = dict(self.non_recurrent_indices)
        while pending:
            progressed = False
            for key, incoming_nodes in list(pending.items()):
                if all(node in rank_of_node for node in incoming_nodes):
                    # Recurrent-only nodes have no non-recurrent inputs -> depth 0.
                    rank_of_node[key] = max((rank_of_node[node] for node in incoming_nodes), default=-1) + 1
                    del pending[key]
                    progressed = True
            if not progressed:
                raise ValueError('Cannot order the network layers: the non-recurrent connections form a cycle!')

        ranked = sorted((item for item in rank_of_node.items() if item[1] > 0), key=lambda item: item[1])
        ranked_keys = [key for key, _ in ranked]
        ranked_set = set(ranked_keys)
        recurrent_only = [key for key, starts in self.recurrent_indices.items() if starts and key not in ranked_set]
        return recurrent_only + ranked_keys

    def forward(self, x, last_memory=None):
        """Evaluate the network on a batch, one sample at a time.

        Args:
            x: batch of observations. Each sample is flattened and written into
                memory slots ``0 .. num_input_nodes - 1``, so it is expected to
                provide ``num_input_nodes`` values.
            last_memory: batch of memory vectors from the previous step, one per
                sample. They are flattened and read at the recurrent connections'
                start ids. Required when the network has recurrent connections.

        Returns:
            The memory slots ``num_input_nodes .. num_input_nodes + 2 * action_num - 1``
            of every sample, reshaped to ``(batch, -1, 2)``.

        Raises:
            ValueError: if the network is recurrent and ``last_memory`` is None.
        """
        x_flat = self.flatten(x)
        last_memory_flat = None
        if last_memory is not None:
            last_memory_flat = self.flatten(last_memory)
        elif self.has_recurrent:
            raise ValueError('Recurrent networks need to be passed their previous memory!')

        output_slice = slice(self.num_input_nodes, self.num_input_nodes + self.action_num * 2)
        outs = []
        memory = None
        for batch_index, batch_element in enumerate(x_flat):
            # A fresh zeroed buffer per sample: reusing one buffer would make every
            # appended output view show the last sample's values, and nodes without
            # a layer must not read uninitialised memory.
            memory = torch.zeros(self.memory_size, device=x_flat.device)
            memory[0:self.num_input_nodes] = batch_element
            for layer_index in self.layer_order:
                combined_in = memory[self.non_recurrent_indices[layer_index]]
                if self.has_recurrent and self.recurrent_indices[layer_index]:
                    recurrent_in = last_memory_flat[batch_index, self.recurrent_indices[layer_index]]
                    # Same layout as the layer weights: non-recurrent, then recurrent.
                    combined_in = torch.cat([combined_in, recurrent_in])
                memory[layer_index] = self.layers[layer_index](combined_in.unsqueeze(0))[0, 0]
            outs.append(memory[output_slice])

        if not outs:
            return torch.zeros((0, self.action_num, 2), device=x_flat.device)

        outs = torch.stack(outs)
        # Snapshot of the last sample's memory, detached so it does not keep this
        # call's autograd graph alive.
        self.memory = memory.detach().clone()
        return torch.reshape(outs, (x.shape[0], outs.shape[1] // 2, 2))

    def update_genes_with_weights(self):
        """Copy the current layer weights and biases back into the connection and node genes."""
        for key, connections in self.incoming_connections.items():
            # Must match the weight layout chosen in __init__.
            connections.sort(key=self._connection_order_key)
            lin = self.layers[key][0]
            for index, connection in enumerate(connections):
                connection.weight = float(lin.weight[0, index])
            if self.genes.nodes[key].bias is not None:
                self.genes.nodes[key].bias = float(lin.bias[0])
|
|
|
|
|
|
class RecurrentDataSet(BaseDataSet):
    """Dataset that yields ``(state, memory, target)`` triples for recurrent training.

    States and targets are passed on to ``BaseDataSet`` (which presumably stores
    them as ``self.states`` / ``self.targets`` — TODO confirm). The per-sample
    memories are kept here as a single float32 tensor.
    """

    def __init__(self, states, targets, memory):
        super(RecurrentDataSet, self).__init__(states, targets)
        assert len(states) == len(memory), "Needs to have as many states as memories!"
        stacked_memory = np.array(memory)
        self.memory = torch.tensor(stacked_memory, dtype=torch.float32)

    def __getitem__(self, idx):
        state = self.states[idx]
        previous_memory = self.memory[idx]
        target = self.targets[idx]
        return state, previous_memory, target
|
|
|
|
|
|
def train_recurrent(states, memory, targets, model, optimizer):
    """Run one training pass over the recorded data of every action.

    For each ``action`` in ``range(model.action_num)`` a ``RecurrentDataSet`` is built
    from ``states[action]``, ``targets[action]`` and ``memory[action]``. It is
    iterated once in shuffled batches of 512 with the loss from
    ``create_loss_function(action)``.

    Args:
        states: per-action sequences of input states.
        memory: per-action sequences of previous network memories, one per state.
        targets: per-action sequences of training targets.
        model: module exposing ``action_num`` and ``forward(x, last_memory)``
            (presumably an ``EvolutionModel``).
        optimizer: optimizer over ``model``'s parameters.
    """
    for action in range(model.action_num):
        data_set = RecurrentDataSet(states[action], targets[action], memory[action])
        dataloader = DataLoader(data_set, batch_size=512, shuffle=True)
        loss_fn = create_loss_function(action)

        # Total number of samples, matching ``current`` below (samples seen so far);
        # len(dataloader) would count batches instead.
        size = len(data_set)
        model.train()
        for batch, (X, M, y) in enumerate(dataloader):
            X, y, M = X.to(device), y.to(device), M.to(device)

            # Compute prediction error
            pred = model(X, M)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            # NOTE(review): retain_graph=True looks unnecessary if each forward call
            # builds a fresh graph; kept to leave the training behaviour unchanged.
            loss.backward(retain_graph=True)
            optimizer.step()

            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
        model.eval()

        # Release this action's data before building the next action's data set.
        del data_set
        del dataloader
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: a 1-sensor / 1-output / 1-hidden genome with one recurrent
    # connection, one forward pass, then one call to train_recurrent.
    sample = np.random.random((1, 1))
    last_memory = np.zeros((1, 3))

    from labirinth_ai.Models.Genotype import NodeGene, ConnectionGene, Genotype

    node_genes = {
        0: NodeGene(0, 'sensor'),
        1: NodeGene(1, 'output'),
        2: NodeGene(2, 'hidden', 1),
    }
    connection_genes = [
        ConnectionGene(0, 2, True, 0, recurrent=True),
        ConnectionGene(2, 1, True, 1, 1),
    ]
    genes = Genotype(nodes=node_genes, connections=connection_genes)

    model = EvolutionModel(1, 1, 1, genes).to(device)
    print(model.has_recurrent)

    prediction = model(
        torch.tensor(sample, dtype=torch.float32),
        torch.tensor(last_memory, dtype=torch.float32),
    )

    state = np.random.random((1, 1))
    memory = np.random.random((1, 1))
    target = np.random.random((2, 1))

    # Four independent per-action lists that all reuse the same sample.
    copies = 4
    states = [[state] for _ in range(copies)]
    targets = [[target] for _ in range(copies)]
    memories = [[memory] for _ in range(copies)]

    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
    train_recurrent(states, memories, targets, model, optimizer)
|