diff --git a/spylon_kernel/scala_interpreter.py b/spylon_kernel/scala_interpreter.py
index 6fe67e4..01a2757 100644
--- a/spylon_kernel/scala_interpreter.py
+++ b/spylon_kernel/scala_interpreter.py
@@ -35,6 +35,43 @@ def init_spark_session(conf: spylon.spark.SparkConfiguration=None, application_n
     spark_jvm_helpers = SparkJVMHelpers(spark_session._sc)
 
 
+def get_web_ui_url(sc):
+    """Get the URL of the web UI for a given Spark context.
+
+    Parameters
+    ----------
+    sc : SparkContext
+
+    Returns
+    -------
+    url : str
+    """
+    # Dig into the Java Spark conf to resolve the effective Spark configuration.
+    # noinspection PyProtectedMember
+    conf = sc._jsc.getConf()
+    if conf.getBoolean("spark.ui.reverseProxy", False):
+        proxy_url = conf.get("spark.ui.reverseProxyUrl", "")
+        if proxy_url:
+            web_ui_url = "Spark Context Web UI is available at {proxy_url}/proxy/{application_id}".format(
+                proxy_url=proxy_url, application_id=sc.applicationId)
+        else:
+            web_ui_url = "Spark Context Web UI is available at Spark Master Public URL"
+    else:
+        # For Spark 2.0 compatibility we have to retrieve this from the Scala side.
+        joption = sc._jsc.sc().uiWebUrl()
+        if joption.isDefined():
+            web_ui_url = joption.get()
+        else:
+            web_ui_url = ""
+
+    # Legacy compatible version for YARN
+    yarn_proxy_spark_property = "spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_URI_BASES"
+    if sc.master.startswith("yarn"):
+        web_ui_url = conf.get(yarn_proxy_spark_property)
+
+    return web_ui_url
+
+
 # noinspection PyProtectedMember
 def initialize_scala_interpreter():
     """
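A quick way to sanity-check the new helper outside the kernel is against a bare `SparkContext`. This sketch is illustrative only and not part of the patch; the master URL and application name are arbitrary choices:

```python
# Illustrative sketch, not part of the patch. Assumes pyspark is importable.
from pyspark import SparkConf, SparkContext
from spylon_kernel.scala_interpreter import get_web_ui_url

conf = SparkConf().setMaster("local[2]").setAppName("web-ui-demo")
sc = SparkContext(conf=conf)
try:
    # With no reverse proxy and no YARN master, this falls through to
    # uiWebUrl, e.g. "http://192.168.0.5:4040".
    print(get_web_ui_url(sc))
finally:
    sc.stop()
```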
@@ -103,12 +140,8 @@ def cleanup():
     def start_imain():
         intp = jvm.scala.tools.nsc.interpreter.IMain(settings, jprintWriter)
         intp.initializeSynchronous()
-        # TODO : Redirect stdout / stderr to a known pair of files that we can watch.
-        """
-        System.setOut(new PrintStream(new File("output-file.txt")));
-        """
 
-        # Copied directly from Spark
+        # Ensure that sc and spark are bound in the interpreter context.
         intp.interpret("""
             @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) {
               org.apache.spark.repl.Main.sparkSession
@@ -117,21 +150,6 @@ def start_imain():
             }
             @transient val sc = {
               val _sc = spark.sparkContext
-              if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) {
-                val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null)
-                if (proxyUrl != null) {
-                  println(s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}")
-                } else {
-                  println(s"Spark Context Web UI is available at Spark Master Public URL")
-                }
-              } else {
-                _sc.uiWebUrl.foreach {
-                  webUrl => println(s"Spark context Web UI available at ${webUrl}")
-                }
-              }
-              println("Spark context available as 'sc' " +
-                s"(master = ${_sc.master}, app id = ${_sc.applicationId}).")
-              println("Spark session available as 'spark'.")
               _sc
             }
             """)
@@ -143,7 +161,6 @@ def start_imain():
         return intp
 
     imain = start_imain()
-
     return SparkInterpreter(jvm, imain, bytes_out)
 
 
@@ -161,22 +178,44 @@ def __init__(self, scala_message, *args, **kwargs):
 
 tOutputHandler = Callable[[List[Any]], None]
 
+
 class SparkInterpreter(object):
+    """Wrapper for a Scala interpreter.
+
+    Notes
+    -----
+    Users should not instantiate this class themselves. Use `get_scala_interpreter` instead.
+
+    Parameters
+    ----------
+    jvm : py4j.java_gateway.JVMView
+    jimain : py4j.java_gateway.JavaObject
+        Java object representing an instance of `scala.tools.nsc.interpreter.IMain`
+    jbyteout : py4j.java_gateway.JavaObject
+        Java object representing an instance of `org.apache.commons.io.output.ByteArrayOutputStream`.
+        This is used to return output data from the REPL.
+    loop : asyncio.AbstractEventLoop, optional
+        Asyncio event loop
+    """
     executor = ThreadPoolExecutor(4)
 
-    def __init__(self, jvm, jiloop, jbyteout, loop: Union[None, asyncio.AbstractEventLoop]=None ):
+    def __init__(self, jvm, jimain, jbyteout, loop: Union[None, asyncio.AbstractEventLoop]=None):
+        self.spark_session = spark_session
+        # noinspection PyProtectedMember
+        self.sc = spark_session._sc
+        self.web_ui_url = get_web_ui_url(self.sc)
         self._jcompleter = None
         self.jvm = jvm
-        self.jiloop = jiloop
+        self.jimain = jimain
 
         if loop is None:
             # TODO: We may want to use new_event_loop here to avoid stopping and starting the main one.
             loop = asyncio.get_event_loop()
         self.loop = loop
         self.log = logging.getLogger(self.__class__.__name__)
-        interpreterPkg = getattr(getattr(self.jvm.scala.tools.nsc.interpreter, 'package$'), "MODULE$")
-        # spark_jvm_helpers.import_scala_package_object("scala.tools.nsc.interpreter")
-        self.iMainOps = interpreterPkg.IMainOps(jiloop)
+        jinterpreter_package = getattr(getattr(self.jvm.scala.tools.nsc.interpreter, 'package$'), "MODULE$")
+        self.iMainOps = jinterpreter_package.IMainOps(jimain)
         self.jbyteout = jbyteout
 
         tempdir = tempfile.mkdtemp()
@@ -184,9 +223,7 @@ def __init__(self, jvm, jiloop, jbyteout, loop: Union[None, asyncio.AbstractEven
         self.tempdir = tempdir
         # Handlers for dealing with stout and stderr. This allows us to insert additional behavior for magics
         self._stdout_handlers = []
-        # self.register_stdout_handler(lambda *args: print(*args, file=sys.stdout))
        self._stderr_handlers = []
-        # self.register_stderr_handler(lambda *args: print(*args, file=sys.stderr))
         self._initialize_stdout_err()
 
     def register_stdout_handler(self, handler: tOutputHandler):
@@ -200,7 +237,8 @@ def _initialize_stdout_err(self):
         stderr_file = os.path.abspath(os.path.join(self.tempdir, 'stderr'))
 
         # Start up the pipes on the JVM side
-        self.log.critical("Before Java redirected")
+        self.log.info("Before Java redirected")
+        self.log.debug("stdout/err redirected to %s", self.tempdir)
         code = 'Console.set{pipe}(new PrintStream(new FileOutputStream(new File(new java.net.URI("{filename}")), true)))'
         code = '\n'.join([
             'import java.io.{PrintStream, FileOutputStream, File}',
@@ -209,7 +247,7 @@ def _initialize_stdout_err(self):
             code.format(pipe="Err", filename=pathlib.Path(stderr_file).as_uri())
         ])
         o = self.interpret(code)
-        self.log.critical("Console redirected")
+        self.log.info("Console redirected, %s", o)
 
         self.loop.create_task(self._poll_file(stdout_file, self.handle_stdout))
         self.loop.create_task(self._poll_file(stderr_file, self.handle_stderr))
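The redirection above writes the JVM's console streams to two files that are then tailed from the asyncio event loop. Reduced to a standalone sketch (the file name and handler below are hypothetical stand-ins, and the file is assumed to already exist):

```python
# Minimal tail-follow sketch of the pattern _poll_file uses.
import asyncio

async def poll_file(filename, handler):
    with open(filename) as fd:
        while True:
            line = fd.readline()
            if line:
                handler(line)              # hand the line to a registered handler
                await asyncio.sleep(0)     # yield so other tasks can run between lines
            else:
                await asyncio.sleep(0.01)  # back off while the file is quiet

# e.g. loop.create_task(poll_file("/tmp/stdout", print))
```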
@@ -235,14 +273,14 @@ async def _poll_file(self, filename: str, fn: Callable[[Any], None]):
         while True:
             line = fd.readline()
             if line:
-                # self.log.critical("READ LINE from %s, %s", filename, line)
+                # Pass the newly read line to the registered output handler.
                 fn(line)
                 # self.log.critical("AFTER PUSH")
                 await asyncio.sleep(0, loop=self.loop)
             else:
                 await asyncio.sleep(0.01, loop=self.loop)
 
-    def interpret_sync(self, code: str, synthetic=False):
+    def _interpret_sync(self, code: str, synthetic=False):
         """Interpret a block of scala code.
 
         If you want to get the result as a python object, follow this will a call to `last_result()`
@@ -258,7 +296,7 @@
             String output from the scala REPL.
         """
         try:
-            res = self.jiloop.interpret(code, synthetic)
+            res = self.jimain.interpret(code, synthetic)
             pyres = self.jbyteout.toByteArray().decode("utf-8")
             # The scala interpreter returns a sentinel case class member here which is typically matched via
             # pattern matching. Due to it having a very long namespace, we just resort to simple string matching here.
@@ -273,17 +311,38 @@
         finally:
             self.jbyteout.reset()
 
-    async def interpret_async(self, code: str, future: Future):
+    async def _interpret_async(self, code: str, future: Future):
+        """Asynchronously interpret a block of scala code, running it on the interpreter's executor.
+
+        Parameters
+        ----------
+        code : str
+        future : Future
+            Future used to hold the result of the computation.
+        """
         try:
-            result = await self.loop.run_in_executor(self.executor, self.interpret_sync, code)
+            result = await self.loop.run_in_executor(self.executor, self._interpret_sync, code)
             future.set_result(result)
         except Exception as e:
             future.set_exception(e)
         return
 
     def interpret(self, code: str):
+        """Interpret a block of scala code.
+
+        If you want to get the result as a python object, follow this with a call to `last_result()`.
+
+        Parameters
+        ----------
+        code : str
+
+        Returns
+        -------
+        reploutput : str
+            String output from the scala REPL.
+        """
         fut = asyncio.Future(loop=self.loop)
-        asyncio.ensure_future(self.interpret_async(code, fut), loop=self.loop)
+        asyncio.ensure_future(self._interpret_async(code, fut), loop=self.loop)
         res = self.loop.run_until_complete(fut)
         return res
 
@@ -298,17 +357,19 @@ def last_result(self):
         object
         """
         # TODO : when evaluating multiline expressions this returns the first result
-        lr = self.jiloop.lastRequest()
+        lr = self.jimain.lastRequest()
         res = lr.lineRep().call("$result", spark_jvm_helpers.to_scala_list([]))
         return res
 
     def bind(self, name: str, value: Any, jtyp: str="Any"):
-        """
+        """Set a variable in the scala repl environment to a python valued type.
 
         Parameters
         ----------
-        varname : str
+        name : str
         value : Any
+        jtyp : str
+            String representation of the Java type that we want to cast this as.
         """
 
         modifiers = spark_jvm_helpers.to_scala_list(["@transient"])
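Together, `interpret`, `last_result`, and `bind` round-trip values between Python and the Scala REPL. A usage sketch, assuming a Spark-backed interpreter can be initialized in the current environment (the variable names are arbitrary):

```python
# Mirrors the usage exercised by the unit tests; illustrative only.
from spylon_kernel.scala_interpreter import get_scala_interpreter

intp = get_scala_interpreter()
print(intp.interpret("4 + 4"))   # REPL output, e.g. "res0: Int = 8"

intp.interpret("""
case class LastResult(member: Int)
val foo = LastResult(8)
""")
jres = intp.last_result()        # py4j handle to the Scala object
print(jres.member())             # 8

intp.bind("answer", 42)          # binds `answer: Any` in the Scala session
print(intp.interpret("answer"))
```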
""" modifiers = spark_jvm_helpers.to_scala_list(["@transient"]) @@ -319,13 +380,13 @@ def bind(self, name: str, value: Any, jtyp: str="Any"): int, str, bytes, bool, list, dict, JavaClass, JavaMember, JavaObject ) if isinstance(value, compatible_types): - self.jiloop.bind(name, "Any", value, modifiers) + self.jimain.bind(name, "Any", value, modifiers) @property def jcompleter(self): if self._jcompleter is None: jClass = self.jvm.scala.tools.nsc.interpreter.PresentationCompilerCompleter - self._jcompleter = jClass(self.jiloop) + self._jcompleter = jClass(self.jimain) return self._jcompleter def complete(self, code: str, pos: int) -> List[str]: @@ -359,7 +420,7 @@ def is_complete(self, code): One of 'complete', 'incomplete' or 'invalid' """ try: - res = self.jiloop.parse().apply(code) + res = self.jimain.parse().apply(code) output_class = res.getClass().getName() _, status = output_class.rsplit("$", 1) if status == 'Success': @@ -398,7 +459,7 @@ def get_help_on(self, info): return scala_type[-1] def printHelp(self): - return self.jiloop.helpSummary() + return self.jimain.helpSummary() def get_scala_interpreter(): diff --git a/spylon_kernel/scala_kernel.py b/spylon_kernel/scala_kernel.py index 389a242..34eedb4 100644 --- a/spylon_kernel/scala_kernel.py +++ b/spylon_kernel/scala_kernel.py @@ -9,13 +9,14 @@ from spylon_kernel.scala_interpreter import ScalaException, SparkInterpreter from .init_spark_magic import InitSparkMagic from .scala_magic import ScalaMagic +from ._version import get_versions class SpylonKernel(MetaKernel): implementation = 'spylon-kernel' - implementation_version = '1.0' + implementation_version = get_versions()['version'] language = 'scala' - language_version = '0.1' + language_version = '2.11' banner = "spylon-kernel - evaluates Scala statements and expressions." language_info = { 'mimetype': 'text/x-scala', @@ -54,6 +55,7 @@ def pythonmagic(self): @property def scala_interpreter(self): + # noinspection PyProtectedMember intp = self._scalamagic._get_scala_interpreter() assert isinstance(intp, SparkInterpreter) return intp @@ -82,7 +84,6 @@ def get_variable(self, name): return intp.last_result() def do_execute_direct(self, code, silent=False): - try: res = self._scalamagic.eval(code.strip(), raw=False) if res: diff --git a/spylon_kernel/scala_magic.py b/spylon_kernel/scala_magic.py index b02630e..8107520 100644 --- a/spylon_kernel/scala_magic.py +++ b/spylon_kernel/scala_magic.py @@ -6,6 +6,7 @@ from metakernel import option from metakernel.process_metakernel import TextOutput from tornado import ioloop, gen +from textwrap import dedent from .scala_interpreter import get_scala_interpreter, ScalaException from . import scala_interpreter @@ -18,15 +19,15 @@ class ScalaMagic(Magic): _interp : spylon_kernel.ScalaInterpreter """ - def __init__(self, kernel): super(ScalaMagic, self).__init__(kernel) self.retval = None self._interp = None self._is_complete_ready = False + self.spark_web_ui_url = "" def _get_scala_interpreter(self): - """ + """Ensure that we have a scala interpreter around and set up the stdout/err handlers if needed. Returns ------- @@ -36,20 +37,32 @@ def _get_scala_interpreter(self): assert isinstance(self.kernel, MetaKernel) self.kernel.Display("Intitializing scala interpreter....") self._interp = get_scala_interpreter() - self.kernel.Display("Scala interpreter initialized.") # Ensure that spark is available in the python session as well. 
diff --git a/spylon_kernel/scala_kernel.py b/spylon_kernel/scala_kernel.py
index 389a242..34eedb4 100644
--- a/spylon_kernel/scala_kernel.py
+++ b/spylon_kernel/scala_kernel.py
@@ -9,13 +9,14 @@
 from spylon_kernel.scala_interpreter import ScalaException, SparkInterpreter
 from .init_spark_magic import InitSparkMagic
 from .scala_magic import ScalaMagic
+from ._version import get_versions
 
 
 class SpylonKernel(MetaKernel):
     implementation = 'spylon-kernel'
-    implementation_version = '1.0'
+    implementation_version = get_versions()['version']
     language = 'scala'
-    language_version = '0.1'
+    language_version = '2.11'
     banner = "spylon-kernel - evaluates Scala statements and expressions."
     language_info = {
         'mimetype': 'text/x-scala',
@@ -54,6 +55,7 @@ def pythonmagic(self):
 
     @property
     def scala_interpreter(self):
+        # noinspection PyProtectedMember
         intp = self._scalamagic._get_scala_interpreter()
         assert isinstance(intp, SparkInterpreter)
         return intp
@@ -82,7 +84,6 @@ def get_variable(self, name):
         return intp.last_result()
 
     def do_execute_direct(self, code, silent=False):
-
         try:
             res = self._scalamagic.eval(code.strip(), raw=False)
             if res:
diff --git a/spylon_kernel/scala_magic.py b/spylon_kernel/scala_magic.py
index b02630e..8107520 100644
--- a/spylon_kernel/scala_magic.py
+++ b/spylon_kernel/scala_magic.py
@@ -6,6 +6,7 @@
 from metakernel import option
 from metakernel.process_metakernel import TextOutput
 from tornado import ioloop, gen
+from textwrap import dedent
 
 from .scala_interpreter import get_scala_interpreter, ScalaException
 from . import scala_interpreter
@@ -18,15 +19,15 @@ class ScalaMagic(Magic):
     _interp : spylon_kernel.ScalaInterpreter
     """
 
-
     def __init__(self, kernel):
         super(ScalaMagic, self).__init__(kernel)
         self.retval = None
         self._interp = None
         self._is_complete_ready = False
+        self.spark_web_ui_url = ""
 
     def _get_scala_interpreter(self):
-        """
+        """Ensure that a scala interpreter exists, setting up the stdout/err handlers if needed.
 
         Returns
         -------
@@ -36,20 +37,32 @@ def _get_scala_interpreter(self):
             assert isinstance(self.kernel, MetaKernel)
             self.kernel.Display("Intitializing scala interpreter....")
             self._interp = get_scala_interpreter()
-            self.kernel.Display("Scala interpreter initialized.")
             # Ensure that spark is available in the python session as well.
-            self.kernel.cell_magics['python'].env['spark'] = scala_interpreter.spark_session
-            # self.Display("Registered spark session in scala and python context as `spark`")
-            self._initialize_pipes()
+            self.kernel.cell_magics['python'].env['spark'] = self._interp.spark_session
+            self.kernel.cell_magics['python'].env['sc'] = self._interp.sc
+
+            sc = self._interp.sc
+            self.kernel.Display(TextOutput(dedent("""\
+                {webui}
+                Spark context available as 'sc' (master = {master}, app id = {app_id}).
+                Spark session available as 'spark'.
+                """.format(
+                    master=sc.master,
+                    app_id=sc.applicationId,
+                    webui=self._interp.web_ui_url
+                )
+            )))
+
+            self._is_complete_ready = True
             self._interp.register_stdout_handler(self.kernel.Write)
             self._interp.register_stderr_handler(self.kernel.Error)
+            # Set up the callbacks
+            self._initialize_pipes()
         return self._interp
 
     def _initialize_pipes(self):
+        self.kernel.log.info("Starting STDOUT/ERR callback")
         ioloop.IOLoop.current().spawn_callback(self._loop_alive)
-        # self._poll_file, STDOUT, self.Write)
-        # ioloop.IOLoop.current().spawn_callback(self._poll_file, STDERR, self.Error)
 
     @gen.coroutine
     def _loop_alive(self):
@@ -97,7 +110,6 @@ def eval(self, code, raw):
             eclass, _, emessage = first.partition(':')
             from metakernel import ExceptionWrapper
             return ExceptionWrapper(eclass, emessage, tb[1:])
-            #return self.kernel.Error(e.scala_message)
 
     @option(
         "-e", "--eval_output", action="store_true", default=False,
@@ -130,7 +142,6 @@ def cell_scala(self, eval_output=False):
         """
         if self.code.strip():
             if eval_output:
-                # TODO: Validate this works?
                 self.eval(self.code, False)
                 # self.code = str(self.env["retval"]) if ("retval" in self.env and
                 #                                         self.env["retval"] != None) else ""
@@ -150,7 +161,6 @@ def post_process(self, retval):
 
     def get_completions(self, info):
         intp = self._get_scala_interpreter()
-        # raise Exception(repr(info))
         c = intp.complete(info['code'], info['help_pos'])
 
         # Find common bits in the middle
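The magic registers `kernel.Write` and `kernel.Error` as the interpreter's output handlers; the same hooks accept any callable that takes a line of output. A hypothetical standalone pairing (note that delivery happens asynchronously through the polling loop, so timing depends on the running event loop):

```python
# Custom handlers in place of kernel.Write / kernel.Error; the prefixes are arbitrary.
from spylon_kernel.scala_interpreter import get_scala_interpreter

intp = get_scala_interpreter()
intp.register_stdout_handler(lambda line: print("[scala out]", line, end=""))
intp.register_stderr_handler(lambda line: print("[scala err]", line, end=""))
intp.interpret('println("hello from the JVM")')
```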
diff --git a/test/test_scala_interpreter.py b/test/test_scala_interpreter.py
index 91da7f9..6fc0884 100644
--- a/test/test_scala_interpreter.py
+++ b/test/test_scala_interpreter.py
@@ -1,77 +1,82 @@
 import pytest
 import re
 
-from spylon_kernel.scala_interpreter import initialize_scala_interpreter
+from spylon_kernel.scala_interpreter import initialize_scala_interpreter, get_web_ui_url
 
 
 @pytest.fixture(scope="module")
-def scala_kernel(request):
+def scala_interpreter(request):
     wrapper = initialize_scala_interpreter()
     return wrapper
 
 
-def test_simple_expression(scala_kernel):
-    result = scala_kernel.interpret("4 + 4")
+def test_simple_expression(scala_interpreter):
+    result = scala_interpreter.interpret("4 + 4")
     assert re.match('res\d+: Int = 8\n', result)
 
 
-def test_completion(scala_kernel):
-    scala_kernel.interpret("val x = 4")
+def test_completion(scala_interpreter):
+    scala_interpreter.interpret("val x = 4")
     code = "x.toL"
-    result = scala_kernel.complete(code, len(code))
+    result = scala_interpreter.complete(code, len(code))
     assert result == ['toLong']
 
 
-def test_is_complete(scala_kernel):
-    result = scala_kernel.is_complete('val foo = 99')
+def test_is_complete(scala_interpreter):
+    result = scala_interpreter.is_complete('val foo = 99')
     assert result == 'complete'
 
-    result = scala_kernel.is_complete('val foo = {99')
+    result = scala_interpreter.is_complete('val foo = {99')
     assert result == 'incomplete'
 
-    result = scala_kernel.is_complete('val foo {99')
+    result = scala_interpreter.is_complete('val foo {99')
     assert result == 'invalid'
 
 
-def test_last_result(scala_kernel):
-    scala_kernel.interpret("""
+def test_last_result(scala_interpreter):
+    scala_interpreter.interpret("""
     case class LastResult(member: Int)
     val foo = LastResult(8)
     """)
-    jres = scala_kernel.last_result()
+    jres = scala_interpreter.last_result()
 
     assert jres.getClass().getName().endswith("LastResult")
     assert jres.member() == 8
 
 
-def test_help(scala_kernel):
-    scala_kernel.interpret("val x = 4")
-    h = scala_kernel.get_help_on("x")
+def test_help(scala_interpreter):
+    scala_interpreter.interpret("val x = 4")
+    h = scala_interpreter.get_help_on("x")
 
-    scala_kernel.interpret("case class Foo(bar: String)")
-    scala_kernel.interpret('val y = Foo("something") ')
+    scala_interpreter.interpret("case class Foo(bar: String)")
+    scala_interpreter.interpret('val y = Foo("something") ')
 
-    h1 = scala_kernel.get_help_on("y")
-    h2 = scala_kernel.get_help_on("y.bar")
+    h1 = scala_interpreter.get_help_on("y")
+    h2 = scala_interpreter.get_help_on("y.bar")
 
     assert h == "Int"
     assert h1 == "Foo"
     assert h2 == "String"
 
 
-def test_spark_rdd(scala_kernel):
+def test_spark_rdd(scala_interpreter):
     """Simple test to ensure we can do RDD things"""
-    result = scala_kernel.interpret("sc.parallelize(0 until 10).sum().toInt")
+    result = scala_interpreter.interpret("sc.parallelize(0 until 10).sum().toInt")
     assert result.strip().endswith(str(sum(range(10))))
 
 
-def test_spark_dataset(scala_kernel):
-    scala_kernel.interpret("""
+def test_spark_dataset(scala_interpreter):
+    scala_interpreter.interpret("""
     case class DatasetTest(y: Int)
     import spark.implicits._
     val df = spark.createDataset((0 until 10).map(DatasetTest(_)))
     import org.apache.spark.sql.functions.sum
     val res = df.agg(sum('y)).collect().head
     """)
-    strres = scala_kernel.interpret("res.getLong(0)")
-    result = scala_kernel.last_result()
+    strres = scala_interpreter.interpret("res.getLong(0)")
+    result = scala_interpreter.last_result()
     assert result == sum(range(10))
+
+
+def test_web_ui_url(scala_interpreter):
+    url = get_web_ui_url(scala_interpreter.sc)
+    assert url != ""