import os

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

try:
    # Not referenced anywhere in this module; imported only when available so
    # that a missing optional progress-bar package does not make the module
    # unimportable.
    import tqdm  # noqa: F401
except ImportError:
    tqdm = None

# NOTE(review): presumably stops PyTorch's autograd engine from waiting on its
# worker threads at interpreter shutdown — confirm against the PyTorch docs.
os.environ["TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT"] = "0"

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")


def from_numpy(x):
    """Copy array-like ``x`` into a new float32 tensor.

    ``x`` is collapsed with ``np.asarray`` first so that a Python list of
    ndarrays is converted in one step; passing such a list straight to
    ``torch.tensor`` is the slow element-by-element path PyTorch warns about.
    Tensor inputs are copied directly (detached, same device).
    """
    if isinstance(x, torch.Tensor):
        return x.detach().clone().to(dtype=torch.float32)
    return torch.tensor(np.asarray(x), dtype=torch.float32)


# Define model
class BaseModel(nn.Module):
    """A set of independent fully connected heads, one per action.

    Every head receives the same flattened input of ``input_size`` features
    and produces 2 outputs; ``forward`` stacks the heads' outputs into a
    tensor of shape ``(batch, action_num, 2)``.

    Args:
        view_dimension: the view grid has side length ``2 * view_dimension + 1``.
        action_num: number of action heads.
        channels: number of channels per grid cell.
    """

    def __init__(self, view_dimension, action_num, channels):
        super().__init__()
        self.flatten = nn.Flatten()
        self.action_num = action_num
        self.viewD = view_dimension
        self.channels = channels

        grid_cells = (2 * view_dimension + 1) ** 2
        reduced_cells = (view_dimension + 1) ** 2
        # The first layer takes the flattened channel grid plus 2 extra
        # features. NOTE(review): what those 2 features are is defined by the
        # caller — confirm.
        self.input_size = channels * grid_cells + 2

        # Heads are registered as "action_<i>" via add_module (preserving the
        # existing state_dict keys); the plain list only provides ordered access.
        self.actions = []
        for action in range(action_num):
            head = nn.Sequential(
                nn.Linear(self.input_size, grid_cells),
                nn.ELU(),
                nn.Linear(grid_cells, reduced_cells),
                nn.ELU(),
                nn.Linear(reduced_cells, 2),
            )
            self.add_module(f"action_{action}", head)
            self.actions.append(head)

    def forward(self, x):
        """Return the stacked head outputs, shape ``(batch, action_num, 2)``.

        ``x`` is flattened from dimension 1 onward, so it must hold exactly
        ``input_size`` values per sample (e.g. a flat ``(batch, input_size)``
        tensor).
        """
        x_flat = self.flatten(x)
        return torch.stack([head(x_flat) for head in self.actions], dim=1)


class BaseDataSet(Dataset):
    """Map-style dataset pairing each state with its target as float32 tensors."""

    def __init__(self, states, targets):
        assert len(states) == len(targets), "Needs to have as many states as targets!"
        self.states = from_numpy(states)
        self.targets = from_numpy(targets)

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        return self.states[idx], self.targets[idx]


def create_optimizer(model):
    """Return an RMSprop optimizer (lr=1e-3) over all of ``model``'s parameters."""
    return torch.optim.RMSprop(model.parameters(), lr=1e-3)


def create_loss_function(action):
    """Build a loss that scores only head ``action`` of the prediction.

    ``prediction`` is indexed as ``(batch, action, output)`` and ``target`` as
    ``(batch, row, col)``; only ``target[:, 0, 0]`` and ``target[:, 1, 0]`` are
    read. Per sample the loss is::

        0.5 * (0.1 * t00 + t10 - (p0 + p1)) ** 2 + 0.5 * (t10 - p0) ** 2

    averaged over the batch. NOTE(review): the meaning of the 0.1 factor and
    of the target entries is not visible here — confirm with the data producer.
    """

    def custom_loss(prediction, target):
        pred_first = prediction[:, action, 0]
        pred_second = prediction[:, action, 1]
        target_first = target[:, 0, 0]
        target_second = target[:, 1, 0]
        combined_error = 0.1 * target_first + target_second - (pred_first + pred_second)
        direct_error = target_second - pred_first
        return torch.mean(
            0.5 * torch.square(combined_error) + 0.5 * torch.square(direct_error),
            dim=0,
        )

    return custom_loss


def _model_device(model):
    """Return the device of ``model``'s first parameter, else the module default."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else device


def train(states, targets, model, optimizer, *, batch_size=64):
    """Train each action head of ``model`` for one pass over its own data.

    For every ``action`` in ``range(model.action_num)``, ``states[action]`` and
    ``targets[action]`` are wrapped in a ``BaseDataSet`` and iterated once in
    shuffled mini-batches. The loss from ``create_loss_function(action)`` reads
    only column ``action`` of the prediction, so only that head gets gradient.

    Args:
        states: indexable by action; each entry is a sequence of input samples.
        targets: indexable by action; each entry is a sequence of targets,
            one per state.
        model: a ``BaseModel`` (anything with ``action_num`` and the same
            output layout works).
        optimizer: optimizer over ``model``'s parameters.
        batch_size: mini-batch size (default 64).

    Batches are moved to the model's own device. A progress line is printed
    every 100 batches. The model is left in eval mode.
    """
    model_device = _model_device(model)
    for action in range(model.action_num):
        data_set = BaseDataSet(states[action], targets[action])
        dataloader = DataLoader(data_set, batch_size=batch_size, shuffle=True)
        loss_fn = create_loss_function(action)
        size = len(data_set)  # number of samples, matching the `seen` counter
        seen = 0

        model.train()
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(model_device), y.to(model_device)

            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            seen += len(X)
            if batch % 100 == 0:
                print(f"loss: {loss.item():>7f} [{seen:>5d}/{size:>5d}]")
        model.eval()


if __name__ == "__main__":
    view_dimension, action_num, channels = 5, 4, 4
    model = BaseModel(view_dimension, action_num, channels).to(device)
    print(model)

    # The model expects flat inputs of model.input_size values (the flattened
    # grid plus 2 extra features), already on the model's device.
    sample = np.random.random((1, model.input_size))
    with torch.no_grad():
        test = model(from_numpy(sample).to(device))
    print(test)

    state = np.random.random(model.input_size)
    target = np.random.random((action_num, 2))
    states = [[state] for _ in range(action_num)]
    targets = [[target] for _ in range(action_num)]

    optimizer = create_optimizer(model)
    train(states, targets, model, optimizer)