Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenACC] Implement 'reduction' sema for compute constructs #92808

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,35 @@ class OpenACCCreateClause final
ArrayRef<Expr *> VarList, SourceLocation EndLoc);
};

class OpenACCReductionClause final
: public OpenACCClauseWithVarList,
public llvm::TrailingObjects<OpenACCReductionClause, Expr *> {
OpenACCReductionOperator Op;

OpenACCReductionClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
OpenACCReductionOperator Operator,
ArrayRef<Expr *> VarList, SourceLocation EndLoc)
: OpenACCClauseWithVarList(OpenACCClauseKind::Reduction, BeginLoc,
LParenLoc, EndLoc),
Op(Operator) {
std::uninitialized_copy(VarList.begin(), VarList.end(),
getTrailingObjects<Expr *>());
setExprs(MutableArrayRef(getTrailingObjects<Expr *>(), VarList.size()));
}

public:
static bool classof(const OpenACCClause *C) {
return C->getClauseKind() == OpenACCClauseKind::Reduction;
}

static OpenACCReductionClause *
Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
OpenACCReductionOperator Operator, ArrayRef<Expr *> VarList,
SourceLocation EndLoc);

OpenACCReductionOperator getReductionOp() const { return Op; }
};

template <class Impl> class OpenACCClauseVisitor {
Impl &getDerived() { return static_cast<Impl &>(*this); }

Expand Down
18 changes: 16 additions & 2 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12343,7 +12343,8 @@ def err_acc_num_gangs_num_args
"provided}0">;
def err_acc_not_a_var_ref
: Error<"OpenACC variable is not a valid variable name, sub-array, array "
"element, or composite variable member">;
"element,%select{| member of a composite variable,}0 or composite "
"variable member">;
def err_acc_typecheck_subarray_value
: Error<"OpenACC sub-array subscripted value is not an array or pointer">;
def err_acc_subarray_function_type
Expand Down Expand Up @@ -12374,5 +12375,18 @@ def note_acc_expected_pointer_var : Note<"expected variable of pointer type">;
def err_acc_clause_after_device_type
: Error<"OpenACC clause '%0' may not follow a '%1' clause in a "
"compute construct">;

def err_acc_reduction_num_gangs_conflict
: Error<
"OpenACC 'reduction' clause may not appear on a 'parallel' construct "
"with a 'num_gangs' clause with more than 1 argument, have %0">;
def err_acc_reduction_type
: Error<"OpenACC 'reduction' variable must be of scalar type, sub-array, or a "
"composite of scalar types;%select{| sub-array base}1 type is %0">;
def err_acc_reduction_composite_type
: Error<"OpenACC 'reduction' variable must be a composite of scalar types; "
"%1 %select{is not a class or struct|is incomplete|is not an "
"aggregate}0">;
def err_acc_reduction_composite_member_type :Error<
"OpenACC 'reduction' composite variable must not have non-scalar field">;
def note_acc_reduction_composite_member_loc : Note<"invalid field is here">;
} // end of sema component.
1 change: 1 addition & 0 deletions clang/include/clang/Basic/OpenACCClauses.def
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ VISIT_CLAUSE(NumGangs)
VISIT_CLAUSE(NumWorkers)
VISIT_CLAUSE(Present)
VISIT_CLAUSE(Private)
VISIT_CLAUSE(Reduction)
VISIT_CLAUSE(Self)
VISIT_CLAUSE(VectorLength)
VISIT_CLAUSE(Wait)
Expand Down
36 changes: 36 additions & 0 deletions clang/include/clang/Basic/OpenACCKinds.h
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,42 @@ enum class OpenACCReductionOperator {
/// Invalid Reduction Clause Kind.
Invalid,
};

