Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Add f6E3M2FN type #105573

Merged
merged 2 commits into from
Sep 10, 2024
Merged

[MLIR] Add f6E3M2FN type #105573

merged 2 commits into from
Sep 10, 2024

Conversation

sergey-kozub
Copy link
Contributor

This PR adds f6E3M2FN type to mlir.

f6E3M2FN type is proposed in OpenCompute MX Specification. It defines a 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754 types, there are no infinity or NaN values.

f6E3M2FN
- Exponent bias: 3
- Maximum stored exponent value: 7 (binary 111)
- Maximum unbiased exponent value: 7 - 3 = 4
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 13 =2
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.000.00
- Max normal number: S.111.11 = ±2^(4) x (1 + 0.75) = ±28
- Min normal number: S.001.00 = ±2^(-2) = ±0.25
- Max subnormal number: S.000.11 = ±2^(-2) x 0.75 = ±0.1875
- Min subnormal number: S.000.01 = ±2^(-2) x 0.25 = ±0.0625

Related PRs:

  • PR-94735 [APFloat] Add APFloat support for FP6 data types
  • PR-97118 [MLIR] Add f8E4M3 type - was used as a template for this PR

@llvmbot
Copy link
Collaborator

llvmbot commented Aug 21, 2024

@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-llvm-adt

Author: Sergey Kozub (sergey-kozub)

Changes

This PR adds f6E3M2FN type to mlir.

f6E3M2FN type is proposed in OpenCompute MX Specification. It defines a 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754 types, there are no infinity or NaN values.

f6E3M2FN
- Exponent bias: 3
- Maximum stored exponent value: 7 (binary 111)
- Maximum unbiased exponent value: 7 - 3 = 4
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 13 =2
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.000.00
- Max normal number: S.111.11 = ±2^(4) x (1 + 0.75) = ±28
- Min normal number: S.001.00 = ±2^(-2) = ±0.25
- Max subnormal number: S.000.11 = ±2^(-2) x 0.75 = ±0.1875
- Min subnormal number: S.000.01 = ±2^(-2) x 0.25 = ±0.0625

Related PRs:

  • PR-94735 [APFloat] Add APFloat support for FP6 data types
  • PR-97118 [MLIR] Add f8E4M3 type - was used as a template for this PR

Patch is 21.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105573.diff

25 Files Affected:

  • (modified) llvm/unittests/ADT/APFloatTest.cpp (+6)
  • (modified) mlir/include/mlir-c/BuiltinTypes.h (+10)
  • (modified) mlir/include/mlir/IR/Builders.h (+1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+8-3)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+18)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+2)
  • (modified) mlir/include/mlir/IR/Types.h (+1)
  • (modified) mlir/lib/AsmParser/TokenKinds.def (+1)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+4)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+22)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+12)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+1)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+1)
  • (modified) mlir/lib/IR/Builders.cpp (+4)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+4)
  • (modified) mlir/lib/IR/MLIRContext.cpp (+5)
  • (modified) mlir/lib/IR/Types.cpp (+1)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+14)
  • (modified) mlir/python/mlir/extras/types.py (+2)
  • (modified) mlir/test/IR/attribute.mlir (+4)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+3)
  • (modified) mlir/test/python/ir/builtin_types.py (+9)
  • (modified) mlir/utils/lldb-scripts/mlirDataFormatters.py (+1)
  • (modified) mlir/utils/tree-sitter-mlir/grammar.js (+1-1)
diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index be675bb7fe5a53..323a35d41bb6d2 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -2084,8 +2084,14 @@ TEST(APFloatTest, getSmallestNormalized) {
   EXPECT_FALSE(test.isDenormal());
   EXPECT_TRUE(test.bitwiseIsEqual(expected));
   EXPECT_TRUE(test.isSmallestNormalized());
+
   test = APFloat::getSmallestNormalized(APFloat::Float6E3M2FN(), false);
   expected = APFloat(APFloat::Float6E3M2FN(), "0x1p-2");
+  EXPECT_FALSE(test.isNegative());
+  EXPECT_TRUE(test.isFiniteNonZero());
+  EXPECT_FALSE(test.isDenormal());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  EXPECT_TRUE(test.isSmallestNormalized());
 
   test = APFloat::getSmallestNormalized(APFloat::Float4E2M1FN(), false);
   expected = APFloat(APFloat::Float4E2M1FN(), "0x1p0");
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index d698bf4764568f..7f2942050dc080 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -149,6 +149,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx);
 
+/// Returns the typeID of an Float6E3M2FN type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void);
+
+/// Checks whether the given type is an f6E3M2FN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E3M2(MlirType type);
+
+/// Creates an f8E3M2FN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx);
+
 /// Returns the typeID of an BFloat16 type.
 MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index b5962f3783924f..e310a94c110a93 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -67,6 +67,7 @@ class Builder {
   FloatType getFloat8E4M3FNUZType();
   FloatType getFloat8E4M3B11FNUZType();
   FloatType getFloat8E3M4Type();
+  FloatType getFloat6E3M2FNType();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getTF32Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index eefa4279df1a01..479771969f869d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -67,6 +67,7 @@ class FloatType : public Type {
   static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E3M4(MLIRContext *ctx);
+  static FloatType getFloat6E3M2FN(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -415,9 +416,9 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 inline bool FloatType::classof(Type type) {
   return llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
                    Float8E5M2FNUZType, Float8E4M3FNUZType,
-                   Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type,
-                   Float16Type, FloatTF32Type, Float32Type, Float64Type,
-                   Float80Type, Float128Type>(type);
+                   Float8E4M3B11FNUZType, Float8E3M4Type, Float6E3M2FNType,
+                   BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
+                   Float64Type, Float80Type, Float128Type>(type);
 }
 
 inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -448,6 +449,10 @@ inline FloatType FloatType::getFloat8E3M4(MLIRContext *ctx) {
   return Float8E3M4Type::get(ctx);
 }
 
+inline FloatType FloatType::getFloat6E3M2FN(MLIRContext *ctx) {
+  return Float6E3M2FNType::get(ctx);
+}
+
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
   return BFloat16Type::get(ctx);
 }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1ab1bbe9bfc9b2..aa495904b69ad3 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -233,6 +233,24 @@ def Builtin_Float8E3M4 : Builtin_FloatType<"Float8E3M4", "f8E3M4"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float6E3M2FNType
+
+def Builtin_Float6E3M2FN : Builtin_FloatType<"Float6E3M2FN", "f6E3M2FN"> {
+  let summary = "6-bit floating point with 3 bits exponent and 2 bit mantissa";
+  let description = [{
+    An 6-bit floating point type with 1 sign bit, 3 bits exponent and 2 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it
+    follows similar conventions with the following characteristics:
+
+      * bit encoding: S1E3M2
+      * exponent bias: 3
+      * infinities: Not supported
+      * NaNs: Not supported
+      * denormals when exponent is 0
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4536d781ef674f..09eab50f53a540 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -344,6 +344,8 @@ def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
                  BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
 def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
              BuildableType<"$_builder.getFloat8E3M4Type()">;
+def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
+               BuildableType<"$_builder.getFloat6E3M2FNType()">;
 
 def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
                       "complex-type", "::mlir::ComplexType">;
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index dfc7e69472c891..8c55d29bb5b5cd 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -132,6 +132,7 @@ class Type {
   bool isFloat8E4M3FNUZ() const;
   bool isFloat8E4M3B11FNUZ() const;
   bool isFloat8E3M4() const;
+  bool isFloat6E3M2FN() const;
   bool isBF16() const;
   bool isF16() const;
   bool isTF32() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 4c1c1c21031c88..fa18cbe9e2b901 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -101,6 +101,7 @@ TOK_KEYWORD(f8E5M2FNUZ)
 TOK_KEYWORD(f8E4M3FNUZ)
 TOK_KEYWORD(f8E4M3B11FNUZ)
 TOK_KEYWORD(f8E3M4)
+TOK_KEYWORD(f6E3M2FN)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index f070c072c43296..b324e0b336de0d 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -46,6 +46,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_f8E4M3FNUZ:
   case Token::kw_f8E4M3B11FNUZ:
   case Token::kw_f8E3M4:
+  case Token::kw_f6E3M2FN:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_tf32:
@@ -324,6 +325,9 @@ Type Parser::parseNonFunctionType() {
   case Token::kw_f8E3M4:
     consumeToken(Token::kw_f8E3M4);
     return builder.getFloat8E3M4Type();
+  case Token::kw_f6E3M2FN:
+    consumeToken(Token::kw_f6E3M2FN);
+    return builder.getFloat6E3M2FNType();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index c3d42c0ef8e3cb..eeb6c8d5c09385 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -266,6 +266,27 @@ class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
   }
 };
 
+/// Floating Point Type subclass - Float6E3M2FNType.
+class PyFloat6E3M2FNType
+    : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E3M2FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float6E3M2FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
+          return PyFloat6E3M2FNType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float6_e3m2fn type.");
+  }
+};
+
 /// Floating Point Type subclass - BF16Type.
 class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
 public:
@@ -885,6 +906,7 @@ void mlir::python::populateIRTypes(py::module &m) {
   PyFloat8E4M3B11FNUZType::bind(m);
   PyFloat8E5M2FNUZType::bind(m);
   PyFloat8E3M4Type::bind(m);
+  PyFloat6E3M2FNType::bind(m);
   PyBF16Type::bind(m);
   PyF16Type::bind(m);
   PyTF32Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 2aa2e922f2abcc..371b714914691e 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -169,6 +169,18 @@ MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E3M4(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
+  return wrap(Float6E3M2FNType::getTypeID());
+}
+
+bool mlirTypeIsAFloat6E3M2FN(MlirType type) {
+  return unwrap(type).isFloat6E3M2FN();
+}
+
+MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat6E3M2FN(unwrap(ctx)));
+}
+
 MlirTypeID mlirBFloat16TypeGetTypeID() {
   return wrap(BFloat16Type::getTypeID());
 }
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5313a64ed47e3a..b2c54bb3212edb 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -249,7 +249,8 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
   if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
       type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
-      type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4())
+      type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
+      type.isFloat6E3M2FN())
     return IntegerType::get(&getContext(), type.getWidth());
   return type;
 }
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 51f229ef937c40..0834b2a07219ee 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -61,6 +61,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
       .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
       .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
       .Case("f8E3M4", b.getFloat8E3M4Type())
