diff --git a/labirinth_ai/Models/EvolutionModel.py b/labirinth_ai/Models/EvolutionModel.py
index 1b7d79f..8b9fd99 100644
--- a/labirinth_ai/Models/EvolutionModel.py
+++ b/labirinth_ai/Models/EvolutionModel.py
@@ -51,36 +51,40 @@ class EvolutionModel(nn.Module):
                 self.incoming_connections[connection.end].append(connection)
 
         self.layers = {}
+        self.layer_non_recurrent_inputs = {}
+        self.layer_recurrent_inputs = {}
+        self.layer_results = {}
+        self.layer_num = 1
         self.indices = {}
         self.has_recurrent = False
         self.non_recurrent_indices = {}
         self.recurrent_indices = {}
 
-        with torch.no_grad():
-            for key, value in self.incoming_connections.items():
-                value.sort(key=lambda element: element.start)
-                lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
-                for index, connection in enumerate(value):
-                    lin.weight[0, index] = value[index].weight
-                if self.genes.nodes[key].bias is not None:
-                    lin.bias[0] = self.genes.nodes[key].bias
+        for key, value in self.incoming_connections.items():
+            value.sort(key=lambda element: element.start)
 
-                non_lin = nn.ELU()
-                sequence = nn.Sequential(
-                    lin,
-                    non_lin
-                )
-                self.add_module('layer_' + str(key), sequence)
-                self.layers[key] = sequence
-                self.indices[key] = list(map(lambda element: element.start, value))
+            # lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
+            # for index, connection in enumerate(value):
+            #     lin.weight[0, index] = value[index].weight
+            # if self.genes.nodes[key].bias is not None:
+            #     lin.bias[0] = self.genes.nodes[key].bias
+            #
+            # non_lin = nn.ELU()
+            # sequence = nn.Sequential(
+            #     lin,
+            #     non_lin
+            # )
+            # self.add_module('layer_' + str(key), sequence)
+            # self.layers[key] = sequence
+            self.indices[key] = list(map(lambda element: element.start, value))
 
-                self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
-                self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
-                if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
-                    self.has_recurrent = True
-                self.non_recurrent_indices[key] = list(map(lambda element: element.start, self.non_recurrent_indices[key]))
-                self.recurrent_indices[key] = list(map(lambda element: element.start, self.recurrent_indices[key]))
+            self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
+            self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
+            if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
+                self.has_recurrent = True
+            self.non_recurrent_indices[key] = list(map(lambda element: element.start, self.non_recurrent_indices[key]))
+            self.recurrent_indices[key] = list(map(lambda element: element.start, self.recurrent_indices[key]))
 
         rank_of_node = {}
         for i in range(self.num_input_nodes):
             rank_of_node[i] = 0
@@ -101,20 +105,39 @@ class EvolutionModel(nn.Module):
                     rank_of_node[key] = max_rank + 1
 
             layers_to_add = list(filter(lambda element: element[0] not in rank_of_node.keys(), layers_to_add))
 
-        ranked_layers = list(rank_of_node.items())
-        ranked_layers.sort(key=lambda element: element[1])
-        ranked_layers = list(filter(lambda element: element[1] > 0, ranked_layers))
-        ranked_layers = list(map(lambda element: (element, 0),
-                                 filter(lambda recurrent_element:
-                                        recurrent_element not in list(
-                                            map(lambda ranked_layer: ranked_layer[0], ranked_layers)
-                                        ),
-                                        list(filter(lambda recurrent_keys:
-                                                    len(self.recurrent_indices[recurrent_keys]) > 0,
-                                                    self.recurrent_indices.keys()))))) + ranked_layers
+        with torch.no_grad():
+            self.layer_num = max_rank = max(map(lambda element: element[1], rank_of_node.items()))
+            # todo: handle solely recurrent nodes
+            for rank in range(1, max_rank + 1):
+                # get nodes
+                nodes = list(map(lambda element: element[0], filter(lambda item: item[1] == rank, rank_of_node.items())))
+                non_recurrent_inputs = list(set.union(*map(lambda node: set(self.non_recurrent_indices[node]), nodes)))
+                non_recurrent_inputs.sort()
+
+                recurrent_inputs = list(set.union(*map(lambda node: set(self.recurrent_indices[node]), nodes)))
+                recurrent_inputs.sort()
+
+                lin = nn.Linear(len(non_recurrent_inputs) + len(recurrent_inputs), len(nodes), bias=True)
+
+                # todo: load weights
+
+                # for index, connection in enumerate(value):
+                #     lin.weight[0, index] = value[index].weight
+                # if self.genes.nodes[key].bias is not None:
+                #     lin.bias[0] = self.genes.nodes[key].bias
+                #
+                non_lin = nn.ELU()
+                sequence = nn.Sequential(
+                    lin,
+                    non_lin
+                )
+                self.add_module('layer_' + str(rank), sequence)
+                self.layers[rank] = sequence
+                self.layer_results[rank] = nodes
+                self.layer_non_recurrent_inputs[rank] = non_recurrent_inputs
+                self.layer_recurrent_inputs[rank] = recurrent_inputs
 
-        self.layer_order = list(map(lambda element: element[0], ranked_layers))
         self.memory_size = (max(map(lambda element: element[1].node_id, self.genes.nodes.items())) + 1)
         self.memory = torch.Tensor(self.memory_size)
         self.output_range = range(self.num_input_nodes, self.num_input_nodes + self.action_num * 2)
@@ -130,24 +153,25 @@ class EvolutionModel(nn.Module):
         outs = []
         for batch_index, batch_element in enumerate(x_flat):
             memory[0:self.num_input_nodes] = batch_element
-            for layer_index in self.layer_order:
-                non_recurrent_in = memory[self.non_recurrent_indices[layer_index]]
+            for layer_index in range(1, self.layer_num + 1):
+                non_recurrent_in = memory[self.layer_non_recurrent_inputs[layer_index]]
                 non_recurrent_in = torch.stack([non_recurrent_in])
-                if self.has_recurrent and len(self.recurrent_indices[layer_index]) > 0:
-                    recurrent_in = last_memory_flat[batch_index, self.recurrent_indices[layer_index]]
+                if self.has_recurrent and len(self.layer_recurrent_inputs[layer_index]) > 0:
+                    recurrent_in = last_memory_flat[batch_index, self.layer_recurrent_inputs[layer_index]]
                     recurrent_in = torch.stack([recurrent_in])
                     combined_in = torch.concat([non_recurrent_in, recurrent_in], dim=1)
                 else:
                     combined_in = non_recurrent_in
-                memory[layer_index] = self.layers[layer_index](combined_in)
-            outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2])
+                memory[self.layer_results[layer_index]] = self.layers[layer_index](combined_in)
+            outs.append(memory[self.output_range])
         outs = torch.stack(outs)
         self.memory = torch.Tensor(memory)
         return torch.reshape(outs, (x.shape[0], outs.shape[1]//2, 2))
 
     def update_genes_with_weights(self):
+        # todo rework
         for key, value in self.incoming_connections.items():
             value.sort(key=lambda element: element.start)
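---
How the rank-based construction plays out on a small, hypothetical genome:
with input nodes 0-1 (rank 0), a node 4 fed only by the inputs (rank 1),
and a node 5 fed by node 4 (rank 2), the new constructor builds one shared
nn.Linear per rank instead of one single-output nn.Linear per node:

    layer_1: input columns [0, 1] -> output row for node 4
    layer_2: input column  [4]    -> output row for node 5

forward() then evaluates the ranks in ascending order, writing each layer's
output back into memory at the node ids stored in self.layer_results[rank],
so later ranks can read the results of earlier ones, and recurrent inputs
are appended after the non-recurrent columns via torch.concat.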
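One possible shape for the remaining "# todo: load weights" marker, sketched
under two assumptions: columns are laid out as
[non_recurrent_inputs | recurrent_inputs] (matching the concat order in
forward), and input columns with no connection to a given node keep weight
zero, mirroring the old per-node initialisation that is now commented out.
The names are the loop locals of the new constructor; this is a sketch, not
the committed implementation:

    # would replace "# todo: load weights" inside the per-rank loop,
    # which already runs under torch.no_grad()
    lin.weight.zero_()
    lin.bias.zero_()
    for row, node in enumerate(nodes):
        for connection in self.incoming_connections[node]:
            if connection.recurrent:
                column = len(non_recurrent_inputs) + recurrent_inputs.index(connection.start)
            else:
                column = non_recurrent_inputs.index(connection.start)
            # copy the evolved connection weight into the grouped layer
            lin.weight[row, column] = connection.weight
        if self.genes.nodes[node].bias is not None:
            lin.bias[row] = self.genes.nodes[node].bias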