Skip to content

Commit

Permalink
[xtensa] widen ops, convert, division, gather_load improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Aelphy committed Jun 19, 2023
1 parent 69b254c commit 06e2720
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 17 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2484,6 +2484,8 @@ XTENSA_RUNTIME_SRC=$(ROOT_DIR)/src/runtime/alignment_128.cpp \
$(ROOT_DIR)/src/runtime/to_string.cpp \
$(ROOT_DIR)/src/runtime/posix_print.cpp \
$(ROOT_DIR)/src/runtime/posix_io.cpp \
$(ROOT_DIR)/src/runtime/posix_aligned_alloc.cpp \
$(ROOT_DIR)/src/runtime/posix_allocator.cpp \
$(ROOT_DIR)/src/runtime/xtensa_dma.cpp \

XTENSA_RUNTIME_OBJS=$(patsubst $(ROOT_DIR)/src/runtime/%,$(BIN_DIR)/%,$(patsubst %.cpp,%.o,$(XTENSA_RUNTIME_SRC)))
Expand Down
7 changes: 7 additions & 0 deletions src/CodeGen_Xtensa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,9 @@ string CodeGen_Xtensa::print_xtensa_call(const Call *op) {
rhs << "IVP_ABSSUBUNX16U(" << args[0] + ", " + args[1] + ")";
}
return rhs.str();
} else if (op->name == "halide_xtensa_absd_u8") {
rhs << "IVP_ABSSUBU2NX8(" << args[0] + ", " + args[1] + ")";
return rhs.str();
} else if (op->name == "halide_xtensa_narrow_i48_with_shift_u16") {
rhs << "xb_vecNx16_rtor_xb_vecNx16U(IVP_PACKVRNRNX48(" << args[0] + ", " + args[1] + "))";
return rhs.str();
Expand Down Expand Up @@ -465,6 +468,10 @@ void CodeGen_Xtensa::visit(const Div *op) {
ostringstream rhs;
rhs << "IVP_DIVN_2XF32(" << print_expr(op->a) << ", " << print_expr(op->b) << ")";
print_assignment(op->type, rhs.str());
} else if (is_native_xtensa_vector<uint32_t>(op->type)) {
string sa = print_expr(op->a);
string sb = print_expr(op->b);
print_assignment(op->type, "halide_xtensa_div32(" + sa + ", " + sb + ")");
} else {
string sa = print_expr(op->a);
string sb = print_expr(op->b);
Expand Down
141 changes: 129 additions & 12 deletions src/CodeGen_Xtensa_vectors.template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2212,8 +2212,9 @@ convert<native_vector_i8, native_vector_u16_x2>(const native_vector_u16_x2 &src)

template<>
HALIDE_ALWAYS_INLINE native_vector_u8 convert<native_vector_u8, native_vector_i16_x2>(const native_vector_i16_x2 &src) {
xb_vec2Nx24 wide = IVP_CVT24S2NX16(src.native_vector[1], src.native_vector[0]);
return xb_vec2Nx8_rtor_xb_vec2Nx8U(IVP_PACKL2NX24(wide));
return IVP_SEL2NX8UI(IVP_MOV2NX8U_FROMNX16(src.native_vector[1]),
IVP_MOV2NX8U_FROMNX16(src.native_vector[0]),
IVP_SELI_8B_EXTRACT_1_OF_2_OFF_0);
}

template<>
Expand Down Expand Up @@ -2367,12 +2368,12 @@ HALIDE_ALWAYS_INLINE native_vector_i32_x4 convert<native_vector_i32_x4, native_v
template<>
HALIDE_ALWAYS_INLINE native_vector_i32_x2
convert<native_vector_i32_x2, native_vector_i16>(const native_vector_i16 &src) {
const native_vector_i32 m = native_vector_i32(1U << (16 - 1));
native_vector_i32 x1 = IVP_MOVN_2X32_FROMNX16(
IVP_SELNX16I(native_vector_i16(0), src, IVP_SELI_16B_INTERLEAVE_1_LO));
native_vector_i32 x2 = IVP_MOVN_2X32_FROMNX16(
IVP_SELNX16I(native_vector_i16(0), src, IVP_SELI_16B_INTERLEAVE_1_HI));
return native_vector_i32_x2(native_vector_i32_x2::from_native_vector, (x1 ^ m) - m, (x2 ^ m) - m);
native_vector_i16 sign_val = src >> 15;
return native_vector_i32_x2(native_vector_i32_x2::from_native_vector,
IVP_MOVN_2X32_FROMNX16(
IVP_SELNX16UI(sign_val, src, IVP_SELI_16B_INTERLEAVE_1_LO)),
IVP_MOVN_2X32_FROMNX16(
IVP_SELNX16UI(sign_val, src, IVP_SELI_16B_INTERLEAVE_1_HI)));
}

template<>
Expand Down Expand Up @@ -2717,13 +2718,11 @@ HALIDE_ALWAYS_INLINE native_vector_u8 halide_xtensa_convert_concat_i16_to_u8(con
}

HALIDE_ALWAYS_INLINE native_vector_i8 halide_xtensa_convert_concat_u16_to_i8(const native_vector_u16 &a, const native_vector_u16 &b) {
xb_vec2Nx24 wide = IVP_CVT24U2NX16(xb_vecNx16U_rtor_xb_vecNx16(b), xb_vecNx16U_rtor_xb_vecNx16(a));
return IVP_PACKL2NX24(wide);
return IVP_SEL2NX8I(IVP_MOV2NX8_FROMNX16(b), IVP_MOV2NX8_FROMNX16(a), IVP_SELI_8B_EXTRACT_1_OF_2_OFF_0);
}

HALIDE_ALWAYS_INLINE native_vector_u8 halide_xtensa_convert_concat_u16_to_u8(const native_vector_u16 &a, const native_vector_u16 &b) {
xb_vec2Nx24 wide = IVP_CVT24U2NX16(xb_vecNx16U_rtor_xb_vecNx16(b), xb_vecNx16U_rtor_xb_vecNx16(a));
return xb_vec2Nx8_rtor_xb_vec2Nx8U(IVP_PACKL2NX24(wide));
return IVP_SEL2NX8UI(IVP_MOV2NX8_FROMNX16(b), IVP_MOV2NX8_FROMNX16(a), IVP_SELI_8B_EXTRACT_1_OF_2_OFF_0);
}

HALIDE_ALWAYS_INLINE native_vector_i16 halide_xtensa_convert_i8_low_i16(const native_vector_i8 &src, int native_lanes, int total_lines) {
Expand Down Expand Up @@ -2919,3 +2918,121 @@ HALIDE_ALWAYS_INLINE HALIDE_MAYBE_UNUSED native_vector_f32_x2 gather_load<native
IVP_GATHERDN_2XF32(gsr0),
IVP_GATHERDN_2XF32(gsr1));
}

// Unsigned 16-bit multiply-accumulate: returns a with b * c folded in
// (IVP_MULANX16UPACKL accumulates the packed-low product per lane).
HALIDE_ALWAYS_INLINE native_vector_u16
halide_xtensa_mul_add_u16(const native_vector_u16 &a, const native_vector_u16 &b, const native_vector_u16 &c) {
    native_vector_u16 acc = a;
    IVP_MULANX16UPACKL(acc, b, c);
    return acc;
}

// Widening add: sums two u8 vectors into a 24-bit-per-lane accumulator vector.
HALIDE_ALWAYS_INLINE native_vector_i24
halide_xtensa_widen_add_u24(const native_vector_u8 &a, const native_vector_u8 &b) {
    // Return the widened sum directly; the original default-constructed a
    // temporary and assigned to it for no benefit.
    return IVP_ADDWU2NX8U(a, b);
}

// Widening accumulate: adds u8 vector b into the existing i24 accumulator a.
// The intrinsic takes two u8 operands, so a zero vector fills the second slot.
HALIDE_ALWAYS_INLINE native_vector_i24
halide_xtensa_widen_accum_u24(const native_vector_i24 &a, const native_vector_u8 &b) {
    native_vector_i24 sum = a;
    IVP_ADDWUA2NX8U(sum, b, native_vector_u8(0));
    return sum;
}

// Narrows four u32 native vectors down to one u8 native vector by funneling
// all lanes through a single 24-bit-per-lane register, then packing the low
// 8 bits of each lane.
template<>
HALIDE_ALWAYS_INLINE native_vector_u8
convert<native_vector_u8, native_vector_u32_x4>(const native_vector_u32_x4 &src) {
    xb_vec2Nx24 packed = IVP_CVT24UNX32L(src.native_vector[1], src.native_vector[0]);
    IVP_CVT24UNX32H(packed, src.native_vector[3], src.native_vector[2]);
    return IVP_PACKL2NX24(packed);
}

// Widens one i24-per-lane vector into four u32 native vectors, one per
// quarter (low-low, low-high, high-low, high-high).
template<>
HALIDE_ALWAYS_INLINE native_vector_u32_x4
convert<native_vector_u32_x4, native_vector_i24>(const native_vector_i24 &src) {
    const auto ll = IVP_CVT32S2NX24LL(src);
    const auto lh = IVP_CVT32S2NX24LH(src);
    const auto hl = IVP_CVT32S2NX24HL(src);
    const auto hh = IVP_CVT32S2NX24HH(src);
    return native_vector_u32_x4(native_vector_u32_x4::from_native_vector, ll, lh, hl, hh);
}

// Divides each u32 lane of a by the low 16 bits of the matching lane of b.
// NOTE(review): only the low 16 bits of b participate (IVP_MOVNX16_FROMN_2X32
// truncates) -- the caller must guarantee the divisor fits in 16 bits.
HALIDE_ALWAYS_INLINE native_vector_u32
halide_xtensa_div_32_by_low16_of_32(const native_vector_u32 &a, const native_vector_u32 &b) {
    // Inputs are pure reads; take them by const reference like every other
    // helper in this file (the original's non-const refs also rejected
    // temporaries at the call site).
    native_vector_u32 quotient, remainder;
    // The divide intrinsic produces both results; the remainder is unused here.
    IVP_DIVN_2X32X16U(quotient, remainder, a, IVP_MOVNX16_FROMN_2X32(b), 0);
    return quotient;
}

// Full 32-bit unsigned per-lane division built on the 32/16 divide intrinsic:
// normalize the divisor into 16 bits, divide, undo the normalization, then
// correct lanes where the quotient estimate overshot by one.
HALIDE_ALWAYS_INLINE native_vector_u32
halide_xtensa_div32(native_vector_u32 dividend, native_vector_u32 divisor) {
    // Normalization shift amount (NSAU); when it exceeds 16 the divisor
    // already fits in 16 bits and no right-shift is needed.
    xb_vecN_2x32Uv shift = IVP_NSAUN_2X32U(divisor);
    vboolN_2 nsa_gt_16 = IVP_LTUN_2X32U(16, shift);
    shift = IVP_MOVN_2X32UT(0, (xb_vecN_2x32Uv)16 - shift, nsa_gt_16);
    xb_vecN_2x32Uv divisor_norm = IVP_SRLN_2X32U(divisor, shift);

    // 32/16 divide against the normalized divisor, then scale the quotient
    // back down by the same shift.
    xb_vecNx16U divisor16 = IVP_MOVNX16_FROMN_2X32U(divisor_norm);
    xb_vecN_2x32Uv quotient;
    xb_vecN_2x32Uv remainder;
    IVP_DIVN_2X32X16U(quotient, remainder, dividend, divisor16, 0);
    quotient = IVP_SRLN_2X32U(quotient, shift);

    // Correction step: if quotient * divisor exceeds the dividend on a lane,
    // the estimate there is one too large.
    xb_vecN_2x64w product = IVP_MULUUN_2X16X32_0(IVP_MOVNX16_FROMN_2X32U(quotient), divisor);
    xb_vecN_2x32Uv product_lo = IVP_PACKLN_2X96(product);
    vboolN_2 overshoot = IVP_LTUN_2X32U(dividend, product_lo);
    IVP_SUBN_2X32UT(quotient, quotient, 1, overshoot);
    return quotient;
}

// Narrows a pair of u32 native vectors to u16 with rounding: widens to a
// 48-bit accumulator, adds a half-ULP rounding bias, then packs with a
// right shift.
HALIDE_ALWAYS_INLINE native_vector_u16
halide_xtensa_narrow_with_rounding_shift_u16(const native_vector_u32_x2 &a, uint32_t shift) {
    xb_vecNx48 wide = convert<native_vector_i48, native_vector_u32_x2>(a);
    // Add rounding factor: 1 << (shift - 1) accumulated into the wide value.
    // NOTE(review): shift == 0 makes (shift - 1) wrap to UINT32_MAX here --
    // confirm callers never pass a zero shift.
    native_vector_u16 v1 = IVP_SLLNX16U(1, (shift - 1));
    IVP_MULUUANX16(wide, v1, 1);
    return xb_vecNx16_rtor_xb_vecNx16U(IVP_PACKVRNRNX48(wide, shift));
}

// Narrows an i48 accumulator vector to u16 with a rounding right shift.
HALIDE_ALWAYS_INLINE native_vector_u16
halide_xtensa_narrow_i48_with_rounding_shift_u16(const native_vector_i48 &a, uint32_t shift) {
    xb_vecNx48 wide = a;
    if (15 == shift) {
        // Dedicated pack intrinsic for the shift == 15 case -- presumably
        // PACKQ performs the rounding shift by 15 itself; verify against the
        // ISA reference.
        return IVP_PACKQNX48(a);
    }
    // Add rounding factor: 1 << (shift - 1) accumulated into the wide value.
    // NOTE(review): shift == 0 makes (shift - 1) wrap to UINT32_MAX here --
    // confirm callers never pass a zero shift.
    native_vector_u16 v1 = IVP_SLLNX16U(1, (shift - 1));
    IVP_MULUUANX16(wide, v1, 1);
    return xb_vecNx16_rtor_xb_vecNx16U(IVP_PACKVRNRNX48(wide, shift));
}

// Widening multiply-subtract: returns the i48 accumulator a with the product
// of b and c subtracted out (IVP_MULSNX16 is multiply-subtract-accumulate).
HALIDE_ALWAYS_INLINE native_vector_i48
halide_xtensa_widen_mul_sub_i48(const native_vector_i48 &a, const native_vector_i16 &b, const native_vector_i16 &c) {
    native_vector_i48 acc = a;
    IVP_MULSNX16(acc, b, c);
    return acc;
}

// Gathers VECTOR_WIDTH_U8 bytes from base using a pair of 16-bit offset
// vectors (one per half of the result).
template<>
HALIDE_ALWAYS_INLINE HALIDE_MAYBE_UNUSED native_vector_u8
gather_load<native_vector_u8, native_vector_i16_x2, uint8_t, VECTOR_WIDTH_U8, true>(const void *base, const native_vector_i16_x2 &offset) {
    const uint8_t *ptr = (const uint8_t *)base;

    // Gather each half separately; the offsets are reinterpreted as unsigned.
    auto lo = IVP_GATHERDNX8U(IVP_GATHERANX8U(ptr, xb_vecNx16_rtor_xb_vecNx16U(offset.native_vector[0])));
    auto hi = IVP_GATHERDNX8U(IVP_GATHERANX8U(ptr, xb_vecNx16_rtor_xb_vecNx16U(offset.native_vector[1])));

    // NOTE(aelphy): the intrinsic for gathering 8-bit elements extends them to
    // 16-bit, so the pair of u16 vectors must be converted back down to u8.
    return convert<native_vector_u8, native_vector_u16_x2>(
        native_vector_u16_x2(native_vector_u16_x2::from_native_vector, lo, hi));
}
62 changes: 57 additions & 5 deletions src/XtensaOptimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,33 @@ class MatchXtensaPatterns : public IRGraphMutator {
return call;
}

// Builds a pure-extern call to halide_xtensa_widen_add_u24 with i24 result type.
static Expr halide_xtensa_widen_add_u24(Expr v0, Expr v1) {
    return Call::make(wild_i24x.type(), "halide_xtensa_widen_add_u24",
                      {std::move(v0), std::move(v1)}, Call::PureExtern);
}

// Builds a pure-extern call to halide_xtensa_widen_accum_u24 with i24 result type.
static Expr halide_xtensa_widen_accum_u24(Expr v0, Expr v1) {
    return Call::make(wild_i24x.type(), "halide_xtensa_widen_accum_u24",
                      {std::move(v0), std::move(v1)}, Call::PureExtern);
}

// Builds a pure-extern call to halide_xtensa_widen_mul_add_u24 with i24 result type.
static Expr halide_xtensa_widen_mul_add_u24(Expr v0, Expr v1, Expr v2) {
    return Call::make(wild_i24x.type(), "halide_xtensa_widen_mul_add_u24",
                      {std::move(v0), std::move(v1), std::move(v2)}, Call::PureExtern);
}

// Builds a pure-extern call to halide_xtensa_widen_pair_mul_add_u24 (accumulator
// plus two operand pairs) with i24 result type.
static Expr halide_xtensa_widen_pair_mul_add_u24(Expr w, Expr v0, Expr v1, Expr v2, Expr v3) {
    return Call::make(wild_i24x.type(), "halide_xtensa_widen_pair_mul_add_u24",
                      {std::move(w), std::move(v0), std::move(v1), std::move(v2), std::move(v3)},
                      Call::PureExtern);
}

// Builds a pure-extern call to halide_xtensa_widen_mul_sub_i48 with i48 result type.
static Expr halide_xtensa_widen_mul_sub_i48(Expr v0, Expr v1, Expr v2) {
    return Call::make(wild_i48x.type(), "halide_xtensa_widen_mul_sub_i48",
                      {std::move(v0), std::move(v1), std::move(v2)}, Call::PureExtern);
}

Expr visit(const Add *op) override {
if (op->type.is_vector()) {
static const std::vector<Pattern> adds = {
Expand Down Expand Up @@ -631,7 +658,7 @@ class MatchXtensaPatterns : public IRGraphMutator {
wild_i24x + call("halide_xtensa_widen_mul_i24", wild_i24x, {wild_i8x, wild_i8x})},

{"halide_xtensa_widen_quad_mul_add_i24",
wild_i24x + call("halide_xtensa_widen_quad_mul_i24", wild_i24x, {wild_i8x, wild_i8x, wild_i8x, wild_i8x, wild_i8x})},
wild_i24x + call("halide_xtensa_widen_quad_mul_i24", wild_i24x, {wild_i8x, wild_i8x, wild_i8x, wild_i8x, wild_i8})},

// Add to accumulator type.
// Paired add.
Expand All @@ -651,6 +678,14 @@ class MatchXtensaPatterns : public IRGraphMutator {
{"halide_xtensa_widen_mul_add_i64", widening_mul(wild_i32x, wild_i32x) + bc(wild_i64), Pattern::NarrowOp2 | Pattern::AccumulatorOutput64},
{"halide_xtensa_widen_mul_add_i64", widening_mul(wild_i32x, wild_i32x) + wild_i64x, Pattern::NarrowOp2 | Pattern::AccumulatorOutput64},
{"halide_xtensa_widen_mul_add_i64", i32(wild_i64x) + i32(call("halide_xtensa_mul_i32", wild_i64x, {wild_i32x, wild_i32x})), Pattern::AccumulatorOutput64},

{"halide_xtensa_widen_pair_mul_add_u24", i16(halide_xtensa_widen_mul_add_u24(wild_i24x, wild_u8x, wild_u8x)) + i16(halide_xtensa_widen_mul_u24(wild_u8x, wild_u8x)), Pattern::AccumulatorOutput24},
{"halide_xtensa_widen_pair_mul_add_u24", halide_xtensa_widen_mul_add_u24(wild_i24x, wild_u8x, wild_u8x) + halide_xtensa_widen_mul_u24(wild_u8x, wild_u8x)},

{"halide_xtensa_mul_add_u16", wild_u16x + wild_u16x*wild_u16x},

{"halide_xtensa_widen_add_u24", i24(wild_u8x) + i24(wild_u8x) , Pattern::AccumulatorOutput24},
{"halide_xtensa_widen_accum_u24", wild_i24x + i24(wild_u8x) , Pattern::AccumulatorOutput24},
};

Expr new_expr = apply_commutative_patterns(op, adds, this);
Expand All @@ -673,6 +708,8 @@ class MatchXtensaPatterns : public IRGraphMutator {
// {"halide_xtensa_pred_sub_i16", wild_i16x - select(wild_u1x, wild_i16x, wild_i16x)},
// {"halide_xtensa_pred_sub_i32", wild_i32x - select(wild_u1x, wild_i32x, wild_i32x)},
{"halide_xtensa_widen_mul_sub_u24", wild_i24x - halide_xtensa_widen_mul_u24(wild_u8x, wild_u8x)},
{"halide_xtensa_widen_mul_sub_i48", i32(wild_i48x) - i32(halide_xtensa_widen_mul_i48(wild_i16x, wild_i16x)), Pattern::AccumulatorOutput48},
{"halide_xtensa_widen_mul_sub_i48", wild_i48x - halide_xtensa_widen_mul_i48(wild_i16x, wild_i16x)},
};

Expr new_expr = apply_patterns(op, subs, this);
Expand Down Expand Up @@ -868,6 +905,7 @@ class MatchXtensaPatterns : public IRGraphMutator {
{"halide_xtensa_convert_concat_i32_to_u16", u16(halide_xtensa_concat_from_native_i32(wild_i32x, wild_i32x))},
{"halide_xtensa_convert_concat_u32_to_i16", i16(halide_xtensa_concat_from_native_u32(wild_u32x, wild_u32x))},
{"halide_xtensa_convert_concat_u32_to_u16", u16(halide_xtensa_concat_from_native_u32(wild_u32x, wild_u32x))},
{"halide_xtensa_narrow_with_rounding_shift_u16", u16(rounding_shift_right(wild_u32x, bc(wild_u32)))},
};
if (op->type.is_vector()) {
Expr cast = op;
Expand Down Expand Up @@ -952,11 +990,18 @@ class MatchXtensaPatterns : public IRGraphMutator {
// that they generate.
internal_assert(op->args.size() == 3);
return mutate(lower_lerp(op->type, op->args[0], op->args[1], op->args[2], target));
} else if (op->is_intrinsic(Call::absd) && op->type.is_vector() && op->type.is_uint() && (op->type.bits() == 16)) {
} else if (op->is_intrinsic(Call::absd) && op->type.is_vector() && op->type.is_uint()) {
internal_assert(op->args.size() == 2);
return Call::make(op->type, "halide_xtensa_absd_i16",

if (op->type.bits() == 16) {
return Call::make(op->type, "halide_xtensa_absd_i16",
{mutate(op->args[0]), mutate(op->args[1])},
Call::PureExtern);
} else if (op->type.bits() == 8) {
return Call::make(op->type, "halide_xtensa_absd_u8",
{mutate(op->args[0]), mutate(op->args[1])},
Call::PureExtern);
}
} else if (op->is_intrinsic(Call::widening_shift_left)) {
// Replace widening left shift with multiplication.
const uint64_t *c = as_const_uint(op->args[1]);
Expand Down Expand Up @@ -1069,8 +1114,7 @@ class MatchXtensaPatterns : public IRGraphMutator {
{"halide_xtensa_widen_quad_mul_add_i24",
call("halide_xtensa_widen_pair_mul_add_i24", wild_i24x, {call("halide_xtensa_widen_pair_mul_add_i24", wild_i24x, {wild_i24x, wild_i8x, wild_i8, wild_i8x, wild_i8}), wild_i8x, wild_i8, wild_i8x, wild_i8})},
{"halide_xtensa_widen_pair_mul_add_i24",
call("halide_xtensa_widen_mul_add_i24", wild_i24x, {call("halide_xtensa_widen_mul_add_i24", wild_i24x, {wild_i24x, wild_i8x, wild_i8}), wild_i8x, wild_i8})},

call("halide_xtensa_widen_mul_add_i24", wild_i24x, {call("halide_xtensa_widen_mul_add_i24", wild_i24x, {wild_i24x, wild_i8x, wild_i8x}), wild_i8x, wild_i8x})},
{"halide_xtensa_widen_pair_mul_add_i48",
call("halide_xtensa_widen_mul_add_i48", wild_i48x,
{call("halide_xtensa_widen_mul_add_i48", wild_i48x, {wild_i48x, wild_i16x, wild_i16x}), wild_i16x, wild_i16x})},
Expand Down Expand Up @@ -1115,6 +1159,14 @@ class MatchXtensaPatterns : public IRGraphMutator {
{"halide_xtensa_narrow_i48_with_shift_i32", i32(wild_i48x) >> wild_i32},
{"halide_xtensa_narrow_i48_with_shift_u32", u32(wild_i48x) >> wild_u32},

{"halide_xtensa_widen_add_u24", widening_add(wild_u8x, wild_u8x), Pattern::AccumulatorOutput24},
{"halide_xtensa_widen_accum_u24", widening_add(wild_i24x, wild_u8x), Pattern::AccumulatorOutput24},

{"halide_xtensa_widen_pair_mul_add_u24",
call("halide_xtensa_widen_mul_add_u24", wild_i24x,
{call("halide_xtensa_widen_mul_add_u24", wild_i24x, {wild_i24x, wild_u8x, wild_u8x}), wild_u8x, wild_u8x})},
{"halide_xtensa_narrow_i48_with_rounding_shift_u16", call("halide_xtensa_narrow_with_rounding_shift_u16", wild_u16x, {u32(wild_i48x), wild_u32})},

// Predicated saturated add/sub.
// NOTE(vksnk): patterns below are for predicated instructions and look like they may
// be more efficient, but they are not according to simulator. We will need to check with
Expand Down

0 comments on commit 06e2720

Please sign in to comment.