+      .Case("f6E3M2FN", b.getFloat6E3M2FNType())
       .Case("bf16", b.getBF16Type())
       .Case("f16", b.getF16Type())
       .Case("f32", b.getF32Type())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..060fbc564803e9 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2582,6 +2582,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
       .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
       .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; })
+      .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index e3d6d71fb61dfb..ce318f32961f94 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -62,6 +62,10 @@ FloatType Builder::getFloat8E3M4Type() {
   return FloatType::getFloat8E3M4(context);
 }
 
+FloatType Builder::getFloat6E3M2FNType() {
+  return FloatType::getFloat6E3M2FN(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 16b53efa55fb80..71f8564e700751 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -91,6 +91,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
+  if (llvm::isa<Float6E3M2FNType>(*this))
+    return 6;
   if (llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
                 Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E4M3B11FNUZType,
                 Float8E3M4Type>(*this))
@@ -124,6 +126,8 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
     return APFloat::Float8E4M3B11FNUZ();
   if (llvm::isa<Float8E3M4Type>(*this))
     return APFloat::Float8E3M4();
+  if (llvm::isa<Float6E3M2FNType>(*this))
+    return APFloat::Float6E3M2FN();
   if (llvm::isa<BFloat16Type>(*this))
     return APFloat::BFloat();
   if (llvm::isa<Float16Type>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 5c93747438ecdb..6ae38b665df6f0 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -228,6 +228,7 @@ class MLIRContextImpl {
   Float8E4M3FNUZType f8E4M3FNUZTy;
   Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
   Float8E3M4Type f8E3M4Ty;
+  Float6E3M2FNType f6E3M2FNTy;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   FloatTF32Type tf32Ty;
@@ -320,6 +321,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
   impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
   impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
   impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
+  impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
@@ -1034,6 +1036,9 @@ Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
 Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
   return context->getImpl().f8E3M4Ty;
 }
+Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
+  return context->getImpl().f6E3M2FNTy;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 2bc26388b6218a..bbdc702c9eed39 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -47,6 +47,7 @@ bool Type::isFloat8E4M3B11FNUZ() const {
   return llvm::isa<Float8E4M3B11FNUZType>(*this);
 }
 bool Type::isFloat8E3M4() const { return llvm::isa<Float8E3M4Type>(*this); }
+bool Type::isFloat6E3M2FN() const { return llvm::isa<Float6E3M2FNType>(*this); }
 bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
 bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
 bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index e3599d3c84ffed..fd504fcdbdd2ef 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -120,6 +120,7 @@ __all__ = [
     "F32Type",
     "F64Type",
     "FlatSymbolRefAttr",
+    "Float6E3M2FNType",
     "Float8E3M4Type",
     "Float8E4M3B11FNUZType",
     "Float8E4M3FNType",
@@ -1538,6 +1539,19 @@ class FlatSymbolRefAttr(Attribute):
         Returns the value of the FlatSymbolRef attribute as a string
         """
 
+class Float6E3M2FNType(FloatType):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(context: Optional[Context] = None) -> Float6E3M2FNType:
+        """
+        Create a float6_e3m2fn type.
+        """
+    @staticmethod
+    def isinstance(other: Type) -> bool: ...
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class Float8E3M4Type(FloatType):
     static_typeid: ClassVar[TypeID]
     @staticmethod
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index fe7c3e25d16906..0c6ece91d8b94a 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -12,6 +12,7 @@
     F16Type,
     F32Type,
     F64Type,
+    Float6E3M2FNType,
     Float8E3M4Type,
     Float8E4M3B11FNUZType,
     Float8E4M3FNType,
@@ -74,6 +75,7 @@ def ui(width):
 f8E4M3FN = lambda: Float8E4M3FNType.get()
 f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
 f8E3M4 = lambda: Float8E3M4Type.get()
+f6E3M2FN = lambda: Float6E3M2FNType.get()
 
 none = lambda: NoneType.get()
 
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index ac0aec113add17..6d2485b68e11d6 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -64,6 +64,10 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f8E3M4
     float_attr = 2. : f8E3M4
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f6E3M2FN
+    float_attr = 2. : f6E3M2FN
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f16
     float_attr = 2. : f16
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 8453983aa07c33..0103694d1fc638 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -39,6 +39,9 @@ llvm.mlir.global internal constant @string_const("foobar") : !llvm.array<6 x i8>
 // CHECK: @int_global_undef = internal global i64 undef
 llvm.mlir.global internal @int_global_undef() : i64
 
+// CHECK: @f6E3M2FN_global_as_i6 = internal global i6 14
+llvm.mlir.global internal @f6E3M2FN_global_as_i6(1.5 : f6E3M2FN) : i6
+
 // CHECK: @f8E3M4_global_as_i8 = internal global i8 56
 llvm.mlir.global internal @f8E3M4_global_as_i8(1.5 : f8E3M4) : i8
 
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index f95cccc54105ed..b72ef4de0bd6dd 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -113,6 +113,8 @@ def testTypeIsInstance():
 def testFloatTypeSubclasses():
     ctx = Context()
     # CHECK: True
+    print(isinstance(Type.parse("f6E3M2FN", ctx), FloatType))
+    # CHECK: True
     print(isinstance(Type.parse("f8E3M4", ctx), FloatType))
     # CHECK: True
     print(isinstance(Type.parse("f8E4M3", ctx), FloatType))
@@ -233,6 +235,8 @@ def testIndexType():
 @run
 def testFloatType():
     with Context():
+        # CHECK: float: f6E3M2FN
+        print("float:", Float6E3M2FNType.get())
         # CHECK: float: f8E3M4
         print("float:", Float8E3M4Type.get())
         # CHECK: float: f8E4M3
@@ -609,6 +613,7 @@ def testTypeIDs():
         types = [
             (IntegerType, IntegerType.get_signless(16)),
             (IndexType, IndexType.get()),
+            (Float6E3M2FNType, Float6E3M2FNType.get()),
             (Float8E3M4Type, Float8E3M4Type.get()),
             (Float8E4M3Type, Float8E4M3Type.get()),
             (Float8E4M3FNType, Float8E4M3FNType.get()),
@@ -634,6 +639,7 @@ def testTypeIDs():
 
         # CHECK: IntegerType(i16)
         # CHECK: IndexType(index)
+        # CHECK: Float6E3M2FNType(f6E3M2FN)
         # CHECK: Float8E3M4Type(f8E3M4)
         # CHECK: Float8E4M3Type(f8E4M3)
         # CHECK: Float8E4M3FNType(f8E4M3FN)
@@ -713,6 +719,9 @@ def print_downcasted(typ):
         # CHECK: F64Type
         # CHECK: F64Type(f64)
         print_downcasted(F64Type.get())
+        # CHECK: Float6E3M2FNType
+        # CHECK: Float6E3M2FNType(f6E3M2FN)
+        print_downcasted(Float6E3M2FNType.get())
         # CHECK: Float8E3M4Type
         # CHECK: Float8E3M4Type(f8E3M4)
         print_downcasted(Float8E3M4Type.get())
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index e7c526842439b9..672a2856294601 100644
--- a/mlir/utils/lldb-sc...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Aug 21, 2024

@llvm/pr-subscribers-mlir-arith

Author: Sergey Kozub (sergey-kozub)

Changes

This PR adds f6E3M2FN type to mlir.

f6E3M2FN type is proposed in OpenCompute MX Specification. It defines a 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754 types, there are no infinity or NaN values.

f6E3M2FN
- Exponent bias: 3
- Maximum stored exponent value: 7 (binary 111)
- Maximum unbiased exponent value: 7 - 3 = 4
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 13 =2
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.000.00
- Max normal number: S.111.11 = ±2^(4) x (1 + 0.75) = ±28
- Min normal number: S.001.00 = ±2^(-2) = ±0.25
- Max subnormal number: S.000.11 = ±2^(-2) x 0.75 = ±0.1875
- Min subnormal number: S.000.01 = ±2^(-2) x 0.25 = ±0.0625

Related PRs:

  • PR-94735 [APFloat] Add APFloat support for FP6 data types
  • PR-97118 [MLIR] Add f8E4M3 type - was used as a template for this PR

Patch is 21.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105573.diff

25 Files Affected:

  • (modified) llvm/unittests/ADT/APFloatTest.cpp (+6)
  • (modified) mlir/include/mlir-c/BuiltinTypes.h (+10)
  • (modified) mlir/include/mlir/IR/Builders.h (+1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+8-3)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+18)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+2)
  • (modified) mlir/include/mlir/IR/Types.h (+1)
  • (modified) mlir/lib/AsmParser/TokenKinds.def (+1)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+4)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+22)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+12)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+1)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+1)
  • (modified) mlir/lib/IR/Builders.cpp (+4)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+4)
  • (modified) mlir/lib/IR/MLIRContext.cpp (+5)
  • (modified) mlir/lib/IR/Types.cpp (+1)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+14)
  • (modified) mlir/python/mlir/extras/types.py (+2)
  • (modified) mlir/test/IR/attribute.mlir (+4)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+3)
  • (modified) mlir/test/python/ir/builtin_types.py (+9)
  • (modified) mlir/utils/lldb-scripts/mlirDataFormatters.py (+1)
  • (modified) mlir/utils/tree-sitter-mlir/grammar.js (+1-1)
diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index be675bb7fe5a53..323a35d41bb6d2 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -2084,8 +2084,14 @@ TEST(APFloatTest, getSmallestNormalized) {
   EXPECT_FALSE(test.isDenormal());
   EXPECT_TRUE(test.bitwiseIsEqual(expected));
   EXPECT_TRUE(test.isSmallestNormalized());
+
   test = APFloat::getSmallestNormalized(APFloat::Float6E3M2FN(), false);
   expected = APFloat(APFloat::Float6E3M2FN(), "0x1p-2");
+  EXPECT_FALSE(test.isNegative());
+  EXPECT_TRUE(test.isFiniteNonZero());
+  EXPECT_FALSE(test.isDenormal());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  EXPECT_TRUE(test.isSmallestNormalized());
 
   test = APFloat::getSmallestNormalized(APFloat::Float4E2M1FN(), false);
   expected = APFloat(APFloat::Float4E2M1FN(), "0x1p0");
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index d698bf4764568f..7f2942050dc080 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -149,6 +149,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx);
 
+/// Returns the typeID of an Float6E3M2FN type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void);
+
+/// Checks whether the given type is an f6E3M2FN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E3M2(MlirType type);
+
+/// Creates an f8E3M2FN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx);
+
 /// Returns the typeID of an BFloat16 type.
 MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index b5962f3783924f..e310a94c110a93 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -67,6 +67,7 @@ class Builder {
   FloatType getFloat8E4M3FNUZType();
   FloatType getFloat8E4M3B11FNUZType();
   FloatType getFloat8E3M4Type();
+  FloatType getFloat6E3M2FNType();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getTF32Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index eefa4279df1a01..479771969f869d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -67,6 +67,7 @@ class FloatType : public Type {
   static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E3M4(MLIRContext *ctx);
+  static FloatType getFloat6E3M2FN(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -415,9 +416,9 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 inline bool FloatType::classof(Type type) {
   return llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
                    Float8E5M2FNUZType, Float8E4M3FNUZType,
-                   Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type,
-                   Float16Type, FloatTF32Type, Float32Type, Float64Type,
-                   Float80Type, Float128Type>(type);
+                   Float8E4M3B11FNUZType, Float8E3M4Type, Float6E3M2FNType,
+                   BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
+                   Float64Type, Float80Type, Float128Type>(type);
 }
 
 inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -448,6 +449,10 @@ inline FloatType FloatType::getFloat8E3M4(MLIRContext *ctx) {
   return Float8E3M4Type::get(ctx);
 }
 
+inline FloatType FloatType::getFloat6E3M2FN(MLIRContext *ctx) {
+  return Float6E3M2FNType::get(ctx);
+}
+
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
   return BFloat16Type::get(ctx);
 }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1ab1bbe9bfc9b2..aa495904b69ad3 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -233,6 +233,24 @@ def Builtin_Float8E3M4 : Builtin_FloatType<"Float8E3M4", "f8E3M4"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float6E3M2FNType
+
+def Builtin_Float6E3M2FN : Builtin_FloatType<"Float6E3M2FN", "f6E3M2FN"> {
+  let summary = "6-bit floating point with 3 bits exponent and 2 bit mantissa";
+  let description = [{
+    An 6-bit floating point type with 1 sign bit, 3 bits exponent and 2 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it
+    follows similar conventions with the following characteristics:
+
+      * bit encoding: S1E3M2
+      * exponent bias: 3
+      * infinities: Not supported
+      * NaNs: Not supported
+      * denormals when exponent is 0
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4536d781ef674f..09eab50f53a540 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -344,6 +344,8 @@ def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
                  BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
 def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
              BuildableType<"$_builder.getFloat8E3M4Type()">;
+def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
+               BuildableType<"$_builder.getFloat6E3M2FNType()">;
 
 def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
                       "complex-type", "::mlir::ComplexType">;
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index dfc7e69472c891..8c55d29bb5b5cd 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -132,6 +132,7 @@ class Type {
   bool isFloat8E4M3FNUZ() const;
   bool isFloat8E4M3B11FNUZ() const;
   bool isFloat8E3M4() const;
+  bool isFloat6E3M2FN() const;
   bool isBF16() const;
   bool isF16() const;
   bool isTF32() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 4c1c1c21031c88..fa18cbe9e2b901 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -101,6 +101,7 @@ TOK_KEYWORD(f8E5M2FNUZ)
 TOK_KEYWORD(f8E4M3FNUZ)
 TOK_KEYWORD(f8E4M3B11FNUZ)
 TOK_KEYWORD(f8E3M4)
+TOK_KEYWORD(f6E3M2FN)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index f070c072c43296..b324e0b336de0d 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -46,6 +46,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_f8E4M3FNUZ:
   case Token::kw_f8E4M3B11FNUZ:
   case Token::kw_f8E3M4:
+  case Token::kw_f6E3M2FN:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_tf32:
@@ -324,6 +325,9 @@ Type Parser::parseNonFunctionType() {
   case Token::kw_f8E3M4:
     consumeToken(Token::kw_f8E3M4);
     return builder.getFloat8E3M4Type();
+  case Token::kw_f6E3M2FN:
+    consumeToken(Token::kw_f6E3M2FN);
+    return builder.getFloat6E3M2FNType();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index c3d42c0ef8e3cb..eeb6c8d5c09385 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -266,6 +266,27 @@ class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
   }
 };
 
+/// Floating Point Type subclass - Float6E3M2FNType.
+class PyFloat6E3M2FNType
+    : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E3M2FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float6E3M2FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
+          return PyFloat6E3M2FNType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float6_e3m2fn type.");
+  }
+};
+
 /// Floating Point Type subclass - BF16Type.
 class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
 public:
@@ -885,6 +906,7 @@ void mlir::python::populateIRTypes(py::module &m) {
   PyFloat8E4M3B11FNUZType::bind(m);
   PyFloat8E5M2FNUZType::bind(m);
   PyFloat8E3M4Type::bind(m);
+  PyFloat6E3M2FNType::bind(m);
   PyBF16Type::bind(m);
   PyF16Type::bind(m);
   PyTF32Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 2aa2e922f2abcc..371b714914691e 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -169,6 +169,18 @@ MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E3M4(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
+  return wrap(Float6E3M2FNType::getTypeID());
+}
+
+bool mlirTypeIsAFloat6E3M2FN(MlirType type) {
+  return unwrap(type).isFloat6E3M2FN();
+}
+
+MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat6E3M2FN(unwrap(ctx)));
+}
+
 MlirTypeID mlirBFloat16TypeGetTypeID() {
   return wrap(BFloat16Type::getTypeID());
 }
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5313a64ed47e3a..b2c54bb3212edb 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -249,7 +249,8 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
   if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
       type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
-      type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4())
+      type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
+      type.isFloat6E3M2FN())
     return IntegerType::get(&getContext(), type.getWidth());
   return type;
 }
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 51f229ef937c40..0834b2a07219ee 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -61,6 +61,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
       .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
       .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
       .Case("f8E3M4", b.getFloat8E3M4Type())
+      .Case("f6E3M2FN", b.getFloat6E3M2FNType())
       .Case("bf16", b.getBF16Type())
       .Case("f16", b.getF16Type())
       .Case("f32", b.getF32Type())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..060fbc564803e9 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2582,6 +2582,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
       .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
       .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; })
+      .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index e3d6d71fb61dfb..ce318f32961f94 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -62,6 +62,10 @@ FloatType Builder::getFloat8E3M4Type() {
   return FloatType::getFloat8E3M4(context);
 }
 
+FloatType Builder::getFloat6E3M2FNType() {
+  return FloatType::getFloat6E3M2FN(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 16b53efa55fb80..71f8564e700751 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -91,6 +91,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
+  if (llvm::isa<Float6E3M2FNType>(*this))
+    return 6;
   if (llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
                 Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E4M3B11FNUZType,
                 Float8E3M4Type>(*this))
@@ -124,6 +126,8 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
     return APFloat::Float8E4M3B11FNUZ();
   if (llvm::isa<Float8E3M4Type>(*this))
     return APFloat::Float8E3M4();
+  if (llvm::isa<Float6E3M2FNType>(*this))
+    return APFloat::Float6E3M2FN();
   if (llvm::isa<BFloat16Type>(*this))
     return APFloat::BFloat();
   if (llvm::isa<Float16Type>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 5c93747438ecdb..6ae38b665df6f0 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -228,6 +228,7 @@ class MLIRContextImpl {
   Float8E4M3FNUZType f8E4M3FNUZTy;
   Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
   Float8E3M4Type f8E3M4Ty;
+  Float6E3M2FNType f6E3M2FNTy;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   FloatTF32Type tf32Ty;
@@ -320,6 +321,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
   impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
   impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
   impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
+  impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
@@ -1034,6 +1036,9 @@ Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
 Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
   return context->getImpl().f8E3M4Ty;
 }
+Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
+  return context->getImpl().f6E3M2FNTy;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 2bc26388b6218a..bbdc702c9eed39 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -47,6 +47,7 @@ bool Type::isFloat8E4M3B11FNUZ() const {
   return llvm::isa<Float8E4M3B11FNUZType>(*this);
 }
 bool Type::isFloat8E3M4() const { return llvm::isa<Float8E3M4Type>(*this); }
+bool Type::isFloat6E3M2FN() const { return llvm::isa<Float6E3M2FNType>(*this); }
 bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
 bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
 bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index e3599d3c84ffed..fd504fcdbdd2ef 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -120,6 +120,7 @@ __all__ = [
     "F32Type",
     "F64Type",
     "FlatSymbolRefAttr",
+    "Float6E3M2FNType",
     "Float8E3M4Type",
     "Float8E4M3B11FNUZType",
     "Float8E4M3FNType",
@@ -1538,6 +1539,19 @@ class FlatSymbolRefAttr(Attribute):
         Returns the value of the FlatSymbolRef attribute as a string
         """
 
+class Float6E3M2FNType(FloatType):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(context: Optional[Context] = None) -> Float6E3M2FNType:
+        """
+        Create a float6_e3m2fn type.
+        """
+    @staticmethod
+    def isinstance(other: Type) -> bool: ...
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class Float8E3M4Type(FloatType):
     static_typeid: ClassVar[TypeID]
     @staticmethod
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index fe7c3e25d16906..0c6ece91d8b94a 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -12,6 +12,7 @@
     F16Type,
     F32Type,
     F64Type,
+    Float6E3M2FNType,
     Float8E3M4Type,
     Float8E4M3B11FNUZType,
     Float8E4M3FNType,
@@ -74,6 +75,7 @@ def ui(width):
 f8E4M3FN = lambda: Float8E4M3FNType.get()
 f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
 f8E3M4 = lambda: Float8E3M4Type.get()
+f6E3M2FN = lambda: Float6E3M2FNType.get()
 
 none = lambda: NoneType.get()
 
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index ac0aec113add17..6d2485b68e11d6 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -64,6 +64,10 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f8E3M4
     float_attr = 2. : f8E3M4
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f6E3M2FN
+    float_attr = 2. : f6E3M2FN
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f16
     float_attr = 2. : f16
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 8453983aa07c33..0103694d1fc638 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -39,6 +39,9 @@ llvm.mlir.global internal constant @string_const("foobar") : !llvm.array<6 x i8>
 // CHECK: @int_global_undef = internal global i64 undef
 llvm.mlir.global internal @int_global_undef() : i64
 
+// CHECK: @f6E3M2FN_global_as_i6 = internal global i6 14
+llvm.mlir.global internal @f6E3M2FN_global_as_i6(1.5 : f6E3M2FN) : i6
+
 // CHECK: @f8E3M4_global_as_i8 = internal global i8 56
 llvm.mlir.global internal @f8E3M4_global_as_i8(1.5 : f8E3M4) : i8
 
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index f95cccc54105ed..b72ef4de0bd6dd 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -113,6 +113,8 @@ def testTypeIsInstance():
 def testFloatTypeSubclasses():
     ctx = Context()
     # CHECK: True
+    print(isinstance(Type.parse("f6E3M2FN", ctx), FloatType))
+    # CHECK: True
     print(isinstance(Type.parse("f8E3M4", ctx), FloatType))
     # CHECK: True
     print(isinstance(Type.parse("f8E4M3", ctx), FloatType))
@@ -233,6 +235,8 @@ def testIndexType():
 @run
 def testFloatType():
     with Context():
+        # CHECK: float: f6E3M2FN
+        print("float:", Float6E3M2FNType.get())
         # CHECK: float: f8E3M4
         print("float:", Float8E3M4Type.get())
         # CHECK: float: f8E4M3
@@ -609,6 +613,7 @@ def testTypeIDs():
         types = [
             (IntegerType, IntegerType.get_signless(16)),
             (IndexType, IndexType.get()),
+            (Float6E3M2FNType, Float6E3M2FNType.get()),
             (Float8E3M4Type, Float8E3M4Type.get()),
             (Float8E4M3Type, Float8E4M3Type.get()),
             (Float8E4M3FNType, Float8E4M3FNType.get()),
@@ -634,6 +639,7 @@ def testTypeIDs():
 
         # CHECK: IntegerType(i16)
         # CHECK: IndexType(index)
+        # CHECK: Float6E3M2FNType(f6E3M2FN)
         # CHECK: Float8E3M4Type(f8E3M4)
         # CHECK: Float8E4M3Type(f8E4M3)
         # CHECK: Float8E4M3FNType(f8E4M3FN)
@@ -713,6 +719,9 @@ def print_downcasted(typ):
         # CHECK: F64Type
         # CHECK: F64Type(f64)
         print_downcasted(F64Type.get())
+        # CHECK: Float6E3M2FNType
+        # CHECK: Float6E3M2FNType(f6E3M2FN)
+        print_downcasted(Float6E3M2FNType.get())
         # CHECK: Float8E3M4Type
         # CHECK: Float8E3M4Type(f8E3M4)
         print_downcasted(Float8E3M4Type.get())
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index e7c526842439b9..672a2856294601 100644
--- a/mlir/utils/lldb-sc...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Aug 21, 2024

@llvm/pr-subscribers-mlir-core

Author: Sergey Kozub (sergey-kozub)

Changes

This PR adds f6E3M2FN type to mlir.

f6E3M2FN type is proposed in OpenCompute MX Specification. It defines a 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754 types, there are no infinity or NaN values.

f6E3M2FN
- Exponent bias: 3
- Maximum stored exponent value: 7 (binary 111)
- Maximum unbiased exponent value: 7 - 3 = 4
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 13 =2
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.000.00
- Max normal number: S.111.11 = ±2^(4) x (1 + 0.75) = ±28
- Min normal number: S.001.00 = ±2^(-2) = ±0.25
- Max subnormal number: S.000.11 = ±2^(-2) x 0.75 = ±0.1875
- Min subnormal number: S.000.01 = ±2^(-2) x 0.25 = ±0.0625

Related PRs:

  • PR-94735 [APFloat] Add APFloat support for FP6 data types
  • PR-97118 [MLIR] Add f8E4M3 type - was used as a template for this PR

Patch is 21.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105573.diff

25 Files Affected:

  • (modified) llvm/unittests/ADT/APFloatTest.cpp (+6)
  • (modified) mlir/include/mlir-c/BuiltinTypes.h (+10)
  • (modified) mlir/include/mlir/IR/Builders.h (+1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+8-3)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+18)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+2)
  • (modified) mlir/include/mlir/IR/Types.h (+1)
  • (modified) mlir/lib/AsmParser/TokenKinds.def (+1)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+4)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+22)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+12)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+1)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+1)
  • (modified) mlir/lib/IR/Builders.cpp (+4)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+4)
  • (modified) mlir/lib/IR/MLIRContext.cpp (+5)
  • (modified) mlir/lib/IR/Types.cpp (+1)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+14)
  • (modified) mlir/python/mlir/extras/types.py (+2)
  • (modified) mlir/test/IR/attribute.mlir (+4)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+3)
  • (modified) mlir/test/python/ir/builtin_types.py (+9)
  • (modified) mlir/utils/lldb-scripts/mlirDataFormatters.py (+1)
  • (modified) mlir/utils/tree-sitter-mlir/grammar.js (+1-1)
diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index be675bb7fe5a53..323a35d41bb6d2 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -2084,8 +2084,14 @@ TEST(APFloatTest, getSmallestNormalized) {
   EXPECT_FALSE(test.isDenormal());
   EXPECT_TRUE(test.bitwiseIsEqual(expected));
   EXPECT_TRUE(test.isSmallestNormalized());
+
   test = APFloat::getSmallestNormalized(APFloat::Float6E3M2FN(), false);
   expected = APFloat(APFloat::Float6E3M2FN(), "0x1p-2");
+  EXPECT_FALSE(test.isNegative());
+  EXPECT_TRUE(test.isFiniteNonZero());
+  EXPECT_FALSE(test.isDenormal());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  EXPECT_TRUE(test.isSmallestNormalized());
 
   test = APFloat::getSmallestNormalized(APFloat::Float4E2M1FN(), false);
   expected = APFloat(APFloat::Float4E2M1FN(), "0x1p0");
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index d698bf4764568f..7f2942050dc080 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -149,6 +149,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx);
 
+/// Returns the typeID of an Float6E3M2FN type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void);
+
+/// Checks whether the given type is an f6E3M2FN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E3M2(MlirType type);
+
+/// Creates an f8E3M2FN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx);
+
 /// Returns the typeID of an BFloat16 type.
 MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index b5962f3783924f..e310a94c110a93 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -67,6 +67,7 @@ class Builder {
   FloatType getFloat8E4M3FNUZType();
   FloatType getFloat8E4M3B11FNUZType();
   FloatType getFloat8E3M4Type();
+  FloatType getFloat6E3M2FNType();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getTF32Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index eefa4279df1a01..479771969f869d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -67,6 +67,7 @@ class FloatType : public Type {
   static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E3M4(MLIRContext *ctx);
+  static FloatType getFloat6E3M2FN(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -415,9 +416,9 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 inline bool FloatType::classof(Type type) {
   return llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
                    Float8E5M2FNUZType, Float8E4M3FNUZType,
-                   Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type,
-                   Float16Type, FloatTF32Type, Float32Type, Float64Type,
-                   Float80Type, Float128Type>(type);
+                   Float8E4M3B11FNUZType, Float8E3M4Type, Float6E3M2FNType,
+                   BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
+                   Float64Type, Float80Type, Float128Type>(type);
 }
 
 inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -448,6 +449,10 @@ inline FloatType FloatType::getFloat8E3M4(MLIRContext *ctx) {
   return Float8E3M4Type::get(ctx);
 }
 
+inline FloatType FloatType::getFloat6E3M2FN(MLIRContext *ctx) {
+  return Float6E3M2FNType::get(ctx);
+}
+
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
   return BFloat16Type::get(ctx);
 }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1ab1bbe9bfc9b2..aa495904b69ad3 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -233,6 +233,24 @@ def Builtin_Float8E3M4 : Builtin_FloatType<"Float8E3M4", "f8E3M4"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float6E3M2FNType
+
+def Builtin_Float6E3M2FN : Builtin_FloatType<"Float6E3M2FN", "f6E3M2FN"> {
+  let summary = "6-bit floating point with 3 bits exponent and 2 bit mantissa";
+  let description = [{
+    An 6-bit floating point type with 1 sign bit, 3 bits exponent and 2 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it
+    follows similar conventions with the following characteristics:
+
+      * bit encoding: S1E3M2
+      * exponent bias: 3
+      * infinities: Not supported
+      * NaNs: Not supported
+      * denormals when exponent is 0
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4536d781ef674f..09eab50f53a540 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -344,6 +344,8 @@ def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
                  BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
 def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
              BuildableType<"$_builder.getFloat8E3M4Type()">;
+def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
+               BuildableType<"$_builder.getFloat6E3M2FNType()">;
 
 def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
                       "complex-type", "::mlir::ComplexType">;
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index dfc7e69472c891..8c55d29bb5b5cd 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -132,6 +132,7 @@ class Type {
   bool isFloat8E4M3FNUZ() const;
   bool isFloat8E4M3B11FNUZ() const;
   bool isFloat8E3M4() const;
+  bool isFloat6E3M2FN() const;
   bool isBF16() const;
   bool isF16() const;
   bool isTF32() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 4c1c1c21031c88..fa18cbe9e2b901 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -101,6 +101,7 @@ TOK_KEYWORD(f8E5M2FNUZ)
 TOK_KEYWORD(f8E4M3FNUZ)
 TOK_KEYWORD(f8E4M3B11FNUZ)
 TOK_KEYWORD(f8E3M4)
+TOK_KEYWORD(f6E3M2FN)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index f070c072c43296..b324e0b336de0d 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -46,6 +46,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_f8E4M3FNUZ:
   case Token::kw_f8E4M3B11FNUZ:
   case Token::kw_f8E3M4:
+  case Token::kw_f6E3M2FN:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_tf32:
@@ -324,6 +325,9 @@ Type Parser::parseNonFunctionType() {
   case Token::kw_f8E3M4:
     consumeToken(Token::kw_f8E3M4);
     return builder.getFloat8E3M4Type();
+  case Token::kw_f6E3M2FN:
+    consumeToken(Token::kw_f6E3M2FN);
+    return builder.getFloat6E3M2FNType();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index c3d42c0ef8e3cb..eeb6c8d5c09385 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -266,6 +266,27 @@ class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
   }
 };
 
+/// Floating Point Type subclass - Float6E3M2FNType.
+class PyFloat6E3M2FNType
+    : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E3M2FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float6E3M2FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
+          return PyFloat6E3M2FNType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float6_e3m2fn type.");
+  }
+};
+
 /// Floating Point Type subclass - BF16Type.
 class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
 public:
@@ -885,6 +906,7 @@ void mlir::python::populateIRTypes(py::module &m) {
   PyFloat8E4M3B11FNUZType::bind(m);
   PyFloat8E5M2FNUZType::bind(m);
   PyFloat8E3M4Type::bind(m);
+  PyFloat6E3M2FNType::bind(m);
   PyBF16Type::bind(m);
   PyF16Type::bind(m);
   PyTF32Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 2aa2e922f2abcc..371b714914691e 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -169,6 +169,18 @@ MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E3M4(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
+  return wrap(Float6E3M2FNType::getTypeID());
+}
+
+bool mlirTypeIsAFloat6E3M2FN(MlirType type) {
+  return unwrap(type).isFloat6E3M2FN();
+}
+
+MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat6E3M2FN(unwrap(ctx)));
+}
+
 MlirTypeID mlirBFloat16TypeGetTypeID() {
   return wrap(BFloat16Type::getTypeID());
 }
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5313a64ed47e3a..b2c54bb3212edb 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -249,7 +249,8 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
   if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
       type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
-      type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4())
+      type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
+      type.isFloat6E3M2FN())
     return IntegerType::get(&getContext(), type.getWidth());
   return type;
 }
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 51f229ef937c40..0834b2a07219ee 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -61,6 +61,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
       .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
       .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
       .Case("f8E3M4", b.getFloat8E3M4Type())
