Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【paddle_test No.1】replace cc_test with paddle_test #60830

Merged
merged 40 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
7ed9156
Update CMakeLists.txt
Liyulingyue Jan 15, 2024
e5e4670
Update CMakeLists.txt
Liyulingyue Jan 16, 2024
77306c3
Update CMakeLists.txt
Liyulingyue Jan 16, 2024
3056242
Update CMakeLists.txt
Liyulingyue Jan 28, 2024
6bedbd1
Merge branch 'develop' into cc1
Liyulingyue Jan 28, 2024
2e2bf71
Update CMakeLists.txt
Liyulingyue Jan 29, 2024
6351c9f
Update op_gen.py
Liyulingyue Jan 29, 2024
68a92d3
Update CMakeLists.txt
Liyulingyue Jan 29, 2024
63d2a70
Merge branch 'develop' into cc1
Liyulingyue Jan 30, 2024
af7f20c
Update CMakeLists.txt
Liyulingyue Jan 30, 2024
b19e188
Apply suggestions from code review
Liyulingyue Jan 30, 2024
57f5052
Update op_gen.py
Liyulingyue Jan 30, 2024
8788524
Merge branch 'PaddlePaddle:develop' into cc1
Liyulingyue Jan 31, 2024
b0685a1
Merge branch 'PaddlePaddle:develop' into cc1
Liyulingyue Feb 1, 2024
126b33b
Update op_gen.py
Liyulingyue Feb 1, 2024
5cbd233
Update op_gen.py
Liyulingyue Feb 1, 2024
af41fe8
Update transform_general_functions.h
Liyulingyue Feb 1, 2024
a7b78ca
Merge branch 'PaddlePaddle:develop' into cc1
Liyulingyue Feb 2, 2024
e4e6c81
Update manual_op.h
Liyulingyue Feb 2, 2024
4314aa0
Update drr_pattern_context.h
Liyulingyue Feb 2, 2024
5f3db99
Update drr_pattern_context.h
Liyulingyue Feb 2, 2024
7a4c908
Merge branch 'PaddlePaddle:develop' into cc1
Liyulingyue Feb 2, 2024
9759d3b
Merge branch 'PaddlePaddle:develop' into cc1
Liyulingyue Feb 2, 2024
eab5a64
Apply suggestions from code review
Liyulingyue Feb 3, 2024
588f009
Merge branch 'PaddlePaddle:develop' into cc1
Liyulingyue Feb 3, 2024
a781777
Merge branch 'PaddlePaddle:develop' into cc1
Liyulingyue Feb 3, 2024
62ef0f9
Update drr_pattern_context.h
Liyulingyue Feb 3, 2024
78cce3d
Update drr_pattern_context.h
Liyulingyue Feb 3, 2024
828fcb7
Update drr_pattern_base.h
Liyulingyue Feb 4, 2024
25c4c36
Update drr_match_context.h
Liyulingyue Feb 4, 2024
aa35989
Update drr_pattern_context.h
Liyulingyue Feb 4, 2024
1aa763b
Update drr_pattern_context.h
Liyulingyue Feb 4, 2024
6dde0cd
Merge branch 'develop' into cc1
Liyulingyue Feb 6, 2024
30c3828
Update drr_match_context.h
Liyulingyue Feb 6, 2024
cbfa108
Apply suggestions from code review
Liyulingyue Feb 6, 2024
527dd40
Apply suggestions from code review
Liyulingyue Feb 6, 2024
8ab27ad
add test_API
Liyulingyue Feb 7, 2024
c2b5e0d
Merge remote-tracking branch 'origin/cc1' into cc1
Liyulingyue Feb 7, 2024
f3f990c
Update CMakeLists.txt
Liyulingyue Feb 7, 2024
f69ccd9
Apply suggestions from code review
Liyulingyue Feb 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
# Note(Galaxy1458) The need_export_symbol_op_list is used
# for some unittests these need to export symbol op compiled with dynamic lib.
need_export_symbol_op_list = [
'Add_Op',
'AbsOp',
'FullOp',
'UniformOp',
Expand All @@ -55,12 +56,22 @@
'Conv2dOp',
'BatchNormOp',
'FetchOp',
'FullIntArrayOp',
'MatmulOp',
'SoftmaxOp',
'ReshapeOp',
'TransposeOp',
'LessThanOp',
'LayerNormOp',
'AddGradOp',
'ConcatOp',
'CummaxOp',
'CastOp',
'ReluOp',
'ReluGradOp',
'BatchNorm_Op',
'GeluOp',
'GeluGradOp',
'MatmulGradOp',
]

Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -549,10 +549,10 @@ class ExpandOp : public pir::Op<ExpandOp,
pir::OperationArgument &argument, // NOLINT
pir::Value x_, // NOLINT
const std::vector<int64_t> &shape = {});
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value x_, // NOLINT
pir::Value shape_ // NOLINT
TEST_API static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value x_, // NOLINT
pir::Value shape_ // NOLINT
);
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
Expand Down
27 changes: 14 additions & 13 deletions paddle/fluid/pir/drr/include/drr_pattern_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <vector>

#include "paddle/fluid/pir/drr/include/drr_match_context.h"
#include "paddle/utils/test_macros.h"

