From e11e87cabc0bde8f265c6f9ecfac404306248e2f Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 26 Sep 2023 20:06:08 +0200 Subject: [PATCH] Add native support for BFloat16. --- base/boot.jl | 2 ++ src/abi_x86_64.cpp | 7 +++++-- src/aotcompile.cpp | 8 +++++-- src/ccall.cpp | 2 +- src/cgutils.cpp | 2 ++ src/codegen.cpp | 3 +++ src/intrinsics.cpp | 2 +- src/jitlayers.cpp | 8 ++++--- src/jl_exported_data.inc | 1 + src/jltypes.c | 2 ++ src/julia.h | 1 + src/julia_internal.h | 2 ++ src/runtime_intrinsics.c | 45 ++++++++++++++++++++++++++++++++++++++++ src/staticdata.c | 3 ++- 14 files changed, 78 insertions(+), 10 deletions(-) diff --git a/base/boot.jl b/base/boot.jl index 637b16e04c13e..7f7f4cf02422d 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -217,6 +217,8 @@ primitive type Float16 <: AbstractFloat 16 end primitive type Float32 <: AbstractFloat 32 end primitive type Float64 <: AbstractFloat 64 end +primitive type BFloat16 <: AbstractFloat 16 end + #primitive type Bool <: Integer 8 end abstract type AbstractChar end primitive type Char <: AbstractChar 32 end diff --git a/src/abi_x86_64.cpp b/src/abi_x86_64.cpp index c3d12417e6de8..5938e1e5778a2 100644 --- a/src/abi_x86_64.cpp +++ b/src/abi_x86_64.cpp @@ -118,7 +118,8 @@ struct Classification { void classifyType(Classification& accum, jl_datatype_t *dt, uint64_t offset) const { // Floating point types - if (dt == jl_float64_type || dt == jl_float32_type) { + if (dt == jl_float64_type || dt == jl_float32_type || dt == jl_float16_type || + dt == jl_bfloat16_type) { accum.addField(offset, Sse); } // Misc types @@ -239,7 +240,9 @@ Type *preferred_llvm_type(jl_datatype_t *dt, bool isret, LLVMContext &ctx) const types[0] = Type::getIntNTy(ctx, nbits); break; case Sse: - if (size <= 4) + if (size <= 2) + types[0] = Type::getHalfTy(ctx); + else if (size <= 4) types[0] = Type::getFloatTy(ctx); else types[0] = Type::getDoubleTy(ctx); diff --git a/src/aotcompile.cpp b/src/aotcompile.cpp index 3a54e2729ff5f..e3417a4c0dca1 100644 --- a/src/aotcompile.cpp +++ b/src/aotcompile.cpp @@ -497,7 +497,6 @@ static void reportWriterError(const ErrorInfoBase &E) jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str()); } -#if JULIA_FLOAT16_ABI == 1 static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT) { Function *target = M.getFunction(alias); @@ -514,7 +513,7 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT auto val = builder.CreateCall(target, CallArgs); builder.CreateRet(val); } -#endif + void multiversioning_preannotate(Module &M); // See src/processor.h for documentation about this table. Corresponds to jl_image_shard_t. @@ -1061,6 +1060,11 @@ static AOTOutputs add_output_impl(Module &M, TargetMachine &SourceTM, ShardTimer #else emitFloat16Wrappers(M, false); #endif + + injectCRTAlias(M, "__truncsfbf2", "julia__truncsfbf2", + FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false)); + injectCRTAlias(M, "__truncsdbf2", "julia__truncdfbf2", + FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getDoubleTy(M.getContext()) }, false)); } timers.optimize.stopTimer(); } diff --git a/src/ccall.cpp b/src/ccall.cpp index 1add621edde28..353c02490f438 100644 --- a/src/ccall.cpp +++ b/src/ccall.cpp @@ -1125,7 +1125,7 @@ std::string generate_func_sig(const char *fname) // see pull req #978. need to annotate signext/zeroext for // small integer arguments. jl_datatype_t *bt = (jl_datatype_t*)tti; - if (jl_datatype_size(bt) < 4 && bt != jl_float16_type) { + if (jl_datatype_size(bt) < 4) { if (jl_signed_type && jl_subtype(tti, (jl_value_t*)jl_signed_type)) ab.addAttribute(Attribute::SExt); else diff --git a/src/cgutils.cpp b/src/cgutils.cpp index c0916ef8a7076..d7355b1d9683c 100644 --- a/src/cgutils.cpp +++ b/src/cgutils.cpp @@ -658,6 +658,8 @@ static Type *bitstype_to_llvm(jl_value_t *bt, LLVMContext &ctxt, bool llvmcall = return getFloatTy(ctxt); if (bt == (jl_value_t*)jl_float64_type) return getDoubleTy(ctxt); + if (bt == (jl_value_t*)jl_bfloat16_type) + return getBFloatTy(ctxt); if (jl_is_llvmpointer_type(bt)) { jl_value_t *as_param = jl_tparam1(bt); int as; diff --git a/src/codegen.cpp b/src/codegen.cpp index edc3b614b2ccc..41a853a4b4ec3 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -125,6 +125,9 @@ auto getFloatTy(LLVMContext &ctxt) { auto getDoubleTy(LLVMContext &ctxt) { return Type::getDoubleTy(ctxt); } +auto getBFloatTy(LLVMContext &ctxt) { + return Type::getBFloatTy(ctxt); +} auto getFP128Ty(LLVMContext &ctxt) { return Type::getFP128Ty(ctxt); } diff --git a/src/intrinsics.cpp b/src/intrinsics.cpp index c7f1263af030a..57d7853ca6d70 100644 --- a/src/intrinsics.cpp +++ b/src/intrinsics.cpp @@ -165,7 +165,7 @@ static Type *INTT(Type *t, const DataLayout &DL) return getInt64Ty(ctxt); if (t == getFloatTy(ctxt)) return getInt32Ty(ctxt); - if (t == getHalfTy(ctxt)) + if (t == getHalfTy(ctxt) || t == getBFloatTy(ctxt)) return getInt16Ty(ctxt); unsigned nb = t->getPrimitiveSizeInBits(); assert(t != getVoidTy(ctxt) && nb > 0); diff --git a/src/jitlayers.cpp b/src/jitlayers.cpp index f0360c6addc95..6c356759cc066 100644 --- a/src/jitlayers.cpp +++ b/src/jitlayers.cpp @@ -1727,16 +1727,18 @@ JuliaOJIT::JuliaOJIT() ExternalJD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly); ExternalJD.addToLinkOrder(JD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly); -#if JULIA_FLOAT16_ABI == 1 orc::SymbolAliasMap jl_crt = { +#if JULIA_FLOAT16_ABI == 1 { mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } }, { mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } }, { mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } }, { mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } }, - { mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } } + { mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }, +#endif + { mangle("__truncsfbf2"), { mangle("julia__truncsfbf2"), JITSymbolFlags::Exported } }, + { mangle("__truncdfbf2"), { mangle("julia__truncdfbf2"), JITSymbolFlags::Exported } }, }; cantFail(GlobalJD.define(orc::symbolAliases(jl_crt))); -#endif #ifdef MSAN_EMUTLS_WORKAROUND orc::SymbolMap msan_crt; diff --git a/src/jl_exported_data.inc b/src/jl_exported_data.inc index 2acde218a104c..aa23b9d7b8205 100644 --- a/src/jl_exported_data.inc +++ b/src/jl_exported_data.inc @@ -42,6 +42,7 @@ XX(jl_float16_type) \ XX(jl_float32_type) \ XX(jl_float64_type) \ + XX(jl_bfloat16_type) \ XX(jl_floatingpoint_type) \ XX(jl_function_type) \ XX(jl_binding_type) \ diff --git a/src/jltypes.c b/src/jltypes.c index 998f3fe47f157..33b52158488a3 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -3403,6 +3403,8 @@ void post_boot_hooks(void) //XX(float32); jl_float64_type = (jl_datatype_t*)core("Float64"); //XX(float64); + jl_bfloat16_type = (jl_datatype_t*)core("BFloat16"); + //XX(bfloat16); jl_floatingpoint_type = (jl_datatype_t*)core("AbstractFloat"); jl_number_type = (jl_datatype_t*)core("Number"); jl_signed_type = (jl_datatype_t*)core("Signed"); diff --git a/src/julia.h b/src/julia.h index 07f8459d37238..a357bdf558360 100644 --- a/src/julia.h +++ b/src/julia.h @@ -848,6 +848,7 @@ extern JL_DLLIMPORT jl_datatype_t *jl_uint64_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_float16_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_float32_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_float64_type JL_GLOBALLY_ROOTED; +extern JL_DLLIMPORT jl_datatype_t *jl_bfloat16_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_floatingpoint_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_number_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_void_type JL_GLOBALLY_ROOTED; // deprecated diff --git a/src/julia_internal.h b/src/julia_internal.h index 41f976b8585f3..9dff8e75cb2f5 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -1663,6 +1663,8 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT; JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT; JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT; JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT; +JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT; +JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT; //JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT; //JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT; //JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT; diff --git a/src/runtime_intrinsics.c b/src/runtime_intrinsics.c index ed320aa9a6c35..287d22314c027 100644 --- a/src/runtime_intrinsics.c +++ b/src/runtime_intrinsics.c @@ -217,6 +217,51 @@ JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) return float_to_half(res); } +JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT +{ + uint16_t result; + + if (isnan(param)) + result = 0x7fc0; + else { + uint32_t bits = *((uint32_t*) ¶m); + + // round to nearest even + uint32_t bit_above_round = (bits >> 17) & 1; + uint32_t round_bit = (bits >> 16) & 1; + uint32_t sticky_bit = (bits & 0xFFFF) != 0; + if (round_bit && (sticky_bit || bit_above_round)) + bits += 0x10000; // Add 1 to bit just above the target bits + + result = (uint16_t)(bits >> 16); + } + + // on x86, bfloat16 needs to be returned in XMM. only GCC 13 provides the necessary ABI + // support in the form of the __bf16 type; older versions only provide __bfloat16 which + // is simply a typedef for short (i16). so use float, which is passed in XMM too. + uint32_t result_32bit = (uint32_t)result; + return *(float*)&result_32bit; +} + +JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT +{ + float res = (float)param; + uint32_t resi; + memcpy(&resi, &res, sizeof(res)); + + // Handle subnormals: If this logic is activated, it indicates that when we + // cast our double to a float, the float is a subnormal number. However, + // bfloat16 uses the same exponent as float32, so we don't need special handling + // for subnormals when truncating to bfloat16. + + if ((resi & 0x1ffu) == 0x100u) { // if we are halfway between 2 bfloat16 values + // adjust the value by 1 ULP in the direction that will make bfloat16(res) give the right answer + resi += (fabs(res) < fabs(param)) - (fabs(param) < fabs(res)); + memcpy(&res, &resi, sizeof(res)); + } + return julia__truncsfbf2(res); +} + //JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) { return (double)julia__gnu_h2f_ieee(n); } //JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) { return (int32_t)julia__gnu_h2f_ieee(n); } //JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) { return (int64_t)julia__gnu_h2f_ieee(n); } diff --git a/src/staticdata.c b/src/staticdata.c index 536ca4cd6c3aa..df5652a5719c4 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -99,7 +99,7 @@ extern "C" { // TODO: put WeakRefs on the weak_refs list during deserialization // TODO: handle finalizers -#define NUM_TAGS 159 +#define NUM_TAGS 160 // An array of references that need to be restored from the sysimg // This is a manually constructed dual of the gvars array, which would be produced by codegen for Julia code, for C. @@ -194,6 +194,7 @@ jl_value_t **const*const get_tags(void) { INSERT_TAG(jl_float16_type); INSERT_TAG(jl_float32_type); INSERT_TAG(jl_float64_type); + INSERT_TAG(jl_bfloat16_type); INSERT_TAG(jl_floatingpoint_type); INSERT_TAG(jl_number_type); INSERT_TAG(jl_signed_type);