Skip to content

Commit

Permalink
ModulusRemainder and VectorReduceOp, some minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
TH3CHARLie committed Jun 20, 2023
1 parent db47e2c commit 63ac57a
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 197 deletions.
37 changes: 33 additions & 4 deletions apps/serdes/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,28 @@ Halide::Internal::Call::CallType Deserializer::deserialize_call_type(const Halid
}
}

Halide::Internal::VectorReduce::Operator Deserializer::deserialize_vector_reduce_op(const Halide::Serialize::VectorReduceOp vector_reduce_op) {
switch (vector_reduce_op) {
case Halide::Serialize::VectorReduceOp::VectorReduceOp_Add:
return Halide::Internal::VectorReduce::Operator::Add;
case Halide::Serialize::VectorReduceOp::VectorReduceOp_SaturatingAdd:
return Halide::Internal::VectorReduce::Operator::SaturatingAdd;
case Halide::Serialize::VectorReduceOp::VectorReduceOp_Mul:
return Halide::Internal::VectorReduce::Operator::Mul;
case Halide::Serialize::VectorReduceOp::VectorReduceOp_Min:
return Halide::Internal::VectorReduce::Operator::Min;
case Halide::Serialize::VectorReduceOp::VectorReduceOp_Max:
return Halide::Internal::VectorReduce::Operator::Max;
case Halide::Serialize::VectorReduceOp::VectorReduceOp_And:
return Halide::Internal::VectorReduce::Operator::And;
case Halide::Serialize::VectorReduceOp::VectorReduceOp_Or:
return Halide::Internal::VectorReduce::Operator::Or;
default:
std::cerr << "unknown vector reduce op " << vector_reduce_op << "\n";
exit(1);
}
}

Halide::Type Deserializer::deserialize_type(const Halide::Serialize::Type *type) {
using Halide::Serialize::TypeCode;
int bits = type->bits();
Expand Down Expand Up @@ -207,7 +229,8 @@ Halide::Internal::Stmt Deserializer::deserialize_stmt(uint8_t type_code, const v
auto predicate = deserialize_expr(store_stmt->predicate_type(), store_stmt->predicate());
auto value = deserialize_expr(store_stmt->value_type(), store_stmt->value());
auto index = deserialize_expr(store_stmt->index_type(), store_stmt->index());
return Halide::Internal::Store::make(name, value, index, Halide::Internal::Parameter(), predicate, Halide::Internal::ModulusRemainder());
auto alignment = deserialize_modulus_remainder(store_stmt->alignment());
return Halide::Internal::Store::make(name, value, index, Halide::Internal::Parameter(), predicate, alignment);
}
case Halide::Serialize::Stmt_Provide: {
const Halide::Serialize::Provide *provide_stmt = (const Halide::Serialize::Provide *)stmt;
Expand Down Expand Up @@ -460,7 +483,8 @@ Halide::Expr Deserializer::deserialize_expr(uint8_t type_code, const void *expr)
auto name = deserialize_string(load_expr->name());
auto predicate = deserialize_expr(load_expr->predicate_type(), load_expr->predicate());
auto index = deserialize_expr(load_expr->index_type(), load_expr->index());
return Halide::Internal::Load::make(Halide::Int(64), name, index, Halide::Buffer<float, 3>(), Halide::Internal::Parameter(), predicate, Halide::Internal::ModulusRemainder());
auto alignment = deserialize_modulus_remainder(load_expr->alignment());
return Halide::Internal::Load::make(Halide::Int(64), name, index, Halide::Buffer<float, 3>(), Halide::Internal::Parameter(), predicate, alignment);
}
case Halide::Serialize::Expr::Expr_Ramp: {
const Halide::Serialize::Ramp *ramp_expr = (const Halide::Serialize::Ramp *)expr;
Expand Down Expand Up @@ -511,8 +535,9 @@ Halide::Expr Deserializer::deserialize_expr(uint8_t type_code, const void *expr)
case Halide::Serialize::Expr::Expr_VectorReduce: {
const Halide::Serialize::VectorReduce *vector_reduce_expr = (const Halide::Serialize::VectorReduce *)expr;
auto value = deserialize_expr(vector_reduce_expr->value_type(), vector_reduce_expr->value());
// TODO: fix op here and store lanes during serialization
return Halide::Internal::VectorReduce::make(Halide::Internal::VectorReduce::Operator::Add, value, 16);
auto reduction_op = deserialize_vector_reduce_op(vector_reduce_expr->reduction_op());
// TODO: store lanes during serialization
return Halide::Internal::VectorReduce::make(reduction_op, value, 16);
}
case Halide::Serialize::Expr::Expr_UndefinedExpr: {
return Halide::Expr();
Expand Down Expand Up @@ -656,6 +681,10 @@ Halide::Internal::ReductionDomain Deserializer::deserialize_reduction_domain(con
return Halide::Internal::ReductionDomain(domain, predicate, frozen);
}

Halide::Internal::ModulusRemainder Deserializer::deserialize_modulus_remainder(const Halide::Serialize::ModulusRemainder *modulus_remainder) {
return Halide::Internal::ModulusRemainder(modulus_remainder->modulus(), modulus_remainder->remainder());
}

// TODO: will need to serialize a reverse table of map<address, func_name> to
// later reconstruct a map of <name, func_ptr> find out which function ptrs to use here
// std::map<std::string, Halide::Internal::FunctionPtr> Deserializer::deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs) {
Expand Down
6 changes: 5 additions & 1 deletion apps/serdes/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ class Deserializer {

Halide::Internal::ForType deserialize_for_type(const Halide::Serialize::ForType for_type);

Halide::DeviceAPI deserialize_device_api(const Halide::Serialize::DeviceAPI device_api);

Halide::Internal::Call::CallType deserialize_call_type(const Halide::Serialize::CallType call_type);

Halide::DeviceAPI deserialize_device_api(const Halide::Serialize::DeviceAPI device_api);
Halide::Internal::VectorReduce::Operator deserialize_vector_reduce_op(const Halide::Serialize::VectorReduceOp vector_reduce_op);

std::string deserialize_string(const flatbuffers::String *str);

Expand Down Expand Up @@ -56,6 +58,8 @@ class Deserializer {

Halide::Internal::ReductionDomain deserialize_reduction_domain(const Halide::Serialize::ReductionDomain *reduction_domain);

Halide::Internal::ModulusRemainder deserialize_modulus_remainder(const Halide::Serialize::ModulusRemainder *modulus_remainder);

// std::map<std::string, Halide::Internal::FunctionPtr> deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::WrapperRef>> *wrapper_refs);

// std::map<std::string, int32_t> deserialize_func_mappings(const flatbuffers::Vector<flatbuffers::Offset<Halide::Serialize::FuncMapping>> *func_mappings);
Expand Down
Loading

0 comments on commit 63ac57a

Please sign in to comment.