Skip to content

[RFC] Add TrainingModule and SGD JNI #12188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

georgehong
Copy link
Contributor

As title, adds wrappers together with unit test based on XOR train.cpp example.

Summary

Adds JNI for SGD and TrainingModule, including a unit test that mirrors train.cpp for a simple XOR example. Also makes the following change:

  • Refactor jni_layer.cpp JTensor <--> Tensor conversion to be a general TensorHybrid utility. This is useful for TrainingModule classes that move maps of Tensors around.

Training dependencies are already enabled for Java JNI library, so we skip adding additional guard flags.

Test plan

Followed steps in README for extension/android, namely:

sh scripts/build_android_library.sh
sh executorch_android/android_test_setup.sh // Now creates xor.ptd and xor.pte test dependencies

./gradlew :executorch_android:connectedAndroidTest // Added unit test to check toy model convergence loss < 0.01

@georgehong georgehong requested a review from JacobSzwejbka July 3, 2025 00:42
Copy link

pytorch-bot bot commented Jul 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/12188

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job, 1 Unrelated Failure

As of commit 19da644 with merge base 9905026 (image):

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 3, 2025
@facebook-github-bot
Copy link
Contributor

@georgehong has imported this pull request. If you are a Meta employee, you can view this in D77702768.

Copy link

github-actions bot commented Jul 3, 2025

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@georgehong georgehong force-pushed the gh/georgehong/training_jni branch 2 times, most recently from fc8efa0 to 19da994 Compare July 3, 2025 00:58
@facebook-github-bot
Copy link
Contributor

@georgehong has imported this pull request. If you are a Meta employee, you can view this in D77702768.

As title, adds wrappers together with unit test based on XOR train.cpp example.
@georgehong georgehong force-pushed the gh/georgehong/training_jni branch from 19da994 to 19da644 Compare July 3, 2025 06:41
@facebook-github-bot
Copy link
Contributor

@georgehong has imported this pull request. If you are a Meta employee, you can view this in D77702768.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants