
jax-metal: reduce with multiple operands failed to legalize #21384

Open
jonatanklosko opened this issue May 23, 2024 · 2 comments
Labels: Apple GPU (Metal) plugin, bug


@jonatanklosko

Description

import jax
import jax.numpy as jnp

def f(x):
  def reducer(op_val_index, acc_val_index):
    op_val, op_index = op_val_index
    acc_val, acc_index = acc_val_index
    return (op_val + acc_val, op_index + acc_index)

  idx = jax.lax.broadcasted_iota(jnp.int32, jnp.shape(x), 0)
  return jax.lax.reduce([x, idx], [jnp.array(0, x.dtype), jnp.array(0, idx.dtype)], reducer, [0])

x = jnp.array([1, 2])

print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))
HLO
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2xi32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<i32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.iota dim = 0 : tensor<2xi32>
    %1 = stablehlo.constant dense<0> : tensor<i32>
    %2 = stablehlo.constant dense<0> : tensor<i32>
    %3:2 = stablehlo.reduce(%arg0 init: %1), (%0 init: %2) across dimensions = [0] : (tensor<2xi32>, tensor<2xi32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
     reducer(%arg1: tensor<i32>, %arg3: tensor<i32>) (%arg2: tensor<i32>, %arg4: tensor<i32>)  {
      %4 = stablehlo.add %arg1, %arg3 : tensor<i32>
      %5 = stablehlo.add %arg2, %arg4 : tensor<i32>
      stablehlo.return %4, %5 : tensor<i32>, tensor<i32>
    }
    return %3#0, %3#1 : tensor<i32>, tensor<i32>
  }
}
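For reference, the two-operand stablehlo.reduce above folds each pair (x[i], idx[i]) into a pair of accumulators using the reducer body. The plain-Python sketch below illustrates those semantics with a sequential left fold. The helper variadic_reduce is hypothetical, and XLA does not promise any particular combination order. That is harmless here because both components are plain sums.

```python
def variadic_reduce(operands, inits, reducer):
    """Fold 1-D operands element-wise into a tuple accumulator.

    Illustration only: a sequential left fold, whereas XLA may combine
    elements in any order.
    """
    acc = tuple(inits)
    for elems in zip(*operands):  # elems is (x[i], idx[i], ...)
        acc = reducer(elems, acc)
    return acc


def reducer(op_val_index, acc_val_index):
    # Same reducer as in the report: each component is summed independently.
    op_val, op_index = op_val_index
    acc_val, acc_index = acc_val_index
    return (op_val + acc_val, op_index + acc_index)


x = [1, 2]
idx = list(range(len(x)))  # plays the role of broadcasted_iota along dim 0
result = variadic_reduce([x, idx], [0, 0], reducer)
print(result)  # (3, 1)
```

So a backend that supports the module should return the pair (3, 1) for x = [1, 2].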

This fails with:

Traceback (most recent call last):
  File "/Users/jonatanklosko/tmp/jax_metal_repro/reduce_multi_arg_illegal.py", line 16, in <module>
    print(jax.jit(f)(x))
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/jonatanklosko/tmp/jax_metal_repro/reduce_multi_arg_illegal.py:11:0: error: failed to legalize operation 'mhlo.reduce'
/Users/jonatanklosko/tmp/jax_metal_repro/reduce_multi_arg_illegal.py:15:0: note: called from
/Users/jonatanklosko/tmp/jax_metal_repro/reduce_multi_arg_illegal.py:11:0: note: see current operation:
%5:2 = "mhlo.reduce"(%arg0, %4, %1, %1) ({
^bb0(%arg1: tensor<si32>, %arg2: tensor<si32>, %arg3: tensor<si32>, %arg4: tensor<si32>):
  %6 = "mhlo.add"(%arg1, %arg3) : (tensor<si32>, tensor<si32>) -> tensor<si32>
  %7 = "mhlo.add"(%arg2, %arg4) : (tensor<si32>, tensor<si32>) -> tensor<si32>
  "mhlo.return"(%6, %7) : (tensor<si32>, tensor<si32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xsi32>, tensor<2xsi32>, tensor<si32>, tensor<si32>) -> (tensor<si32>, tensor<si32>)

Interestingly, jnp.argmax works, even though it is lowered to a similar (just more elaborate) reduce over the operand and an index.
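To compare the two lowerings side by side, one can lower jnp.argmax the same way as f above and inspect the text. Since the HLO dump above was printed before the failing call, lowering itself evidently succeeds and the error arises later, when the plugin compiles the module. A minimal sketch:

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2])

# Lower (but do not compile or run) jnp.argmax and print its StableHLO,
# to compare its two-operand reduce with the one produced by f.
argmax_ir = jax.jit(jnp.argmax).lower(x).as_text()
print(argmax_ir)
```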

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7

(I also tried jax/jaxlib 0.4.28 with ENABLE_PJRT_COMPATIBILITY=1, with the same result.)

@jonatanklosko jonatanklosko added the bug Something isn't working label May 23, 2024
@shuhand0
Collaborator

jax-metal (and its backend) does not yet support reducers with custom computation functions, nor multi-operand reductions. Argmax/argmin are mapped to the corresponding backend ops as special cases.
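Because the reducer in this report never mixes its two components, one possible workaround is to split the variadic reduce into two ordinary single-operand sums. This is a sketch under an assumption not confirmed in this thread: that jax-metal handles single-operand sum reductions. It does not apply to reducers whose components interact, such as argmax. The name f_split is hypothetical.

```python
import jax
import jax.numpy as jnp


def f_split(x):
    # Same values as f in the report, computed as two independent
    # single-operand sums instead of one two-operand reduce. This is only
    # valid because the original reducer adds each component separately.
    idx = jax.lax.broadcasted_iota(jnp.int32, jnp.shape(x), 0)
    return (jnp.sum(x), jnp.sum(idx))


x = jnp.array([1, 2])
out = jax.jit(f_split)(x)
print(out)
```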

@jonatanklosko
Author

I see. So I assume single-operand reduce also has special cases for +, * and similar, which makes sense. I saw a reducer x, y -> x + y + 1 silently ignore the + 1, and that is what made me think so :D
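One way to confirm that the + 1 is present in the IR handed to the plugin, which means it is dropped somewhere in the backend, is to lower such a reducer and inspect the reducer body. A sketch; g is a hypothetical name:

```python
import jax
import jax.numpy as jnp


def g(x):
    # Non-standard reducer: not plain addition, so a backend that only
    # recognises known reducers should not treat it as a sum.
    return jax.lax.reduce(x, jnp.array(0, x.dtype), lambda a, b: a + b + 1, [0])


x = jnp.array([1, 2])
ir = jax.jit(g).lower(x).as_text()
print(ir)  # the reducer body should contain two stablehlo.add ops
```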

For context, we are trying to integrate the Metal plugin into Nx (an Elixir project that uses XLA similarly to JAX). We implement argmax/argmin on top of reduce, but our IR does not exactly match JAX's, so it presumably does not hit the special case. I may try to align the IR in the meantime.
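For readers following the IR-alignment idea, here is a sketch of argmax written as a two-operand reduce over (value, index) pairs, using the same reducer signature as the report. It is modeled loosely on the shape JAX emits, but it is an illustration, not JAX's exact lowering. It handles 1-D integer arrays only, with no NaN handling. Whether jax-metal's special case would recognise this exact IR is not established here. The name argmax_via_reduce is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def argmax_via_reduce(x):
    # Hypothetical helper: argmax over axis 0 of a 1-D integer array.
    idx = lax.broadcasted_iota(jnp.int32, x.shape, 0)

    def reducer(op_val_index, acc_val_index):
        op_val, op_index = op_val_index
        acc_val, acc_index = acc_val_index
        # Prefer the larger value; on ties prefer the smaller index, so the
        # result does not depend on the order in which pairs are combined.
        pick_op = (op_val > acc_val) | ((op_val == acc_val) & (op_index < acc_index))
        return (lax.select(pick_op, op_val, acc_val),
                lax.select(pick_op, op_index, acc_index))

    # Identity element: the smallest representable value paired with index 0.
    init = (jnp.array(jnp.iinfo(x.dtype).min, x.dtype), jnp.array(0, jnp.int32))
    _, index = lax.reduce((x, idx), init, reducer, (0,))
    return index


x = jnp.array([1, 3, 3, 2])
print(jax.jit(argmax_via_reduce)(x))
```

The tie-break on the index keeps the combine step associative and commutative, which a reduce needs because the backend may combine elements in any order.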
