VoxelEngine/labirinth_ai/Models/EvolutionModel.py


import torch
from torch import nn
import numpy as np
import tqdm
from torch.utils.data import Dataset, DataLoader
from labirinth_ai.Models.BaseModel import device
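
# NodeGene and ConnectionGene resemble a NEAT-style genome encoding: nodes carry an
# optional bias, connections carry a weight, an enabled flag, an innovation number
# for lineage tracking, and a flag marking recurrent edges.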
class NodeGene:
    valid_types = ['sensor', 'hidden', 'output']

    def __init__(self, node_id, node_type, bias=None):
        assert node_type in self.valid_types, 'Unknown node type!'
        self.node_id = node_id
        self.node_type = node_type

        if node_type == 'hidden':
            assert bias is not None, 'Expected a bias for hidden node types!'
            self.bias = bias
        else:
            self.bias = None

class ConnectionGene:
    def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
        self.start = start
        self.end = end
        self.enabled = enabled
        self.innovation_num = innovation_num
        self.recurrent = recurrent

        if weight is None:
            # Random initial weight drawn uniformly from [-1, 1).
            self.weight = np.random.random(1)[0] * 2 - 1.0
        else:
            self.weight = weight
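

# EvolutionModel compiles the genome into a runnable torch module: every node with
# incoming connections becomes its own single-output nn.Linear followed by an ELU,
# and the nodes are evaluated in topological order over a shared memory vector.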
class EvolutionModel(nn.Module):
    evolutionary = True

    def __init__(self, view_dimension, action_num, channels, genes=None):
        super(EvolutionModel, self).__init__()
        self.flatten = nn.Flatten()

        self.action_num = action_num
        self.viewD = view_dimension
        self.channels = channels
        if genes is None:
            # One sensor node per channel per cell of the (2 * viewD + 1)^2 field of
            # view, plus two extra scalar inputs.
            self.num_input_nodes = channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2
            self.genes = {'nodes': {}, 'connections': []}
            node_id = 0
            for _ in range(self.num_input_nodes):
                self.genes['nodes'][node_id] = NodeGene(node_id, 'sensor')
                node_id += 1
            first_action = node_id
            for _ in range(action_num * 2):
                self.genes['nodes'][node_id] = NodeGene(node_id, 'output')
                node_id += 1
            # Fully connect every sensor to every output, numbering innovations consecutively.
            for index in range(self.num_input_nodes):
                for action in range(action_num * 2):
                    self.genes['connections'].append(
                        ConnectionGene(index, first_action + action, True,
                                       index * (action_num * 2) + action)
                    )
        else:
            # Assumed handling for a supplied genome (this branch is not shown in the
            # source): reuse it directly; the input count depends only on view size
            # and channel count.
            self.genes = genes
            self.num_input_nodes = channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2
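
        # Group connections by their target node so each node becomes one layer.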
        self.incoming_connections = {}
        for connection in self.genes['connections']:
            if connection.end not in self.incoming_connections.keys():
                self.incoming_connections[connection.end] = []
            self.incoming_connections[connection.end].append(connection)
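
        # Instantiate one tiny Linear(+ELU) per node, seeding its weights and bias
        # from the gene values; inputs are ordered by source node id.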
        self.layers = {}
        self.indices = {}
        self.has_recurrent = False
        non_recurrent_indices = {}
        with torch.no_grad():
            for key, value in self.incoming_connections.items():
                value.sort(key=lambda element: element.start)

                lin = nn.Linear(len(value), 1, bias=self.genes['nodes'][key].bias is not None)
                for index, connection in enumerate(value):
                    lin.weight[0, index] = value[index].weight
                if self.genes['nodes'][key].bias is not None:
                    lin.bias[0] = self.genes['nodes'][key].bias

                non_lin = nn.ELU()
                sequence = nn.Sequential(
                    lin,
                    non_lin
                )
                self.add_module('layer_' + str(key), sequence)
                self.layers[key] = sequence
                self.indices[key] = list(map(lambda element: element.start, value))

                non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
                if not self.has_recurrent and len(non_recurrent_indices[key]) != len(self.indices[key]):
                    self.has_recurrent = True
                non_recurrent_indices[key] = list(map(lambda element: element.start, non_recurrent_indices[key]))
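
        # Rank every node by its longest non-recurrent path from the inputs; the
        # ranks give a feed-forward evaluation order (recurrent edges read the
        # stale value left in memory by the previous step).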
        rank_of_node = {}
        for i in range(self.num_input_nodes):
            rank_of_node[i] = 0

        layers_to_add = list(non_recurrent_indices.items())
        while len(layers_to_add) > 0:
            for index, (key, incoming_nodes) in enumerate(list(layers_to_add)):
                max_rank = -1
                all_ranks_found = True
                for incoming_node in incoming_nodes:
                    if incoming_node in rank_of_node.keys():
                        max_rank = max(max_rank, rank_of_node[incoming_node])
                    else:
                        all_ranks_found = False
                if all_ranks_found:
                    rank_of_node[key] = max_rank + 1
            layers_to_add = list(filter(lambda element: element[0] not in rank_of_node.keys(), layers_to_add))

        ranked_layers = list(rank_of_node.items())
        ranked_layers.sort(key=lambda element: element[1])
        ranked_layers = list(filter(lambda element: element[1] > 0, ranked_layers))
        self.layer_order = list(map(lambda element: element[0], ranked_layers))

        # One memory slot per node id, zero-initialised; registered as a buffer so
        # the recurrent state moves with the module in .to(device).
        self.register_buffer(
            'memory',
            torch.zeros(max(map(lambda element: element[1].node_id, self.genes['nodes'].items())) + 1)
        )
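
    # Without an explicit memory, forward iterates the batch one element at a time
    # against the module's own recurrent memory, carrying state across calls; with
    # a batched memory tensor it evaluates all elements at once.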
    def forward(self, x, memory=None):
        x_flat = self.flatten(x)
        if memory is None:
            memory = self.memory.clone()
            outs = []
            for batch_element in x_flat:
                memory[0:self.num_input_nodes] = batch_element
                for layer_index in self.layer_order:
                    memory[layer_index] = self.layers[layer_index](memory[self.indices[layer_index]])
                # Clone the output slice: it is a view into memory and would
                # otherwise be overwritten by the next batch element.
                outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2].clone())
            outs = torch.stack(outs)
            self.memory = memory.detach()
            return torch.reshape(outs, (x.shape[0], self.action_num, 2))
        else:
            memory[:, 0:self.num_input_nodes] = x
            for layer_index in self.layer_order:
                # Each layer outputs shape (batch, 1); drop the trailing dim to store it.
                memory[:, layer_index] = self.layers[layer_index](memory[:, self.indices[layer_index]])[:, 0]
            return torch.reshape(
                memory[:, self.num_input_nodes: self.num_input_nodes + self.action_num * 2],
                (x.shape[0], self.action_num, 2))
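

# Smoke test: a model with view radius 5, 4 actions, and 4 channels expects
# 4 * 11 * 11 + 2 = 486 inputs per sample.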
if __name__ == '__main__':
    sample = np.random.random((1, 486))

    model = EvolutionModel(5, 4, 4).to(device)
    print(model)
    print(model.has_recurrent)

    test = model(torch.tensor(sample, dtype=torch.float32, device=device))
    # test = test.cpu().detach().numpy()
    print(test)

    state = np.random.random((1, 486))
    target = np.random.random((4, 2))

    states = [
        [state],
        [state],
        [state],
        [state],
    ]
    targets = [
        [target],
        [target],
        [target],
        [target],
    ]

    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)

    from labirinth_ai.Models.BaseModel import train
    train(states, targets, model, optimizer)
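
    # Sketch of the second forward mode (assumed usage, not exercised above): keep
    # an explicit batched memory so recurrent state can be stepped manually.
    batched_memory = model.memory.unsqueeze(0).repeat(4, 1)
    stepped = model(torch.tensor(np.random.random((4, 486)), dtype=torch.float32, device=device),
                    memory=batched_memory)
    print(stepped.shape)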