Skip to content

Commit 71ac2c2

Browse files
committed
fx
1 parent 08cb430 commit 71ac2c2

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

src/nuterm/pytorch_tests.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ void pytorch_test6()
258258

259259
auto my_loss =
260260
[](const torch::Tensor &curr, const torch::Tensor &next) -> torch::Tensor {
261-
return torch::relu(next - curr + 1.0);
261+
return torch::relu(next - curr + 1.0);
262262
};
263263

264264
torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.1);
@@ -318,7 +318,7 @@ void pytorch_test7()
318318

319319
auto my_loss =
320320
[](const torch::Tensor &curr, const torch::Tensor &next) -> torch::Tensor {
321-
return torch::relu(next - curr + 1.0);
321+
return torch::relu(next - curr + 1.0);
322322
};
323323

324324
torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.1);

src/nuterm/training.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ void ranking_function_training(
2424
assert(batch.size(0) == 2);
2525

2626
auto ranking_function_loss =
27-
[](const torch::Tensor &curr, const torch::Tensor &next) -> torch::Tensor
28-
{
27+
[](const torch::Tensor &curr, const torch::Tensor &next) -> torch::Tensor {
2928
assert(curr.dim() == 1 && next.dim() == 1);
3029
// The ranking needs to decrease from 'curr' to 'next'
3130
// by at least 'delta'. Anything less than that is loss.

0 commit comments

Comments
 (0)