bundling weights
This commit is contained in:
parent
bd56173379
commit
b0d22f6bf1
1 changed file with 64 additions and 40 deletions
@@ -51,36 +51,40 @@ class EvolutionModel(nn.Module):
             self.incoming_connections[connection.end].append(connection)
 
         self.layers = {}
+        self.layer_non_recurrent_inputs = {}
+        self.layer_recurrent_inputs = {}
+        self.layer_results = {}
+        self.layer_num = 1
         self.indices = {}
 
         self.has_recurrent = False
         self.non_recurrent_indices = {}
         self.recurrent_indices = {}
-        with torch.no_grad():
-            for key, value in self.incoming_connections.items():
-                value.sort(key=lambda element: element.start)
+        for key, value in self.incoming_connections.items():
+            value.sort(key=lambda element: element.start)
 
-                lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
-                for index, connection in enumerate(value):
-                    lin.weight[0, index] = value[index].weight
-                if self.genes.nodes[key].bias is not None:
-                    lin.bias[0] = self.genes.nodes[key].bias
-
-                non_lin = nn.ELU()
-                sequence = nn.Sequential(
-                    lin,
-                    non_lin
-                )
-                self.add_module('layer_' + str(key), sequence)
-                self.layers[key] = sequence
-                self.indices[key] = list(map(lambda element: element.start, value))
+            # lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
+            # for index, connection in enumerate(value):
+            #     lin.weight[0, index] = value[index].weight
+            # if self.genes.nodes[key].bias is not None:
+            #     lin.bias[0] = self.genes.nodes[key].bias
+            #
+            # non_lin = nn.ELU()
+            # sequence = nn.Sequential(
+            #     lin,
+            #     non_lin
+            # )
+            # self.add_module('layer_' + str(key), sequence)
+            # self.layers[key] = sequence
+            self.indices[key] = list(map(lambda element: element.start, value))
 
-                self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
-                self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
-                if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
-                    self.has_recurrent = True
-                self.non_recurrent_indices[key] = list(map(lambda element: element.start, self.non_recurrent_indices[key]))
-                self.recurrent_indices[key] = list(map(lambda element: element.start, self.recurrent_indices[key]))
+            self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
+            self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
+            if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
+                self.has_recurrent = True
+            self.non_recurrent_indices[key] = list(map(lambda element: element.start, self.non_recurrent_indices[key]))
+            self.recurrent_indices[key] = list(map(lambda element: element.start, self.recurrent_indices[key]))
         rank_of_node = {}
         for i in range(self.num_input_nodes):
             rank_of_node[i] = 0
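After this hunk the per-key loop only records bookkeeping per node: the sorted input indices plus separate non-recurrent and recurrent index lists; the per-node Linear construction is commented out. A minimal sketch of that bookkeeping, using a hypothetical Connection record with the start/end/weight/recurrent fields the diff relies on (the concrete ids are made up):

    from dataclasses import dataclass

    @dataclass
    class Connection:
        start: int       # id of the source node
        end: int         # id of the target node
        weight: float
        recurrent: bool

    # hypothetical incoming connections of one node
    value = [Connection(2, 5, 0.3, False), Connection(7, 5, -1.1, True), Connection(0, 5, 0.9, False)]
    value.sort(key=lambda element: element.start)

    indices = [c.start for c in value]                                    # [0, 2, 7]
    non_recurrent_indices = [c.start for c in value if not c.recurrent]   # [0, 2]
    recurrent_indices = [c.start for c in value if c.recurrent]           # [7]
    has_recurrent = len(non_recurrent_indices) != len(indices)            # True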
@@ -101,20 +105,39 @@ class EvolutionModel(nn.Module):
                 rank_of_node[key] = max_rank + 1
 
             layers_to_add = list(filter(lambda element: element[0] not in rank_of_node.keys(), layers_to_add))
-        ranked_layers = list(rank_of_node.items())
-        ranked_layers.sort(key=lambda element: element[1])
-        ranked_layers = list(filter(lambda element: element[1] > 0, ranked_layers))
 
-        ranked_layers = list(map(lambda element: (element, 0),
-                                 filter(lambda recurrent_element:
-                                        recurrent_element not in list(
-                                            map(lambda ranked_layer: ranked_layer[0], ranked_layers)
-                                        ),
-                                        list(filter(lambda recurrent_keys:
-                                                    len(self.recurrent_indices[recurrent_keys]) > 0,
-                                                    self.recurrent_indices.keys()))))) + ranked_layers
-
-        self.layer_order = list(map(lambda element: element[0], ranked_layers))
+        with torch.no_grad():
+            self.layer_num = max_rank = max(map(lambda element: element[1], rank_of_node.items()))
+            # todo: handle solely recurrent nodes
+            for rank in range(1, max_rank + 1):
+                # get nodes
+                nodes = list(map(lambda element: element[0], filter(lambda item: item[1] == rank, rank_of_node.items())))
+                non_recurrent_inputs = list(set.union(*map(lambda node: set(self.non_recurrent_indices[node]), nodes)))
+                non_recurrent_inputs.sort()
+
+                recurrent_inputs = list(set.union(*map(lambda node: set(self.recurrent_indices[node]), nodes)))
+                recurrent_inputs.sort()
+
+                lin = nn.Linear(len(non_recurrent_inputs) + len(recurrent_inputs), len(nodes), bias=True)
+
+                # todo: load weights
+                # for index, connection in enumerate(value):
+                #     lin.weight[0, index] = value[index].weight
+                # if self.genes.nodes[key].bias is not None:
+                #     lin.bias[0] = self.genes.nodes[key].bias
+                #
+                non_lin = nn.ELU()
+                sequence = nn.Sequential(
+                    lin,
+                    non_lin
+                )
+                self.add_module('layer_' + str(rank), sequence)
+                self.layers[rank] = sequence
+                self.layer_results[rank] = nodes
+                self.layer_non_recurrent_inputs[rank] = non_recurrent_inputs
+                self.layer_recurrent_inputs[rank] = recurrent_inputs
 
         self.memory_size = (max(map(lambda element: element[1].node_id, self.genes.nodes.items())) + 1)
         self.memory = torch.Tensor(self.memory_size)
         self.output_range = range(self.num_input_nodes, self.num_input_nodes + self.action_num * 2)
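This hunk is where the weights get bundled: instead of one single-output Linear per node, each topological rank now gets one Linear whose inputs are the union of the rank's non-recurrent and recurrent input node ids and whose outputs are the rank's nodes (copying the evolved connection weights into it is still marked as a todo). A sketch with made-up node ids of what one such bundled layer computes:

    import torch
    import torch.nn as nn

    nodes = [5, 6]                 # nodes that share this rank
    non_recurrent_inputs = [0, 2]  # sorted feed-forward input node ids
    recurrent_inputs = [7]         # sorted recurrent input node ids

    # one layer for the whole rank: 3 inputs -> 2 outputs (one per node)
    lin = nn.Linear(len(non_recurrent_inputs) + len(recurrent_inputs), len(nodes), bias=True)
    sequence = nn.Sequential(lin, nn.ELU())

    memory = torch.zeros(8)        # hypothetical memory vector, one slot per node id
    combined_in = torch.stack([torch.cat([memory[non_recurrent_inputs], memory[recurrent_inputs]])])
    print(sequence(combined_in).shape)   # torch.Size([1, 2]) -> new values for nodes 5 and 6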
@@ -130,24 +153,25 @@ class EvolutionModel(nn.Module):
         outs = []
         for batch_index, batch_element in enumerate(x_flat):
             memory[0:self.num_input_nodes] = batch_element
-            for layer_index in self.layer_order:
-                non_recurrent_in = memory[self.non_recurrent_indices[layer_index]]
+            for layer_index in range(1, self.layer_num + 1):
+                non_recurrent_in = memory[self.layer_non_recurrent_inputs[layer_index]]
                 non_recurrent_in = torch.stack([non_recurrent_in])
-                if self.has_recurrent and len(self.recurrent_indices[layer_index]) > 0:
-                    recurrent_in = last_memory_flat[batch_index, self.recurrent_indices[layer_index]]
+                if self.has_recurrent and len(self.layer_recurrent_inputs[layer_index]) > 0:
+                    recurrent_in = last_memory_flat[batch_index, self.layer_recurrent_inputs[layer_index]]
                     recurrent_in = torch.stack([recurrent_in])
 
                     combined_in = torch.concat([non_recurrent_in, recurrent_in], dim=1)
                 else:
                     combined_in = non_recurrent_in
 
-                memory[layer_index] = self.layers[layer_index](combined_in)
-            outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2])
+                memory[self.layer_results[layer_index]] = self.layers[layer_index](combined_in)
+            outs.append(memory[self.output_range])
         outs = torch.stack(outs)
         self.memory = torch.Tensor(memory)
         return torch.reshape(outs, (x.shape[0], outs.shape[1]//2, 2))
 
     def update_genes_with_weights(self):
+        # todo rework
        for key, value in self.incoming_connections.items():
            value.sort(key=lambda element: element.start)
 
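In the forward pass the per-node writes become per-rank scatter writes (one assignment fills all node slots of a rank) and the output slice is now read via self.output_range. An illustrative sketch of those two indexing steps, with made-up sizes:

    import torch

    memory = torch.zeros(12)                 # hypothetical memory vector, one slot per node id
    layer_results = [5, 6]                   # node ids produced by one bundled layer
    layer_out = torch.tensor([[0.7, -0.2]])  # result of self.layers[rank](combined_in), shape (1, 2)
    memory[layer_results] = layer_out[0]     # scatter both node values in one assignment

    num_input_nodes, action_num = 4, 3
    output_range = list(range(num_input_nodes, num_input_nodes + action_num * 2))
    out = torch.stack([memory[output_range]])     # shape (1, 6), two values per action
    print(torch.reshape(out, (1, -1, 2)).shape)   # torch.Size([1, 3, 2])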