2022-02-12 17:35:15 +01:00
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
import numpy as np
|
|
|
|
import tqdm
|
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
|
2022-02-12 19:30:03 +01:00
|
|
|
import os
|
|
|
|
# NOTE(review): presumably caps how long PyTorch's autograd engine waits for its
# worker threads when the process shuts down ("0" = do not wait) — confirm
# against the PyTorch version in use.
os.environ["TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT"] = "0"

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using {device} device")
|
|
|
|
|
|
|
|
|
|
|
|
# Define model
|
|
|
|
class BaseModel(nn.Module):
    """Multi-head MLP with one independent fully-connected head per action.

    Every head receives the same flattened input and produces two outputs;
    ``forward`` stacks the per-head results along dimension 1, giving a
    tensor of shape ``(batch, action_num, 2)``.

    Args:
        view_dimension: ``viewD``; layer widths are derived from
            ``(2 * viewD + 1) ** 2`` and ``(viewD + 1) ** 2``.
        action_num: number of heads to build.
        channels: multiplier on ``(2 * viewD + 1) ** 2`` for the input width
            (the input also carries 2 extra features).
    """

    # Class-level flag, not read anywhere in this class.
    # NOTE(review): presumably inspected by callers to tell model kinds apart — confirm.
    evolutionary = False

    def __init__(self, view_dimension, action_num, channels):
        super().__init__()
        self.flatten = nn.Flatten()
        self.actions = []
        self.action_num = action_num
        self.viewD = view_dimension
        self.channels = channels

        # NOTE(review): (2 * viewD + 1) looks like the side of a square view
        # window centred on the agent — confirm.
        side = 2 * view_dimension + 1
        window_size = side * side
        reduced_size = (view_dimension + 1) * (view_dimension + 1)
        input_size = channels * window_size + 2

        for index in range(action_num):
            head = nn.Sequential(
                nn.Linear(input_size, window_size),
                nn.ELU(),
                nn.Linear(window_size, reduced_size),
                nn.ELU(),
                nn.Linear(reduced_size, 2),
            )
            # Registering under "action_<i>" is what makes the head's parameters
            # visible to .parameters()/.to()/state_dict(); the plain list below
            # is only a convenient positional handle to the same modules.
            self.add_module("action_" + str(index), head)
            self.actions.append(head)

    def forward(self, x):
        """Run every head on the flattened input and stack results on dim 1."""
        flat = self.flatten(x)
        per_head = [self.actions[index](flat) for index in range(self.action_num)]
        return torch.stack(per_head, dim=1)
|
|
|
|
|
2022-03-11 14:19:55 +01:00
|
|
|
|
2022-02-12 17:35:15 +01:00
|
|
|
class BaseDataSet(Dataset):
    """In-memory dataset of (state, target) pairs stored as float32 tensors.

    Args:
        states: sequence of array-likes, one per sample.
        targets: sequence of array-likes, parallel to *states*.

    Raises:
        AssertionError: if *states* and *targets* differ in length.
    """

    def __init__(self, states, targets):
        assert len(states) == len(targets), "Needs to have as many states as targets!"
        # Stack into a single ndarray first; NOTE(review): presumably to avoid
        # PyTorch's slow path for building a tensor from a list of ndarrays.
        stacked_states = np.array(states)
        self.states = torch.tensor(stacked_states, dtype=torch.float32)
        stacked_targets = np.array(targets)
        self.targets = torch.tensor(stacked_targets, dtype=torch.float32)

    def __len__(self):
        """Return the number of samples."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return the ``(state, target)`` tensor pair at *idx*."""
        state, target = self.states[idx], self.targets[idx]
        return state, target
|
|
|
|
|
|
|
|
|
|
|
|
def create_optimizer(model, lr=1e-3):
    """Return an RMSprop optimizer over all of *model*'s parameters.

    Args:
        model: module whose ``parameters()`` will be optimized.
        lr: learning rate; defaults to 1e-3, the value previously hard-coded,
            so existing ``create_optimizer(model)`` calls are unchanged.

    Returns:
        A ``torch.optim.RMSprop`` instance.
    """
    return torch.optim.RMSprop(model.parameters(), lr=lr)
