Skip to content

Commit 5487046

Browse files
authored
Add native support for BFloat16. (JuliaLang#51470)
This PR adds native support for the LLVM `bfloat` type, through a new `BFloat16` type. It doesn't add any language-level functionality, only the bare minimum support (e.g. runtime conversion routines). Use of the BFloat16s.jl package is still required to use BFloat16 values.
1 parent 20a5fa7 commit 5487046

21 files changed

+188
-34
lines changed

base/boot.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ primitive type Float16 <: AbstractFloat 16 end
217217
primitive type Float32 <: AbstractFloat 32 end
218218
primitive type Float64 <: AbstractFloat 64 end
219219

220+
primitive type BFloat16 <: AbstractFloat 16 end
221+
220222
#primitive type Bool <: Integer 8 end
221223
abstract type AbstractChar end
222224
primitive type Char <: AbstractChar 32 end

doc/src/base/reflection.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ the abstract `DataType` [`AbstractFloat`](@ref) has four (concrete) subtypes:
5252

5353
```jldoctest; setup = :(using InteractiveUtils)
5454
julia> subtypes(AbstractFloat)
55-
4-element Vector{Any}:
55+
5-element Vector{Any}:
5656
BigFloat
57+
Core.BFloat16
5758
Float16
5859
Float32
5960
Float64

src/abi_x86_64.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ struct Classification {
118118
void classifyType(Classification& accum, jl_datatype_t *dt, uint64_t offset) const
119119
{
120120
// Floating point types
121-
if (dt == jl_float64_type || dt == jl_float32_type) {
121+
if (dt == jl_float64_type || dt == jl_float32_type || dt == jl_bfloat16_type) {
122122
accum.addField(offset, Sse);
123123
}
124124
// Misc types
@@ -239,7 +239,9 @@ Type *preferred_llvm_type(jl_datatype_t *dt, bool isret, LLVMContext &ctx) const
239239
types[0] = Type::getIntNTy(ctx, nbits);
240240
break;
241241
case Sse:
242-
if (size <= 4)
242+
if (size <= 2)
243+
types[0] = Type::getHalfTy(ctx);
244+
else if (size <= 4)
243245
types[0] = Type::getFloatTy(ctx);
244246
else
245247
types[0] = Type::getDoubleTy(ctx);

src/aotcompile.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,6 @@ static void reportWriterError(const ErrorInfoBase &E)
497497
jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str());
498498
}
499499

500-
#if JULIA_FLOAT16_ABI == 1
501500
static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
502501
{
503502
Function *target = M.getFunction(alias);
@@ -514,7 +513,7 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT
514513
auto val = builder.CreateCall(target, CallArgs);
515514
builder.CreateRet(val);
516515
}
517-
#endif
516+
518517
void multiversioning_preannotate(Module &M);
519518

520519
// See src/processor.h for documentation about this table. Corresponds to jl_image_shard_t.
@@ -1061,6 +1060,11 @@ static AOTOutputs add_output_impl(Module &M, TargetMachine &SourceTM, ShardTimer
10611060
#else
10621061
emitFloat16Wrappers(M, false);
10631062
#endif
1063+
1064+
injectCRTAlias(M, "__truncsfbf2", "julia__truncsfbf2",
1065+
FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false));
1066+
injectCRTAlias(M, "__truncsdbf2", "julia__truncdfbf2",
1067+
FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getDoubleTy(M.getContext()) }, false));
10641068
}
10651069
timers.optimize.stopTimer();
10661070
}

src/ccall.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,22 +1123,21 @@ std::string generate_func_sig(const char *fname)
11231123
isboxed = false;
11241124
}
11251125
else {
1126-
if (jl_is_primitivetype(tti)) {
1126+
t = _julia_struct_to_llvm(ctx, LLVMCtx, tti, &isboxed, llvmcall);
1127+
if (t == getVoidTy(LLVMCtx)) {
1128+
return make_errmsg(fname, i + 1, " type doesn't correspond to a C type");
1129+
}
1130+
if (jl_is_primitivetype(tti) && t->isIntegerTy()) {
11271131
// see pull req #978. need to annotate signext/zeroext for
11281132
// small integer arguments.
11291133
jl_datatype_t *bt = (jl_datatype_t*)tti;
1130-
if (jl_datatype_size(bt) < 4 && bt != jl_float16_type) {
1134+
if (jl_datatype_size(bt) < 4) {
11311135
if (jl_signed_type && jl_subtype(tti, (jl_value_t*)jl_signed_type))
11321136
ab.addAttribute(Attribute::SExt);
11331137
else
11341138
ab.addAttribute(Attribute::ZExt);
11351139
}
11361140
}
1137-
1138-
t = _julia_struct_to_llvm(ctx, LLVMCtx, tti, &isboxed, llvmcall);
1139-
if (t == getVoidTy(LLVMCtx)) {
1140-
return make_errmsg(fname, i + 1, " type doesn't correspond to a C type");
1141-
}
11421141
}
11431142

11441143
Type *pat;

src/cgutils.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,8 @@ static Type *bitstype_to_llvm(jl_value_t *bt, LLVMContext &ctxt, bool llvmcall =
665665
return getFloatTy(ctxt);
666666
if (bt == (jl_value_t*)jl_float64_type)
667667
return getDoubleTy(ctxt);
668+
if (bt == (jl_value_t*)jl_bfloat16_type)
669+
return getBFloatTy(ctxt);
668670
if (jl_is_llvmpointer_type(bt)) {
669671
jl_value_t *as_param = jl_tparam1(bt);
670672
int as;

src/codegen.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ auto getFloatTy(LLVMContext &ctxt) {
125125
auto getDoubleTy(LLVMContext &ctxt) {
126126
return Type::getDoubleTy(ctxt);
127127
}
128+
auto getBFloatTy(LLVMContext &ctxt) {
129+
return Type::getBFloatTy(ctxt);
130+
}
128131
auto getFP128Ty(LLVMContext &ctxt) {
129132
return Type::getFP128Ty(ctxt);
130133
}

src/intrinsics.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ static Type *INTT(Type *t, const DataLayout &DL)
165165
return getInt64Ty(ctxt);
166166
if (t == getFloatTy(ctxt))
167167
return getInt32Ty(ctxt);
168-
if (t == getHalfTy(ctxt))
168+
if (t == getHalfTy(ctxt) || t == getBFloatTy(ctxt))
169169
return getInt16Ty(ctxt);
170170
unsigned nb = t->getPrimitiveSizeInBits();
171171
assert(t != getVoidTy(ctxt) && nb > 0);

src/jitlayers.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,16 +1727,18 @@ JuliaOJIT::JuliaOJIT()
17271727
ExternalJD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
17281728
ExternalJD.addToLinkOrder(JD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
17291729

1730-
#if JULIA_FLOAT16_ABI == 1
17311730
orc::SymbolAliasMap jl_crt = {
1731+
#if JULIA_FLOAT16_ABI == 1
17321732
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
17331733
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
17341734
{ mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
17351735
{ mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
1736-
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
1736+
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } },
1737+
#endif
1738+
{ mangle("__truncsfbf2"), { mangle("julia__truncsfbf2"), JITSymbolFlags::Exported } },
1739+
{ mangle("__truncdfbf2"), { mangle("julia__truncdfbf2"), JITSymbolFlags::Exported } },
17371740
};
17381741
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
1739-
#endif
17401742

17411743
#ifdef MSAN_EMUTLS_WORKAROUND
17421744
orc::SymbolMap msan_crt;

src/jl_exported_data.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
XX(jl_float16_type) \
4343
XX(jl_float32_type) \
4444
XX(jl_float64_type) \
45+
XX(jl_bfloat16_type) \
4546
XX(jl_floatingpoint_type) \
4647
XX(jl_function_type) \
4748
XX(jl_binding_type) \

0 commit comments

Comments
 (0)