+      .Case("f6E3M2FN", b.getFloat6E3M2FNType())
       .Case("bf16", b.getBF16Type())
       .Case("f16", b.getF16Type())
       .Case("f32", b.getF32Type())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..060fbc564803e9 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2582,6 +2582,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
       .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
       .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; })
+      .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index e3d6d71fb61dfb..ce318f32961f94 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -62,6 +62,10 @@ FloatType Builder::getFloat8E3M4Type() {
   return FloatType::getFloat8E3M4(context);
 }
 
+FloatType Builder::getFloat6E3M2FNType() {
+  return FloatType::getFloat6E3M2FN(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 16b53efa55fb80..71f8564e700751 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -91,6 +91,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
+  if (llvm::isa<Float6E3M2FNType>(*this))
+    return 6;
   if (llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
                 Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E4M3B11FNUZType,
                 Float8E3M4Type>(*this))
@@ -124,6 +126,8 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
     return APFloat::Float8E4M3B11FNUZ();
   if (llvm::isa<Float8E3M4Type>(*this))
     return APFloat::Float8E3M4();
+  if (llvm::isa<Float6E3M2FNType>(*this))
+    return APFloat::Float6E3M2FN();
   if (llvm::isa<BFloat16Type>(*this))
     return APFloat::BFloat();
   if (llvm::isa<Float16Type>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 5c93747438ecdb..6ae38b665df6f0 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -228,6 +228,7 @@ class MLIRContextImpl {
   Float8E4M3FNUZType f8E4M3FNUZTy;
   Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
   Float8E3M4Type f8E3M4Ty;
+  Float6E3M2FNType f6E3M2FNTy;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   FloatTF32Type tf32Ty;
@@ -320,6 +321,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
   impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
   impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
   impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
+  impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
@@ -1034,6 +1036,9 @@ Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
 Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
   return context->getImpl().f8E3M4Ty;
 }
+Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
+  return context->getImpl().f6E3M2FNTy;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 2bc26388b6218a..bbdc702c9eed39 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -47,6 +47,7 @@ bool Type::isFloat8E4M3B11FNUZ() const {
   return llvm::isa<Float8E4M3B11FNUZType>(*this);
 }
 bool Type::isFloat8E3M4() const { return llvm::isa<Float8E3M4Type>(*this); }
+bool Type::isFloat6E3M2FN() const { return llvm::isa<Float6E3M2FNType>(*this); }
 bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
 bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
 bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index e3599d3c84ffed..fd504fcdbdd2ef 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -120,6 +120,7 @@ __all__ = [
     "F32Type",
     "F64Type",
     "FlatSymbolRefAttr",
+    "Float6E3M2FNType",
     "Float8E3M4Type",
     "Float8E4M3B11FNUZType",
     "Float8E4M3FNType",
@@ -1538,6 +1539,19 @@ class FlatSymbolRefAttr(Attribute):
         Returns the value of the FlatSymbolRef attribute as a string
         """
 
+class Float6E3M2FNType(FloatType):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(context: Optional[Context] = None) -> Float6E3M2FNType:
+        """
+        Create a float6_e3m2fn type.
+        """
+    @staticmethod
+    def isinstance(other: Type) -> bool: ...
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class Float8E3M4Type(FloatType):
     static_typeid: ClassVar[TypeID]
     @staticmethod
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index fe7c3e25d16906..0c6ece91d8b94a 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -12,6 +12,7 @@
     F16Type,
     F32Type,
     F64Type,
+    Float6E3M2FNType,
     Float8E3M4Type,
     Float8E4M3B11FNUZType,
     Float8E4M3FNType,
@@ -74,6 +75,7 @@ def ui(width):
 f8E4M3FN = lambda: Float8E4M3FNType.get()
 f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
 f8E3M4 = lambda: Float8E3M4Type.get()
+f6E3M2FN = lambda: Float6E3M2FNType.get()
 
 none = lambda: NoneType.get()
 
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index ac0aec113add17..6d2485b68e11d6 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -64,6 +64,10 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f8E3M4
     float_attr = 2. : f8E3M4
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f6E3M2FN
+    float_attr = 2. : f6E3M2FN
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f16
     float_attr = 2. : f16
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 8453983aa07c33..0103694d1fc638 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -39,6 +39,9 @@ llvm.mlir.global internal constant @string_const("foobar") : !llvm.array<6 x i8>
 // CHECK: @int_global_undef = internal global i64 undef
 llvm.mlir.global internal @int_global_undef() : i64
 
+// CHECK: @f6E3M2FN_global_as_i6 = internal global i6 14
+llvm.mlir.global internal @f6E3M2FN_global_as_i6(1.5 : f6E3M2FN) : i6
+
 // CHECK: @f8E3M4_global_as_i8 = internal global i8 56
 llvm.mlir.global internal @f8E3M4_global_as_i8(1.5 : f8E3M4) : i8
 
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index f95cccc54105ed..b72ef4de0bd6dd 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -113,6 +113,8 @@ def testTypeIsInstance():
 def testFloatTypeSubclasses():
     ctx = Context()
     # CHECK: True
+    print(isinstance(Type.parse("f6E3M2FN", ctx), FloatType))
+    # CHECK: True
     print(isinstance(Type.parse("f8E3M4", ctx), FloatType))
     # CHECK: True
     print(isinstance(Type.parse("f8E4M3", ctx), FloatType))
@@ -233,6 +235,8 @@ def testIndexType():
 @run
 def testFloatType():
     with Context():
+        # CHECK: float: f6E3M2FN
+        print("float:", Float6E3M2FNType.get())
         # CHECK: float: f8E3M4
         print("float:", Float8E3M4Type.get())
         # CHECK: float: f8E4M3
@@ -609,6 +613,7 @@ def testTypeIDs():
         types = [
             (IntegerType, IntegerType.get_signless(16)),
             (IndexType, IndexType.get()),
+            (Float6E3M2FNType, Float6E3M2FNType.get()),
             (Float8E3M4Type, Float8E3M4Type.get()),
             (Float8E4M3Type, Float8E4M3Type.get()),
             (Float8E4M3FNType, Float8E4M3FNType.get()),
@@ -634,6 +639,7 @@ def testTypeIDs():
 
         # CHECK: IntegerType(i16)
         # CHECK: IndexType(index)
+        # CHECK: Float6E3M2FNType(f6E3M2FN)
         # CHECK: Float8E3M4Type(f8E3M4)
         # CHECK: Float8E4M3Type(f8E4M3)
         # CHECK: Float8E4M3FNType(f8E4M3FN)
@@ -713,6 +719,9 @@ def print_downcasted(typ):
         # CHECK: F64Type
         # CHECK: F64Type(f64)
         print_downcasted(F64Type.get())
+        # CHECK: Float6E3M2FNType
+        # CHECK: Float6E3M2FNType(f6E3M2FN)
+        print_downcasted(Float6E3M2FNType.get())
         # CHECK: Float8E3M4Type
         # CHECK: Float8E3M4Type(f8E3M4)
         print_downcasted(Float8E3M4Type.get())
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index e7c526842439b9..672a2856294601 100644
--- a/mlir/utils/lldb-sc...
[truncated]

@sergey-kozub
Copy link
Contributor Author

I'm planning to add a similar PR for f6E2M3FN type (also has APFloat semantics already).

@sergey-kozub sergey-kozub marked this pull request as draft August 21, 2024 19:35
sergey-kozub added a commit to sergey-kozub/llvm-project that referenced this pull request Aug 27, 2024
`f6E3M2FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754 types, there are no infinity or NaN values.

```c
f6E3M2FN
- Exponent bias: 3
- Maximum stored exponent value: 7 (binary 111)
- Maximum unbiased exponent value: 7 - 3 = 4
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.000.00
- Max normal number: S.111.11 = ±2^(4) x (1 + 0.75) = ±28
- Min normal number: S.001.00 = ±2^(-2) = ±0.25
- Max subnormal number: S.000.11 = ±2^(-2) x 0.75 = ±0.1875
- Min subnormal number: S.000.01 = ±2^(-2) x 0.25 = ±0.0625
```

Related PRs:
- [PR-94735](llvm#94735) [APFloat] Add APFloat support for FP6 data types
- [PR-97118](llvm#97118) [MLIR] Add f8E4M3 type - was used as a template for this PR
@sergey-kozub sergey-kozub marked this pull request as ready for review August 27, 2024 19:40
sergey-kozub added a commit to sergey-kozub/llvm-project that referenced this pull request Aug 28, 2024
`f6E3M2FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754 types, there are no infinity or NaN values.

```c
f6E3M2FN
- Exponent bias: 3
- Maximum stored exponent value: 7 (binary 111)
- Maximum unbiased exponent value: 7 - 3 = 4
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.000.00
- Max normal number: S.111.11 = ±2^(4) x (1 + 0.75) = ±28
- Min normal number: S.001.00 = ±2^(-2) = ±0.25
- Max subnormal number: S.000.11 = ±2^(-2) x 0.75 = ±0.1875
- Min subnormal number: S.000.01 = ±2^(-2) x 0.25 = ±0.0625
```

Related PRs:
- [PR-94735](llvm#94735) [APFloat] Add APFloat support for FP6 data types
- [PR-97118](llvm#97118) [MLIR] Add f8E4M3 type - was used as a template for this PR
@durga4github
Copy link
Contributor

The changes look good to me,

@ThomasRaoux , Could you please help with a review?

sergey-kozub added a commit to sergey-kozub/llvm-project that referenced this pull request Aug 28, 2024
`f6E3M2FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754 types, there are no infinity or NaN values.

```c
f6E3M2FN
- Exponent bias: 3
- Maximum stored exponent value: 7 (binary 111)
- Maximum unbiased exponent value: 7 - 3 = 4
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.000.00
- Max normal number: S.111.11 = ±2^(4) x (1 + 0.75) = ±28
- Min normal number: S.001.00 = ±2^(-2) = ±0.25
- Max subnormal number: S.000.11 = ±2^(-2) x 0.75 = ±0.1875
- Min subnormal number: S.000.01 = ±2^(-2) x 0.25 = ±0.0625
```

Related PRs:
- [PR-94735](llvm#94735) [APFloat] Add APFloat support for FP6 data types
- [PR-97118](llvm#97118) [MLIR] Add f8E4M3 type - was used as a template for this PR
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG in general. Dropped some minor comments

@@ -67,6 +67,7 @@ class Builder {
FloatType getFloat8E4M3FNUZType();
FloatType getFloat8E4M3B11FNUZType();
FloatType getFloat8E3M4Type();
FloatType getFloat6E3M2FNType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it looks like methods are ordered by type bitwidth (?) so I would move this one before the Float8 ones

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, here and in the other places.

Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type,
Float16Type, FloatTF32Type, Float32Type, Float64Type,
Float80Type, Float128Type>(type);
Float8E4M3B11FNUZType, Float8E3M4Type, Float6E3M2FNType,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same? (follow bitwidth order)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

let description = [{
An 6-bit floating point type with 1 sign bit, 3 bits exponent and 2 bits
mantissa. This is not a standard type as defined by IEEE-754, but it
follows similar conventions with the following characteristics:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a reference to the OCP spec or at least mention it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a link.

@@ -61,6 +61,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
.Case("f8E3M4", b.getFloat8E3M4Type())
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -91,6 +91,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
if (llvm::isa<Float6E3M2FNType>(*this))
return 6;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we use APFloat here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll do this in a separate PR, if you don't mind.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created #107372

@stellaraccident
Copy link
Contributor

stellaraccident commented Sep 5, 2024

Just curious as there have been various conversations: have we found a use case for the elemental form of f6 yet? When I've drilled in before, it seemed like there might be a case for that but all uses in the wild I could see were for the shared scale forms (which are definitely not a scalar type).

(Edit: I'm not opposed. In prior discussions, I had suggested that a scalar form of these types would likely be convenient at some point in the process, but those I spoke to couldn't identify a use case)

`f6E3M2FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754 types, there are no infinity or NaN values.

```c
f6E3M2FN
- Exponent bias: 3
- Maximum stored exponent value: 7 (binary 111)
- Maximum unbiased exponent value: 7 - 3 = 4
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.000.00
- Max normal number: S.111.11 = ±2^(4) x (1 + 0.75) = ±28
- Min normal number: S.001.00 = ±2^(-2) = ±0.25
- Max subnormal number: S.000.11 = ±2^(-2) x 0.75 = ±0.1875
- Min subnormal number: S.000.01 = ±2^(-2) x 0.25 = ±0.0625
```

Related PRs:
- [PR-94735](llvm#94735) [APFloat] Add APFloat support for FP6 data types
- [PR-97118](llvm#97118) [MLIR] Add f8E4M3 type - was used as a template for this PR
@sergey-kozub
Copy link
Contributor Author

have we found a use case for the elemental form of f6 yet?

I'm not sure about that. This type is supposed to be used as a block scaled format, i.e. for reducing the weights size in forward layers.

@sergey-kozub
Copy link
Contributor Author

The PR appears green now, should I press the "Squash and merge" button, or should I let a reviewer do it instead?

@apivovarov
Copy link
Member

The PR appears green now, should I press the "Squash and merge" button, or should I let a reviewer do it instead?

If github allows you to merge - you can merge.

@dcaballe
Copy link
Contributor

Just curious as there have been various conversations: have we found a use case for the elemental form of f6 yet? When I've drilled in before, it seemed like there might be a case for that but all uses in the wild I could see were for the shared scale forms (which are definitely not a scalar type).

I believe the goal is to use the element types as a building block to model block-scaled types. After looking at how block-scaled types are modeled in other projects, I reached the conclusion that we need to build more expertise before introducing the full concept of a block-scaled type in MLIR, if at all. By using element types, we could experiment with block-scaled types by modeling them with separate tensors: one for all the block elements and another for all the block scales. Regardless of our final approach, we would need this basic support in APFloat to handle their specific semantics.

@sergey-kozub sergey-kozub merged commit 918222b into llvm:main Sep 10, 2024
8 checks passed
sergey-kozub added a commit that referenced this pull request Sep 10, 2024
sergey-kozub added a commit to sergey-kozub/llvm-project that referenced this pull request Sep 10, 2024
This PR adds `f6E2M3FN` type to mlir.

`f6E2M3FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 6-bit floating point number with bit layout S1E2M3. Unlike IEEE-754 types, there are no infinity or NaN values.

```c
f6E2M3FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.000
- Max normal number: S.11.111 = ±2^(2) x (1 + 0.875) = ±7.5
- Min normal number: S.01.000 = ±2^(0) = ±1.0
- Max subnormal number: S.00.111 = ±2^(0) x 0.875 = ±0.875
- Min subnormal number: S.00.001 = ±2^(0) x 0.125 = ±0.125
```

Related PRs:
- [PR-94735](llvm#94735) [APFloat] Add APFloat support for FP6 data types
- [PR-105573](llvm#105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR
VitaNuo pushed a commit to VitaNuo/llvm-project that referenced this pull request Sep 12, 2024
This PR adds `f6E3M2FN` type to mlir.

`f6E3M2FN` type is proposed in [OpenCompute MX
Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 6-bit floating point number with bit layout S1E3M2. Unlike
IEEE-754 types, there are no infinity or NaN values.

```c
f6E3M2FN
- Exponent bias: 3
- Maximum stored exponent value: 7 (binary 111)
- Maximum unbiased exponent value: 7 - 3 = 4
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.000.00
- Max normal number: S.111.11 = ±2^(4) x (1 + 0.75) = ±28
- Min normal number: S.001.00 = ±2^(-2) = ±0.25
- Max subnormal number: S.000.11 = ±2^(-2) x 0.75 = ±0.1875
- Min subnormal number: S.000.01 = ±2^(-2) x 0.25 = ±0.0625
```

Related PRs:
- [PR-94735](llvm#94735) [APFloat]
Add APFloat support for FP6 data types
- [PR-97118](llvm#97118) [MLIR] Add
f8E4M3 type - was used as a template for this PR
VitaNuo pushed a commit to VitaNuo/llvm-project that referenced this pull request Sep 12, 2024
sergey-kozub added a commit to sergey-kozub/llvm-project that referenced this pull request Sep 16, 2024
This PR adds `f6E2M3FN` type to mlir.

`f6E2M3FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 6-bit floating point number with bit layout S1E2M3. Unlike IEEE-754 types, there are no infinity or NaN values.

```c
f6E2M3FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.000
- Max normal number: S.11.111 = ±2^(2) x (1 + 0.875) = ±7.5
- Min normal number: S.01.000 = ±2^(0) = ±1.0
- Max subnormal number: S.00.111 = ±2^(0) x 0.875 = ±0.875
- Min subnormal number: S.00.001 = ±2^(0) x 0.125 = ±0.125
```

Related PRs:
- [PR-94735](llvm#94735) [APFloat] Add APFloat support for FP6 data types
- [PR-105573](llvm#105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR
sergey-kozub added a commit that referenced this pull request Sep 16, 2024
This PR adds `f6E2M3FN` type to mlir.

`f6E2M3FN` type is proposed in [OpenCompute MX
Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
It defines a 6-bit floating point number with bit layout S1E2M3. Unlike
IEEE-754 types, there are no infinity or NaN values.

```c
f6E2M3FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.000
- Max normal number: S.11.111 = ±2^(2) x (1 + 0.875) = ±7.5
- Min normal number: S.01.000 = ±2^(0) = ±1.0
- Max subnormal number: S.00.111 = ±2^(0) x 0.875 = ±0.875
- Min subnormal number: S.00.001 = ±2^(0) x 0.125 = ±0.125
```

Related PRs:
- [PR-94735](#94735) [APFloat]
Add APFloat support for FP6 data types
- [PR-105573](#105573) [MLIR]
Add f6E3M2FN type - was used as a template for this PR
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants