Skip to content

Commit

Permalink
[OpenACC] Implement 'reduction' sema for compute constructs (llvm#92808)
Browse files Browse the repository at this point in the history
'reduction' has a few restrictions over normal 'var-list' clauses:

1- On parallel, a num_gangs can only have 1 argument when combined with
reduction. These two aren't able to be combined on any other of the
compute constructs however.

2- The vars all must be 'numerical data types' types of some sort, or a
'composite of numerical data types'. A list of types is given in the
standard as a minimum, so we choose 'isScalar', which covers all of
these types and keeps types that are actually numeric. Other compilers
don't seem to implement the 'composite of numerical data types', though
we do.

3- Because of the above restrictions, member-of-composite is not
allowed, so any access via a memberexpr is disallowed. Array-element and
sub-arrays (aka array sections) are both permitted, so long as they meet
the requirements of rust-lang#2.

This patch implements all of these for compute constructs.
  • Loading branch information
erichkeane committed May 21, 2024
1 parent fbc798e commit a15b685
Show file tree
Hide file tree
Showing 39 changed files with 1,005 additions and 157 deletions.
29 changes: 29 additions & 0 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,35 @@ class OpenACCCreateClause final
ArrayRef<Expr *> VarList, SourceLocation EndLoc);
};

class OpenACCReductionClause final
: public OpenACCClauseWithVarList,
public llvm::TrailingObjects<OpenACCReductionClause, Expr *> {
OpenACCReductionOperator Op;

OpenACCReductionClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
OpenACCReductionOperator Operator,
ArrayRef<Expr *> VarList, SourceLocation EndLoc)
: OpenACCClauseWithVarList(OpenACCClauseKind::Reduction, BeginLoc,
LParenLoc, EndLoc),
Op(Operator) {
std::uninitialized_copy(VarList.begin(), VarList.end(),
getTrailingObjects<Expr *>());
setExprs(MutableArrayRef(getTrailingObjects<Expr *>(), VarList.size()));
}

public:
static bool classof(const OpenACCClause *C) {
return C->getClauseKind() == OpenACCClauseKind::Reduction;
}

static OpenACCReductionClause *
Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
OpenACCReductionOperator Operator, ArrayRef<Expr *> VarList,
SourceLocation EndLoc);

OpenACCReductionOperator getReductionOp() const { return Op; }
};

template <class Impl> class OpenACCClauseVisitor {
Impl &getDerived() { return static_cast<Impl &>(*this); }

Expand Down
18 changes: 16 additions & 2 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12343,7 +12343,8 @@ def err_acc_num_gangs_num_args
"provided}0">;
def err_acc_not_a_var_ref
: Error<"OpenACC variable is not a valid variable name, sub-array, array "
"element, or composite variable member">;
"element,%select{| member of a composite variable,}0 or composite "
"variable member">;
def err_acc_typecheck_subarray_value
: Error<"OpenACC sub-array subscripted value is not an array or pointer">;
def err_acc_subarray_function_type
Expand Down Expand Up @@ -12374,5 +12375,18 @@ def note_acc_expected_pointer_var : Note<"expected variable of pointer type">;
def err_acc_clause_after_device_type
: Error<"OpenACC clause '%0' may not follow a '%1' clause in a "
"compute construct">;

def err_acc_reduction_num_gangs_conflict
: Error<
"OpenACC 'reduction' clause may not appear on a 'parallel' construct "
"with a 'num_gangs' clause with more than 1 argument, have %0">;
def err_acc_reduction_type
: Error<"OpenACC 'reduction' variable must be of scalar type, sub-array, or a "
"composite of scalar types;%select{| sub-array base}1 type is %0">;
def err_acc_reduction_composite_type
: Error<"OpenACC 'reduction' variable must be a composite of scalar types; "
"%1 %select{is not a class or struct|is incomplete|is not an "
"aggregate}0">;
def err_acc_reduction_composite_member_type :Error<
"OpenACC 'reduction' composite variable must not have non-scalar field">;
def note_acc_reduction_composite_member_loc : Note<"invalid field is here">;
} // end of sema component.
1 change: 1 addition & 0 deletions clang/include/clang/Basic/OpenACCClauses.def
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ VISIT_CLAUSE(NumGangs)
VISIT_CLAUSE(NumWorkers)
VISIT_CLAUSE(Present)
VISIT_CLAUSE(Private)
VISIT_CLAUSE(Reduction)
VISIT_CLAUSE(Self)
VISIT_CLAUSE(VectorLength)
VISIT_CLAUSE(Wait)
Expand Down
36 changes: 36 additions & 0 deletions clang/include/clang/Basic/OpenACCKinds.h
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,42 @@ enum class OpenACCReductionOperator {
/// Invalid Reduction Clause Kind.
Invalid,
};

