Skip to content

Commit

Permalink
add missing method and fix implementation issues
Browse files Browse the repository at this point in the history
  • Loading branch information
GiviMAD committed Nov 13, 2023
1 parent 7f18ba6 commit d90820f
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<artifactId>whisper-jni</artifactId>
<name>whisper-jni</name>
<url>https://github.com/GiviMAD/whisper-jni</url>
<version>1.4.3-2</version>
<version>1.4.3-3</version>
<description>A JNI wrapper for [whisper.cpp](https://github.com/ggerganov/whisper.cpp), allows to transcribe speech to text in Java</description>

<licenses>
Expand Down
18 changes: 15 additions & 3 deletions src/main/java/io/github/givimad/whisperjni/WhisperJNI.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ public WhisperContext init(Path model, WhisperContextParams params) throws IOExc
if(params == null) {
params = new WhisperContextParams();
}
return new WhisperContext(this, init(model.toAbsolutePath().toString(), params));
int ref = init(model.toAbsolutePath().toString(), params);
if(ref == -1) {
return null;
}
return new WhisperContext(this, ref);
}

/**
Expand All @@ -107,7 +111,11 @@ public WhisperContext initNoState(Path model, WhisperContextParams params) throw
if(params == null) {
params = new WhisperContextParams();
}
return new WhisperContext(this, initNoState(model.toAbsolutePath().toString(), params));
int ref = initNoState(model.toAbsolutePath().toString(), params);
if(ref == -1) {
return null;
}
return new WhisperContext(this, ref);
}

/**
Expand All @@ -118,7 +126,11 @@ public WhisperContext initNoState(Path model, WhisperContextParams params) throw
*/
public WhisperState initState(WhisperContext context) {
WhisperJNIPointer.assertAvailable(context);
return new WhisperState(this, initState(context.ref), context);
int ref = initState(context.ref);
if(ref == -1) {
return null;
}
return new WhisperState(this, ref, context);
}

/**
Expand Down
83 changes: 66 additions & 17 deletions src/main/native/io_github_givimad_whisperjni_WhisperJNI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
std::map<int, whisper_context *> contextMap;
std::map<int, whisper_state *> stateMap;

static JNIEnv *envRef;
static jclass whisperJniJClass;
static jmethodID logMethodId;
static JNIEnv *envRef = nullptr;
static void whisper_log_proxy(const char * text) {
if(whisperJniJClass) {
jstring jstr = envRef->NewStringUTF(text);
envRef->CallStaticVoidMethod(whisperJniJClass, logMethodId, jstr);
if(envRef) {
jclass whisperJNIClass = envRef->FindClass("io/github/givimad/whisperjni/WhisperJNI");
if(whisperJNIClass) {
jmethodID logMethodId = envRef->GetStaticMethodID(whisperJNIClass, "log", "(Ljava/lang/String;)V");
jstring jstr = envRef->NewStringUTF(text);
envRef->CallStaticVoidMethod(whisperJNIClass, logMethodId, jstr);
}
}
}
int insertModel(whisper_context *ctx)
Expand All @@ -25,17 +27,34 @@ int insertModel(whisper_context *ctx)

/**
 * Builds the native context parameters from the Java-side params object.
 *
 * Starts from whisper_context_default_params() and overrides use_gpu with the
 * boolean "useGPU" field read from jParams. Also stores env in the file-level
 * envRef.
 *
 * @param env     JNI environment of the calling thread.
 * @param jParams Java object exposing a boolean field named "useGPU".
 * @return the populated whisper_context_params value.
 */
struct whisper_context_params newWhisperContextParams(JNIEnv *env, jobject jParams)
{
    envRef = env;
    jclass paramsClass = env->GetObjectClass(jParams);
    struct whisper_context_params contextParams = whisper_context_default_params();
    jfieldID useGpuField = env->GetFieldID(paramsClass, "useGPU", "Z");
    contextParams.use_gpu = static_cast<jboolean>(env->GetBooleanField(jParams, useGpuField));
    return contextParams;
}

/**
 * Hands the string buffers referenced by params back to the JVM.
 *
 * Re-reads the "language" and "initialPrompt" String fields of jParams and,
 * for each one that is non-null, calls ReleaseStringUTFChars with the matching
 * pointer (params.language / params.initial_prompt). Also stores env in the
 * file-level envRef.
 *
 * NOTE(review): presumably pairs with GetStringUTFChars calls made while the
 * params were built — the acquiring code is not visible here, confirm.
 *
 * @param env     JNI environment of the calling thread.
 * @param jParams Java params object the native params were derived from.
 * @param params  native params whose string pointers are released.
 */
