Skip to content

Commit

Permalink
update documentation and use macro
Browse files Browse the repository at this point in the history
Co-authored-by: Marcel Koch <marcel.koch@kit.edu>
Co-authored-by: Pratik Nayak <pratikvn@protonmail.com>
  • Loading branch information
3 people committed Aug 23, 2024
1 parent 1f825ee commit 7d1ef72
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 74 deletions.
80 changes: 19 additions & 61 deletions core/config/type_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "ginkgo/core/config/type_descriptor.hpp"

#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/types.hpp>

#include "core/config/type_descriptor_helper.hpp"

Expand Down Expand Up @@ -45,67 +46,19 @@ type_descriptor make_type_descriptor()
type_string<GlobalIndexType>::str()};
}

// global_index: void
template type_descriptor make_type_descriptor<void, void, void>();
template type_descriptor make_type_descriptor<float, void, void>();
template type_descriptor make_type_descriptor<double, void, void>();
template type_descriptor
make_type_descriptor<std::complex<float>, void, void>();
template type_descriptor
make_type_descriptor<std::complex<double>, void, void>();
template type_descriptor make_type_descriptor<void, int32, void>();
template type_descriptor make_type_descriptor<float, int32, void>();
template type_descriptor make_type_descriptor<double, int32, void>();
template type_descriptor
make_type_descriptor<std::complex<float>, int32, void>();
template type_descriptor
make_type_descriptor<std::complex<double>, int32, void>();
template type_descriptor make_type_descriptor<void, int64, void>();
template type_descriptor make_type_descriptor<float, int64, void>();
template type_descriptor make_type_descriptor<double, int64, void>();
template type_descriptor
make_type_descriptor<std::complex<float>, int64, void>();
template type_descriptor
make_type_descriptor<std::complex<double>, int64, void>();

// global_index int32
template type_descriptor make_type_descriptor<void, void, int32>();
template type_descriptor make_type_descriptor<float, void, int32>();
template type_descriptor make_type_descriptor<double, void, int32>();
template type_descriptor
make_type_descriptor<std::complex<float>, void, int32>();
template type_descriptor
make_type_descriptor<std::complex<double>, void, int32>();
template type_descriptor make_type_descriptor<void, int32, int32>();
template type_descriptor make_type_descriptor<float, int32, int32>();
template type_descriptor make_type_descriptor<double, int32, int32>();
template type_descriptor
make_type_descriptor<std::complex<float>, int32, int32>();
template type_descriptor
make_type_descriptor<std::complex<double>, int32, int32>();

// global_index_type int64
template type_descriptor make_type_descriptor<void, void, int64>();
template type_descriptor make_type_descriptor<float, void, int64>();
template type_descriptor make_type_descriptor<double, void, int64>();
template type_descriptor
make_type_descriptor<std::complex<float>, void, int64>();
template type_descriptor
make_type_descriptor<std::complex<double>, void, int64>();
template type_descriptor make_type_descriptor<void, int32, int64>();
template type_descriptor make_type_descriptor<float, int32, int64>();
template type_descriptor make_type_descriptor<double, int32, int64>();
template type_descriptor
make_type_descriptor<std::complex<float>, int32, int64>();
template type_descriptor
make_type_descriptor<std::complex<double>, int32, int64>();
template type_descriptor make_type_descriptor<void, int64, int64>();
template type_descriptor make_type_descriptor<float, int64, int64>();
template type_descriptor make_type_descriptor<double, int64, int64>();
template type_descriptor
make_type_descriptor<std::complex<float>, int64, int64>();
template type_descriptor
make_type_descriptor<std::complex<double>, int64, int64>();
#define GKO_DECLARE_MAKE_TYPE_DESCRIPTOR(ValueType, LocalIndexType, \
GlobalIndexType) \
type_descriptor \
make_type_descriptor<ValueType, LocalIndexType, GlobalIndexType>()
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_MAKE_TYPE_DESCRIPTOR);

#define GKO_DECLARE_MAKE_VOID_TYPE_DESCRIPTOR(ValueType, LocalIndexType, \
GlobalIndexType) \
type_descriptor \
make_type_descriptor<void, LocalIndexType, GlobalIndexType>()
GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_MAKE_VOID_TYPE_DESCRIPTOR);


type_descriptor::type_descriptor(std::string value_typestr,
Expand All @@ -126,6 +79,11 @@ const std::string& type_descriptor::get_index_typestr() const
return index_typestr_;
}

const std::string& type_descriptor::get_local_index_typestr() const
{
return this->get_index_typestr();
}