|
|
|
|
|
|
|
|
|
|
|
|
def create_loss_function(action):
    """Build a loss closure that scores only the head at index *action*.

    The returned ``custom_loss(prediction, target)`` computes the batch mean of

        0.5 * (0.1 * target[:, 0, 0] + target[:, 1, 0]
               - (prediction[:, action, 0] + prediction[:, action, 1])) ** 2
        + 0.5 * (target[:, 1, 0] - prediction[:, action, 0]) ** 2

    so *prediction* must be indexable as ``[batch, action, 2]`` and *target*
    as ``[batch, >=2, >=1]``. NOTE(review): the domain meaning of the two target
    rows (and of the 0.1 factor) is not visible here — confirm with the caller.
    """

    def custom_loss(prediction, target):
        head_first = prediction[:, action, 0]
        head_second = prediction[:, action, 1]
        target_row0 = target[:, 0, 0]
        target_row1 = target[:, 1, 0]
        # Same operation order as the original formula, so results are bit-identical.
        summed_error = 0.1 * target_row0 + target_row1 - (head_first + head_second)
        first_error = target_row1 - head_first
        per_sample = 0.5 * torch.square(summed_error) + 0.5 * torch.square(first_error)
        return torch.mean(per_sample, dim=0)

    return custom_loss
|
|
|
|
|
|
|
|
|
|
|
|
def from_numpy(x):
    """Return a new float32 tensor built from the array-like *x*.

    *x* is first materialised with ``np.array`` (so nested lists of arrays are
    stacked), then copied into a fresh tensor; the result never shares memory
    with *x*.
    """
    as_array = np.array(x)
    return torch.tensor(as_array, dtype=torch.float32)
|
2022-02-12 17:35:15 +01:00
|
|
|
|
|
|
|
|
|
|
|
def train(states, targets, model, optimizer):
    """Train *model* one action head at a time, each on its own dataset.

    Args:
        states: sequence indexed by action; ``states[a]`` is the list of input
            samples used to train head ``a``.
        targets: sequence indexed by action, parallel to *states*.
        model: module exposing ``action_num``; called on batches of states.
        optimizer: optimizer over *model*'s parameters, stepped once per batch.

    Prints a progress line every 100 batches and leaves *model* in eval mode.
    """
    for action in range(model.action_num):
        data_set = BaseDataSet(states[action], targets[action])
        dataloader = DataLoader(data_set, batch_size=64, shuffle=True)
        loss_fn = create_loss_function(action)

        # Total number of samples. The progress line below reports a sample
        # count, so this must not be len(dataloader), which counts batches.
        size = len(data_set)
        model.train()
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            # NOTE(review): a fresh graph is built every batch, so retain_graph
            # looks unnecessary for BaseModel; kept in case other model types
            # reuse graph parts across calls — confirm before removing.
            loss.backward(retain_graph=True)
            optimizer.step()

            if batch % 100 == 0:
                loss_value, current = loss.item(), batch * len(X)
                print(f"loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]")
        model.eval()

        # Drop the per-action dataset and loader explicitly before the next action.
        del data_set
        del dataloader
|
|
|
|
|
2022-02-12 17:35:15 +01:00
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test. BaseModel(5, 4, 4) expects 4 * (2*5 + 1)**2 + 2 == 486 input features.
    sample = np.random.random((1, 486))

    model = BaseModel(5, 4, 4).to(device)
    print(model)

    # Create the input on the model's device; a CPU tensor fed to a CUDA model fails.
    test = model(torch.tensor(sample, dtype=torch.float32, device=device))
    print(test)

    state = np.random.random((486,))
    target = np.random.random((4, 2))
    # One single-sample dataset per action head.
    states = [[state] for _ in range(model.action_num)]
    targets = [[target] for _ in range(model.action_num)]

    optimizer = create_optimizer(model)

    train(states, targets, model, optimizer)
|
|
|
|
|