void freeWhisperFullParams(JNIEnv *env, jobject jParams, whisper_full_params params)
{
    envRef = env;
    jclass paramsClass = env->GetObjectClass(jParams);

    jfieldID languageField = env->GetFieldID(paramsClass, "language", "Ljava/lang/String;");
    jstring jLanguage = static_cast<jstring>(env->GetObjectField(jParams, languageField));
    if (jLanguage != nullptr) {
        env->ReleaseStringUTFChars(jLanguage, params.language);
    }

    jfieldID promptField = env->GetFieldID(paramsClass, "initialPrompt", "Ljava/lang/String;");
    jstring jPrompt = static_cast<jstring>(env->GetObjectField(jParams, promptField));
    if (jPrompt != nullptr) {
        env->ReleaseStringUTFChars(jPrompt, params.initial_prompt);
    }
}

struct whisper_full_params newWhisperFullParams(JNIEnv *env, jobject jParams)
{
envRef = env;
jclass paramsJClass = env->GetObjectClass(jParams);

whisper_sampling_strategy samplingStrategy = (whisper_sampling_strategy)env->GetIntField(jParams, env->GetFieldID(paramsJClass, "strategy", "I"));
struct whisper_full_params params = whisper_full_default_params(samplingStrategy);
whisper_full_params params = whisper_full_default_params(samplingStrategy);

int nThreads = (jint)env->GetIntField(jParams, env->GetFieldID(paramsJClass, "nThreads", "I"));
if (nThreads > 0)
Expand Down Expand Up @@ -92,58 +111,85 @@ struct whisper_full_params newWhisperFullParams(JNIEnv *env, jobject jParams)

JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_init(JNIEnv *env, jobject thisObject, jstring modelPath, jobject jParams)
{
envRef = env;
const char *path = env->GetStringUTFChars(modelPath, NULL);
return insertModel(whisper_init_from_file_with_params(path, newWhisperContextParams(env, jParams)));
struct whisper_context *context = whisper_init_from_file_with_params(path, newWhisperContextParams(env, jParams));
env->ReleaseStringUTFChars(modelPath, path);
if(!context) {
return -1;
}
return insertModel(context);
}

JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initNoState(JNIEnv *env, jobject thisObject, jstring modelPath, jobject jParams)
{
envRef = env;
const char *path = env->GetStringUTFChars(modelPath, NULL);
return insertModel(whisper_init_from_file_with_params_no_state(path, newWhisperContextParams(env, jParams)));
struct whisper_context *context = whisper_init_from_file_with_params_no_state(path, newWhisperContextParams(env, jParams));
env->ReleaseStringUTFChars(modelPath, path);
if(!context) {
return -1;
}
return insertModel(context);
}

/**
 * Creates a whisper_state for the context registered under ctxRef and
 * registers it in stateMap.
 *
 * The handle is drawn with rand() and re-drawn until it is not already a key
 * of stateMap. std::map::insert does nothing on a duplicate key, so without
 * this check a collision would leak the new state and return a handle that
 * points at a different, existing state.
 *
 * @param env        JNI environment of the calling thread (stored in envRef).
 * @param thisObject calling Java object (unused).
 * @param ctxRef     handle of a context stored in contextMap; std::map::at
 *                   throws std::out_of_range if it is unknown.
 * @return the new state handle (never negative), or -1 when
 *         whisper_init_state returns null.
 */
JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initState(JNIEnv *env, jobject thisObject, jint ctxRef)
{
    envRef = env;
    whisper_state *state = whisper_init_state(contextMap.at(ctxRef));
    if (!state) {
        return -1;
    }
    // rand() yields values in [0, RAND_MAX], so it can never clash with the -1 error sentinel.
    int stateRef = rand();
    while (stateMap.count(stateRef) != 0) {
        stateRef = rand();
    }
    stateMap.insert({stateRef, state});
    return stateRef;
}

JNIEXPORT void JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initOpenVINOEncoder(JNIEnv *env, jobject thisObject, jint ctxRef, jstring deviceString) {
const char *device = env->GetStringUTFChars(deviceString, NULL);
envRef = env;
const char* device = env->GetStringUTFChars(deviceString, NULL);
whisper_ctx_init_openvino_encoder(contextMap.at(ctxRef), nullptr, device, nullptr);
env->ReleaseStringUTFChars(deviceString, device);
}

/**
 * Returns the result of whisper_is_multilingual for the context registered
 * under ctxRef, converted to jboolean.
 *
 * @param env        JNI environment of the calling thread (stored in envRef).
 * @param thisObject calling Java object (unused).
 * @param ctxRef     handle of a context stored in contextMap; std::map::at
 *                   throws std::out_of_range if it is unknown.
 */
JNIEXPORT jboolean JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_isMultilingual(JNIEnv *env, jobject thisObject, jint ctxRef)
{
    envRef = env;
    whisper_context *ctx = contextMap.at(ctxRef);
    return whisper_is_multilingual(ctx);
}

JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_full(JNIEnv *env, jobject thisObject, jint ctxRef, jobject jParams, jfloatArray samples, jint numSamples)
{
envRef = env;
whisper_full_params params = newWhisperFullParams(env, jParams);
const float *samplesPointer = env->GetFloatArrayElements(samples, NULL);
return whisper_full(contextMap.at(ctxRef), newWhisperFullParams(env, jParams), samplesPointer, numSamples);
int result = whisper_full(contextMap.at(ctxRef), params, samplesPointer, numSamples);
freeWhisperFullParams(env, jParams, params);
return result;
}

/**
 * Runs whisper_full_with_state on the given samples using the context and
 * state registered under ctxRef / stateRef.
 *
 * Unlike the previous one-line form, every JNI resource acquired for the call
 * is released afterwards: the strings referenced by the native params (via
 * freeWhisperFullParams, as the sibling `full` entry point does) and the
 * float array elements obtained from `samples`.
 *
 * @param env        JNI environment of the calling thread (stored in envRef).
 * @param thisObject calling Java object (unused).
 * @param ctxRef     handle of a context stored in contextMap.
 * @param stateRef   handle of a state stored in stateMap.
 * @param jParams    Java params object converted with newWhisperFullParams.
 * @param samples    Java float array passed to whisper as the sample buffer.
 * @param numSamples number of samples forwarded to whisper_full_with_state.
 * @return the value returned by whisper_full_with_state, or -1 when the
 *         sample buffer could not be obtained from the JVM.
 */
JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullWithState(JNIEnv *env, jobject thisObject, jint ctxRef, jint stateRef, jobject jParams, jfloatArray samples, jint numSamples)
{
    envRef = env;
    // Resolve both handles before acquiring anything; std::map::at throws on an unknown key.
    whisper_context *ctx = contextMap.at(ctxRef);
    whisper_state *state = stateMap.at(stateRef);

    whisper_full_params params = newWhisperFullParams(env, jParams);
    jfloat *samplesPointer = env->GetFloatArrayElements(samples, NULL);
    if (samplesPointer == nullptr) {
        // The JVM could not provide the buffer; release the params strings and bail out.
        freeWhisperFullParams(env, jParams, params);
        return -1;
    }

    int result = whisper_full_with_state(ctx, state, params, samplesPointer, numSamples);

    freeWhisperFullParams(env, jParams, params);
    // The samples are only read, so JNI_ABORT releases the buffer without copying it back.
    env->ReleaseFloatArrayElements(samples, samplesPointer, JNI_ABORT);
    return result;
}

/**
 * Returns whisper_full_n_segments for the context registered under ctxRef.
 *
 * @param env        JNI environment of the calling thread (stored in envRef).
 * @param thisObject calling Java object (unused).
 * @param ctxRef     handle of a context stored in contextMap; std::map::at
 *                   throws std::out_of_range if it is unknown.
 */
JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullNSegments(JNIEnv *env, jobject thisObject, jint ctxRef)
{
    envRef = env;
    whisper_context *ctx = contextMap.at(ctxRef);
    const int segmentCount = whisper_full_n_segments(ctx);
    return segmentCount;
}

/**
 * Returns whisper_full_n_segments_from_state for the state registered under
 * stateRef.
 *
 * @param env        JNI environment of the calling thread (stored in envRef).
 * @param thisObject calling Java object (unused).
 * @param stateRef   handle of a state stored in stateMap; std::map::at throws
 *                   std::out_of_range if it is unknown.
 */
JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullNSegmentsFromState(JNIEnv *env, jobject thisObject, jint stateRef)
{
    envRef = env;
    whisper_state *state = stateMap.at(stateRef);
    const int segmentCount = whisper_full_n_segments_from_state(state);
    return segmentCount;
}

JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegmentTimestamp0(JNIEnv *env, jobject thisObject, jint ctxRef, jint index)
{
envRef = env;
whisper_context *whisper_ctx = contextMap.at(ctxRef);
int nSegments = whisper_full_n_segments(whisper_ctx);
if (nSegments < index + 1)
Expand All @@ -157,6 +203,7 @@ JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegm

JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegmentTimestamp1(JNIEnv *env, jobject thisObject, jint ctxRef, jint index)
{
envRef = env;
whisper_context *whisper_ctx = contextMap.at(ctxRef);
int nSegments = whisper_full_n_segments(whisper_ctx);
if (nSegments < index + 1)
Expand All @@ -170,6 +217,7 @@ JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegm

JNIEXPORT jstring JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegmentText(JNIEnv *env, jobject thisObject, jint ctxRef, jint index)
{
envRef = env;
whisper_context *whisper_ctx = contextMap.at(ctxRef);
int nSegments = whisper_full_n_segments(whisper_ctx);
if (nSegments < index + 1)
Expand All @@ -184,6 +232,7 @@ JNIEXPORT jstring JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSe

JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegmentTimestamp0FromState(JNIEnv *env, jobject thisObject, jint stateRef, jint index)
{
envRef = env;
whisper_state *state = stateMap.at(stateRef);
int nSegments = whisper_full_n_segments_from_state(state);
if (nSegments < index + 1)
Expand All @@ -197,6 +246,7 @@ JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegm

JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegmentTimestamp1FromState(JNIEnv *env, jobject thisObject, jint stateRef, jint index)
{
envRef = env;
whisper_state *state = stateMap.at(stateRef);
int nSegments = whisper_full_n_segments_from_state(state);
if (nSegments < index + 1)
Expand All @@ -210,6 +260,7 @@ JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegm

JNIEXPORT jstring JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegmentTextFromState(JNIEnv *env, jobject thisObject, jint stateRef, jint index)
{
envRef = env;
whisper_state *state = stateMap.at(stateRef);
int nSegments = whisper_full_n_segments_from_state(state);
if (nSegments < index + 1)
Expand All @@ -223,30 +274,28 @@ JNIEXPORT jstring JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSe
}
/**
 * Returns the text produced by whisper_print_system_info() as a new Java
 * string.
 *
 * @param env        JNI environment of the calling thread (stored in envRef).
 * @param thisObject calling Java object (unused).
 */
JNIEXPORT jstring JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_printSystemInfo(JNIEnv *env, jobject thisObject)
{
    envRef = env;
    return env->NewStringUTF(whisper_print_system_info());
}
/**
 * Frees the whisper context registered under ctxRef and removes it from
 * contextMap.
 *
 * An unknown or already-freed handle is a no-op. The previous
 * contextMap.at() lookup threw std::out_of_range in that case, and a C++
 * exception must not propagate out of a JNI entry point. Erasing through the
 * iterator also avoids a second map lookup.
 *
 * @param env        JNI environment of the calling thread (stored in envRef).
 * @param thisObject calling Java object (unused).
 * @param ctxRef     handle of the context to release.
 */
JNIEXPORT void JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_freeContext(JNIEnv *env, jobject thisObject, jint ctxRef)
{
    envRef = env;
    auto entry = contextMap.find(ctxRef);
    if (entry == contextMap.end()) {
        return;
    }
    whisper_free(entry->second);
    contextMap.erase(entry);
}

/**
 * Frees the whisper state registered under stateRef and removes it from
 * stateMap.
 *
 * An unknown or already-freed handle is a no-op. The previous stateMap.at()
 * lookup threw std::out_of_range in that case, and a C++ exception must not
 * propagate out of a JNI entry point. Erasing through the iterator also
 * avoids a second map lookup.
 *
 * @param env        JNI environment of the calling thread (stored in envRef).
 * @param thisObject calling Java object (unused).
 * @param stateRef   handle of the state to release.
 */
JNIEXPORT void JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_freeState(JNIEnv *env, jobject thisObject, jint stateRef)
{
    envRef = env;
    auto entry = stateMap.find(stateRef);
    if (entry == stateMap.end()) {
        return;
    }
    whisper_free_state(entry->second);
    stateMap.erase(entry);
}
JNIEXPORT void JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_setLogger(JNIEnv *env, jclass thisClass, jboolean enabled) {
envRef = env;
if (enabled) {
envRef = env;
whisperJniJClass = thisClass;
logMethodId = env->GetStaticMethodID(thisClass, "log", "(Ljava/lang/String;)V");
whisper_set_log_callback(whisper_log_proxy);
} else {
whisperJniJClass = NULL;
envRef = NULL;
logMethodId = NULL;
whisper_set_log_callback(NULL);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,14 @@ public void printSystemInfo() throws Exception {
System.out.println("whisper.cpp library info: " + whisperCPPSystemInfo);
}

/**
 * Opens the test model without state, checks that a context was returned and
 * calls initOpenVINO on it with the "CPU" device string.
 */
@Test
public void initOpenVINO() throws Exception {
    try (var context = whisper.initNoState(testModelPath)) {
        assertNotNull(context);
        whisper.initOpenVINO(context, "CPU");
    }
}

private float[] readJFKFileSamples() throws UnsupportedAudioFileException, IOException {
// sample is a 16 bit int 16000hz little endian wav file
AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(samplePath.toFile());
Expand Down

0 comments on commit d90820f

Please sign in to comment.