Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NCF_PyTorch models #536

Merged
merged 6 commits into from
Oct 21, 2023
Merged

Add NCF_PyTorch models #536

merged 6 commits into from
Oct 21, 2023

Conversation

hieuddo
Copy link
Member

@hieuddo hieuddo commented Oct 20, 2023

Description

Add the option to use NCF models in PyTorch (implemented models are tensorflow).

Checklist:

  • I have added tests.
  • I have updated the documentation accordingly.

@hieuddo hieuddo requested a review from tqtg October 20, 2023 08:48
@hieuddo hieuddo self-assigned this Oct 20, 2023
@hieuddo hieuddo marked this pull request as draft October 20, 2023 09:06
@hieuddo hieuddo requested review from darrylong and removed request for tqtg October 20, 2023 09:06
@darrylong darrylong requested a review from tqtg October 20, 2023 09:41
@hieuddo hieuddo marked this pull request as ready for review October 20, 2023 11:50
@hieuddo
Copy link
Member Author

hieuddo commented Oct 20, 2023

I did a quick exp with movielens-100k dataset, n_factors=8 and layers=[32,16,8]

  • MSELoss for GMF, 100 epochs
  • BCELoss for MLP and NeuMF, each 10 epochs

and get these results:

NDCG@50 Recall@50 Train (s) Test (s)
GMF-PyTorch 0.0344 0.0556 137.7308 0.5606
MLP-PyTorch 0.2127 0.3179 14.0588 0.7142
NeuMF-PyTorch 0.2123 0.3169 14.3680 0.7654

The results for the same dataset from implemented Tensorflow version, same embedding/structure size (which are acquired from cornac/tutorials notebook ) are:

NDCG@50 Recall@50 Train (s) Test (s)
GMF 0.1609 0.2778 33.4836 1.2062
MLP 0.1630 0.2807 24.3624 1.2698
NeuMF 0.2412 0.4268 24.6304 1.3844

Those results are with some random hyperparameters without carefully tuning, so I think PyTorch version is comparable with Tensorflow version. @darrylong please help review the code! Thanks.

@tqtg
Copy link
Member

tqtg commented Oct 20, 2023

@hieuddo any idea why GMF performance is quite low as compared to others?

@hieuddo
Copy link
Member Author

hieuddo commented Oct 21, 2023

Turned out to be because of seed. I carried out the exp for GMF with the same hyperparameters, except seed, and got these results:

NDCG@50 Recall@50 Train (s) Test (s)
GMF-PyTorch 0.2059 0.3083 166.3036 0.5522
GMF-PyTorch 0.1979 0.2867 163.1750 0.5426
GMF-PyTorch 0.0005 0.0009 162.1954 0.5394

This could be due to overfitting because training_loss is very low (0.04).

Anyway, my point still stands: comparable exp results are achievable with this PyTorch implementation.

One discussion though, should we stick to the original paper and hard-coded BCE as loss function or should we offer the option to change loss function by hyperparameter?

@tqtg
Copy link
Member

tqtg commented Oct 21, 2023

@hieuddo I refactor the code so we can specify backend for the models. They're all runnable though could you help check the logic if I've missed something? Refer to the updated ncf_example.py.

@tqtg
Copy link
Member

tqtg commented Oct 21, 2023

One discussion though, should we stick to the original paper and hard-coded BCE as loss function or should we offer the option to change loss function by hyperparameter?

BCE is fine. Any reason why we want to add more options?

@hieuddo
Copy link
Member Author

hieuddo commented Oct 21, 2023

@hieuddo I refactor the code so we can specify backend for the models. They're all runnable though could you help check the logic if I've missed something? Refer to the updated ncf_example.py.

For now, it seems fine to me. Will check again later because I'll be using NCF-PyTorch in the future.

BCE is fine. Any reason why we want to add more options?

No specific reason. Just my random wonder. So let's just stick with BCE.

Btw, your "refactor" commit is wonderful madness. I guess I'll be contributing more to cornac so I can learn from your bag of tricks haha

@tqtg
Copy link
Member

tqtg commented Oct 21, 2023

Here are results running the example with the two backend:

TensorFlow NDCG@50 Recall@50 Train (s) Test (s)
GMF 0.0404 0.1148 69.4215 1.9982
MLP 0.0397 0.1212 68.0473 3.3358
NeuMF 0.0397 0.1207 68.6332 3.5113
NeuMF_pretrained 0.0398 0.1212 67.5522 3.4957
PyTorch NDCG@50 Recall@50 Train (s) Test (s)
GMF 0.0407 0.1167 70.2337 1.7052
MLP 0.0391 0.1105 67.4546 2.071
NeuMF 0.0397 0.1155 68.184 2.3072
NeuMF_pretrained 0.0407 0.1167 67.5983 2.2617

They look comparable both in terms of performance and time. Still missing model load and save logic for PyTorch though we can add them later in another PR. This one has become quite complex with many changes.

@hieuddo @darrylong feel free to merge if you're comfortable with this. Thanks Hieu again for the contribution.

@hieuddo hieuddo merged commit f4d97b7 into PreferredAI:master Oct 21, 2023
12 checks passed
@hieuddo hieuddo deleted the ncf_pytorch branch October 21, 2023 17:18
darrylong pushed a commit to darrylong/cornac that referenced this pull request Oct 23, 2023
Add PyTorch backend for NCF models

---------

Co-authored-by: tqtg <tuantq.vnu@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants