Skip to content

Commit

Permalink
Merge pull request #11 from mariusvniekerk/doccleanup
Browse files Browse the repository at this point in the history
Added some more documentation and added a method to retrieve the webui url.
  • Loading branch information
parente committed Jan 31, 2017
2 parents 58eebe3 + 40a307c commit 57669cf
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 84 deletions.
147 changes: 104 additions & 43 deletions spylon_kernel/scala_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,43 @@ def init_spark_session(conf: spylon.spark.SparkConfiguration=None, application_n
spark_jvm_helpers = SparkJVMHelpers(spark_session._sc)


def get_web_ui_url(sc):
    """Get the web UI location for a Spark context.

    Parameters
    ----------
    sc : SparkContext

    Returns
    -------
    url : str
        When ``spark.ui.reverseProxy`` is enabled this is a human readable
        message (optionally containing the proxied URL) rather than a bare URL.
        Otherwise it is the URL reported by the JVM ``SparkContext``, or an
        empty string when none is defined. When the master starts with
        ``"yarn"`` the YARN proxy URI property takes precedence if it is set.
    """
    # Dig into the java spark conf to actually be able to resolve the spark configuration
    # noinspection PyProtectedMember
    conf = sc._jsc.getConf()
    if conf.getBoolean("spark.ui.reverseProxy", False):
        proxy_url = conf.get("spark.ui.reverseProxyUrl", "")
        if proxy_url:
            # NOTE: str.format placeholders are plain ``{name}``. The Scala style
            # ``${name}`` would leave a stray ``$`` in front of each value.
            web_ui_url = "Spark Context Web UI is available at {proxy_url}/proxy/{application_id}".format(
                proxy_url=proxy_url, application_id=sc.applicationId)
        else:
            web_ui_url = "Spark Context Web UI is available at Spark Master Public URL"
    else:
        # For spark 2.0 compatibility we have to retrieve this from the scala side.
        # uiWebUrl is a scala Option, so it has to be unwrapped by hand.
        joption = sc._jsc.sc().uiWebUrl()
        web_ui_url = joption.get() if joption.isDefined() else ""

    # Legacy compatible version for YARN
    yarn_proxy_spark_property = "spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_URI_BASES"
    if sc.master.startswith("yarn"):
        # Keep whatever was resolved above as the default so that a missing
        # proxy property does not discard it or abort the lookup.
        web_ui_url = conf.get(yarn_proxy_spark_property, web_ui_url)

    return web_ui_url


# noinspection PyProtectedMember
def initialize_scala_interpreter():
"""
Expand Down Expand Up @@ -103,12 +140,8 @@ def cleanup():
def start_imain():
intp = jvm.scala.tools.nsc.interpreter.IMain(settings, jprintWriter)
intp.initializeSynchronous()
# TODO : Redirect stdout / stderr to a known pair of files that we can watch.
"""
System.setOut(new PrintStream(new File("output-file.txt")));
"""

# Copied directly from Spark
# Ensure that sc and spark are bound in the interpreter context.
intp.interpret("""
@transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) {
org.apache.spark.repl.Main.sparkSession
Expand All @@ -117,21 +150,6 @@ def start_imain():
}
@transient val sc = {
val _sc = spark.sparkContext
if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) {
val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null)
if (proxyUrl != null) {
println(s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}")
} else {
println(s"Spark Context Web UI is available at Spark Master Public URL")
}
} else {
_sc.uiWebUrl.foreach {
webUrl => println(s"Spark context Web UI available at ${webUrl}")
}
}
println("Spark context available as 'sc' " +
s"(master = ${_sc.master}, app id = ${_sc.applicationId}).")
println("Spark session available as 'spark'.")
_sc
}
""")
Expand All @@ -143,7 +161,6 @@ def start_imain():
return intp

imain = start_imain()

return SparkInterpreter(jvm, imain, bytes_out)


Expand All @@ -161,32 +178,52 @@ def __init__(self, scala_message, *args, **kwargs):

tOutputHandler = Callable[[List[Any]], None]


class SparkInterpreter(object):
"""Wrapper for a scala interpreter.
Notes
-----
Users should not instantiate this class themselves. Use `get_scala_interpreter` instead.
Parameters
----------
jvm : py4j.java_gateway.JVMView
jimain : py4j.java_gateway.JavaObject
Java object representing an instance of `scala.tools.nsc.interpreter.IMain`
jbyteout : py4j.java_gateway.JavaObject
Java object representing an instance of `org.apache.commons.io.output.ByteArrayOutputStream`
This is used to return output data from the REPL.
loop : asyncio.AbstractEventLoop, optional
Asyncio eventloop
"""
executor = ThreadPoolExecutor(4)

def __init__(self, jvm, jiloop, jbyteout, loop: Union[None, asyncio.AbstractEventLoop]=None ):
def __init__(self, jvm, jimain, jbyteout, loop: Union[None, asyncio.AbstractEventLoop]=None):
self.spark_session = spark_session
# noinspection PyProtectedMember
self.sc = spark_session._sc
self.web_ui_url = get_web_ui_url(self.sc)
self._jcompleter = None
self.jvm = jvm
self.jiloop = jiloop
self.jimain = jimain
if loop is None:
# TODO: We may want to use new_event_loop here to avoid stopping and starting the main one.
loop = asyncio.get_event_loop()
self.loop = loop
self.log = logging.getLogger(self.__class__.__name__)

interpreterPkg = getattr(getattr(self.jvm.scala.tools.nsc.interpreter, 'package$'), "MODULE$")
# spark_jvm_helpers.import_scala_package_object("scala.tools.nsc.interpreter")
self.iMainOps = interpreterPkg.IMainOps(jiloop)
jinterpreter_package = getattr(getattr(self.jvm.scala.tools.nsc.interpreter, 'package$'), "MODULE$")
self.iMainOps = jinterpreter_package.IMainOps(jimain)
self.jbyteout = jbyteout

tempdir = tempfile.mkdtemp()
atexit.register(shutil.rmtree, tempdir, True)
self.tempdir = tempdir
# Handlers for dealing with stdout and stderr. This allows us to insert additional behavior for magics
self._stdout_handlers = []
# self.register_stdout_handler(lambda *args: print(*args, file=sys.stdout))
self._stderr_handlers = []
# self.register_stderr_handler(lambda *args: print(*args, file=sys.stderr))
self._initialize_stdout_err()

def register_stdout_handler(self, handler: tOutputHandler):
Expand All @@ -200,7 +237,8 @@ def _initialize_stdout_err(self):
stderr_file = os.path.abspath(os.path.join(self.tempdir, 'stderr'))
# Start up the pipes on the JVM side

self.log.critical("Before Java redirected")
self.log.info("Before Java redirected")
self.log.debug("stdout/err redirected to %s", self.tempdir)
code = 'Console.set{pipe}(new PrintStream(new FileOutputStream(new File(new java.net.URI("{filename}")), true)))'
code = '\n'.join([
'import java.io.{PrintStream, FileOutputStream, File}',
Expand All @@ -209,7 +247,7 @@ def _initialize_stdout_err(self):
code.format(pipe="Err", filename=pathlib.Path(stderr_file).as_uri())
])
o = self.interpret(code)
self.log.critical("Console redirected")
self.log.info("Console redirected, %s", o)

self.loop.create_task(self._poll_file(stdout_file, self.handle_stdout))
self.loop.create_task(self._poll_file(stderr_file, self.handle_stderr))
Expand All @@ -235,14 +273,14 @@ async def _poll_file(self, filename: str, fn: Callable[[Any], None]):
while True:
line = fd.readline()
if line:
# self.log.critical("READ LINE from %s, %s", filename, line)
# processing a line from the file and running our processing function.
fn(line)
# self.log.critical("AFTER PUSH")
await asyncio.sleep(0, loop=self.loop)
else:
await asyncio.sleep(0.01, loop=self.loop)

def interpret_sync(self, code: str, synthetic=False):
def _interpret_sync(self, code: str, synthetic=False):
"""Interpret a block of scala code.
If you want to get the result as a python object, follow this with a call to `last_result()`
Expand All @@ -258,7 +296,7 @@ def interpret_sync(self, code: str, synthetic=False):
String output from the scala REPL.
"""
try:
res = self.jiloop.interpret(code, synthetic)
res = self.jimain.interpret(code, synthetic)
pyres = self.jbyteout.toByteArray().decode("utf-8")
# The scala interpreter returns a sentinel case class member here which is typically matched via
# pattern matching. Due to it having a very long namespace, we just resort to simple string matching here.
Expand All @@ -273,17 +311,38 @@ def interpret_sync(self, code: str, synthetic=False):
finally:
self.jbyteout.reset()

async def interpret_async(self, code: str, future: Future):
async def _interpret_async(self, code: str, future: Future):
"""Async execute for running a block of scala code.
Parameters
----------
code : str
future : Future
future used to hold the result of the computation.
"""
try:
result = await self.loop.run_in_executor(self.executor, self.interpret_sync, code)
result = await self.loop.run_in_executor(self.executor, self._interpret_sync, code)
future.set_result(result)
except Exception as e:
future.set_exception(e)
return

def interpret(self, code: str):
"""Interpret a block of scala code.
If you want to get the result as a python object, follow this with a call to `last_result()`
Parameters
----------
code : str
Returns
-------
reploutput : str
String output from the scala REPL.
"""
fut = asyncio.Future(loop=self.loop)
asyncio.ensure_future(self.interpret_async(code, fut), loop=self.loop)
asyncio.ensure_future(self._interpret_async(code, fut), loop=self.loop)
res = self.loop.run_until_complete(fut)
return res

Expand All @@ -298,17 +357,19 @@ def last_result(self):
object
"""
# TODO : when evaluating multiline expressions this returns the first result
lr = self.jiloop.lastRequest()
lr = self.jimain.lastRequest()
res = lr.lineRep().call("$result", spark_jvm_helpers.to_scala_list([]))
return res

def bind(self, name: str, value: Any, jtyp: str="Any"):
"""
"""Set a variable in the scala repl environment to a python valued type.
Parameters
----------
varname : str
name : str
value : Any
jtyp : str
String representation of the Java type that we want to cast this as.
"""
modifiers = spark_jvm_helpers.to_scala_list(["@transient"])
Expand All @@ -319,13 +380,13 @@ def bind(self, name: str, value: Any, jtyp: str="Any"):
int, str, bytes, bool, list, dict, JavaClass, JavaMember, JavaObject
)
if isinstance(value, compatible_types):
self.jiloop.bind(name, "Any", value, modifiers)
self.jimain.bind(name, "Any", value, modifiers)

@property
def jcompleter(self):
if self._jcompleter is None:
jClass = self.jvm.scala.tools.nsc.interpreter.PresentationCompilerCompleter
self._jcompleter = jClass(self.jiloop)
self._jcompleter = jClass(self.jimain)
return self._jcompleter

def complete(self, code: str, pos: int) -> List[str]:
Expand Down Expand Up @@ -359,7 +420,7 @@ def is_complete(self, code):
One of 'complete', 'incomplete' or 'invalid'
"""
try:
res = self.jiloop.parse().apply(code)
res = self.jimain.parse().apply(code)
output_class = res.getClass().getName()
_, status = output_class.rsplit("$", 1)
if status == 'Success':
Expand Down Expand Up @@ -398,7 +459,7 @@ def get_help_on(self, info):
return scala_type[-1]