template <typename StreamTy>
inline StreamTy &printOpenACCReductionOperator(StreamTy &Out,
OpenACCReductionOperator Op) {
switch (Op) {
case OpenACCReductionOperator::Addition:
return Out << "+";
case OpenACCReductionOperator::Multiplication:
return Out << "*";
case OpenACCReductionOperator::Max:
return Out << "max";
case OpenACCReductionOperator::Min:
return Out << "min";
case OpenACCReductionOperator::BitwiseAnd:
return Out << "&";
case OpenACCReductionOperator::BitwiseOr:
return Out << "|";
case OpenACCReductionOperator::BitwiseXOr:
return Out << "^";
case OpenACCReductionOperator::And:
return Out << "&&";
case OpenACCReductionOperator::Or:
return Out << "||";
case OpenACCReductionOperator::Invalid:
return Out << "<invalid>";
}
llvm_unreachable("Unknown reduction operator kind");
}
inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
OpenACCReductionOperator Op) {
return printOpenACCReductionOperator(Out, Op);
}
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &Out,
OpenACCReductionOperator Op) {
return printOpenACCReductionOperator(Out, Op);
}
} // namespace clang

#endif // LLVM_CLANG_BASIC_OPENACCKINDS_H
4 changes: 2 additions & 2 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3686,9 +3686,9 @@ class Parser : public CodeCompletionHandler {

using OpenACCVarParseResult = std::pair<ExprResult, OpenACCParseCanContinue>;
/// Parses a single variable in a variable list for OpenACC.
OpenACCVarParseResult ParseOpenACCVar();
OpenACCVarParseResult ParseOpenACCVar(OpenACCClauseKind CK);
/// Parses the variable list for the variety of places that take a var-list.
llvm::SmallVector<Expr *> ParseOpenACCVarList();
llvm::SmallVector<Expr *> ParseOpenACCVarList(OpenACCClauseKind CK);
/// Parses any parameters for an OpenACC Clause, including required/optional
/// parens.
OpenACCClauseParseResult
Expand Down
29 changes: 27 additions & 2 deletions clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,14 @@ class SemaOpenACC : public SemaBase {
struct DeviceTypeDetails {
SmallVector<DeviceTypeArgument> Archs;
};
struct ReductionDetails {
OpenACCReductionOperator Op;
SmallVector<Expr *> VarList;
};

std::variant<std::monostate, DefaultDetails, ConditionDetails,
IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails>
IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails,
ReductionDetails>
Details = std::monostate{};

public:
Expand Down Expand Up @@ -170,6 +175,10 @@ class SemaOpenACC : public SemaBase {
return const_cast<OpenACCParsedClause *>(this)->getIntExprs();
}

OpenACCReductionOperator getReductionOp() const {
return std::get<ReductionDetails>(Details).Op;
}

ArrayRef<Expr *> getVarList() {
assert((ClauseKind == OpenACCClauseKind::Private ||
ClauseKind == OpenACCClauseKind::NoCreate ||
Expand All @@ -188,8 +197,13 @@ class SemaOpenACC : public SemaBase {
ClauseKind == OpenACCClauseKind::PresentOrCreate ||
ClauseKind == OpenACCClauseKind::Attach ||
ClauseKind == OpenACCClauseKind::DevicePtr ||
ClauseKind == OpenACCClauseKind::Reduction ||
ClauseKind == OpenACCClauseKind::FirstPrivate) &&
"Parsed clause kind does not have a var-list");

if (ClauseKind == OpenACCClauseKind::Reduction)
return std::get<ReductionDetails>(Details).VarList;

return std::get<VarListDetails>(Details).VarList;
}

Expand Down Expand Up @@ -334,6 +348,13 @@ class SemaOpenACC : public SemaBase {
Details = VarListDetails{std::move(VarList), IsReadOnly, IsZero};
}

void setReductionDetails(OpenACCReductionOperator Op,
llvm::SmallVector<Expr *> &&VarList) {
assert(ClauseKind == OpenACCClauseKind::Reduction &&
"reduction details only valid on reduction");
Details = ReductionDetails{Op, std::move(VarList)};
}

void setWaitDetails(Expr *DevNum, SourceLocation QueuesLoc,
llvm::SmallVector<Expr *> &&IntExprs) {
assert(ClauseKind == OpenACCClauseKind::Wait &&
Expand Down Expand Up @@ -394,7 +415,11 @@ class SemaOpenACC : public SemaBase {

/// Called when encountering a 'var' for OpenACC, ensures it is actually a
/// declaration reference to a variable of the correct type.
ExprResult ActOnVar(Expr *VarExpr);
ExprResult ActOnVar(OpenACCClauseKind CK, Expr *VarExpr);

/// Called while semantically analyzing the reduction clause, ensuring the var
/// is the correct kind of reference.
ExprResult CheckReductionVar(Expr *VarExpr);

/// Called to check the 'var' type is a variable of pointer type, necessary
/// for 'deviceptr' and 'attach' clauses. Returns true on success.
Expand Down
20 changes: 19 additions & 1 deletion clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ bool OpenACCClauseWithVarList::classof(const OpenACCClause *C) {
OpenACCAttachClause::classof(C) || OpenACCNoCreateClause::classof(C) ||
OpenACCPresentClause::classof(C) || OpenACCCopyClause::classof(C) ||
OpenACCCopyInClause::classof(C) || OpenACCCopyOutClause::classof(C) ||
OpenACCCreateClause::classof(C);
OpenACCReductionClause::classof(C) || OpenACCCreateClause::classof(C);
}
bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
return OpenACCIfClause::classof(C) || OpenACCSelfClause::classof(C);
Expand Down Expand Up @@ -310,6 +310,16 @@ OpenACCDeviceTypeClause *OpenACCDeviceTypeClause::Create(
OpenACCDeviceTypeClause(K, BeginLoc, LParenLoc, Archs, EndLoc);
}

OpenACCReductionClause *OpenACCReductionClause::Create(
const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
OpenACCReductionOperator Operator, ArrayRef<Expr *> VarList,
SourceLocation EndLoc) {
void *Mem = C.Allocate(
OpenACCReductionClause::totalSizeToAlloc<Expr *>(VarList.size()));
return new (Mem)
OpenACCReductionClause(BeginLoc, LParenLoc, Operator, VarList, EndLoc);
}

//===----------------------------------------------------------------------===//
// OpenACC clauses printing methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -445,6 +455,14 @@ void OpenACCClausePrinter::VisitCreateClause(const OpenACCCreateClause &C) {
OS << ")";
}

void OpenACCClausePrinter::VisitReductionClause(
const OpenACCReductionClause &C) {
OS << "reduction(" << C.getReductionOp() << ": ";
llvm::interleaveComma(C.getVarList(), OS,
[&](const Expr *E) { printExpr(E); });
OS << ")";
}

void OpenACCClausePrinter::VisitWaitClause(const OpenACCWaitClause &C) {
OS << "wait";
if (!C.getLParenLoc().isInvalid()) {
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2588,6 +2588,12 @@ void OpenACCClauseProfiler::VisitWaitClause(const OpenACCWaitClause &Clause) {
/// Nothing to do here, there are no sub-statements.
void OpenACCClauseProfiler::VisitDeviceTypeClause(
const OpenACCDeviceTypeClause &Clause) {}

void OpenACCClauseProfiler::VisitReductionClause(
const OpenACCReductionClause &Clause) {
for (auto *E : Clause.getVarList())
Profiler.VisitStmt(E);
}
} // namespace

void StmtProfiler::VisitOpenACCComputeConstruct(
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,10 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
});
OS << ")";
break;
case OpenACCClauseKind::Reduction:
OS << " clause Operator: "
<< cast<OpenACCReductionClause>(C)->getReductionOp();
break;
default:
// Nothing to do here.
break;
Expand Down
30 changes: 16 additions & 14 deletions clang/lib/Parse/ParseOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
case OpenACCClauseKind::PresentOrCopyIn: {
bool IsReadOnly = tryParseAndConsumeSpecialTokenKind(
*this, OpenACCSpecialTokenKind::ReadOnly, ClauseKind);
ParsedClause.setVarListDetails(ParseOpenACCVarList(), IsReadOnly,
ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
IsReadOnly,
/*IsZero=*/false);
break;
}
Expand All @@ -932,16 +933,17 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
case OpenACCClauseKind::PresentOrCopyOut: {
bool IsZero = tryParseAndConsumeSpecialTokenKind(
*this, OpenACCSpecialTokenKind::Zero, ClauseKind);
ParsedClause.setVarListDetails(ParseOpenACCVarList(),
ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
/*IsReadOnly=*/false, IsZero);
break;
}
case OpenACCClauseKind::Reduction:
case OpenACCClauseKind::Reduction: {
// If we're missing a clause-kind (or it is invalid), see if we can parse
// the var-list anyway.
ParseReductionOperator(*this);
ParseOpenACCVarList();
OpenACCReductionOperator Op = ParseReductionOperator(*this);
ParsedClause.setReductionDetails(Op, ParseOpenACCVarList(ClauseKind));
break;
}
case OpenACCClauseKind::Self:
// The 'self' clause is a var-list instead of a 'condition' in the case of
// the 'update' clause, so we have to handle it here. U se an assert to
Expand All @@ -955,11 +957,11 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
case OpenACCClauseKind::Host:
case OpenACCClauseKind::Link:
case OpenACCClauseKind::UseDevice:
ParseOpenACCVarList();
ParseOpenACCVarList(ClauseKind);
break;
case OpenACCClauseKind::Attach:
case OpenACCClauseKind::DevicePtr:
ParsedClause.setVarListDetails(ParseOpenACCVarList(),
ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
/*IsReadOnly=*/false, /*IsZero=*/false);
break;
case OpenACCClauseKind::Copy:
Expand All @@ -969,7 +971,7 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
case OpenACCClauseKind::NoCreate:
case OpenACCClauseKind::Present:
case OpenACCClauseKind::Private:
ParsedClause.setVarListDetails(ParseOpenACCVarList(),
ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
/*IsReadOnly=*/false, /*IsZero=*/false);
break;
case OpenACCClauseKind::Collapse: {
Expand Down Expand Up @@ -1278,7 +1280,7 @@ ExprResult Parser::ParseOpenACCBindClauseArgument() {
/// - an array element
/// - a member of a composite variable
/// - a common block name between slashes (fortran only)
Parser::OpenACCVarParseResult Parser::ParseOpenACCVar() {
Parser::OpenACCVarParseResult Parser::ParseOpenACCVar(OpenACCClauseKind CK) {
OpenACCArraySectionRAII ArraySections(*this);

ExprResult Res = ParseAssignmentExpression();
Expand All @@ -1289,15 +1291,15 @@ Parser::OpenACCVarParseResult Parser::ParseOpenACCVar() {
if (!Res.isUsable())
return {Res, OpenACCParseCanContinue::Can};

Res = getActions().OpenACC().ActOnVar(Res.get());
Res = getActions().OpenACC().ActOnVar(CK, Res.get());

return {Res, OpenACCParseCanContinue::Can};
}

llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList() {
llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList(OpenACCClauseKind CK) {
llvm::SmallVector<Expr *> Vars;

auto [Res, CanContinue] = ParseOpenACCVar();
auto [Res, CanContinue] = ParseOpenACCVar(CK);
if (Res.isUsable()) {
Vars.push_back(Res.get());
} else if (CanContinue == OpenACCParseCanContinue::Cannot) {
Expand All @@ -1308,7 +1310,7 @@ llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList() {
while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
ExpectAndConsume(tok::comma);

auto [Res, CanContinue] = ParseOpenACCVar();
auto [Res, CanContinue] = ParseOpenACCVar(CK);

if (Res.isUsable()) {
Vars.push_back(Res.get());
Expand Down Expand Up @@ -1342,7 +1344,7 @@ void Parser::ParseOpenACCCacheVarList() {

// ParseOpenACCVarList should leave us before a r-paren, so no need to skip
// anything here.
ParseOpenACCVarList();
ParseOpenACCVarList(OpenACCClauseKind::Invalid);
}

Parser::OpenACCDirectiveParseInfo Parser::ParseOpenACCDirective() {
Expand Down
Loading
Loading