"""Evolvable network model: a NEAT-style genome of node and connection genes,
compiled into per-node torch layers that are evaluated in topological order."""
import numpy as np
import torch
from torch import nn

from labirinth_ai.Models.BaseModel import device


class NodeGene:
    """A single node of the genome: sensor (input), hidden, or output."""
    valid_types = ['sensor', 'hidden', 'output']

    def __init__(self, node_id, node_type, bias=None):
        assert node_type in self.valid_types, 'Unknown node type!'
        self.node_id = node_id
        self.node_type = node_type
        if node_type == 'hidden':
            assert bias is not None, 'Expected a bias for hidden node types!'
            self.bias = bias
        else:
            self.bias = None


class ConnectionGene:
    """A weighted edge between two nodes, tagged with an innovation number."""

    def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
        self.start = start
        self.end = end
        self.enabled = enabled
        self.innovation_num = innovation_num
        self.recurrent = recurrent
        if weight is None:
            # Random initial weight in [-1, 1).
            self.weight = np.random.random(1)[0] * 2 - 1.0
        else:
            self.weight = weight


class EvolutionModel(nn.Module):
    evolutionary = True

    def __init__(self, view_dimension, action_num, channels, genes=None):
        super(EvolutionModel, self).__init__()
        self.flatten = nn.Flatten()
        self.action_num = action_num
        self.viewD = view_dimension
        self.channels = channels
        # Inputs: one sensor per channel and view cell, plus two extra sensors.
        self.num_input_nodes = channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2

        if genes is None:
            # Default genome: all sensors fully connected to all output nodes.
            self.genes = {'nodes': {}, 'connections': []}
            node_id = 0
            for _ in range(self.num_input_nodes):
                self.genes['nodes'][node_id] = NodeGene(node_id, 'sensor')
                node_id += 1
            first_action = node_id
            for _ in range(action_num * 2):
                self.genes['nodes'][node_id] = NodeGene(node_id, 'output')
                node_id += 1
            for index in range(self.num_input_nodes):
                for action in range(action_num * 2):
                    self.genes['connections'].append(
                        ConnectionGene(index, first_action + action, True,
                                       index * (action_num * 2) + action))
        else:
            self.genes = genes

        # Group connections by target node. Note: disabled connections are not
        # filtered out here; the default genome enables all of them.
        self.incoming_connections = {}
        for connection in self.genes['connections']:
            if connection.end not in self.incoming_connections.keys():
                self.incoming_connections[connection.end] = []
            self.incoming_connections[connection.end].append(connection)

        # Build one small Linear+ELU module per non-sensor node, seeded with
        # the genome's weights and biases.
        self.layers = {}
        self.indices = {}
        self.has_recurrent = False
        non_recurrent_indices = {}
        with torch.no_grad():
            for key, value in self.incoming_connections.items():
                value.sort(key=lambda element: element.start)
                lin = nn.Linear(len(value), 1, bias=self.genes['nodes'][key].bias is not None)
                for index, connection in enumerate(value):
                    lin.weight[0, index] = connection.weight
                if self.genes['nodes'][key].bias is not None:
                    lin.bias[0] = self.genes['nodes'][key].bias
                sequence = nn.Sequential(lin, nn.ELU())
                self.add_module('layer_' + str(key), sequence)
                self.layers[key] = sequence
                self.indices[key] = list(map(lambda element: element.start, value))
                non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
                if not self.has_recurrent and len(non_recurrent_indices[key]) != len(self.indices[key]):
                    self.has_recurrent = True
                non_recurrent_indices[key] = list(map(lambda element: element.start, non_recurrent_indices[key]))

        # Rank nodes topologically (ignoring recurrent edges): sensors are rank
        # 0; every other node ranks one above its highest-ranked predecessor.
        rank_of_node = {}
        for i in range(self.num_input_nodes):
            rank_of_node[i] = 0

        layers_to_add = list(non_recurrent_indices.items())
        while len(layers_to_add) > 0:
            for key, incoming_nodes in list(layers_to_add):
                max_rank = -1
                all_ranks_found = True
                for incoming_node in incoming_nodes:
                    if incoming_node in rank_of_node.keys():
                        max_rank = max(max_rank, rank_of_node[incoming_node])
                    else:
                        all_ranks_found = False
                if all_ranks_found:
                    rank_of_node[key] = max_rank + 1
            layers_to_add = list(filter(lambda element: element[0] not in rank_of_node.keys(), layers_to_add))

        ranked_layers = list(rank_of_node.items())
        ranked_layers.sort(key=lambda element: element[1])
        ranked_layers = list(filter(lambda element: element[1] > 0, ranked_layers))
        self.layer_order = list(map(lambda element: element[0], ranked_layers))
        # Persistent activation buffer, one slot per node id (zero-initialized
        # so recurrent reads start from a defined state).
        self.memory = torch.zeros(max(self.genes['nodes'].keys()) + 1)

    def forward(self, x, memory=None):
        x_flat = self.flatten(x)
        if memory is None:
            # Stateful path: evaluate one batch element at a time against the
            # persistent memory buffer.
            memory = self.memory.clone()
            outs = []
            for batch_element in x_flat:
                memory[0:self.num_input_nodes] = batch_element
                for layer_index in self.layer_order:
                    memory[layer_index] = self.layers[layer_index](memory[self.indices[layer_index]])
                # Clone the output slice: later in-place writes to `memory`
                # must not alter already-collected outputs.
                outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2].clone())
            outs = torch.stack(outs)
            self.memory = memory.detach().clone()
            return torch.reshape(outs, (x.shape[0], self.action_num, 2))
        else:
            # Stateless batched path: the caller supplies a (batch, num_nodes)
            # memory tensor.
            memory[:, 0:self.num_input_nodes] = x_flat
            for layer_index in self.layer_order:
                # Each per-node layer returns shape (batch, 1); take column 0
                # so it fits the (batch,) slot.
                memory[:, layer_index] = self.layers[layer_index](memory[:, self.indices[layer_index]])[:, 0]
            return torch.reshape(
                memory[:, self.num_input_nodes: self.num_input_nodes + self.action_num * 2],
                (x.shape[0], self.action_num, 2))


if __name__ == '__main__':
    # 486 inputs = 4 channels * (2 * 5 + 1)^2 view cells + 2 extra sensors.
    sample = np.random.random((1, 486))

    model = EvolutionModel(5, 4, 4).to(device)
    print(model)
    print(model.has_recurrent)

    test = model(torch.tensor(sample, dtype=torch.float32))
    # test = test.cpu().detach().numpy()
    print(test)

    state = np.random.random((1, 486))
    target = np.random.random((4, 2))
    states = [
        [state], [state], [state], [state],
    ]
    targets = [
        [target], [target], [target], [target],
    ]

    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
    from labirinth_ai.Models.BaseModel import train
    train(states, targets, model, optimizer)
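
    # Illustrative sketch (not part of the original demo): the `genes` argument
    # lets a caller supply a hand-built genome instead of the default fully
    # connected one. The tiny topology below -- 3 sensors, 1 hidden node,
    # 2 outputs -- and all its ids/weights are made up for illustration; node
    # counts must match the constructor arguments (view_dimension=0, channels=1
    # gives 1 * (2 * 0 + 1)^2 + 2 = 3 sensors, action_num=1 gives 2 outputs).
    custom_genes = {
        'nodes': {
            0: NodeGene(0, 'sensor'),
            1: NodeGene(1, 'sensor'),
            2: NodeGene(2, 'sensor'),
            3: NodeGene(3, 'output'),
            4: NodeGene(4, 'output'),
            5: NodeGene(5, 'hidden', bias=0.1),
        },
        'connections': [
            # Sensors feed the hidden node ...
            ConnectionGene(0, 5, True, 0),
            ConnectionGene(1, 5, True, 1),
            ConnectionGene(2, 5, True, 2),
            # ... which feeds both outputs.
            ConnectionGene(5, 3, True, 3),
            ConnectionGene(5, 4, True, 4),
        ],
    }
    small_model = EvolutionModel(0, 1, 1, genes=custom_genes)
    # Output shape follows (batch, action_num, 2) -> here (1, 1, 2).
    print(small_model(torch.rand(1, 3)))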