From 82500faf4754d7f86649d9bb03f467954ddab164 Mon Sep 17 00:00:00 2001
From: zindzigriffin
Date: Fri, 1 Jul 2022 19:01:03 -0700
Subject: [PATCH] Update MLP_policy.py

---
 hw1/cs285/policies/MLP_policy.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/hw1/cs285/policies/MLP_policy.py b/hw1/cs285/policies/MLP_policy.py
index c8e1fd7d..2e353c68 100644
--- a/hw1/cs285/policies/MLP_policy.py
+++ b/hw1/cs285/policies/MLP_policy.py
@@ -109,7 +109,11 @@ def update(
             adv_n=None, acs_labels_na=None, qvals=None
     ):
         # TODO: update the policy and return the loss
-        loss = TODO
+        self.optimizer.zero_grad()
+        predicted_actions = self(torch.Tensor(observations).to(self.device))
+        loss = self.loss_func(predicted_actions, torch.Tensor(actions).to(self.device))
+        loss.backward()
+        self.optimizer.step()
         return {
             # You can add extra logging information here, but keep this line
             'Training Loss': ptu.to_numpy(loss),
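
Note on the hunk above: it fills in the TODO with the standard supervised (behavioral-cloning)
update step: zero the accumulated gradients, run the observations through the policy network,
compute a loss against the expert actions, backpropagate, and take an optimizer step. It assumes
`self.optimizer`, `self.loss_func`, and `self.device` are defined elsewhere in this fork's
MLPPolicy __init__; they are not shown in this patch. For reference, below is a minimal
self-contained sketch of the same five-step pattern. The `SimplePolicy` class, its MSE loss, and
its Adam optimizer are illustrative assumptions for the sketch, not the cs285 starter code.

    import numpy as np
    import torch
    from torch import nn, optim

    class SimplePolicy(nn.Module):
        """Hypothetical deterministic MLP policy, for illustration only."""

        def __init__(self, ob_dim, ac_dim, hidden=64, lr=1e-3):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(ob_dim, hidden), nn.Tanh(),
                nn.Linear(hidden, ac_dim),
            )
            # Assumed choices: MSE loss and Adam, common for behavioral cloning.
            self.loss_func = nn.MSELoss()
            self.optimizer = optim.Adam(self.parameters(), lr=lr)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.to(self.device)

        def forward(self, obs):
            return self.net(obs)

        def update(self, observations, actions):
            # Same five steps as the patch: zero grads, forward pass,
            # supervised loss against expert actions, backprop, optimizer step.
            self.optimizer.zero_grad()
            obs_t = torch.as_tensor(observations, dtype=torch.float32).to(self.device)
            acs_t = torch.as_tensor(actions, dtype=torch.float32).to(self.device)
            predicted_actions = self(obs_t)
            loss = self.loss_func(predicted_actions, acs_t)
            loss.backward()
            self.optimizer.step()
            return {'Training Loss': loss.detach().cpu().numpy()}

    if __name__ == "__main__":
        # Smoke test on random data; shapes are arbitrary.
        policy = SimplePolicy(ob_dim=4, ac_dim=2)
        obs = np.random.randn(32, 4).astype(np.float32)
        acs = np.random.randn(32, 2).astype(np.float32)
        print(policy.update(obs, acs)['Training Loss'])

The sketch uses `torch.as_tensor` rather than `torch.Tensor` because it preserves the input
dtype and avoids an extra copy when the input is already a matching numpy array; the patch's
`torch.Tensor(...)` calls work too, but always allocate a new float32 tensor.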