-
Notifications
You must be signed in to change notification settings - Fork 1
/
main_1d.py
107 lines (82 loc) · 3.33 KB
/
main_1d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import yaml
import pickle
import argparse
import numpy as np
import torch as T
import torch.nn as nn
import torch.multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from collections import namedtuple
from common.shared_optim import SharedAdam, SharedRMSprop
from Harlow_1D.train import train, train_stacked
from models.a3c_lstm_simple import A3C_LSTM, A3C_StackedLSTM
if __name__ == "__main__":
    # Launcher for A3C training on the 1-D Harlow task: runs `n_seeds`
    # independent training runs, each with its own derived seed/run-title,
    # spawning `n-workers` worker processes that share one model + optimizer.
    #
    # NOTE(review): the scraped source had its indentation stripped; the
    # per-seed loop below is reconstructed to wrap the whole run (model
    # creation through worker join), which is the only reading that makes
    # the seed loop meaningful.

    # "spawn" is required so the shared model/optimizer can be handed to
    # worker processes safely (especially with CUDA tensors).
    mp.set_start_method("spawn")
    os.environ['OMP_NUM_THREADS'] = '1'  # one BLAS/OMP thread per worker

    parser = argparse.ArgumentParser(description='Parameters')
    parser.add_argument('-c', '--config', type=str,
                        default="Harlow_1D/config.yaml",
                        help='path of config file')
    args = parser.parse_args()

    # NOTE(review): yaml.load with FullLoader is fine for trusted local
    # configs; yaml.safe_load would be the stricter choice.
    with open(args.config, 'r', encoding="utf-8") as fin:
        config = yaml.load(fin, Loader=yaml.FullLoader)

    n_seeds = 8
    base_seed = config["seed"]
    base_run_title = config["run-title"]

    for seed_idx in range(1, n_seeds + 1):
        # Derive a per-run title and seed from the base config.
        # NOTE(review): base_seed * seed_idx collapses to 0 for every run
        # when base_seed == 0 — confirm configs never use seed 0.
        config["run-title"] = base_run_title + f"_{seed_idx}"
        config["seed"] = base_seed * seed_idx

        exp_path = os.path.join(config["save-path"], config["run-title"])
        # makedirs(exist_ok=True) also creates a missing "save-path" parent
        # and avoids the isdir/mkdir race (os.mkdir fails on either).
        os.makedirs(exp_path, exist_ok=True)

        # Snapshot the effective (mutated) config next to the run's outputs.
        out_path = os.path.join(exp_path, os.path.basename(args.config))
        with open(out_path, 'w', encoding="utf-8") as fout:
            yaml.dump(config, fout)

        ############## Start Here ##############
        print(f"> Running {config['run-title']} {config['mode']} using {config['optimizer']}")

        # Build the agent for the configured mode.
        if config["mode"] == "vanilla":
            shared_model = A3C_LSTM(
                config["task"]["input-dim"],
                config["agent"]["mem-units"],
                config["task"]["num-actions"],
                config["agent"]["cell-type"]
            )
        elif config["mode"] == "stacked":
            shared_model = A3C_StackedLSTM(
                config["task"]["input-dim"],
                config["agent"]["mem-units"],
                config["task"]["num-actions"],
                device=config["device"]
            )
        else:
            raise ValueError(config["mode"])

        # Parameters (and optimizer state) live in shared memory so every
        # worker process updates the same model (A3C).
        shared_model.share_memory()
        shared_model.to(config['device'])
        print(shared_model)

        optim_class = SharedAdam if config["optimizer"] == "adam" else SharedRMSprop
        optimizer = optim_class(shared_model.parameters(), lr=config["agent"]["lr"])
        optimizer.share_memory()

        processes = []

        # Seed all RNGs for reproducibility of this run.
        T.manual_seed(config["seed"])
        np.random.seed(config["seed"])
        T.random.manual_seed(config["seed"])

        # Optionally resume from a saved checkpoint.
        if config["resume"]:
            filepath = os.path.join(
                config["save-path"],
                config["load-title"],
                f"{config['load-title']}_{config['start-episode']}.pt"
            )
            print(f"> Loading Checkpoint {filepath}")
            shared_model.load_state_dict(T.load(filepath)["state_dict"])

        # Fan out one training process per worker, then wait for all of them.
        train_target = train_stacked if config["mode"] == "stacked" else train
        for rank in range(config["agent"]["n-workers"]):
            p = mp.Process(target=train_target, args=(
                config,
                shared_model,
                optimizer,
                rank,
            ))
            p.start()
            processes += [p]
        for p in processes:
            p.join()