reworks loss function and training data, cleans up directions in code

This commit is contained in:
zomseffen 2022-11-14 11:17:00 +01:00
parent cf4d773c10
commit 26e7ffb12b
2 changed files with 129 additions and 175 deletions
labirinth_ai/Models

View file

@ -1,7 +1,7 @@
import torch
from torch import nn
import numpy as np
import tqdm
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import os
@ -14,6 +14,7 @@ print(f"Using {device} device")
# Define model
class BaseModel(nn.Module):
evolutionary = False
def __init__(self, view_dimension, action_num, channels):
super(BaseModel, self).__init__()
self.flatten = nn.Flatten()
@ -59,11 +60,18 @@ def create_optimizer(model):
def create_loss_function(action):
lambda_factor = 0.0
split_factor = 1.0
def custom_loss(prediction, target):
return torch.mean(0.5 * torch.square(
0.1 * target[:, 0, 0] + target[:, 1, 0] - (
prediction[:, action, 0] + prediction[:, action, 1])) + 0.5 * torch.square(
target[:, 1, 0] - prediction[:, action, 0]), dim=0)
return torch.mean(split_factor * torch.square(
# discounted best estimate the old weights made for t+1
lambda_factor * target[:, 0, 0] +
# actual reward for t
target[:, 1, 0] -
# estimate for current weights
(prediction[:, action, 0] + prediction[:, action, 1])) +
# trying to learn present reward separate from future reward
(1.0 - split_factor) * torch.square(target[:, 1, 0] - prediction[:, action, 0]), dim=0)
return custom_loss
@ -75,26 +83,31 @@ def from_numpy(x):
def train(states, targets, model, optimizer):
for action in range(model.action_num):
data_set = BaseDataSet(states[action], targets[action])
dataloader = DataLoader(data_set, batch_size=64, shuffle=True)
dataloader = DataLoader(data_set, batch_size=256, shuffle=True)
loss_fn = create_loss_function(action)
size = len(dataloader)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
epochs = 1
with tqdm(range(epochs)) as progress_bar:
for _ in enumerate(progress_bar):
losses = []
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
# Backpropagation
optimizer.zero_grad()
loss.backward(retain_graph=True)
optimizer.step()
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
# Backpropagation
optimizer.zero_grad()
loss.backward(retain_graph=True)
optimizer.step()
losses.append(loss.item())
progress_bar.set_postfix(loss=np.average(losses))
progress_bar.update()
model.eval()
del data_set