reworks loss function and training data, cleans up directions in code
parent cf4d773c10
commit 26e7ffb12b

2 changed files with 129 additions and 175 deletions

labirinth_ai/Models
@@ -1,7 +1,7 @@
 import torch
 from torch import nn
 import numpy as np
-import tqdm
+from tqdm import tqdm
 from torch.utils.data import Dataset, DataLoader

 import os
@@ -14,6 +14,7 @@ print(f"Using {device} device")
 # Define model
 class BaseModel(nn.Module):
+    evolutionary = False

     def __init__(self, view_dimension, action_num, channels):
         super(BaseModel, self).__init__()
         self.flatten = nn.Flatten()
@@ -59,11 +60,18 @@ def create_optimizer(model):


 def create_loss_function(action):
+    lambda_factor = 0.0
+    split_factor = 1.0
     def custom_loss(prediction, target):
-        return torch.mean(0.5 * torch.square(
-            0.1 * target[:, 0, 0] + target[:, 1, 0] - (
-                prediction[:, action, 0] + prediction[:, action, 1])) + 0.5 * torch.square(
-            target[:, 1, 0] - prediction[:, action, 0]), dim=0)
+        return torch.mean(split_factor * torch.square(
+            # discounted best estimate the old weights made for t+1
+            lambda_factor * target[:, 0, 0] +
+            # actual reward for t
+            target[:, 1, 0] -
+            # estimate for current weights
+            (prediction[:, action, 0] + prediction[:, action, 1])) +
+            # trying to learn present reward separate from future reward
+            (1.0 - split_factor) * torch.square(target[:, 1, 0] - prediction[:, action, 0]), dim=0)

     return custom_loss

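A quick sanity check on the reworked loss: with the defaults set above (`lambda_factor = 0.0`, `split_factor = 1.0`), the auxiliary present-reward term drops out and the loss reduces to a plain squared error between the reward at t and the sum of the two prediction heads. The sketch below demonstrates that reduction on dummy tensors; the shapes are assumptions inferred from the indexing in `custom_loss`, not taken from the rest of the repository.

```python
import torch

# Assumed shapes, inferred from the indexing in custom_loss:
#   target:     (batch, 2, 1)          -- [:, 0, 0] old t+1 estimate, [:, 1, 0] reward at t
#   prediction: (batch, action_num, 2) -- [:, a, 0] present head, [:, a, 1] future head
batch, action_num, action = 8, 4, 2
target = torch.rand(batch, 2, 1)
prediction = torch.rand(batch, action_num, 2)

lambda_factor, split_factor = 0.0, 1.0
loss = torch.mean(
    split_factor * torch.square(
        lambda_factor * target[:, 0, 0] + target[:, 1, 0]
        - (prediction[:, action, 0] + prediction[:, action, 1]))
    + (1.0 - split_factor) * torch.square(
        target[:, 1, 0] - prediction[:, action, 0]),
    dim=0)

# With these defaults the loss is just the squared error against the summed heads.
reference = torch.mean(
    torch.square(target[:, 1, 0] - (prediction[:, action, 0] + prediction[:, action, 1])),
    dim=0)
assert torch.allclose(loss, reference)
```

Raising `lambda_factor` would reintroduce the discounted bootstrap estimate made by the old weights for t+1, while lowering `split_factor` shifts weight onto learning the present reward separately from the future reward.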
@@ -75,26 +83,31 @@ def from_numpy(x):
 def train(states, targets, model, optimizer):
     for action in range(model.action_num):
         data_set = BaseDataSet(states[action], targets[action])
-        dataloader = DataLoader(data_set, batch_size=64, shuffle=True)
+        dataloader = DataLoader(data_set, batch_size=256, shuffle=True)
         loss_fn = create_loss_function(action)

-        size = len(dataloader)
         model.train()
-        for batch, (X, y) in enumerate(dataloader):
-            X, y = X.to(device), y.to(device)
-
-            # Compute prediction error
-            pred = model(X)
-            loss = loss_fn(pred, y)
+        epochs = 1
+        with tqdm(range(epochs)) as progress_bar:
+            for _ in enumerate(progress_bar):
+                losses = []
+                for batch, (X, y) in enumerate(dataloader):
+                    X, y = X.to(device), y.to(device)

-            # Backpropagation
-            optimizer.zero_grad()
-            loss.backward(retain_graph=True)
-            optimizer.step()
+                    # Compute prediction error
+                    pred = model(X)
+                    loss = loss_fn(pred, y)

-            if batch % 100 == 0:
-                loss, current = loss.item(), batch * len(X)
-                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
+                    # Backpropagation
+                    optimizer.zero_grad()
+                    loss.backward(retain_graph=True)
+                    optimizer.step()
+
+                    losses.append(loss.item())
+                progress_bar.set_postfix(loss=np.average(losses))
+                progress_bar.update()
         model.eval()

         del data_set
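For reference, a condensed, self-contained sketch of the new loop shape: one tqdm bar over epochs, batches of 256, and the averaged loss reported via `set_postfix`. The linear model, MSE loss, and random data here are stand-ins, not the repository's code, and the bar is iterated directly (which advances it automatically) rather than through the committed `enumerate(progress_bar)` plus explicit `update()`.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

# Toy stand-ins: a linear model on random data; only the loop structure
# mirrors the commit (epoch-level tqdm bar, averaged loss in the postfix).
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.MSELoss()
dataloader = DataLoader(
    TensorDataset(torch.randn(1024, 10), torch.randn(1024, 2)),
    batch_size=256, shuffle=True)

epochs = 3
model.train()
with tqdm(range(epochs)) as progress_bar:
    for _ in progress_bar:
        losses = []
        for X, y in dataloader:
            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
        # Show the epoch's average loss on the bar
        progress_bar.set_postfix(loss=np.average(losses))
model.eval()
```

Averaging `loss.item()` over the whole epoch gives a steadier readout than the old per-100-batch print, which is what the `losses` list in the commit is for.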