const std::string& type_descriptor::get_global_index_typestr() const
{
return global_index_typestr_;
Expand Down
40 changes: 35 additions & 5 deletions core/test/config/type_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ TEST(TypeDescriptor, TemplateCreate)

ASSERT_EQ(td.get_value_typestr(), "float64");
ASSERT_EQ(td.get_index_typestr(), "int32");
ASSERT_EQ(td.get_local_index_typestr(), td.get_index_typestr());
ASSERT_EQ(td.get_global_index_typestr(), "int64");
}
{
Expand All @@ -29,6 +30,7 @@ TEST(TypeDescriptor, TemplateCreate)

ASSERT_EQ(td.get_value_typestr(), "float32");
ASSERT_EQ(td.get_index_typestr(), "int32");
ASSERT_EQ(td.get_local_index_typestr(), td.get_index_typestr());
ASSERT_EQ(td.get_global_index_typestr(), "int32");
}
{
Expand All @@ -37,24 +39,37 @@ TEST(TypeDescriptor, TemplateCreate)

ASSERT_EQ(td.get_value_typestr(), "float32");
ASSERT_EQ(td.get_index_typestr(), "int32");
ASSERT_EQ(td.get_local_index_typestr(), td.get_index_typestr());
ASSERT_EQ(td.get_global_index_typestr(), "int64");
}
{
SCOPED_TRACE("specify all template");
SCOPED_TRACE("specify local index template");
auto td =
make_type_descriptor<std::complex<float>, gko::int64, gko::int64>();

ASSERT_EQ(td.get_value_typestr(), "complex<float32>");
ASSERT_EQ(td.get_index_typestr(), "int64");
ASSERT_EQ(td.get_local_index_typestr(), td.get_index_typestr());
ASSERT_EQ(td.get_global_index_typestr(), "int64");
}
{
SCOPED_TRACE("specify global index template");
auto td =
make_type_descriptor<std::complex<float>, gko::int32, gko::int32>();

ASSERT_EQ(td.get_value_typestr(), "complex<float32>");
ASSERT_EQ(td.get_index_typestr(), "int32");
ASSERT_EQ(td.get_local_index_typestr(), td.get_index_typestr());
ASSERT_EQ(td.get_global_index_typestr(), "int32");
}
{
SCOPED_TRACE("specify void");
auto td = make_type_descriptor<void, void, void>();
auto td = make_type_descriptor<void>();

ASSERT_EQ(td.get_value_typestr(), "void");
ASSERT_EQ(td.get_index_typestr(), "void");
ASSERT_EQ(td.get_global_index_typestr(), "void");
ASSERT_EQ(td.get_index_typestr(), "int32");
ASSERT_EQ(td.get_local_index_typestr(), td.get_index_typestr());
ASSERT_EQ(td.get_global_index_typestr(), "int64");
}
}

Expand All @@ -67,19 +82,34 @@ TEST(TypeDescriptor, Constructor)

ASSERT_EQ(td.get_value_typestr(), "float64");
ASSERT_EQ(td.get_index_typestr(), "int32");
ASSERT_EQ(td.get_local_index_typestr(), td.get_index_typestr());
ASSERT_EQ(td.get_global_index_typestr(), "int64");
}
{
SCOPED_TRACE("specify valuetype");
type_descriptor td("float32");

ASSERT_EQ(td.get_value_typestr(), "float32");
ASSERT_EQ(td.get_index_typestr(), "int32");
ASSERT_EQ(td.get_local_index_typestr(), td.get_index_typestr());
ASSERT_EQ(td.get_global_index_typestr(), "int64");
}
{
SCOPED_TRACE("specify all parameters");
SCOPED_TRACE("specify local index parameters");
type_descriptor td("void", "int64");

ASSERT_EQ(td.get_value_typestr(), "void");
ASSERT_EQ(td.get_index_typestr(), "int64");
ASSERT_EQ(td.get_local_index_typestr(), td.get_index_typestr());
ASSERT_EQ(td.get_global_index_typestr(), "int64");
}
{
SCOPED_TRACE("specify global index parameters");
type_descriptor td("void", "int32", "int32");

ASSERT_EQ(td.get_value_typestr(), "void");
ASSERT_EQ(td.get_index_typestr(), "int32");
ASSERT_EQ(td.get_local_index_typestr(), td.get_index_typestr());
ASSERT_EQ(td.get_global_index_typestr(), "int32");
}
}
22 changes: 14 additions & 8 deletions include/ginkgo/core/config/type_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,16 @@ namespace config {
* auto cg = parse(config, context, type_descriptor("float64", "int32"));
* ```
* will have the value type `float64` and the index type `int32`. Any Ginkgo
* object that does not require one of these types will just ignore it. The
* value `void` can be used to specify that no default type is provided. In this
* case, the configuration has to provide the necessary template types.
* object that does not require one of these types will just ignore it. In
* value_type, one additional value `void` can be used to specify that no
* default type is provided. In this case, the configuration has to provide the
* necessary template types.
*
* If the configuration specifies one field (only allow value_type now):
* ```
* value_type: "some_value_type"
* ```
* these types will take precedence over the type_descriptor.
* this type will take precedence over the type_descriptor.
*/
class type_descriptor final {
public:
Expand All @@ -42,9 +43,8 @@ class type_descriptor final {
* `make_type_descriptor` to create the object by template.
*
* @param value_typestr the value type string. "void" means no default.
* @param index_typestr the index type string. "void" means no default.
* @param global_index_typestr the global index type string. "void" means
* no default.
* @param index_typestr the (local) index type string.
* @param global_index_typestr the global index type string.
*
* @note there is no way to call the constructor with explicit template, so
* we create another free function to handle it.
Expand All @@ -63,6 +63,12 @@ class type_descriptor final {
*/
const std::string& get_index_typestr() const;

/**
* Get the local index type string, which gives the same result as
* get_index_typestr()
*/
const std::string& get_local_index_typestr() const;

/**
* Get the global index type string
*/
Expand All @@ -83,7 +89,7 @@ class type_descriptor final {
* @tparam IndexType the index type in descriptor
* @tparam GlobalIndexType the global index type in descriptor
*/
template <typename ValueType = double, typename IndexType = int,
template <typename ValueType = double, typename IndexType = int32,
typename GlobalIndexType = int64>
type_descriptor make_type_descriptor();

Expand Down

0 comments on commit 7d1ef72

Please sign in to comment.