import torch
from torch import nn
import numpy as np
from torch.utils.data import Dataset, DataLoader

from labirinth_ai.Models.BaseModel import device, BaseDataSet, create_loss_function, create_optimizer
from labirinth_ai.Models.Genotype import Genotype


class EvolutionModel(nn.Module):
    """Network whose topology and initial weights are defined by a Genotype of node and connection genes."""
    evolutionary = True

    def __init__(self, view_dimension, action_num, channels, genes: Genotype = None, genotype_class=None):
        if genotype_class is None:
            genotype_class = Genotype
        super(EvolutionModel, self).__init__()
        self.flatten = nn.Flatten()

        self.action_num = action_num
        self.viewD = view_dimension
        self.channels = channels

        if genes is None:
            self.num_input_nodes = channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2
            self.genes = genotype_class(action_num, self.num_input_nodes)
        else:
            self.num_input_nodes = len(list(filter(lambda element: element[1].node_type == 'sensor',
                                                   genes.nodes.items())))
            assert self.num_input_nodes > 0, 'Network needs to have sensor nodes!'

            # Node genes must be ordered sensor -> output -> hidden.
            is_input_over = False
            is_output_over = False
            for key, node in genes.nodes.items():
                if node.node_type == 'sensor':
                    if is_input_over:
                        raise ValueError('Node genes need to follow the order sensor, output, hidden!')
                if node.node_type == 'output':
                    is_input_over = True
                    if is_output_over:
                        raise ValueError('Node genes need to follow the order sensor, output, hidden!')
                if node.node_type == 'hidden':
                    is_output_over = True
            self.genes = genes

        # Group enabled connections by the node they feed into.
        self.incoming_connections = {}
        for connection in self.genes.connections:
            if not connection.enabled:
                continue
            if connection.end not in self.incoming_connections.keys():
                self.incoming_connections[connection.end] = []
            self.incoming_connections[connection.end].append(connection)

        self.layers = {}
        self.indices = {}
        self.has_recurrent = False
        self.non_recurrent_indices = {}
        self.recurrent_indices = {}

        # Build one Linear + ELU layer per non-sensor node, initialised from the gene weights and bias.
        with torch.no_grad():
            for key, value in self.incoming_connections.items():
                value.sort(key=lambda element: element.start)

                lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
                for index, connection in enumerate(value):
                    lin.weight[0, index] = value[index].weight
                if self.genes.nodes[key].bias is not None:
                    lin.bias[0] = self.genes.nodes[key].bias

                non_lin = nn.ELU()
                sequence = nn.Sequential(
                    lin,
                    non_lin
                )
                self.add_module('layer_' + str(key), sequence)
                self.layers[key] = sequence
                self.indices[key] = list(map(lambda element: element.start, value))

                self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
                self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
                if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
                    self.has_recurrent = True
                self.non_recurrent_indices[key] = list(map(lambda element: element.start,
                                                           self.non_recurrent_indices[key]))
                self.recurrent_indices[key] = list(map(lambda element: element.start,
                                                       self.recurrent_indices[key]))

        # Rank every node by its depth in the feed-forward graph (sensors are rank 0),
        # so the layers can later be evaluated in topological order.
        rank_of_node = {}
        for i in range(self.num_input_nodes):
            rank_of_node[i] = 0

        layers_to_add = list(self.non_recurrent_indices.items())
        while len(layers_to_add) > 0:
            for index, (key, incoming_nodes) in enumerate(list(layers_to_add)):
                max_rank = -1
                all_ranks_found = True
                for incoming_node in incoming_nodes:
                    if incoming_node in rank_of_node.keys():
                        max_rank = max(max_rank, rank_of_node[incoming_node])
                    else:
                        all_ranks_found = False
                if all_ranks_found:
                    rank_of_node[key] = max_rank + 1

            layers_to_add = list(filter(lambda element: element[0] not in rank_of_node.keys(), layers_to_add))

        ranked_layers = list(rank_of_node.items())
        ranked_layers.sort(key=lambda element: element[1])
        ranked_layers = list(filter(lambda element: element[1] > 0, ranked_layers))

        # Layers that are fed only by recurrent connections never receive a rank above;
        # prepend them with rank 0 so they are still evaluated.
        ranked_layers = list(map(lambda element: (element, 0), filter(
            lambda recurrent_element: recurrent_element not in list(
                map(lambda ranked_layer: ranked_layer[0], ranked_layers)
            ),
            list(filter(lambda recurrent_keys: len(self.recurrent_indices[recurrent_keys]) > 0,
                        self.recurrent_indices.keys()))))) + ranked_layers

        self.layer_order = list(map(lambda element: element[0], ranked_layers))
        self.memory_size = max(map(lambda element: element[1].node_id, self.genes.nodes.items())) + 1
        self.memory = torch.Tensor(self.memory_size)
        self.output_range = range(self.num_input_nodes, self.num_input_nodes + self.action_num * 2)

    def forward(self, x, last_memory=None):
        x_flat = self.flatten(x)
        if last_memory is not None:
            last_memory_flat = self.flatten(last_memory)
        elif self.has_recurrent:
            raise ValueError('Recurrent networks need to be passed their previous memory!')

        # `memory` is a flat vector indexed by node id: the first num_input_nodes slots hold
        # the sensor inputs, the remaining slots hold the activations of output/hidden nodes.
        memory = torch.Tensor(self.memory_size)
        outs = []
        for batch_index, batch_element in enumerate(x_flat):
            memory[0:self.num_input_nodes] = batch_element
            for layer_index in self.layer_order:
                non_recurrent_in = memory[self.non_recurrent_indices[layer_index]]
                non_recurrent_in = torch.stack([non_recurrent_in])
                if self.has_recurrent and len(self.recurrent_indices[layer_index]) > 0:
                    recurrent_in = last_memory_flat[batch_index, self.recurrent_indices[layer_index]]
                    recurrent_in = torch.stack([recurrent_in])
                    combined_in = torch.concat([non_recurrent_in, recurrent_in], dim=1)
                else:
                    combined_in = non_recurrent_in
                memory[layer_index] = self.layers[layer_index](combined_in)
            outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2])

        outs = torch.stack(outs)
        self.memory = torch.Tensor(memory)
        return torch.reshape(outs, (x.shape[0], outs.shape[1] // 2, 2))

    def update_genes_with_weights(self):
        # Write the trained layer weights and biases back into the connection and node genes.
        for key, value in self.incoming_connections.items():
            value.sort(key=lambda element: element.start)
            sequence = self.layers[key]
            lin = sequence[0]
            for index, connection in enumerate(value):
                value[index].weight = float(lin.weight[0, index])
            if self.genes.nodes[key].bias is not None:
                self.genes.nodes[key].bias = float(lin.bias[0])


class RecurrentDataSet(BaseDataSet):
    def __init__(self, states, targets, memory):
        super().__init__(states, targets)
        assert len(states) == len(memory), "Needs to have as many states as memories!"
        self.memory = torch.tensor(np.array(memory), dtype=torch.float32)

    def __getitem__(self, idx):
        return self.states[idx], self.memory[idx], self.targets[idx]


def train_recurrent(states, memory, targets, model, optimizer):
    for action in range(model.action_num):
        data_set = RecurrentDataSet(states[action], targets[action], memory[action])
        dataloader = DataLoader(data_set, batch_size=512, shuffle=True)
        loss_fn = create_loss_function(action)

        size = len(dataloader.dataset)
        model.train()
        for batch, (X, M, y) in enumerate(dataloader):
            X, y, M = X.to(device), y.to(device), M.to(device)

            # Compute prediction error
            pred = model(X, M)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

        model.eval()
        del data_set
        del dataloader


if __name__ == '__main__':
    sample = np.random.random((1, 1))
    last_memory = np.zeros((1, 3))

    from labirinth_ai.Models.Genotype import NodeGene, ConnectionGene, Genotype
    genes = Genotype(nodes={0: NodeGene(0, 'sensor'), 1: NodeGene(1, 'output'), 2: NodeGene(2, 'hidden', 1)},
                     connections=[ConnectionGene(0, 2, True, 0, recurrent=True),
                                  ConnectionGene(2, 1, True, 1, 1)])

    model = EvolutionModel(1, 1, 1, genes)
    model = model.to(device)
    # print(model)
    print(model.has_recurrent)

    test = model(torch.tensor(sample, dtype=torch.float32), torch.tensor(last_memory, dtype=torch.float32))
    # test = test.cpu().detach().numpy()
    # print(test)

    state = np.random.random((1, 1))
    memory = np.random.random((1, 1))
    target = np.random.random((2, 1))

    states = [
        [state],
        [state],
        [state],
        [state],
    ]
    targets = [
        [target],
        [target],
        [target],
        [target],
    ]
    memories = [
        [memory],
        [memory],
        [memory],
        [memory],
    ]

    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
    train_recurrent(states, memories, targets, model, optimizer)
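
    # A minimal sketch of carrying the recurrent memory across time steps: after a forward
    # pass, model.memory is assumed to hold the latest per-node activations, so it can be
    # detached and passed back in as last_memory for the next call (CPU device assumed,
    # as in the calls above).
    next_memory = model.memory.detach().unsqueeze(0)  # shape (1, memory_size)
    next_step_out = model(torch.tensor(sample, dtype=torch.float32), next_memory)
    print(next_step_out)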