solves exiting

This commit is contained in:
zomseffen 2022-02-12 19:30:03 +01:00
parent e718873caa
commit 33b5d9c83e
2 changed files with 7 additions and 3 deletions

View file

@ -4,6 +4,9 @@ import numpy as np
import tqdm import tqdm
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
import os
os.environ["TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT"] = "0"
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device") print(f"Using {device} device")
@ -36,7 +39,6 @@ class BaseModel(nn.Module):
actions.append(self.actions[action](x_flat)) actions.append(self.actions[action](x_flat))
return torch.stack(actions, dim=1) return torch.stack(actions, dim=1)
class BaseDataSet(Dataset): class BaseDataSet(Dataset):
def __init__(self, states, targets): def __init__(self, states, targets):
assert len(states) == len(targets), "Needs to have as many states as targets!" assert len(states) == len(targets), "Needs to have as many states as targets!"
@ -93,6 +95,9 @@ def train(states, targets, model, optimizer):
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
model.eval() model.eval()
del data_set
del dataloader
if __name__ == '__main__': if __name__ == '__main__':
sample = np.random.random((1, 4, 11, 11)) sample = np.random.random((1, 4, 11, 11))

View file

@ -370,8 +370,7 @@ class NetLearner(Subject):
self.x_in = [] self.x_in = []
self.actions = [] self.actions = []
self.target = [] self.target = []
self.model = BaseModel(self.viewD, 4, 4) self.model = BaseModel(self.viewD, 4, 4).to(device)
self.model.to(device)
self.optimizer = create_optimizer(self.model) self.optimizer = create_optimizer(self.model)
if len(self.samples) < self.randomBuffer: if len(self.samples) < self.randomBuffer: