From 491b59618cbad0ca9d53c70c53247ec3c85c6e13 Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Tue, 20 Jun 2023 16:28:04 -0700 Subject: [PATCH] dim --- apps/serdes/Deserializer.cpp | 27 +++++++++++++++++++++++++++ apps/serdes/Deserializer.h | 5 +++++ apps/serdes/Serializer.cpp | 22 ++++++++++++++++++++++ apps/serdes/Serializer.h | 4 ++++ apps/serdes/halide_ir.fbs | 18 ++++++++++++------ 5 files changed, 70 insertions(+), 6 deletions(-) diff --git a/apps/serdes/Deserializer.cpp b/apps/serdes/Deserializer.cpp index 48694b0dbc27..51e0ccaf366f 100644 --- a/apps/serdes/Deserializer.cpp +++ b/apps/serdes/Deserializer.cpp @@ -194,6 +194,20 @@ Halide::Internal::Split::SplitType Deserializer::deserialize_split_type(const Ha } } +Halide::Internal::DimType Deserializer::deserialize_dim_type(const Halide::Serialize::DimType dim_type) { + switch (dim_type) { + case Halide::Serialize::DimType::DimType_PureVar: + return Halide::Internal::DimType::PureVar; + case Halide::Serialize::DimType::DimType_PureRVar: + return Halide::Internal::DimType::PureRVar; + case Halide::Serialize::DimType::DimType_ImpureRVar: + return Halide::Internal::DimType::ImpureRVar; + default: + std::cerr << "unknown dim type " << dim_type << "\n"; + exit(1); + } +} + Halide::Type Deserializer::deserialize_type(const Halide::Serialize::Type *type) { using Halide::Serialize::TypeCode; int bits = type->bits(); @@ -796,6 +810,19 @@ Halide::Internal::Split Deserializer::deserialize_split(const Halide::Serialize: return hl_split; } +Halide::Internal::Dim Deserializer::deserialize_dim(const Halide::Serialize::Dim *dim) { + auto var = deserialize_string(dim->var()); + auto for_type = deserialize_for_type(dim->for_type()); + auto device_api = deserialize_device_api(dim->device_api()); + auto dim_type = deserialize_dim_type(dim->dim_type()); + auto hl_dim = Halide::Internal::Dim(); + hl_dim.var = var; + hl_dim.for_type = for_type; + hl_dim.device_api = device_api; + hl_dim.dim_type = dim_type; + return hl_dim; +} + // TODO: will need to serialize a reverse table of map to // later reconstruct a map of find out which function ptrs to use here // std::map Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs) { diff --git a/apps/serdes/Deserializer.h b/apps/serdes/Deserializer.h index 2c509b8295a7..9693bf1e98c5 100644 --- a/apps/serdes/Deserializer.h +++ b/apps/serdes/Deserializer.h @@ -36,6 +36,8 @@ class Deserializer { Halide::Internal::Split::SplitType deserialize_split_type(const Halide::Serialize::SplitType split_type); + Halide::Internal::DimType deserialize_dim_type(const Halide::Serialize::DimType dim_type); + std::string deserialize_string(const flatbuffers::String *str); Halide::Type deserialize_type(const Halide::Serialize::Type *type); @@ -71,6 +73,9 @@ class Deserializer { Halide::Internal::PrefetchDirective deserialize_prefetch_directive(const Halide::Serialize::PrefetchDirective *prefetch_directive); Halide::Internal::Split deserialize_split(const Halide::Serialize::Split *split); + + Halide::Internal::Dim deserialize_dim(const Halide::Serialize::Dim *dim); + // std::map deserialize_wrapper_refs(const flatbuffers::Vector> *wrapper_refs); // std::map deserialize_func_mappings(const flatbuffers::Vector> *func_mappings); diff --git a/apps/serdes/Serializer.cpp b/apps/serdes/Serializer.cpp index a3f6262caddd..fa74f85778cb 100644 --- a/apps/serdes/Serializer.cpp +++ b/apps/serdes/Serializer.cpp @@ -193,6 +193,20 @@ Halide::Serialize::SplitType Serializer::serialize_split_type(const Halide::Inte } } +Halide::Serialize::DimType Serializer::serialize_dim_type(const Halide::Internal::DimType &dim_type) { + switch (dim_type) { + case Halide::Internal::DimType::PureVar: + return Halide::Serialize::DimType::DimType_PureVar; + case Halide::Internal::DimType::PureRVar: + return Halide::Serialize::DimType::DimType_PureRVar; + case Halide::Internal::DimType::ImpureRVar: + return Halide::Serialize::DimType::DimType_ImpureRVar; + default: + std::cerr << "Unsupported dim type\n"; + exit(1); + } +} + flatbuffers::Offset Serializer::serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str) { return builder.CreateString(str); } @@ -816,6 +830,14 @@ flatbuffers::Offset Serializer::serialize_split(flatbu return Halide::Serialize::CreateSplit(builder, old_var_serialized, outer_serialized, inner_serialized, factor_serialized.first, factor_serialized.second, tail_serialized, inner_to_outer_serialized); } +flatbuffers::Offset Serializer::serialize_dim(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Dim &dim) { + auto var_serialized = serialize_string(builder, dim.var); + auto for_type_serialized = serialize_for_type(dim.for_type); + auto device_api_serialized = serialize_device_api(dim.device_api); + auto dim_type_serialized = serialize_dim_type(dim.dim_type); + return Halide::Serialize::CreateDim(builder, var_serialized, for_type_serialized, device_api_serialized, dim_type_serialized); +} + // std::vector> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers) { // // instead of storing the function pointer or raw function address, // // we store a pre-computed function index as the serialized format for WrapperRef diff --git a/apps/serdes/Serializer.h b/apps/serdes/Serializer.h index 238e9920b556..81b343c977e2 100644 --- a/apps/serdes/Serializer.h +++ b/apps/serdes/Serializer.h @@ -37,6 +37,8 @@ class Serializer { Halide::Serialize::SplitType serialize_split_type(const Halide::Internal::Split::SplitType &split_type); + Halide::Serialize::DimType serialize_dim_type(const Halide::Internal::DimType &dim_type); + flatbuffers::Offset serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str); flatbuffers::Offset serialize_type(flatbuffers::FlatBufferBuilder &builder, const Halide::Type &type); @@ -73,6 +75,8 @@ class Serializer { flatbuffers::Offset serialize_split(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Split &split); + flatbuffers::Offset serialize_dim(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Dim &dim); + // std::vector> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map &wrappers); // std::vector> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map &func_mappings); diff --git a/apps/serdes/halide_ir.fbs b/apps/serdes/halide_ir.fbs index 3044b6722d3e..8dd0824ee390 100644 --- a/apps/serdes/halide_ir.fbs +++ b/apps/serdes/halide_ir.fbs @@ -523,12 +523,18 @@ table Split { split_type: SplitType; } -// table Dim { -// var: string; -// for_type: ForType; -// device_api: DeviceAPI; -// dim_type: DimType; -// } +enum DimType: ubyte { + PureVar, + PureRVar, + ImpureRVar, +} + +table Dim { + var: string; + for_type: ForType; + device_api: DeviceAPI; + dim_type: DimType; +} // enum LoopAlignStrategy: ubyte { // AlignStart,