diff --git a/R/training_run.R b/R/training_run.R index 0e6c888..6f2495c 100644 --- a/R/training_run.R +++ b/R/training_run.R @@ -388,8 +388,8 @@ reset_tf_graph <- function() { tryCatch({ if (reticulate::py_module_available("tensorflow")) { tf <- reticulate::import("tensorflow") - if (tf_version(tf) >= 1.13 && !tf$executing_eagerly()) - if(tf_version(tf) >= 2.0) { + if (tf_version(tf) >= "1.13" && !tf$executing_eagerly()) + if(tf_version(tf) >= "2.0") { tf$compat$v1$reset_default_graph() } else { tf$reset_default_graph()