def printHelp(self):
return self.jiloop.helpSummary()
return self.jimain.helpSummary()


def get_scala_interpreter():
Expand Down
7 changes: 4 additions & 3 deletions spylon_kernel/scala_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from spylon_kernel.scala_interpreter import ScalaException, SparkInterpreter
from .init_spark_magic import InitSparkMagic
from .scala_magic import ScalaMagic
from ._version import get_versions


class SpylonKernel(MetaKernel):
implementation = 'spylon-kernel'
implementation_version = '1.0'
implementation_version = get_versions()['version']
language = 'scala'
language_version = '0.1'
language_version = '2.11'
banner = "spylon-kernel - evaluates Scala statements and expressions."
language_info = {
'mimetype': 'text/x-scala',
Expand Down Expand Up @@ -54,6 +55,7 @@ def pythonmagic(self):

@property
def scala_interpreter(self):
# noinspection PyProtectedMember
intp = self._scalamagic._get_scala_interpreter()
assert isinstance(intp, SparkInterpreter)
return intp
Expand Down Expand Up @@ -82,7 +84,6 @@ def get_variable(self, name):
return intp.last_result()

def do_execute_direct(self, code, silent=False):

try:
res = self._scalamagic.eval(code.strip(), raw=False)
if res:
Expand Down
Loading

0 comments on commit 57669cf

Please sign in to comment.