base evolution model. needs different memory handling

parent 33b5d9c83e
commit 4a05baa103

4 changed files with 218 additions and 48 deletions
@@ -146,23 +146,27 @@ class LabyrinthWorld(World):
 
         # adding subjects
         from labirinth_ai.Subject import Hunter, Herbivore
-        while len(self.subjects) < 2:
-            px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
-            py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
-            if self.board[px, py] == 1:
-                self.subjects.append(Hunter(px, py))
-                self.ins += self.subjects[-1].x_in
-                self.actions += self.subjects[-1].actions
-                self.targets += self.subjects[-1].target
+        for _ in range(10):
+            while True:
+                px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
+                py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
+                if self.board[px, py] == 1:
+                    self.subjects.append(Hunter(px, py))
+                    self.ins += self.subjects[-1].x_in
+                    self.actions += self.subjects[-1].actions
+                    self.targets += self.subjects[-1].target
+                    break
 
-        while len(self.subjects) < 10:
-            px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
-            py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
-            if self.board[px, py] == 1:
-                self.subjects.append(Herbivore(px, py))
-                self.ins += self.subjects[-1].x_in
-                self.actions += self.subjects[-1].actions
-                self.targets += self.subjects[-1].target
+        for _ in range(40):
+            while True:
+                px = random.randint(self.max_room_dim, self.board_shape[0] - self.max_room_dim)
+                py = random.randint(self.max_room_dim, self.board_shape[1] - self.max_room_dim)
+                if self.board[px, py] == 1:
+                    self.subjects.append(Herbivore(px, py))
+                    self.ins += self.subjects[-1].x_in
+                    self.actions += self.subjects[-1].actions
+                    self.targets += self.subjects[-1].target
+                    break
 
         for x in range(self.board_shape[0]):
             for y in range(self.board_shape[1]):
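The two spawn loops above change semantics, not just counts: the old version stopped once `len(self.subjects)` reached 2 and then 10 in total (so the second loop added only 8 herbivores after the 2 hunters), while the new version places exactly 10 hunters and 40 herbivores, retrying each placement until it lands on a floor tile. A minimal sketch of that placement pattern, with a hypothetical helper name:

```python
import random

# Rejection-sample board coordinates until one lands on a floor tile
# (board value 1). Mirrors the inner `while True: ... break` loop above.
def sample_floor_tile(board, lo, hi_x, hi_y):
    while True:
        px = random.randint(lo, hi_x)
        py = random.randint(lo, hi_y)
        if board[px, py] == 1:
            return px, py
```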
@@ -173,36 +177,14 @@ class LabyrinthWorld(World):
 
     def update(self):
         # start = time.time()
-        if self.model is None:
-            for sub in self.subjects:
-                sub.calculateAction(self)
-        else:
-            states = list(map(lambda e: e.createState(self), self.subjects))
-            states = sum(list(map(lambda e: [e, e, e, e], states)), [])
-            vals = self.model.predict(states)
-            vals = np.reshape(np.transpose(np.reshape(vals, (len(self.subjects), 4, 2)), (0, 2, 1)),
-                              (len(self.subjects), 1, 8))
-            list(map(lambda e: e[1].calculateAction(self, vals[e[0]], states[e[0]]), enumerate(self.subjects)))
+        for sub in self.subjects:
+            sub.calculateAction(self)
 
         for sub in self.subjects:
             if sub.alive:
-                sub.update(self, doTrain=self.model is None)
+                sub.update(self)
                 sub.tick += 1
 
-        if self.model is not None:
-            if self.round >= self.nextTrain:
-                samples = list(map(lambda e: e.generateSamples(), self.subjects))
-                states = sum(list(map(lambda e: e[0], samples)), [])
-                targets = sum(list(map(lambda e: e[1], samples)), [])
-                self.model.fit(states, targets)
-                self.nextTrain = self.batchsize / 5
-                self.round = 0
-                for sub in self.subjects:
-                    if len(sub.samples) > 20*self.batchsize:
-                        sub.samples = sub.samples[:-20*self.batchsize]
-            else:
-                self.round += 1
-
         new_subjects = []
         kill_table = {}
         live_table = {}
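For reference, the removed branch batched all subjects through one shared world-level model: every state was repeated four times (once per direction), predicted in a single call, and the flat output regrouped into one row of eight values per subject. A small NumPy sketch of that regrouping, with `N` standing in for `len(self.subjects)`:

```python
import numpy as np

N = 3
vals = np.arange(N * 4 * 2)  # stand-in for the flat model.predict output
# (N, 4, 2): four direction evaluations of two values each, per subject;
# the transpose groups the two value channels before flattening to (N, 1, 8).
vals = np.reshape(np.transpose(np.reshape(vals, (N, 4, 2)), (0, 2, 1)), (N, 1, 8))
assert vals.shape == (N, 1, 8)
```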
@@ -13,6 +13,7 @@ print(f"Using {device} device")
 
 # Define model
 class BaseModel(nn.Module):
+    evolutionary = False
     def __init__(self, view_dimension, action_num, channels):
         super(BaseModel, self).__init__()
         self.flatten = nn.Flatten()
@@ -39,6 +40,7 @@ class BaseModel(nn.Module):
             actions.append(self.actions[action](x_flat))
         return torch.stack(actions, dim=1)
 
+
 class BaseDataSet(Dataset):
     def __init__(self, states, targets):
         assert len(states) == len(targets), "Needs to have as many states as targets!"
@@ -87,7 +89,7 @@ def train(states, targets, model, optimizer):
 
         # Backpropagation
         optimizer.zero_grad()
-        loss.backward()
+        loss.backward(retain_graph=True)
        optimizer.step()
 
         if batch % 100 == 0:
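`retain_graph=True` is presumably needed because `EvolutionModel` (added below) carries node activations in a `memory` tensor that stays attached to the autograd graph across forward passes, so consecutive `backward()` calls can traverse shared graph sections. A minimal, repository-independent sketch of the error this avoids:

```python
import torch

w = torch.ones(1, requires_grad=True)
y = (w * 2).sum()
y.backward(retain_graph=True)  # keeps the graph alive for another pass
# Without retain_graph above, a second backward over the same graph raises
# "RuntimeError: Trying to backward through the graph a second time".
y.backward()
```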
@@ -100,7 +102,7 @@ def train(states, targets, model, optimizer):
 
 
 if __name__ == '__main__':
-    sample = np.random.random((1, 4, 11, 11))
+    sample = np.random.random((1, 486))
 
     model = BaseModel(5, 4, 4).to(device)
     print(model)
@@ -109,7 +111,7 @@ if __name__ == '__main__':
     # test = test.cpu().detach().numpy()
     print(test)
 
-    state = np.random.random((4, 11, 11))
+    state = np.random.random((486,))
     target = np.random.random((4, 2))
     states = [
         [state],
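The test inputs change from image-shaped `(1, 4, 11, 11)` arrays to flat 486-vectors. That 486 matches the input-node count `EvolutionModel` (below) derives from the same constructor arguments, `BaseModel(5, 4, 4)` being `(view_dimension, action_num, channels)`:

```python
channels, view_dimension = 4, 5
num_input_nodes = channels * (2 * view_dimension + 1) ** 2 + 2
assert num_input_nodes == 4 * 11 * 11 + 2 == 486
```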
labirinth_ai/Models/EvolutionModel.py (new file, 176 lines)
@@ -0,0 +1,176 @@
+import torch
+from torch import nn
+import numpy as np
+import tqdm
+from torch.utils.data import Dataset, DataLoader
+from labirinth_ai.Models.BaseModel import device
+
+
+class NodeGene:
+    valid_types = ['sensor', 'hidden', 'output']
+
+    def __init__(self, node_id, node_type, bias=None):
+        assert node_type in self.valid_types, 'Unknown node type!'
+        self.node_id = node_id
+        self.node_type = node_type
+        if node_type == 'hidden':
+            assert bias is not None, 'Expected a bias for hidden node types!'
+            self.bias = bias
+        else:
+            self.bias = None
+
+
+class ConnectionGene:
+    def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
+        self.start = start
+        self.end = end
+        self.enabled = enabled
+        self.innvovation_num = innovation_num
+        self.recurrent = recurrent
+        if weight is None:
+            self.weight = np.random.random(1)[0] * 2 - 1.0
+        else:
+            self.weight = weight
+
+
+class EvolutionModel(nn.Module):
+    evolutionary = True
+
+    def __init__(self, view_dimension, action_num, channels, genes=None):
+        super(EvolutionModel, self).__init__()
+        self.flatten = nn.Flatten()
+
+        self.action_num = action_num
+        self.viewD = view_dimension
+        self.channels = channels
+
+        if genes is None:
+            self.num_input_nodes = channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2
+
+            self.genes = {'nodes': {}, 'connections': []}
+            node_id = 0
+            for _ in range(self.num_input_nodes):
+                self.genes['nodes'][node_id] = NodeGene(node_id, 'sensor')
+                node_id += 1
+            first_action = node_id
+            for _ in range(action_num * 2):
+                self.genes['nodes'][node_id] = NodeGene(node_id, 'output')
+                node_id += 1
+
+            for index in range(self.num_input_nodes):
+                for action in range(action_num * 2):
+                    self.genes['connections'].append(
+                        ConnectionGene(index, first_action + action, True, index*(action_num * 2) + action)
+                    )
+
+        self.incoming_connections = {}
+        for connection in self.genes['connections']:
+            if connection.end not in self.incoming_connections.keys():
+                self.incoming_connections[connection.end] = []
+            self.incoming_connections[connection.end].append(connection)
+
+        self.layers = {}
+        self.indices = {}
+
+        self.has_recurrent = False
+        non_recurrent_indices = {}
+        with torch.no_grad():
+            for key, value in self.incoming_connections.items():
+                value.sort(key=lambda element: element.start)
+
+                lin = nn.Linear(len(value), 1, bias=self.genes['nodes'][key].bias is not None)
+                for index, connection in enumerate(value):
+                    lin.weight[0, index] = value[index].weight
+                if self.genes['nodes'][key].bias is not None:
+                    lin.bias[0] = self.genes['nodes'][key].bias
+
+                non_lin = nn.ELU()
+                sequence = nn.Sequential(
+                    lin,
+                    non_lin
+                )
+                self.add_module('layer_' + str(key), sequence)
+                self.layers[key] = sequence
+                self.indices[key] = list(map(lambda element: element.start, value))
+
+                non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
+                if not self.has_recurrent and len(non_recurrent_indices[key]) != len(self.indices[key]):
+                    self.has_recurrent = True
+                non_recurrent_indices[key] = list(map(lambda element: element.start, non_recurrent_indices[key]))
+        rank_of_node = {}
+        for i in range(self.num_input_nodes):
+            rank_of_node[i] = 0
+
+        layers_to_add = list(non_recurrent_indices.items())
+        while len(layers_to_add) > 0:
+            for index, (key, incoming_nodes) in enumerate(list(layers_to_add)):
+                max_rank = -1
+                all_ranks_found = True
+
+                for incoming_node in incoming_nodes:
+                    if incoming_node in rank_of_node.keys():
+                        max_rank = max(max_rank, rank_of_node[incoming_node])
+                    else:
+                        all_ranks_found = False
+
+                if all_ranks_found:
+                    rank_of_node[key] = max_rank + 1
+
+            layers_to_add = list(filter(lambda element: element[0] not in rank_of_node.keys(), layers_to_add))
+        ranked_layers = list(rank_of_node.items())
+        ranked_layers.sort(key=lambda element: element[1])
+        ranked_layers = list(filter(lambda element: element[1] > 0, ranked_layers))
+        self.layer_order = list(map(lambda element: element[0], ranked_layers))
+        self.memory = torch.Tensor((max(map(lambda element: element[1].node_id, self.genes['nodes'].items())) + 1))
+
+    def forward(self, x, memory=None):
+        x_flat = self.flatten(x)
+        if memory is None:
+            memory = torch.Tensor(self.memory)
+            outs = []
+            for batch_element in x_flat:
+                memory[0:self.num_input_nodes] = batch_element
+                for layer_index in self.layer_order:
+                    memory[layer_index] = self.layers[layer_index](memory[self.indices[layer_index]])
+                outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2])
+            outs = torch.stack(outs)
+            self.memory = torch.Tensor(memory)
+            return torch.reshape(outs, (x.shape[0], 4, 2))
+        else:
+            memory[:, 0:self.num_input_nodes] = x
+            for layer_index in self.layer_order:
+                memory[:, layer_index] = self.layers[layer_index](memory[:, self.indices[layer_index]])
+            return torch.reshape(
+                memory[:, self.num_input_nodes: self.num_input_nodes + self.action_num * 2],
+                (x.shape[0], 4, 2))
+
+
+if __name__ == '__main__':
+    sample = np.random.random((1, 486))
+
+    model = EvolutionModel(5, 4, 4).to(device)
+    print(model)
+    print(model.has_recurrent)
+
+    test = model(torch.tensor(sample, dtype=torch.float32))
+    # test = test.cpu().detach().numpy()
+    print(test)
+
+    state = np.random.random((1, 486))
+    target = np.random.random((4, 2))
+    states = [
+        [state],
+        [state],
+        [state],
+        [state],
+    ]
+    targets = [
+        [target],
+        [target],
+        [target],
+        [target],
+    ]
+
+    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
+    from labirinth_ai.Models.BaseModel import train
+    train(states, targets, model, optimizer)
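The "needs different memory handling" in the commit message most plausibly points at `self.memory` above: `torch.Tensor(n)` with an integer allocates an uninitialized length-`n` tensor rather than storing the number, the single-sample branch of `forward` mutates that buffer in place, and the stored activations keep their autograd history between calls (which is also why `train` now passes `retain_graph=True`). A sketch of a detached, zero-initialized buffer, as an assumed direction rather than the author's eventual fix:

```python
import torch

num_nodes = 486 + 8  # assumption: 486 input nodes plus 2 * action_num outputs

# As committed: arbitrary memory contents, not zeros, and no detaching.
uninitialised = torch.Tensor(num_nodes)

# Alternative: an explicitly zeroed buffer, written without autograd history
# so successive forward passes do not share a graph.
memory = torch.zeros(num_nodes)
with torch.no_grad():
    memory[0:486] = torch.rand(486)  # stand-in for one flattened observation
```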
@@ -382,6 +382,8 @@ class NetLearner(Subject):
 
         self.lastRewards = []
 
+        self.accumulated_rewards = 0
+
     def visualize(self):
         print(self.name)
         layers = self.model.get_weights()
@@ -542,6 +544,8 @@ class NetLearner(Subject):
             self.train()
             self.nextTrain = min(self.batchsize + self.nextTrain, (self.historySizeMul + 1) * self.batchsize)
 
+        self.accumulated_rewards += self.lastReward
+
         self.lastAction = self.action
         self.lastState = self.state
         self.lastReward = 0
@@ -728,10 +732,12 @@ class Herbivore(NetLearner):
         if len(action) == 2:
             if len(world.subjectDict[(self.x + action[0], self.y + action[1])]) > 0:
                 for sub in world.subjectDict[(self.x + action[0], self.y + action[1])]:
-                    if sub.alive:
-                        self.kills += 1
-                        sub.alive = False
-                        self.alive = True
+                    if isinstance(sub, Hunter):
+                        if sub.alive:
+                            sub.kills += 1
+                            sub.alive = True
+                            sub.lastReward += 10
+                            self.alive = False
 
         self.lastRewards = []
         if right in directions:
@@ -795,6 +801,10 @@ class Herbivore(NetLearner):
 
         return action
 
+    def respawnUpdate(self, x, y, world: LabyrinthWorld):
+        super(Herbivore, self).respawnUpdate(x, y, world)
+        self.lastReward -= 1
+
 
 class Hunter(NetLearner):
     name = 'Hunter'