Skip to content

Commit

Permalink
fix passing library logs to java and release audio buffer memory
Browse files Browse the repository at this point in the history
  • Loading branch information
GiviMAD committed Nov 19, 2023
1 parent d90820f commit f709a5c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 36 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<artifactId>whisper-jni</artifactId>
<name>whisper-jni</name>
<url>https://github.com/GiviMAD/whisper-jni</url>
<version>1.4.3-3</version>
<version>1.4.3-4</version>
<description>A JNI wrapper for [whisper.cpp](https://github.com/ggerganov/whisper.cpp), allows to transcribe speech to text in Java</description>

<licenses>
Expand Down
81 changes: 46 additions & 35 deletions src/main/native/io_github_givimad_whisperjni_WhisperJNI.cpp
Original file line number Diff line number Diff line change
@@ -1,33 +1,55 @@
#include <iostream>
#include <map>
#include <jni.h>
#include "io_github_givimad_whisperjni_WhisperJNI.h"
#include "whisper.h"
#include <map>

std::map<int, whisper_context *> contextMap;
std::map<int, whisper_state *> stateMap;

static JNIEnv *envRef = nullptr;
static JavaVM *jvmRef = nullptr;
static void whisper_log_proxy(const char * text) {
if(envRef) {
jclass whisperJNIClass = envRef->FindClass("io/github/givimad/whisperjni/WhisperJNI");
if(whisperJNIClass) {
jmethodID logMethodId = envRef->GetStaticMethodID(whisperJNIClass, "log", "(Ljava/lang/String;)V");
jstring jstr = envRef->NewStringUTF(text);
envRef->CallStaticVoidMethod(whisperJNIClass, logMethodId, jstr);
if(jvmRef) {
JNIEnv *env;
if (jvmRef->AttachCurrentThread((void**)&env, NULL) != JNI_OK) {
return;
}
jclass whisperJNIClass = env->FindClass("io/github/givimad/whisperjni/WhisperJNI");
jmethodID logMethodId = env->GetStaticMethodID(whisperJNIClass, "log", "(Ljava/lang/String;)V");
jstring jstr = env->NewStringUTF(text);
env->CallStaticVoidMethod(whisperJNIClass, logMethodId, jstr);
jvmRef->DetachCurrentThread();
}
}
/**
 * Picks a random id that is not currently a key of contextMap.
 *
 * Draws from rand() at most 1000 times.
 *
 * @return an int that is not present in contextMap.
 * @throws std::runtime_error when every draw collided with an existing key.
 */
int getContextId() {
    for (int attempt = 0; attempt < 1000; ++attempt) {
        const int candidate = rand();
        if (contextMap.find(candidate) == contextMap.end()) {
            return candidate;
        }
    }
    throw std::runtime_error("Wrapper error: Unable to get config id");
}
/**
 * Picks a random id that is not currently a key of stateMap.
 *
 * Draws from rand() at most 1000 times.
 *
 * @return an int that is not present in stateMap.
 * @throws std::runtime_error when every draw collided with an existing key.
 */
int getStateId() {
    for (int attempt = 0; attempt < 1000; ++attempt) {
        const int candidate = rand();
        if (stateMap.find(candidate) == stateMap.end()) {
            return candidate;
        }
    }
    throw std::runtime_error("Wrapper error: Unable to get state id");
}
/**
 * Registers a whisper context in contextMap under a fresh id.
 *
 * The id comes from getContextId(), which only returns values that are not
 * already keys of contextMap, so an existing entry is never shadowed.
 *
 * @param ctx context pointer to store.
 * @return the id under which ctx was stored.
 */
int insertModel(whisper_context *ctx)
{
    const int ref = getContextId();
    contextMap.insert({ref, ctx});
    return ref;
}

struct whisper_context_params newWhisperContextParams(JNIEnv *env, jobject jParams)
{
envRef = env;
jclass paramsJClass = env->GetObjectClass(jParams);
struct whisper_context_params params = whisper_context_default_params();
params.use_gpu = (jboolean)env->GetBooleanField(jParams, env->GetFieldID(paramsJClass, "useGPU", "Z"));
Expand All @@ -36,7 +58,6 @@ struct whisper_context_params newWhisperContextParams(JNIEnv *env, jobject jPara

void freeWhisperFullParams(JNIEnv *env, jobject jParams, whisper_full_params params)
{
envRef = env;
jclass paramsJClass = env->GetObjectClass(jParams);
jstring language = (jstring)env->GetObjectField(jParams, env->GetFieldID(paramsJClass, "language", "Ljava/lang/String;"));
if(language) {
Expand All @@ -50,7 +71,6 @@ void freeWhisperFullParams(JNIEnv *env, jobject jParams, whisper_full_params par

struct whisper_full_params newWhisperFullParams(JNIEnv *env, jobject jParams)
{
envRef = env;
jclass paramsJClass = env->GetObjectClass(jParams);

whisper_sampling_strategy samplingStrategy = (whisper_sampling_strategy)env->GetIntField(jParams, env->GetFieldID(paramsJClass, "strategy", "I"));
Expand Down Expand Up @@ -111,7 +131,6 @@ struct whisper_full_params newWhisperFullParams(JNIEnv *env, jobject jParams)

JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_init(JNIEnv *env, jobject thisObject, jstring modelPath, jobject jParams)
{
envRef = env;
const char *path = env->GetStringUTFChars(modelPath, NULL);
struct whisper_context *context = whisper_init_from_file_with_params(path, newWhisperContextParams(env, jParams));
env->ReleaseStringUTFChars(modelPath, path);
Expand All @@ -123,7 +142,6 @@ JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_init(JNIEnv

JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initNoState(JNIEnv *env, jobject thisObject, jstring modelPath, jobject jParams)
{
envRef = env;
const char *path = env->GetStringUTFChars(modelPath, NULL);
struct whisper_context *context = whisper_init_from_file_with_params_no_state(path, newWhisperContextParams(env, jParams));
env->ReleaseStringUTFChars(modelPath, path);
Expand All @@ -135,8 +153,7 @@ JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initNoState(

JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initState(JNIEnv *env, jobject thisObject, jint ctxRef)
{
envRef = env;
int stateRef = rand();
int stateRef = getStateId();
whisper_state *state = whisper_init_state(contextMap.at(ctxRef));
if(!state) {
return -1;
Expand All @@ -146,50 +163,48 @@ JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initState(JN
}

/**
 * JNI entry point: initializes the OpenVINO encoder for a stored context.
 *
 * @param ctxRef       key into contextMap; contextMap.at() throws
 *                     std::out_of_range for an unknown key.
 * @param deviceString device name passed through to
 *                     whisper_ctx_init_openvino_encoder.
 */
JNIEXPORT void JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initOpenVINOEncoder(JNIEnv *env, jobject thisObject, jint ctxRef, jstring deviceString) {
    const char *device = env->GetStringUTFChars(deviceString, nullptr);
    if (!device) {
        // GetStringUTFChars failed; nothing was acquired, so nothing to release.
        return;
    }
    whisper_ctx_init_openvino_encoder(contextMap.at(ctxRef), nullptr, device, nullptr);
    env->ReleaseStringUTFChars(deviceString, device);
}

/**
 * JNI entry point: reports whether the stored context's model is multilingual.
 *
 * @param ctxRef key into contextMap; contextMap.at() throws
 *               std::out_of_range for an unknown key.
 * @return JNI_TRUE when whisper_is_multilingual returns non-zero, else JNI_FALSE.
 */
JNIEXPORT jboolean JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_isMultilingual(JNIEnv *env, jobject thisObject, jint ctxRef)
{
    // Explicit mapping avoids relying on an implicit int -> jboolean narrowing.
    return whisper_is_multilingual(contextMap.at(ctxRef)) != 0 ? JNI_TRUE : JNI_FALSE;
}

/**
 * JNI entry point: runs whisper_full on the given samples using the context's
 * own state.
 *
 * @param ctxRef     key into contextMap; contextMap.at() throws
 *                   std::out_of_range for an unknown key.
 * @param jParams    Java parameters object converted by newWhisperFullParams.
 * @param samples    float array with the audio samples.
 * @param numSamples number of samples handed to whisper_full.
 * @return the whisper_full result, or -1 when the sample buffer could not be
 *         obtained from the JVM.
 */
JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_full(JNIEnv *env, jobject thisObject, jint ctxRef, jobject jParams, jfloatArray samples, jint numSamples)
{
    jfloat *samplesPointer = env->GetFloatArrayElements(samples, nullptr);
    if (!samplesPointer) {
        return -1;
    }
    whisper_full_params params = newWhisperFullParams(env, jParams);
    const int result = whisper_full(contextMap.at(ctxRef), params, samplesPointer, numSamples);
    freeWhisperFullParams(env, jParams, params);
    // The samples are passed to whisper_full as input only, so free the buffer
    // without copying it back into the Java array.
    env->ReleaseFloatArrayElements(samples, samplesPointer, JNI_ABORT);
    return result;
}

/**
 * JNI entry point: runs whisper_full_with_state on the given samples using an
 * explicitly provided state.
 *
 * @param ctxRef     key into contextMap; contextMap.at() throws
 *                   std::out_of_range for an unknown key.
 * @param stateRef   key into stateMap; stateMap.at() throws
 *                   std::out_of_range for an unknown key.
 * @param jParams    Java parameters object converted by newWhisperFullParams.
 * @param samples    float array with the audio samples.
 * @param numSamples number of samples handed to whisper_full_with_state.
 * @return the whisper_full_with_state result, or -1 when the sample buffer
 *         could not be obtained from the JVM.
 */
JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullWithState(JNIEnv *env, jobject thisObject, jint ctxRef, jint stateRef, jobject jParams, jfloatArray samples, jint numSamples)
{
    jfloat *samplesPointer = env->GetFloatArrayElements(samples, nullptr);
    if (!samplesPointer) {
        return -1;
    }
    whisper_full_params params = newWhisperFullParams(env, jParams);
    const int result = whisper_full_with_state(contextMap.at(ctxRef), stateMap.at(stateRef), params, samplesPointer, numSamples);
    // Release everything acquired above before returning; returning the call
    // result directly would leak both the params strings and the sample buffer.
    freeWhisperFullParams(env, jParams, params);
    // Input-only buffer: free it without copying back into the Java array.
    env->ReleaseFloatArrayElements(samples, samplesPointer, JNI_ABORT);
    return result;
}

/**
 * JNI entry point: returns whisper_full_n_segments for a stored context.
 *
 * @param ctxRef key into contextMap; contextMap.at() throws
 *               std::out_of_range for an unknown key.
 */
JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullNSegments(JNIEnv *env, jobject thisObject, jint ctxRef)
{
    // The JNIEnv is deliberately not cached: it is only valid on this thread.
    return whisper_full_n_segments(contextMap.at(ctxRef));
}

/**
 * JNI entry point: returns whisper_full_n_segments_from_state for a stored
 * state.
 *
 * @param stateRef key into stateMap; stateMap.at() throws
 *                 std::out_of_range for an unknown key.
 */
JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullNSegmentsFromState(JNIEnv *env, jobject thisObject, jint stateRef)
{
    // The JNIEnv is deliberately not cached: it is only valid on this thread.
    return whisper_full_n_segments_from_state(stateMap.at(stateRef));
}

JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegmentTimestamp0(JNIEnv *env, jobject thisObject, jint ctxRef, jint index)
{
envRef = env;
whisper_context *whisper_ctx = contextMap.at(ctxRef);
int nSegments = whisper_full_n_segments(whisper_ctx);
if (nSegments < index + 1)
Expand All @@ -203,7 +218,6 @@ JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegm

JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegmentTimestamp1(JNIEnv *env, jobject thisObject, jint ctxRef, jint index)
{
envRef = env;
whisper_context *whisper_ctx = contextMap.at(ctxRef);
int nSegments = whisper_full_n_segments(whisper_ctx);
if (nSegments < index + 1)
Expand All @@ -217,7 +231,6 @@ JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegm

JNIEXPORT jstring JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegmentText(JNIEnv *env, jobject thisObject, jint ctxRef, jint index)
{
envRef = env;
whisper_context *whisper_ctx = contextMap.at(ctxRef);
int nSegments = whisper_full_n_segments(whisper_ctx);
if (nSegments < index + 1)
Expand All @@ -232,7 +245,6 @@ JNIEXPORT jstring JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSe

JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegmentTimestamp0FromState(JNIEnv *env, jobject thisObject, jint stateRef, jint index)
{
envRef = env;
whisper_state *state = stateMap.at(stateRef);
int nSegments = whisper_full_n_segments_from_state(state);
if (nSegments < index + 1)
Expand All @@ -246,7 +258,6 @@ JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegm

JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegmentTimestamp1FromState(JNIEnv *env, jobject thisObject, jint stateRef, jint index)
{
envRef = env;
whisper_state *state = stateMap.at(stateRef);
int nSegments = whisper_full_n_segments_from_state(state);
if (nSegments < index + 1)
Expand All @@ -260,7 +271,6 @@ JNIEXPORT jlong JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegm

JNIEXPORT jstring JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSegmentTextFromState(JNIEnv *env, jobject thisObject, jint stateRef, jint index)
{
envRef = env;
whisper_state *state = stateMap.at(stateRef);
int nSegments = whisper_full_n_segments_from_state(state);
if (nSegments < index + 1)
Expand All @@ -274,26 +284,27 @@ JNIEXPORT jstring JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullGetSe
}
/**
 * JNI entry point: returns the text produced by whisper_print_system_info()
 * as a new Java string.
 */
JNIEXPORT jstring JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_printSystemInfo(JNIEnv *env, jobject thisObject)
{
    // The JNIEnv is deliberately not cached: it is only valid on this thread.
    const char *systemInfo = whisper_print_system_info();
    return env->NewStringUTF(systemInfo);
}
/**
 * JNI entry point: frees a stored whisper context and removes its id from
 * contextMap.
 *
 * @param ctxRef key into contextMap; contextMap.at() throws
 *               std::out_of_range for an unknown key.
 */
JNIEXPORT void JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_freeContext(JNIEnv *env, jobject thisObject, jint ctxRef)
{
    // The JNIEnv is deliberately not cached: it is only valid on this thread.
    whisper_free(contextMap.at(ctxRef));
    // Drop the entry so the freed pointer can no longer be looked up.
    contextMap.erase(ctxRef);
}

/**
 * JNI entry point: frees a stored whisper state and removes its id from
 * stateMap.
 *
 * @param stateRef key into stateMap; stateMap.at() throws
 *                 std::out_of_range for an unknown key.
 */
JNIEXPORT void JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_freeState(JNIEnv *env, jobject thisObject, jint stateRef)
{
    // The JNIEnv is deliberately not cached: it is only valid on this thread.
    whisper_free_state(stateMap.at(stateRef));
    // Drop the entry so the freed pointer can no longer be looked up.
    stateMap.erase(stateRef);
}
JNIEXPORT void JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_setLogger(JNIEnv *env, jclass thisClass, jboolean enabled) {
envRef = env;
if (enabled) {
if (!jvmRef && env->GetJavaVM(&jvmRef) != JNI_OK) {
jclass exClass = env->FindClass("java/lang/RuntimeException");
env->ThrowNew(exClass, "Failed getting reference to Java VM");
return;
}
whisper_set_log_callback(whisper_log_proxy);
} else {
whisper_set_log_callback(NULL);
Expand Down

0 comments on commit f709a5c

Please sign in to comment.