Skip to content

Commit

Permalink
testing 101
Browse files Browse the repository at this point in the history
  • Loading branch information
TH3CHARLie committed Jun 27, 2023
1 parent ad6eb06 commit ad54aeb
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ set(HEADER_FILES
Introspection.h
IntrusivePtr.h
IR.h
IRComparator.h
IREquality.h
IRMatch.h
IRMutator.h
Expand Down Expand Up @@ -256,6 +257,7 @@ set(SOURCE_FILES
Interval.cpp
Introspection.cpp
IR.cpp
IRComparator.cpp
IREquality.cpp
IRMatch.cpp
IRMutator.cpp
Expand Down
67 changes: 67 additions & 0 deletions src/IRComparator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "IRComparator.h"
#include "FindCalls.h"
#include "Func.h"
#include "Function.h"
#include "IREquality.h"
#include "IRVisitor.h"
#include <map>
#include <string>

namespace Halide {
namespace Internal {
class IRComparator {
public:
IRComparator() = default;

bool compare_pipeline(const Pipeline &p1, const Pipeline &p2);

private:
bool compare_function(const Function &f1, const Function &f2);
};

bool IRComparator::compare_pipeline(const Pipeline &p1, const Pipeline &p2) {
std::map<std::string, Function> p1_env, p2_env;
for (const Func &func : p1.outputs()) {
const Halide::Internal::Function &f = func.function();
std::map<std::string, Halide::Internal::Function> more_funcs = find_transitive_calls(f);
p1_env.insert(more_funcs.begin(), more_funcs.end());
}
for (const Func &func : p2.outputs()) {
const Halide::Internal::Function &f = func.function();
std::map<std::string, Halide::Internal::Function> more_funcs = find_transitive_calls(f);
p2_env.insert(more_funcs.begin(), more_funcs.end());
}
if (p1_env.size() != p2_env.size()) {
return false;
}
for (auto it = p1_env.begin(); it != p1_env.end(); it++) {
if (p2_env.find(it->first) == p2_env.end()) {
return false;
}
if (!compare_function(it->second, p2_env[it->first])) {
return false;
}
}
for (size_t i = 0; i < p1.requirements().size() && i < p2.requirements().size(); i++) {
if (!equal(p1.requirements()[i], p2.requirements()[i])) {
return false;
}
}
return true;
}

bool IRComparator::compare_function(const Function &f1, const Function &f2) {
if (f1.name() != f2.name()) {
return false;
}
if (f1.origin_name() != f2.origin_name()) {
return false;
}
return true;
}

bool equal(const Halide::Pipeline &p1, const Halide::Pipeline &p2) {
return IRComparator().compare_pipeline(p1, p2);
}
} // namespace Internal
} // namespace Halide
11 changes: 11 additions & 0 deletions src/IRComparator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#ifndef HALIDE_IRCOMPARATOR_H_
#define HALIDE_IRCOMPARATOR_H_
#include "Pipeline.h"

namespace Halide {
namespace Internal {
bool equal(const Pipeline &p1, const Pipeline &p2);
}
} // namespace Halide

#endif // HALIDE_SERIALIZATION_IRCOMPARATOR_H_
2 changes: 1 addition & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ else ()
endif ()


if (WITH_SERIALIZATION)
if (BUILD_SERIALIZATION)
message(STATUS "Building serialization tests enabled")
add_subdirectory(serialization)
else ()
Expand Down
3 changes: 2 additions & 1 deletion test/serialization/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
tests(GROUPS roundtrip
SOURCES
roundtrip.cpp
single_func_pipe.cpp
multiple_func_pipe.cpp
)

set_tests_properties(${TEST_NAMES} PROPERTIES RUN_SERIAL TRUE)
Expand Down
File renamed without changes.
22 changes: 22 additions & 0 deletions test/serialization/single_func_pipe.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;
using namespace Halide::Internal;

int main(int argc, char **argv) {
Func gradient("gradient_func");
Var x, y;
gradient(x, y) = x + y;
Pipeline pipe(gradient);

Serializer serializer;
serializer.serialize(pipe, "single_func_pipe.hlpipe");
Deserializer deserializer;
Pipeline deserialized_pipe = deserializer.deserialize("single_func_pipe.hlpipe");
bool result = equal(pipe, deserialized_pipe);

assert(result == true);
printf("Success!\n");
return 0;
}
2 changes: 1 addition & 1 deletion tools/build_halide_h.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <set>
#include <string>
#include <iostream>

std::set<std::string> done;

Expand Down

0 comments on commit ad54aeb

Please sign in to comment.