Skip to content

Commit

Permalink
dim
Browse files Browse the repository at this point in the history
  • Loading branch information
TH3CHARLie committed Jun 20, 2023
1 parent 88f4a5d commit 491b596
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 6 deletions.
27 changes: 27 additions & 0 deletions apps/serdes/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,20 @@ Halide::Internal::Split::SplitType Deserializer::deserialize_split_type(const Ha
}
}

Halide::Internal::DimType Deserializer::deserialize_dim_type(const Halide::Serialize::DimType dim_type) {
switch (dim_type) {
case Halide::Serialize::DimType::DimType_PureVar:
return Halide::Internal::DimType::PureVar;
case Halide::Serialize::DimType::DimType_PureRVar:
return Halide::Internal::DimType::PureRVar;
case Halide::Serialize::DimType::DimType_ImpureRVar:
return Halide::Internal::DimType::ImpureRVar;
default:
std::cerr << "unknown dim type " << dim_type << "\n";
exit(1);
}
}

Halide::Type Deserializer::deserialize_type(const Halide::Serialize::Type *type) {
using Halide::Serialize::TypeCode;
int bits = type->bits();
Expand Down Expand Up @@ -796,6 +810,19 @@ Halide::Internal::Split Deserializer::deserialize_split(const Halide::Serialize:
return hl_split;
}

Halide::Internal::Dim Deserializer::deserialize_dim(const Halide::Serialize::Dim *dim) {
auto var = deserialize_string(dim->var());
auto for_type = deserialize_for_type(dim->for_type());
auto device_api = deserialize_device_api(dim->device_api());
auto dim_type = deserialize_dim_type(dim->dim_type());
auto hl_dim = Halide::Internal::Dim();
hl_dim.var = var;
hl_dim.for_type = for_type;
hl_dim.device_api = device_api;
hl_dim.dim_type = dim_type;
return hl_dim;
}

// TODO: will need to serialize a reverse table of map<address, func_name> to
// later reconstruct a map of <name, func_ptr> find out which function ptrs to use here
// std::map<std::string, Halide::Internal::FunctionPtr> Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs) {
Expand Down
5 changes: 5 additions & 0 deletions apps/serdes/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class Deserializer {

Halide::Internal::Split::SplitType deserialize_split_type(const Halide::Serialize::SplitType split_type);

Halide::Internal::DimType deserialize_dim_type(const Halide::Serialize::DimType dim_type);

std::string deserialize_string(const flatbuffers::String *str);

Halide::Type deserialize_type(const Halide::Serialize::Type *type);
Expand Down Expand Up @@ -71,6 +73,9 @@ class Deserializer {
Halide::Internal::PrefetchDirective deserialize_prefetch_directive(const Halide::Serialize::PrefetchDirective *prefetch_directive);

Halide::Internal::Split deserialize_split(const Halide::Serialize::Split *split);

Halide::Internal::Dim deserialize_dim(const Halide::Serialize::Dim *dim);

// std::map<std::string, Halide::Internal::FunctionPtr> deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs);

// std::map<std::string, int32_t> deserialize_func_mappings(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> *func_mappings);
Expand Down
22 changes: 22 additions & 0 deletions apps/serdes/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,20 @@ Halide::Serialize::SplitType Serializer::serialize_split_type(const Halide::Inte
}
}

Halide::Serialize::DimType Serializer::serialize_dim_type(const Halide::Internal::DimType &dim_type) {
switch (dim_type) {
case Halide::Internal::DimType::PureVar:
return Halide::Serialize::DimType::DimType_PureVar;
case Halide::Internal::DimType::PureRVar:
return Halide::Serialize::DimType::DimType_PureRVar;
case Halide::Internal::DimType::ImpureRVar:
return Halide::Serialize::DimType::DimType_ImpureRVar;
default:
std::cerr << "Unsupported dim type\n";
exit(1);
}
}

flatbuffers::Offset<flatbuffers::String> Serializer::serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str) {
return builder.CreateString(str);
}
Expand Down Expand Up @@ -816,6 +830,14 @@ flatbuffers::Offset<Halide::Serialize::Split> Serializer::serialize_split(flatbu
return Halide::Serialize::CreateSplit(builder, old_var_serialized, outer_serialized, inner_serialized, factor_serialized.first, factor_serialized.second, tail_serialized, inner_to_outer_serialized);
}

flatbuffers::Offset<Halide::Serialize::Dim> Serializer::serialize_dim(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Dim &dim) {
auto var_serialized = serialize_string(builder, dim.var);
auto for_type_serialized = serialize_for_type(dim.for_type);
auto device_api_serialized = serialize_device_api(dim.device_api);
auto dim_type_serialized = serialize_dim_type(dim.dim_type);
return Halide::Serialize::CreateDim(builder, var_serialized, for_type_serialized, device_api_serialized, dim_type_serialized);
}

// std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> Serializer::serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, Halide::Internal::FunctionPtr> &wrappers) {
// // instead of storing the function pointer or raw function address,
// // we store a pre-computed function index as the serialized format for WrapperRef
Expand Down
4 changes: 4 additions & 0 deletions apps/serdes/Serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class Serializer {

Halide::Serialize::SplitType serialize_split_type(const Halide::Internal::Split::SplitType &split_type);

Halide::Serialize::DimType serialize_dim_type(const Halide::Internal::DimType &dim_type);

flatbuffers::Offset<flatbuffers::String> serialize_string(flatbuffers::FlatBufferBuilder &builder, const std::string &str);

flatbuffers::Offset<Halide::Serialize::Type> serialize_type(flatbuffers::FlatBufferBuilder &builder, const Halide::Type &type);
Expand Down Expand Up @@ -73,6 +75,8 @@ class Serializer {

flatbuffers::Offset<Halide::Serialize::Split> serialize_split(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Split &split);

flatbuffers::Offset<Halide::Serialize::Dim> serialize_dim(flatbuffers::FlatBufferBuilder &builder, const Halide::Internal::Dim &dim);

// std::vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> serialize_wrapper_refs(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, Halide::Internal::FunctionPtr> &wrappers);

// std::vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> serialize_func_mappings(flatbuffers::FlatBufferBuilder &builder, const std::map<std::string, int32_t> &func_mappings);
Expand Down
18 changes: 12 additions & 6 deletions apps/serdes/halide_ir.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -523,12 +523,18 @@ table Split {
split_type: SplitType;
}

// table Dim {
// var: string;
// for_type: ForType;
// device_api: DeviceAPI;
// dim_type: DimType;
// }
enum DimType: ubyte {
PureVar,
PureRVar,
ImpureRVar,
}

table Dim {
var: string;
for_type: ForType;
device_api: DeviceAPI;
dim_type: DimType;
}

// enum LoopAlignStrategy: ubyte {
// AlignStart,
Expand Down

0 comments on commit 491b596

Please sign in to comment.