beundling weights
This commit is contained in:
parent
bd56173379
commit
b0d22f6bf1
1 changed files with 64 additions and 40 deletions
|
@ -51,36 +51,40 @@ class EvolutionModel(nn.Module):
|
|||
self.incoming_connections[connection.end].append(connection)
|
||||
|
||||
self.layers = {}
|
||||
self.layer_non_recurrent_inputs = {}
|
||||
self.layer_recurrent_inputs = {}
|
||||
self.layer_results = {}
|
||||
self.layer_num = 1
|
||||
self.indices = {}
|
||||
|
||||
self.has_recurrent = False
|
||||
self.non_recurrent_indices = {}
|
||||
self.recurrent_indices = {}
|
||||
with torch.no_grad():
|
||||
for key, value in self.incoming_connections.items():
|
||||
value.sort(key=lambda element: element.start)
|
||||
|
||||
lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
|
||||
for index, connection in enumerate(value):
|
||||
lin.weight[0, index] = value[index].weight
|
||||
if self.genes.nodes[key].bias is not None:
|
||||
lin.bias[0] = self.genes.nodes[key].bias
|
||||
for key, value in self.incoming_connections.items():
|
||||
value.sort(key=lambda element: element.start)
|
||||
|
||||
non_lin = nn.ELU()
|
||||
sequence = nn.Sequential(
|
||||
lin,
|
||||
non_lin
|
||||
)
|
||||
self.add_module('layer_' + str(key), sequence)
|
||||
self.layers[key] = sequence
|
||||
self.indices[key] = list(map(lambda element: element.start, value))
|
||||
# lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
|
||||
# for index, connection in enumerate(value):
|
||||
# lin.weight[0, index] = value[index].weight
|
||||
# if self.genes.nodes[key].bias is not None:
|
||||
# lin.bias[0] = self.genes.nodes[key].bias
|
||||
#
|
||||
# non_lin = nn.ELU()
|
||||
# sequence = nn.Sequential(
|
||||
# lin,
|
||||
# non_lin
|
||||
# )
|
||||
# self.add_module('layer_' + str(key), sequence)
|
||||
# self.layers[key] = sequence
|
||||
self.indices[key] = list(map(lambda element: element.start, value))
|
||||
|
||||
self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
|
||||
self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
|
||||
if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
|
||||
self.has_recurrent = True
|
||||
self.non_recurrent_indices[key] = list(map(lambda element: element.start, self.non_recurrent_indices[key]))
|
||||
self.recurrent_indices[key] = list(map(lambda element: element.start, self.recurrent_indices[key]))
|
||||
self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
|
||||
self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
|
||||
if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
|
||||
self.has_recurrent = True
|
||||
self.non_recurrent_indices[key] = list(map(lambda element: element.start, self.non_recurrent_indices[key]))
|
||||
self.recurrent_indices[key] = list(map(lambda element: element.start, self.recurrent_indices[key]))
|
||||
rank_of_node = {}
|
||||
for i in range(self.num_input_nodes):
|
||||
rank_of_node[i] = 0
|
||||
|
@ -101,20 +105,39 @@ class EvolutionModel(nn.Module):
|
|||
rank_of_node[key] = max_rank + 1
|
||||
|
||||
layers_to_add = list(filter(lambda element: element[0] not in rank_of_node.keys(), layers_to_add))
|
||||
ranked_layers = list(rank_of_node.items())
|
||||
ranked_layers.sort(key=lambda element: element[1])
|
||||
ranked_layers = list(filter(lambda element: element[1] > 0, ranked_layers))
|
||||
|
||||
ranked_layers = list(map(lambda element: (element, 0),
|
||||
filter(lambda recurrent_element:
|
||||
recurrent_element not in list(
|
||||
map(lambda ranked_layer: ranked_layer[0], ranked_layers)
|
||||
),
|
||||
list(filter(lambda recurrent_keys:
|
||||
len(self.recurrent_indices[recurrent_keys]) > 0,
|
||||
self.recurrent_indices.keys()))))) + ranked_layers
|
||||
with torch.no_grad():
|
||||
self.layer_num = max_rank = max(map(lambda element: element[1], rank_of_node.items()))
|
||||
#todo: handle solely recurrent nodes
|
||||
for rank in range(1, max_rank + 1):
|
||||
# get nodes
|
||||
nodes = list(map(lambda element: element[0], filter(lambda item: item[1] == rank, rank_of_node.items())))
|
||||
non_recurrent_inputs = list(set.union(*map(lambda node: set(self.non_recurrent_indices[node]), nodes)))
|
||||
non_recurrent_inputs.sort()
|
||||
|
||||
recurrent_inputs = list(set.union(*map(lambda node: set(self.recurrent_indices[node]), nodes)))
|
||||
recurrent_inputs.sort()
|
||||
|
||||
lin = nn.Linear(len(non_recurrent_inputs) + len(recurrent_inputs), len(nodes), bias=True)
|
||||
|
||||
# todo: load weights
|
||||
|
||||
# for index, connection in enumerate(value):
|
||||
# lin.weight[0, index] = value[index].weight
|
||||
# if self.genes.nodes[key].bias is not None:
|
||||
# lin.bias[0] = self.genes.nodes[key].bias
|
||||
#
|
||||
non_lin = nn.ELU()
|
||||
sequence = nn.Sequential(
|
||||
lin,
|
||||
non_lin
|
||||
)
|
||||
self.add_module('layer_' + str(rank), sequence)
|
||||
self.layers[rank] = sequence
|
||||
self.layer_results[rank] = nodes
|
||||
self.layer_non_recurrent_inputs[rank] = non_recurrent_inputs
|
||||
self.layer_recurrent_inputs[rank] = recurrent_inputs
|
||||
|
||||
self.layer_order = list(map(lambda element: element[0], ranked_layers))
|
||||
self.memory_size = (max(map(lambda element: element[1].node_id, self.genes.nodes.items())) + 1)
|
||||
self.memory = torch.Tensor(self.memory_size)
|
||||
self.output_range = range(self.num_input_nodes, self.num_input_nodes + self.action_num * 2)
|
||||
|
@ -130,24 +153,25 @@ class EvolutionModel(nn.Module):
|
|||
outs = []
|
||||
for batch_index, batch_element in enumerate(x_flat):
|
||||
memory[0:self.num_input_nodes] = batch_element
|
||||
for layer_index in self.layer_order:
|
||||
non_recurrent_in = memory[self.non_recurrent_indices[layer_index]]
|
||||
for layer_index in range(1, self.layer_num + 1):
|
||||
non_recurrent_in = memory[self.layer_non_recurrent_inputs[layer_index]]
|
||||
non_recurrent_in = torch.stack([non_recurrent_in])
|
||||
if self.has_recurrent and len(self.recurrent_indices[layer_index]) > 0:
|
||||
recurrent_in = last_memory_flat[batch_index, self.recurrent_indices[layer_index]]
|
||||
if self.has_recurrent and len(self.layer_recurrent_inputs[layer_index]) > 0:
|
||||
recurrent_in = last_memory_flat[batch_index, self.layer_recurrent_inputs[layer_index]]
|
||||
recurrent_in = torch.stack([recurrent_in])
|
||||
|
||||
combined_in = torch.concat([non_recurrent_in, recurrent_in], dim=1)
|
||||
else:
|
||||
combined_in = non_recurrent_in
|
||||
|
||||
memory[layer_index] = self.layers[layer_index](combined_in)
|
||||
outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2])
|
||||
memory[self.layer_results[layer_index]] = self.layers[layer_index](combined_in)
|
||||
outs.append(memory[self.output_range])
|
||||
outs = torch.stack(outs)
|
||||
self.memory = torch.Tensor(memory)
|
||||
return torch.reshape(outs, (x.shape[0], outs.shape[1]//2, 2))
|
||||
|
||||
def update_genes_with_weights(self):
|
||||
# todo rework
|
||||
for key, value in self.incoming_connections.items():
|
||||
value.sort(key=lambda element: element.start)
|
||||
|
||||
|
|
Loading…
Reference in a new issue