diff --git a/include/cppflow/datatype.h b/include/cppflow/datatype.h index 712f767..a02a273 100644 --- a/include/cppflow/datatype.h +++ b/include/cppflow/datatype.h @@ -10,6 +10,7 @@ #include #include #include +#include namespace cppflow { @@ -101,6 +102,10 @@ namespace cppflow { return TF_UINT32; if (std::is_same::value) return TF_UINT64; + if (std::is_same>::value) + return TF_COMPLEX64; + if (std::is_same>::value) + return TF_COMPLEX128; // decode with `c++filt --type $output` for gcc throw std::runtime_error{"Could not deduce type! type_name: " + std::string(typeid(T).name())};