namespace paddle {
namespace drr {
Expand Down Expand Up @@ -105,7 +106,7 @@ class DrrPatternContext {
DrrPatternContext();
~DrrPatternContext() = default;

drr::SourcePattern SourcePattern();
TEST_API drr::SourcePattern SourcePattern();

std::shared_ptr<SourcePatternGraph> source_pattern_graph() const {
return source_pattern_graph_;
Expand All @@ -121,19 +122,19 @@ class DrrPatternContext {
friend class drr::SourcePattern;
friend class drr::ResultPattern;

const Op& SourceOpPattern(
TEST_API const Op& SourceOpPattern(
const std::string& op_type,
const std::unordered_map<std::string, Attribute>& attributes = {});
const drr::Tensor& SourceTensorPattern(const std::string& name);
TEST_API const drr::Tensor& SourceTensorPattern(const std::string& name);

const Op& ResultOpPattern(
TEST_API const Op& ResultOpPattern(
const std::string& op_type,
const std::unordered_map<std::string, Attribute>& attributes = {});
drr::Tensor& ResultTensorPattern(const std::string& name);
TEST_API drr::Tensor& ResultTensorPattern(const std::string& name);

// void RequireEqual(const Attribute& first, const Attribute& second);
void RequireEqual(const TensorShape& first, const TensorShape& second);
void RequireEqual(const TensorDataType& first, const TensorDataType& second);
TEST_API void RequireEqual(const TensorDataType& first, const TensorDataType& second);
Liyulingyue marked this conversation as resolved.
Show resolved Hide resolved
void RequireNativeCall(const ConstraintFunction& custom_fn);

std::shared_ptr<SourcePatternGraph> source_pattern_graph_;
Expand All @@ -147,17 +148,17 @@ class Op {
public:
const std::string& name() const { return op_type_name_; }

void operator()(const Tensor& arg, const Tensor* out) const;
TEST_API void operator()(const Tensor& arg, const Tensor* out) const;

Tensor& operator()() const;

Tensor& operator()(const Tensor& arg) const;
Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const;
TEST_API Tensor& operator()(const Tensor& arg) const;
TEST_API Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const;
Tensor& operator()(const Tensor& arg0,
const Tensor& arg1,
const Tensor& arg2) const;
void operator()(const std::vector<const Tensor*>& args,
const std::vector<const Tensor*>& outputs) const;
TEST_API void operator()(const std::vector<const Tensor*>& args,
const std::vector<const Tensor*>& outputs) const;
// const Tensor& operator()(const Tensor& arg0, const Tensor& arg1, const
// Tensor& arg2) const; const Tensor& operator()(const Tensor& arg0, const
// Tensor& arg1, const Tensor& arg2, const Tensor& arg3) const; const Tensor&
Expand Down Expand Up @@ -198,9 +199,9 @@ class Tensor {

bool is_none() const { return name_ == NONE_TENSOR_NAME; }

void Assign(const Tensor& other);
TEST_API void Assign(const Tensor& other);

void operator=(const Tensor& other) const; // NOLINT
TEST_API void operator=(const Tensor& other) const; // NOLINT

const std::string& name() const { return name_; }

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/transforms/transform_general_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pir::Type GetDataTypeFromValue(pir::Value value);
*
* @return Operation*
*/
Operation* GetDefiningOpForInput(const Operation* op, uint32_t index);
TEST_API Operation* GetDefiningOpForInput(const Operation* op, uint32_t index);

/**
* @brief Get operations and the index of designative op operand (op result)
Expand Down
31 changes: 7 additions & 24 deletions test/cpp/pir/pattern_rewrite/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,31 +1,14 @@
cc_test(
pattern_rewrite_test
SRCS pattern_rewrite_test.cc
DEPS gtest op_dialect_vjp pir pir_transforms)
paddle_test(pattern_rewrite_test SRCS pattern_rewrite_test.cc)

cc_test(
drr_test
SRCS drr_test.cc
DEPS drr pir_transforms)
paddle_test(drr_test SRCS drr_test.cc)

cc_test(
drr_same_type_binding_test
SRCS drr_same_type_binding_test.cc
DEPS drr gtest op_dialect_vjp pir pir_transforms)
paddle_test(drr_same_type_binding_test SRCS drr_same_type_binding_test.cc)

cc_test(
drr_fuse_linear_test
SRCS drr_fuse_linear_test.cc
DEPS pir_transforms drr gtest op_dialect_vjp pir)
paddle_test(drr_fuse_linear_test SRCS drr_fuse_linear_test.cc)

cc_test(
drr_fuse_linear_param_grad_add_test
SRCS drr_fuse_linear_param_grad_add_test.cc
DEPS pir_transforms drr gtest op_dialect_vjp pir)
paddle_test(drr_fuse_linear_param_grad_add_test SRCS
drr_fuse_linear_param_grad_add_test.cc)

if(WITH_GPU)
cc_test(
drr_attention_fuse_test
SRCS drr_attention_fuse_test.cc
DEPS pir_transforms drr gtest op_dialect_vjp pir)
paddle_test(drr_attention_fuse_test SRCS drr_attention_fuse_test.cc)
endif()