Skip to content

Commit

Permalink
Add native support for BFloat16. (#51470)
Browse files Browse the repository at this point in the history
This PR adds native support for the LLVM `bfloat` type, through a new
`BFloat16` type. It doesn't add any language-level functionality,
only the bare minimum support (e.g. runtime conversion routines).
Use of the BFloat16s.jl package is still required to use BFloat16 values.
  • Loading branch information
maleadt committed Oct 6, 2023
1 parent 20a5fa7 commit 5487046
Show file tree
Hide file tree
Showing 21 changed files with 188 additions and 34 deletions.
2 changes: 2 additions & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ primitive type Float16 <: AbstractFloat 16 end
primitive type Float32 <: AbstractFloat 32 end
primitive type Float64 <: AbstractFloat 64 end

primitive type BFloat16 <: AbstractFloat 16 end

#primitive type Bool <: Integer 8 end
abstract type AbstractChar end
primitive type Char <: AbstractChar 32 end
Expand Down
3 changes: 2 additions & 1 deletion doc/src/base/reflection.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ the abstract `DataType` [`AbstractFloat`](@ref) has four (concrete) subtypes:

```jldoctest; setup = :(using InteractiveUtils)
julia> subtypes(AbstractFloat)
4-element Vector{Any}:
5-element Vector{Any}:
BigFloat
Core.BFloat16
Float16
Float32
Float64
Expand Down
6 changes: 4 additions & 2 deletions src/abi_x86_64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ struct Classification {
void classifyType(Classification& accum, jl_datatype_t *dt, uint64_t offset) const
{
// Floating point types
if (dt == jl_float64_type || dt == jl_float32_type) {
if (dt == jl_float64_type || dt == jl_float32_type || dt == jl_bfloat16_type) {
accum.addField(offset, Sse);
}
// Misc types
Expand Down Expand Up @@ -239,7 +239,9 @@ Type *preferred_llvm_type(jl_datatype_t *dt, bool isret, LLVMContext &ctx) const
types[0] = Type::getIntNTy(ctx, nbits);
break;
case Sse:
if (size <= 4)
if (size <= 2)
types[0] = Type::getHalfTy(ctx);
else if (size <= 4)
types[0] = Type::getFloatTy(ctx);
else
types[0] = Type::getDoubleTy(ctx);
Expand Down
8 changes: 6 additions & 2 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,6 @@ static void reportWriterError(const ErrorInfoBase &E)
jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str());
}

#if JULIA_FLOAT16_ABI == 1
static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
{
Function *target = M.getFunction(alias);
Expand All @@ -514,7 +513,7 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT
auto val = builder.CreateCall(target, CallArgs);
builder.CreateRet(val);
}
#endif

void multiversioning_preannotate(Module &M);

// See src/processor.h for documentation about this table. Corresponds to jl_image_shard_t.
Expand Down Expand Up @@ -1061,6 +1060,11 @@ static AOTOutputs add_output_impl(Module &M, TargetMachine &SourceTM, ShardTimer
#else
emitFloat16Wrappers(M, false);
#endif

// Provide the bfloat16 truncation libcalls in the emitted image by forwarding them to
// Julia's runtime implementations (julia__truncsfbf2 / julia__truncdfbf2).
// The libcall names follow the compiler-rt scheme __trunc<src><dst>2, where `sf` is
// float, `df` is double and `bf` is bfloat; the double variant is therefore
// `__truncdfbf2` (it was previously misspelled `__truncsdbf2`, so the real symbol
// was never defined).
injectCRTAlias(M, "__truncsfbf2", "julia__truncsfbf2",
        FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false));
injectCRTAlias(M, "__truncdfbf2", "julia__truncdfbf2",
        FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getDoubleTy(M.getContext()) }, false));
}
timers.optimize.stopTimer();
}
Expand Down
13 changes: 6 additions & 7 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1123,22 +1123,21 @@ std::string generate_func_sig(const char *fname)
isboxed = false;
}
else {
if (jl_is_primitivetype(tti)) {
t = _julia_struct_to_llvm(ctx, LLVMCtx, tti, &isboxed, llvmcall);
if (t == getVoidTy(LLVMCtx)) {
return make_errmsg(fname, i + 1, " type doesn't correspond to a C type");
}
if (jl_is_primitivetype(tti) && t->isIntegerTy()) {
// see pull req #978. need to annotate signext/zeroext for
// small integer arguments.
jl_datatype_t *bt = (jl_datatype_t*)tti;
if (jl_datatype_size(bt) < 4 && bt != jl_float16_type) {
if (jl_datatype_size(bt) < 4) {
if (jl_signed_type && jl_subtype(tti, (jl_value_t*)jl_signed_type))
ab.addAttribute(Attribute::SExt);
else
ab.addAttribute(Attribute::ZExt);
}
}

t = _julia_struct_to_llvm(ctx, LLVMCtx, tti, &isboxed, llvmcall);
if (t == getVoidTy(LLVMCtx)) {
return make_errmsg(fname, i + 1, " type doesn't correspond to a C type");
}
}

Type *pat;
Expand Down
2 changes: 2 additions & 0 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,8 @@ static Type *bitstype_to_llvm(jl_value_t *bt, LLVMContext &ctxt, bool llvmcall =
return getFloatTy(ctxt);
if (bt == (jl_value_t*)jl_float64_type)
return getDoubleTy(ctxt);
if (bt == (jl_value_t*)jl_bfloat16_type)
return getBFloatTy(ctxt);
if (jl_is_llvmpointer_type(bt)) {
jl_value_t *as_param = jl_tparam1(bt);
int as;
Expand Down
3 changes: 3 additions & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ auto getFloatTy(LLVMContext &ctxt) {
// Shorthand for LLVM's `double` type in the given context; forwards to Type::getDoubleTy.
auto getDoubleTy(LLVMContext &ctxt) {
return Type::getDoubleTy(ctxt);
}
// Shorthand for LLVM's `bfloat` type in the given context; forwards to Type::getBFloatTy.
auto getBFloatTy(LLVMContext &ctxt) {
return Type::getBFloatTy(ctxt);
}
// Shorthand for LLVM's `fp128` type in the given context; forwards to Type::getFP128Ty.
auto getFP128Ty(LLVMContext &ctxt) {
return Type::getFP128Ty(ctxt);
}
Expand Down
2 changes: 1 addition & 1 deletion src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ static Type *INTT(Type *t, const DataLayout &DL)
return getInt64Ty(ctxt);
if (t == getFloatTy(ctxt))
return getInt32Ty(ctxt);
if (t == getHalfTy(ctxt))
if (t == getHalfTy(ctxt) || t == getBFloatTy(ctxt))
return getInt16Ty(ctxt);
unsigned nb = t->getPrimitiveSizeInBits();
assert(t != getVoidTy(ctxt) && nb > 0);
Expand Down
8 changes: 5 additions & 3 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1727,16 +1727,18 @@ JuliaOJIT::JuliaOJIT()
ExternalJD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
ExternalJD.addToLinkOrder(JD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);

#if JULIA_FLOAT16_ABI == 1
orc::SymbolAliasMap jl_crt = {
#if JULIA_FLOAT16_ABI == 1
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } },
#endif
{ mangle("__truncsfbf2"), { mangle("julia__truncsfbf2"), JITSymbolFlags::Exported } },
{ mangle("__truncdfbf2"), { mangle("julia__truncdfbf2"), JITSymbolFlags::Exported } },
};
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
#endif

#ifdef MSAN_EMUTLS_WORKAROUND
orc::SymbolMap msan_crt;
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_data.inc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
XX(jl_float16_type) \
XX(jl_float32_type) \
XX(jl_float64_type) \
XX(jl_bfloat16_type) \
XX(jl_floatingpoint_type) \
XX(jl_function_type) \
XX(jl_binding_type) \
Expand Down
2 changes: 2 additions & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3403,6 +3403,8 @@ void post_boot_hooks(void)
//XX(float32);
jl_float64_type = (jl_datatype_t*)core("Float64");
//XX(float64);
jl_bfloat16_type = (jl_datatype_t*)core("BFloat16");
//XX(bfloat16);
jl_floatingpoint_type = (jl_datatype_t*)core("AbstractFloat");
jl_number_type = (jl_datatype_t*)core("Number");
jl_signed_type = (jl_datatype_t*)core("Signed");
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ extern JL_DLLIMPORT jl_datatype_t *jl_uint64_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_float16_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_float32_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_float64_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_bfloat16_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_floatingpoint_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_number_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_void_type JL_GLOBALLY_ROOTED; // deprecated
Expand Down
2 changes: 2 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,8 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
Expand Down
43 changes: 32 additions & 11 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// This file is a part of Julia. License is MIT: https://julialang.org/license

// This pass finds floating-point operations on 16-bit (half precision) values, and replaces
// them by equivalent operations on 32-bit (single precision) values surrounded by a fpext
// and fptrunc. This ensures that the exact semantics of IEEE floating-point are preserved.
// This pass finds floating-point operations on 16-bit values (half precision and bfloat),
// and replaces them by equivalent operations on 32-bit (single precision) values surrounded
// by a fpext and fptrunc. This ensures that the exact semantics of IEEE floating-point are
// preserved.
//
// Without this pass, back-ends that do not natively support half-precision (e.g. x86_64)
// similarly pattern-match half-precision operations with single-precision equivalents, but
Expand Down Expand Up @@ -71,25 +72,40 @@ static bool have_fp16(Function &caller, const Triple &TT) {
return false;
}

// Decide whether bfloat operations in `caller` may be left as-is instead of being
// demoted to Float32.
//
// No target fully supports bfloat yet (AVX512BF16 only provides conversion and dot
// product instructions), so the target triple `TT` is not consulted; the only way to
// opt in is the explicit "julia.hasbf16" function attribute.
static bool have_bf16(Function &caller, const Triple &TT) {
    return caller.hasFnAttribute("julia.hasbf16");
}

static bool demoteFloat16(Function &F)
{
auto TT = Triple(F.getParent()->getTargetTriple());
if (have_fp16(F, TT))
auto has_fp16 = have_fp16(F, TT);
auto has_bf16 = have_bf16(F, TT);
if (has_fp16 && has_bf16)
return false;

auto &ctx = F.getContext();
auto T_float32 = Type::getFloatTy(ctx);
SmallVector<Instruction *, 0> erase;
for (auto &BB : F) {
for (auto &I : BB) {
// extend Float16 operands to Float32
// check whether there's any 16-bit floating point operands to extend
bool Float16 = I.getType()->getScalarType()->isHalfTy();
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
bool BFloat16 = I.getType()->getScalarType()->isBFloatTy();
for (size_t i = 0; !BFloat16 && !Float16 && i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->getScalarType()->isHalfTy())
if (!has_fp16 && Op->getType()->getScalarType()->isHalfTy())
Float16 = true;
else if (!has_bf16 && Op->getType()->getScalarType()->isBFloatTy())
BFloat16 = true;
}
if (!Float16)
if (!Float16 && !BFloat16)
continue;

switch (I.getOpcode()) {
Expand All @@ -113,19 +129,24 @@ static bool demoteFloat16(Function &F)

IRBuilder<> builder(&I);

// extend Float16 operands to Float32
// extend 16-bit floating point operands
SmallVector<Value *, 2> Operands(I.getNumOperands());
for (size_t i = 0; i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->getScalarType()->isHalfTy()) {
if (!has_fp16 && Op->getType()->getScalarType()->isHalfTy()) {
// extend Float16 to Float32
++TotalExt;
Op = builder.CreateFPExt(Op, Op->getType()->getWithNewType(T_float32));
} else if (!has_bf16 && Op->getType()->getScalarType()->isBFloatTy()) {
// extend BFloat16 to Float32
++TotalExt;
Op = builder.CreateFPExt(Op, Op->getType()->getWithNewType(T_float32));
}
Operands[i] = Op;
}

// recreate the instruction if any operands changed,
// truncating the result back to Float16
// truncating the result back to the original type
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
Expand Down
10 changes: 6 additions & 4 deletions src/llvm-multiversioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ extern Optional<bool> always_have_fma(Function&, const Triple &TT);

namespace {
constexpr uint32_t clone_mask =
JL_TARGET_CLONE_LOOP | JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU | JL_TARGET_CLONE_FLOAT16;
JL_TARGET_CLONE_LOOP | JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU | JL_TARGET_CLONE_FLOAT16 | JL_TARGET_CLONE_BFLOAT16;

// Treat identical mapping as missing and return `def` in that case.
// We mainly need this to identify cloned function using value map after LLVM cloning
Expand Down Expand Up @@ -126,12 +126,14 @@ static uint32_t collect_func_info(Function &F, const Triple &TT, bool &has_vecca
}

for (size_t i = 0; i < I.getNumOperands(); i++) {
if(I.getOperand(i)->getType()->isHalfTy()){
if(I.getOperand(i)->getType()->isHalfTy()) {
flag |= JL_TARGET_CLONE_FLOAT16;
}
// Check for BFloat16 when they are added to julia can be done here
if(I.getOperand(i)->getType()->isBFloatTy()) {
flag |= JL_TARGET_CLONE_BFLOAT16;
}
}
uint32_t veccall_flags = JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU | JL_TARGET_CLONE_FLOAT16;
uint32_t veccall_flags = JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU | JL_TARGET_CLONE_FLOAT16 | JL_TARGET_CLONE_BFLOAT16;
if (has_veccall && (flag & veccall_flags) == veccall_flags) {
return flag;
}
Expand Down
2 changes: 2 additions & 0 deletions src/processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ enum {
JL_TARGET_CLONE_CPU = 1 << 8,
// Clone when the function uses fp16
JL_TARGET_CLONE_FLOAT16 = 1 << 9,
// Clone when the function uses bf16
JL_TARGET_CLONE_BFLOAT16 = 1 << 10,
};

#define JL_FEATURE_DEF_NAME(name, bit, llvmver, str) JL_FEATURE_DEF(name, bit, llvmver)
Expand Down
7 changes: 7 additions & 0 deletions src/processor_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,13 @@ static void ensure_jit_target(bool imaging)
break;
}
}
static constexpr uint32_t clone_bf16[] = {Feature::avx512bf16};
for (auto fe: clone_bf16) {
if (!test_nbit(features0, fe) && test_nbit(t.en.features, fe)) {
t.en.flags |= JL_TARGET_CLONE_BFLOAT16;
break;
}
}
}
}

Expand Down
38 changes: 38 additions & 0 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,44 @@ JL_DLLEXPORT uint16_t julia__truncdfhf2(double param)
return float_to_half(res);
}

// Truncate a Float32 to BFloat16, rounding to nearest with ties to even.
//
// NaN inputs map to the fixed pattern 0x7fc0. The 16-bit result is returned in the low
// 16 bits of a `float` (upper 16 bits zero), not as a numeric float value:
// on x86, bfloat16 needs to be returned in XMM. only GCC 13 provides the necessary ABI
// support in the form of the __bf16 type; older versions only provide __bfloat16 which
// is simply a typedef for short (i16). so use float, which is passed in XMM too.
JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT
{
    uint16_t result;

    if (isnan(param)) {
        result = 0x7fc0;
    }
    else {
        // copy the bits out with memcpy; casting the pointer would violate strict aliasing
        uint32_t bits;
        memcpy(&bits, &param, sizeof(bits));

        // round to nearest even: add just under half of the dropped range, plus one more
        // when the lowest kept bit is odd, so that exact ties carry only towards even
        bits += 0x7fff + ((bits >> 16) & 1);
        result = (uint16_t)(bits >> 16);
    }

    // place the 16 result bits in the low half of a float-sized value (see above)
    uint32_t result_32bit = (uint32_t)result;
    float ret;
    memcpy(&ret, &result_32bit, sizeof(ret));
    return ret;
}

// Truncate a Float64 to BFloat16 by going through Float32, correcting for double
// rounding. The result uses the same "bits in a float" return convention as
// julia__truncsfbf2.
JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT
{
    float res = (float)param;
    uint32_t resi;
    memcpy(&resi, &res, sizeof(res));

    // bfloat16 uses the same exponent as float32, so we don't need special handling
    // for subnormals when truncating float64 to bfloat16.

    // bfloat16 keeps the upper 16 bits of a float32, so `res` lies exactly halfway
    // between two bfloat16 values when its 16 discarded bits are 0x8000. In that case
    // the float64 -> float32 rounding may have hidden on which side of the tie `param`
    // really is.
    if ((resi & 0xffffu) == 0x8000u) {
        // adjust the value by 1 ULP in the direction that will make bfloat16(res) give the right answer
        resi += (fabs(res) < fabs(param)) - (fabs(param) < fabs(res));
        memcpy(&res, &resi, sizeof(res));
    }
    return julia__truncsfbf2(res);
}

//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) { return (double)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) { return (int32_t)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) { return (int64_t)julia__gnu_h2f_ieee(n); }
Expand Down
3 changes: 2 additions & 1 deletion src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ extern "C" {
// TODO: put WeakRefs on the weak_refs list during deserialization
// TODO: handle finalizers

#define NUM_TAGS 159
#define NUM_TAGS 160

// An array of references that need to be restored from the sysimg
// This is a manually constructed dual of the gvars array, which would be produced by codegen for Julia code, for C.
Expand Down Expand Up @@ -194,6 +194,7 @@ jl_value_t **const*const get_tags(void) {
INSERT_TAG(jl_float16_type);
INSERT_TAG(jl_float32_type);
INSERT_TAG(jl_float64_type);
INSERT_TAG(jl_bfloat16_type);
INSERT_TAG(jl_floatingpoint_type);
INSERT_TAG(jl_number_type);
INSERT_TAG(jl_signed_type);
Expand Down
Loading

0 comments on commit 5487046

Please sign in to comment.