solves exiting
This commit is contained in:
parent
e718873caa
commit
33b5d9c83e
2 changed files with 7 additions and 3 deletions
|
@ -4,6 +4,9 @@ import numpy as np
|
||||||
import tqdm
|
import tqdm
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ["TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT"] = "0"
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
print(f"Using {device} device")
|
print(f"Using {device} device")
|
||||||
|
|
||||||
|
@ -36,7 +39,6 @@ class BaseModel(nn.Module):
|
||||||
actions.append(self.actions[action](x_flat))
|
actions.append(self.actions[action](x_flat))
|
||||||
return torch.stack(actions, dim=1)
|
return torch.stack(actions, dim=1)
|
||||||
|
|
||||||
|
|
||||||
class BaseDataSet(Dataset):
|
class BaseDataSet(Dataset):
|
||||||
def __init__(self, states, targets):
|
def __init__(self, states, targets):
|
||||||
assert len(states) == len(targets), "Needs to have as many states as targets!"
|
assert len(states) == len(targets), "Needs to have as many states as targets!"
|
||||||
|
@ -93,6 +95,9 @@ def train(states, targets, model, optimizer):
|
||||||
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
|
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
del data_set
|
||||||
|
del dataloader
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
sample = np.random.random((1, 4, 11, 11))
|
sample = np.random.random((1, 4, 11, 11))
|
||||||
|
|
|
@ -370,8 +370,7 @@ class NetLearner(Subject):
|
||||||
self.x_in = []
|
self.x_in = []
|
||||||
self.actions = []
|
self.actions = []
|
||||||
self.target = []
|
self.target = []
|
||||||
self.model = BaseModel(self.viewD, 4, 4)
|
self.model = BaseModel(self.viewD, 4, 4).to(device)
|
||||||
self.model.to(device)
|
|
||||||
self.optimizer = create_optimizer(self.model)
|
self.optimizer = create_optimizer(self.model)
|
||||||
|
|
||||||
if len(self.samples) < self.randomBuffer:
|
if len(self.samples) < self.randomBuffer:
|
||||||
|
|
Loading…
Reference in a new issue