template <typename StreamTy>
inline StreamTy &printOpenACCReductionOperator(StreamTy &Out,
OpenACCReductionOperator Op) {
switch (Op) {
case OpenACCReductionOperator::Addition:
return Out << "+";
case OpenACCReductionOperator::Multiplication:
return Out << "*";
case OpenACCReductionOperator::Max:
return Out << "max";
case OpenACCReductionOperator::Min:
return Out << "min";
case OpenACCReductionOperator::BitwiseAnd:
return Out << "&";
case OpenACCReductionOperator::BitwiseOr:
return Out << "|";
case OpenACCReductionOperator::BitwiseXOr:
return Out << "^";
case OpenACCReductionOperator::And:
return Out << "&&";
case OpenACCReductionOperator::Or:
return Out << "||";
case OpenACCReductionOperator::Invalid:
return Out << "<invalid>";
}
llvm_unreachable("Unknown reduction operator kind");
}
inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
OpenACCReductionOperator Op) {
return printOpenACCReductionOperator(Out, Op);
}
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &Out,
OpenACCReductionOperator Op) {
return printOpenACCReductionOperator(Out, Op);
}
} // namespace clang

#endif // LLVM_CLANG_BASIC_OPENACCKINDS_H
4 changes: 2 additions & 2 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3686,9 +3686,9 @@ class Parser : public CodeCompletionHandler {

using OpenACCVarParseResult = std::pair<ExprResult, OpenACCParseCanContinue>;
/// Parses a single variable in a variable list for OpenACC.
OpenACCVarParseResult ParseOpenACCVar();
OpenACCVarParseResult ParseOpenACCVar(OpenACCClauseKind CK);
/// Parses the variable list for the variety of places that take a var-list.
llvm::SmallVector<Expr *> ParseOpenACCVarList();
llvm::SmallVector<Expr *> ParseOpenACCVarList(OpenACCClauseKind CK);
/// Parses any parameters for an OpenACC Clause, including required/optional
/// parens.
OpenACCClauseParseResult
Expand Down
29 changes: 27 additions & 2 deletions clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,14 @@ class SemaOpenACC : public SemaBase {
struct DeviceTypeDetails {
SmallVector<DeviceTypeArgument> Archs;
};
struct ReductionDetails {
OpenACCReductionOperator Op;
SmallVector<Expr *> VarList;
};

std::variant<std::monostate, DefaultDetails, ConditionDetails,
IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails>
IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails,
ReductionDetails>
Details = std::monostate{};

public:
Expand Down Expand Up @@ -170,6 +175,10 @@ class SemaOpenACC : public SemaBase {
return const_cast<OpenACCParsedClause *>(this)->getIntExprs();
}

OpenACCReductionOperator getReductionOp() const {
return std::get<ReductionDetails>(Details).Op;
}

ArrayRef<Expr *> getVarList() {
assert((ClauseKind == OpenACCClauseKind::Private ||
ClauseKind == OpenACCClauseKind::NoCreate ||
Expand All @@ -188,8 +197,13 @@ class SemaOpenACC : public SemaBase {
ClauseKind == OpenACCClauseKind::PresentOrCreate ||
ClauseKind == OpenACCClauseKind::Attach ||
ClauseKind == OpenACCClauseKind::DevicePtr ||
ClauseKind == OpenACCClauseKind::Reduction ||
ClauseKind == OpenACCClauseKind::FirstPrivate) &&
"Parsed clause kind does not have a var-list");

if (ClauseKind == OpenACCClauseKind::Reduction)
return std::get<ReductionDetails>(Details).VarList;

return std::get<VarListDetails>(Details).VarList;
}

Expand Down Expand Up @@ -334,6 +348,13 @@ class SemaOpenACC : public SemaBase {
Details = VarListDetails{std::move(VarList), IsReadOnly, IsZero};
}

void setReductionDetails(OpenACCReductionOperator Op,
llvm::SmallVector<Expr *> &&VarList) {
assert(ClauseKind == OpenACCClauseKind::Reduction &&
"reduction details only valid on reduction");
Details = ReductionDetails{Op, std::move(VarList)};
}

void setWaitDetails(Expr *DevNum, SourceLocation QueuesLoc,
llvm::SmallVector<Expr *> &&IntExprs) {
assert(ClauseKind == OpenACCClauseKind::Wait &&
Expand Down Expand Up @@ -394,7 +415,11 @@ class SemaOpenACC : public SemaBase {

/// Called when encountering a 'var' for OpenACC, ensures it is actually a
/// declaration reference to a variable of the correct type.
ExprResult ActOnVar(Expr *VarExpr);
ExprResult ActOnVar(OpenACCClauseKind CK, Expr *VarExpr);

/// Called while semantically analyzing the reduction clause, ensuring the var
/// is the correct kind of reference.
ExprResult CheckReductionVar(Expr *VarExpr);

/// Called to check the 'var' type is a variable of pointer type, necessary
/// for 'deviceptr' and 'attach' clauses. Returns true on success.
Expand Down
20 changes: 19 additions & 1 deletion clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ bool OpenACCClauseWithVarList::classof(const OpenACCClause *C) {
OpenACCAttachClause::classof(C) || OpenACCNoCreateClause::classof(C) ||
OpenACCPresentClause::classof(C) || OpenACCCopyClause::classof(C) ||
OpenACCCopyInClause::classof(C) || OpenACCCopyOutClause::classof(C) ||
OpenACCCreateClause::classof(C);
OpenACCReductionClause::classof(C) || OpenACCCreateClause::classof(C);
}
bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
return OpenACCIfClause::classof(C) || OpenACCSelfClause::classof(C);
Expand Down Expand Up @@ -310,6 +310,16 @@ OpenACCDeviceTypeClause *OpenACCDeviceTypeClause::Create(
OpenACCDeviceTypeClause(K, BeginLoc, LParenLoc, Archs, EndLoc);
}

OpenACCReductionClause *OpenACCReductionClause::Create(
const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
OpenACCReductionOperator Operator, ArrayRef<Expr *> VarList,
SourceLocation EndLoc) {
void *Mem = C.Allocate(
OpenACCReductionClause::totalSizeToAlloc<Expr *>(VarList.size()));
return new (Mem)
OpenACCReductionClause(BeginLoc, LParenLoc, Operator, VarList, EndLoc);
}

//===----------------------------------------------------------------------===//
// OpenACC clauses printing methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -445,6 +455,14 @@ void OpenACCClausePrinter::VisitCreateClause(const OpenACCCreateClause &C) {
OS << ")";
}

void OpenACCClausePrinter::VisitReductionClause(
const OpenACCReductionClause &C) {
OS << "reduction(" << C.getReductionOp() << ": ";
llvm::interleaveComma(C.getVarList(), OS,
[&](const Expr *E) { printExpr(E); });
OS << ")";
}

void OpenACCClausePrinter::VisitWaitClause(const OpenACCWaitClause &C) {
OS << "wait";
if (!C.getLParenLoc().isInvalid()) {
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2588,6 +2588,12 @@ void OpenACCClauseProfiler::VisitWaitClause(const OpenACCWaitClause &Clause) {
/// Nothing to do here, there are no sub-statements.
void OpenACCClauseProfiler::VisitDeviceTypeClause(
const OpenACCDeviceTypeClause &Clause) {}

void OpenACCClauseProfiler::VisitReductionClause(
const OpenACCReductionClause &Clause) {
for (auto *E : Clause.getVarList())
Profiler.VisitStmt(E);
}
} // namespace

void StmtProfiler::VisitOpenACCComputeConstruct(
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,10 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
});
OS << ")";
break;
case OpenACCClauseKind::Reduction:
OS << " clause Operator: "
<< cast<OpenACCReductionClause>(C)->getReductionOp();
break;
default:
// Nothing to do here.
break;
Expand Down
30 changes: 16 additions & 14 deletions clang/lib/Parse/ParseOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
case OpenACCClauseKind::PresentOrCopyIn: {
bool IsReadOnly = tryParseAndConsumeSpecialTokenKind(
*this, OpenACCSpecialTokenKind::ReadOnly, ClauseKind);
ParsedClause.setVarListDetails(ParseOpenACCVarList(), IsReadOnly,
ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
IsReadOnly,
/*IsZero=*/false);
break;
}
Expand All @@ -932,16 +933,17 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
case OpenACCClauseKind::PresentOrCopyOut: {
bool IsZero = tryParseAndConsumeSpecialTokenKind(
*this, OpenACCSpecialTokenKind::Zero, ClauseKind);
ParsedClause.setVarListDetails(ParseOpenACCVarList(),
ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
/*IsReadOnly=*/false, IsZero);
break;
}
case OpenACCClauseKind::Reduction:
case OpenACCClauseKind::Reduction: {
// If we're missing a clause-kind (or it is invalid), see if we can parse
// the var-list anyway.
ParseReductionOperator(*this);
ParseOpenACCVarList();
OpenACCReductionOperator Op = ParseReductionOperator(*this);
ParsedClause.setReductionDetails(Op, ParseOpenACCVarList(ClauseKind));
break;
}
case OpenACCClauseKind::Self:
// The 'self' clause is a var-list instead of a 'condition' in the case of
// the 'update' clause, so we have to handle it here. U se an assert to
Expand All @@ -955,11 +957,11 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
case OpenACCClauseKind::Host:
case OpenACCClauseKind::Link:
case OpenACCClauseKind::UseDevice:
ParseOpenACCVarList();
ParseOpenACCVarList(ClauseKind);
break;
case OpenACCClauseKind::Attach:
case OpenACCClauseKind::DevicePtr:
ParsedClause.setVarListDetails(ParseOpenACCVarList(),
ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
/*IsReadOnly=*/false, /*IsZero=*/false);
break;
case OpenACCClauseKind::Copy:
Expand All @@ -969,7 +971,7 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
case OpenACCClauseKind::NoCreate:
case OpenACCClauseKind::Present:
case OpenACCClauseKind::Private:
ParsedClause.setVarListDetails(ParseOpenACCVarList(),
ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
/*IsReadOnly=*/false, /*IsZero=*/false);
break;
case OpenACCClauseKind::Collapse: {
Expand Down Expand Up @@ -1278,7 +1280,7 @@ ExprResult Parser::ParseOpenACCBindClauseArgument() {
/// - an array element
/// - a member of a composite variable
/// - a common block name between slashes (fortran only)
Parser::OpenACCVarParseResult Parser::ParseOpenACCVar() {
Parser::OpenACCVarParseResult Parser::ParseOpenACCVar(OpenACCClauseKind CK) {
OpenACCArraySectionRAII ArraySections(*this);

ExprResult Res = ParseAssignmentExpression();
Expand All @@ -1289,15 +1291,15 @@ Parser::OpenACCVarParseResult Parser::ParseOpenACCVar() {
if (!Res.isUsable())
return {Res, OpenACCParseCanContinue::Can};

Res = getActions().OpenACC().ActOnVar(Res.get());
Res = getActions().OpenACC().ActOnVar(CK, Res.get());

return {Res, OpenACCParseCanContinue::Can};
}

llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList() {
llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList(OpenACCClauseKind CK) {
llvm::SmallVector<Expr *> Vars;

auto [Res, CanContinue] = ParseOpenACCVar();
auto [Res, CanContinue] = ParseOpenACCVar(CK);
if (Res.isUsable()) {
Vars.push_back(Res.get());
} else if (CanContinue == OpenACCParseCanContinue::Cannot) {
Expand All @@ -1308,7 +1310,7 @@ llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList() {
while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
ExpectAndConsume(tok::comma);

auto [Res, CanContinue] = ParseOpenACCVar();
auto [Res, CanContinue] = ParseOpenACCVar(CK);

if (Res.isUsable()) {
Vars.push_back(Res.get());
Expand Down Expand Up @@ -1342,7 +1344,7 @@ void Parser::ParseOpenACCCacheVarList() {

// ParseOpenACCVarList should leave us before a r-paren, so no need to skip
// anything here.
ParseOpenACCVarList();
ParseOpenACCVarList(OpenACCClauseKind::Invalid);
}

Parser::OpenACCDirectiveParseInfo Parser::ParseOpenACCDirective() {
Expand Down
Loading

0 comments on commit a15b685

Please sign in to comment.