Skip to content

Thread Safety Analysis: Compare values of literals #148551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

aaronpuchert
Copy link
Member

The typical case for literals is an array of mutexes, where we want to distinguish mutex[0] from mutex[1] and so on. Currently they're treated as the same expression, in fact all literals are treated as the same expression.

The infrastructure for literals is already there, although it required some changes, and some simplifications seemed opportune:

  • The ValueType had fields for size and signedness. But Clang doesn't use native types and stores integer and (floating-point) literals as llvm::APInt regardless of size, so we don't need these properties. We could use them for characters, but it seems easier to just create different base types for now.
  • We remove the BT_Void: void literals don't exist in C++.
  • We remove BT_Float and BT_ValueRef: floating-point numbers and complex numbers are probably not used in lock expressions.

We turn Literal into a pure base class, as it seems to have been intended, and only create LiteralT instances of the correct type. Assertions on as ensure we're not mixing up types.

We print to llvm::raw_ostream instead of std::ostream because that's required for CharacterLiteral::print. Perhaps we should implement that ourselves though.

Fixes #58535.

The typical case for literals is an array of mutexes, where we want to
distinguish `mutex[0]` from `mutex[1]` and so on. Currently they're
treated as the same expression, in fact all literals are treated as the
same expression.

The infrastructure for literals is already there, although it required
some changes, and some simplifications seemed opportune:
* The `ValueType` had fields for size and signedness. But Clang doesn't
  use native types and stores integer and (floating-point) literals as
  `llvm::APInt` regardless of size, so we don't need these properties.
  We could use them for characters, but it seems easier to just create
  different base types for now.
* We remove the `BT_Void`: `void` literals don't exist in C++.
* We remove `BT_Float` and `BT_ValueRef`: floating-point numbers and
  complex numbers are probably not used in lock expressions.

We turn `Literal` into a pure base class, as it seems to have been
intended, and only create `LiteralT` instances of the correct type.
Assertions on `as` ensure we're not mixing up types.

We print to `llvm::raw_ostream` instead of `std::ostream` because that's
required for `CharacterLiteral::print`. Perhaps we should implement that
ourselves though.

Fixes llvm#58535.
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:analysis labels Jul 13, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 13, 2025

@llvm/pr-subscribers-clang

Author: Aaron Puchert (aaronpuchert)

Changes

The typical case for literals is an array of mutexes, where we want to distinguish mutex[0] from mutex[1] and so on. Currently they're treated as the same expression, in fact all literals are treated as the same expression.

The infrastructure for literals is already there, although it required some changes, and some simplifications seemed opportune:

  • The ValueType had fields for size and signedness. But Clang doesn't use native types and stores integer and (floating-point) literals as llvm::APInt regardless of size, so we don't need these properties. We could use them for characters, but it seems easier to just create different base types for now.
  • We remove the BT_Void: void literals don't exist in C++.
  • We remove BT_Float and BT_ValueRef: floating-point numbers and complex numbers are probably not used in lock expressions.

We turn Literal into a pure base class, as it seems to have been intended, and only create LiteralT instances of the correct type. Assertions on as ensure we're not mixing up types.

We print to llvm::raw_ostream instead of std::ostream because that's required for CharacterLiteral::print. Perhaps we should implement that ourselves though.

Fixes #58535.


Patch is 20.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148551.diff

5 Files Affected:

  • (modified) clang/include/clang/Analysis/Analyses/ThreadSafetyCommon.h (+4-3)
  • (modified) clang/include/clang/Analysis/Analyses/ThreadSafetyTIL.h (+68-145)
  • (modified) clang/include/clang/Analysis/Analyses/ThreadSafetyTraverse.h (+43-80)
  • (modified) clang/lib/Analysis/ThreadSafetyCommon.cpp (+25-4)
  • (modified) clang/test/SemaCXX/warn-thread-safety-analysis.cpp (+50)
diff --git a/clang/include/clang/Analysis/Analyses/ThreadSafetyCommon.h b/clang/include/clang/Analysis/Analyses/ThreadSafetyCommon.h
index 6c97905a2d7f9..e5cd1948c9314 100644
--- a/clang/include/clang/Analysis/Analyses/ThreadSafetyCommon.h
+++ b/clang/include/clang/Analysis/Analyses/ThreadSafetyCommon.h
@@ -35,7 +35,7 @@
 #include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
-#include <sstream>
+#include "llvm/Support/raw_ostream.h"
 #include <string>
 #include <utility>
 #include <vector>
@@ -90,9 +90,10 @@ inline bool partiallyMatches(const til::SExpr *E1, const til::SExpr *E2) {
 }
 
 inline std::string toString(const til::SExpr *E) {
-  std::stringstream ss;
+  std::string s;
+  llvm::raw_string_ostream ss(s);
   til::StdPrinter::print(E, ss);
-  return ss.str();
+  return s;
 }
 
 }  // namespace sx
diff --git a/clang/include/clang/Analysis/Analyses/ThreadSafetyTIL.h b/clang/include/clang/Analysis/Analyses/ThreadSafetyTIL.h
index 14c5b679428a3..890ba19465f7f 100644
--- a/clang/include/clang/Analysis/Analyses/ThreadSafetyTIL.h
+++ b/clang/include/clang/Analysis/Analyses/ThreadSafetyTIL.h
@@ -148,129 +148,63 @@ StringRef getBinaryOpcodeString(TIL_BinaryOpcode Op);
 /// All variables and expressions must have a value type.
 /// Pointer types are further subdivided into the various heap-allocated
 /// types, such as functions, records, etc.
-/// Structured types that are passed by value (e.g. complex numbers)
-/// require special handling; they use BT_ValueRef, and size ST_0.
 struct ValueType {
   enum BaseType : unsigned char {
-    BT_Void = 0,
     BT_Bool,
+    BT_AsciiChar,
+    BT_WideChar,
+    BT_UTF16Char,
+    BT_UTF32Char,
     BT_Int,
-    BT_Float,
-    BT_String,    // String literals
+    BT_String, // String literals
     BT_Pointer,
-    BT_ValueRef
   };
 
-  enum SizeType : unsigned char {
-    ST_0 = 0,
-    ST_1,
-    ST_8,
-    ST_16,
-    ST_32,
-    ST_64,
-    ST_128
-  };
-
-  ValueType(BaseType B, SizeType Sz, bool S, unsigned char VS)
-      : Base(B), Size(Sz), Signed(S), VectSize(VS) {}
-
-  inline static SizeType getSizeType(unsigned nbytes);
+  ValueType(BaseType B) : Base(B) {}
 
   template <class T>
   inline static ValueType getValueType();
 
   BaseType Base;
-  SizeType Size;
-  bool Signed;
-
-  // 0 for scalar, otherwise num elements in vector
-  unsigned char VectSize;
 };
 
-inline ValueType::SizeType ValueType::getSizeType(unsigned nbytes) {
-  switch (nbytes) {
-    case 1: return ST_8;
-    case 2: return ST_16;
-    case 4: return ST_32;
-    case 8: return ST_64;
-    case 16: return ST_128;
-    default: return ST_0;
-  }
-}
-
-template<>
-inline ValueType ValueType::getValueType<void>() {
-  return ValueType(BT_Void, ST_0, false, 0);
+inline bool operator==(const ValueType &a, const ValueType &b) {
+  return a.Base == b.Base;
 }
 
 template<>
 inline ValueType ValueType::getValueType<bool>() {
-  return ValueType(BT_Bool, ST_1, false, 0);
-}
-
-template<>
-inline ValueType ValueType::getValueType<int8_t>() {
-  return ValueType(BT_Int, ST_8, true, 0);
-}
-
-template<>
-inline ValueType ValueType::getValueType<uint8_t>() {
-  return ValueType(BT_Int, ST_8, false, 0);
+  return ValueType(BT_Bool);
 }
 
-template<>
-inline ValueType ValueType::getValueType<int16_t>() {
-  return ValueType(BT_Int, ST_16, true, 0);
-}
-
-template<>
-inline ValueType ValueType::getValueType<uint16_t>() {
-  return ValueType(BT_Int, ST_16, false, 0);
-}
-
-template<>
-inline ValueType ValueType::getValueType<int32_t>() {
-  return ValueType(BT_Int, ST_32, true, 0);
-}
-
-template<>
-inline ValueType ValueType::getValueType<uint32_t>() {
-  return ValueType(BT_Int, ST_32, false, 0);
-}
-
-template<>
-inline ValueType ValueType::getValueType<int64_t>() {
-  return ValueType(BT_Int, ST_64, true, 0);
+template <> inline ValueType ValueType::getValueType<char>() {
+  return ValueType(BT_AsciiChar);
 }
 
-template<>
-inline ValueType ValueType::getValueType<uint64_t>() {
-  return ValueType(BT_Int, ST_64, false, 0);
+template <> inline ValueType ValueType::getValueType<wchar_t>() {
+  return ValueType(BT_WideChar);
 }
 
-template<>
-inline ValueType ValueType::getValueType<float>() {
-  return ValueType(BT_Float, ST_32, true, 0);
+template <> inline ValueType ValueType::getValueType<char16_t>() {
+  return ValueType(BT_UTF16Char);
 }
 
-template<>
-inline ValueType ValueType::getValueType<double>() {
-  return ValueType(BT_Float, ST_64, true, 0);
+template <> inline ValueType ValueType::getValueType<char32_t>() {
+  return ValueType(BT_UTF32Char);
 }
 
-template<>
-inline ValueType ValueType::getValueType<long double>() {
-  return ValueType(BT_Float, ST_128, true, 0);
+template <> inline ValueType ValueType::getValueType<llvm::APInt>() {
+  return ValueType(BT_Int);
 }
 
 template<>
 inline ValueType ValueType::getValueType<StringRef>() {
-  return ValueType(BT_String, getSizeType(sizeof(StringRef)), false, 0);
+  return ValueType(BT_String);
 }
 
 template<>
 inline ValueType ValueType::getValueType<void*>() {
-  return ValueType(BT_Pointer, getSizeType(sizeof(void*)), false, 0);
+  return ValueType(BT_Pointer);
 }
 
 /// Base class for AST nodes in the typed intermediate language.
@@ -532,37 +466,29 @@ template <class T> class LiteralT;
 
 // Base class for literal values.
 class Literal : public SExpr {
-public:
-  Literal(const Expr *C)
-     : SExpr(COP_Literal), ValType(ValueType::getValueType<void>()), Cexpr(C) {}
+protected:
   Literal(ValueType VT) : SExpr(COP_Literal), ValType(VT) {}
-  Literal(const Literal &) = default;
 
+public:
   static bool classof(const SExpr *E) { return E->opcode() == COP_Literal; }
 
-  // The clang expression for this literal.
-  const Expr *clangExpr() const { return Cexpr; }
-
   ValueType valueType() const { return ValType; }
 
   template<class T> const LiteralT<T>& as() const {
+    assert(ValType == ValueType::getValueType<T>());
     return *static_cast<const LiteralT<T>*>(this);
   }
   template<class T> LiteralT<T>& as() {
+    assert(ValType == ValueType::getValueType<T>());
     return *static_cast<LiteralT<T>*>(this);
   }
 
   template <class V> typename V::R_SExpr traverse(V &Vs, typename V::R_Ctx Ctx);
 
-  template <class C>
-  typename C::CType compare(const Literal* E, C& Cmp) const {
-    // TODO: defer actual comparison to LiteralT
-    return Cmp.trueResult();
-  }
+  template <class C> typename C::CType compare(const Literal *E, C &Cmp) const;
 
 private:
   const ValueType ValType;
-  const Expr *Cexpr = nullptr;
 };
 
 // Derived class for literal values, which stores the actual value.
@@ -585,58 +511,55 @@ class LiteralT : public Literal {
 
 template <class V>
 typename V::R_SExpr Literal::traverse(V &Vs, typename V::R_Ctx Ctx) {
-  if (Cexpr)
-    return Vs.reduceLiteral(*this);
-
   switch (ValType.Base) {
-  case ValueType::BT_Void:
-    break;
   case ValueType::BT_Bool:
     return Vs.reduceLiteralT(as<bool>());
-  case ValueType::BT_Int: {
-    switch (ValType.Size) {
-    case ValueType::ST_8:
-      if (ValType.Signed)
-        return Vs.reduceLiteralT(as<int8_t>());
-      else
-        return Vs.reduceLiteralT(as<uint8_t>());
-    case ValueType::ST_16:
-      if (ValType.Signed)
-        return Vs.reduceLiteralT(as<int16_t>());
-      else
-        return Vs.reduceLiteralT(as<uint16_t>());
-    case ValueType::ST_32:
-      if (ValType.Signed)
-        return Vs.reduceLiteralT(as<int32_t>());
-      else
-        return Vs.reduceLiteralT(as<uint32_t>());
-    case ValueType::ST_64:
-      if (ValType.Signed)
-        return Vs.reduceLiteralT(as<int64_t>());
-      else
-        return Vs.reduceLiteralT(as<uint64_t>());
-    default:
-      break;
-    }
-  }
-  case ValueType::BT_Float: {
-    switch (ValType.Size) {
-    case ValueType::ST_32:
-      return Vs.reduceLiteralT(as<float>());
-    case ValueType::ST_64:
-      return Vs.reduceLiteralT(as<double>());
-    default:
-      break;
-    }
-  }
+  case ValueType::BT_AsciiChar:
+    return Vs.reduceLiteralT(as<char>());
+  case ValueType::BT_WideChar:
+    return Vs.reduceLiteralT(as<wchar_t>());
+  case ValueType::BT_UTF16Char:
+    return Vs.reduceLiteralT(as<char16_t>());
+  case ValueType::BT_UTF32Char:
+    return Vs.reduceLiteralT(as<char32_t>());
+  case ValueType::BT_Int:
+    return Vs.reduceLiteralT(as<llvm::APInt>());
   case ValueType::BT_String:
     return Vs.reduceLiteralT(as<StringRef>());
   case ValueType::BT_Pointer:
-    return Vs.reduceLiteralT(as<void*>());
-  case ValueType::BT_ValueRef:
-    break;
+    return Vs.reduceLiteralT(as<void *>());
+  }
+  llvm_unreachable("Invalid BaseType");
+}
+
+template <class C>
+typename C::CType Literal::compare(const Literal *E, C &Cmp) const {
+  typename C::CType Ct = Cmp.compareIntegers(ValType.Base, E->ValType.Base);
+  if (Cmp.notTrue(Ct))
+    return Ct;
+  switch (ValType.Base) {
+  case ValueType::BT_Bool:
+    return Cmp.compareIntegers(as<bool>().value(), E->as<bool>().value());
+  case ValueType::BT_AsciiChar:
+    return Cmp.compareIntegers(as<char>().value(), E->as<char>().value());
+  case ValueType::BT_WideChar:
+    return Cmp.compareIntegers(as<wchar_t>().value(), E->as<wchar_t>().value());
+  case ValueType::BT_UTF16Char:
+    return Cmp.compareIntegers(as<char16_t>().value(),
+                               E->as<char16_t>().value());
+  case ValueType::BT_UTF32Char:
+    return Cmp.compareIntegers(as<char32_t>().value(),
+                               E->as<char32_t>().value());
+  case ValueType::BT_Int:
+    return Cmp.compareIntegers(as<llvm::APInt>().value(),
+                               E->as<llvm::APInt>().value());
+  case ValueType::BT_String:
+    return Cmp.compareStrings(as<StringRef>().value(),
+                              E->as<StringRef>().value());
+  case ValueType::BT_Pointer:
+    return Cmp.trueResult();
   }
-  return Vs.reduceLiteral(*this);
+  llvm_unreachable("Invalid BaseType");
 }
 
 /// A Literal pointer to an object allocated in memory.
diff --git a/clang/include/clang/Analysis/Analyses/ThreadSafetyTraverse.h b/clang/include/clang/Analysis/Analyses/ThreadSafetyTraverse.h
index acab8bcdc1dab..6b0c240bc4a9b 100644
--- a/clang/include/clang/Analysis/Analyses/ThreadSafetyTraverse.h
+++ b/clang/include/clang/Analysis/Analyses/ThreadSafetyTraverse.h
@@ -192,7 +192,6 @@ class VisitReducer : public Traversal<Self, VisitReducerBase>,
   R_SExpr reduceUndefined(Undefined &Orig) { return true; }
   R_SExpr reduceWildcard(Wildcard &Orig) { return true; }
 
-  R_SExpr reduceLiteral(Literal &Orig) { return true; }
   template<class T>
   R_SExpr reduceLiteralT(LiteralT<T> &Orig) { return true; }
   R_SExpr reduceLiteralPtr(Literal &Orig) { return true; }
@@ -337,6 +336,9 @@ class EqualsComparator : public Comparator<EqualsComparator> {
   bool notTrue(CType ct) { return !ct; }
 
   bool compareIntegers(unsigned i, unsigned j) { return i == j; }
+  bool compareIntegers(const llvm::APInt &i, const llvm::APInt &j) {
+    return i == j;
+  }
   bool compareStrings (StringRef s, StringRef r) { return s == r; }
   bool comparePointers(const void* P, const void* Q) { return P == Q; }
 
@@ -365,6 +367,9 @@ class MatchComparator : public Comparator<MatchComparator> {
   bool notTrue(CType ct) { return !ct; }
 
   bool compareIntegers(unsigned i, unsigned j) { return i == j; }
+  bool compareIntegers(const llvm::APInt &i, const llvm::APInt &j) {
+    return i == j;
+  }
   bool compareStrings (StringRef s, StringRef r) { return s == r; }
   bool comparePointers(const void *P, const void *Q) { return P == Q; }
 
@@ -532,88 +537,46 @@ class PrettyPrinter {
     SS << "*";
   }
 
-  template<class T>
-  void printLiteralT(const LiteralT<T> *E, StreamType &SS) {
-    SS << E->value();
-  }
-
-  void printLiteralT(const LiteralT<uint8_t> *E, StreamType &SS) {
-    SS << "'" << E->value() << "'";
-  }
-
   void printLiteral(const Literal *E, StreamType &SS) {
-    if (E->clangExpr()) {
-      SS << getSourceLiteralString(E->clangExpr());
+    ValueType VT = E->valueType();
+    switch (VT.Base) {
+    case ValueType::BT_Bool:
+      if (E->as<bool>().value())
+        SS << "true";
+      else
+        SS << "false";
+      return;
+    case ValueType::BT_AsciiChar:
+      CharacterLiteral::print(E->as<char>().value(),
+                              CharacterLiteralKind::Ascii, SS);
+      return;
+    case ValueType::BT_WideChar:
+      CharacterLiteral::print(E->as<wchar_t>().value(),
+                              CharacterLiteralKind::Wide, SS);
+      return;
+    case ValueType::BT_UTF16Char:
+      CharacterLiteral::print(E->as<char16_t>().value(),
+                              CharacterLiteralKind::UTF16, SS);
+      return;
+    case ValueType::BT_UTF32Char:
+      CharacterLiteral::print(E->as<char32_t>().value(),
+                              CharacterLiteralKind::UTF32, SS);
+      return;
+    case ValueType::BT_Int: {
+      SmallVector<char, 32> Str;
+      E->as<llvm::APInt>().value().toStringSigned(Str);
+      Str.push_back('\0');
+      SS << Str.data();
       return;
     }
-    else {
-      ValueType VT = E->valueType();
-      switch (VT.Base) {
-      case ValueType::BT_Void:
-        SS << "void";
-        return;
-      case ValueType::BT_Bool:
-        if (E->as<bool>().value())
-          SS << "true";
-        else
-          SS << "false";
-        return;
-      case ValueType::BT_Int:
-        switch (VT.Size) {
-        case ValueType::ST_8:
-          if (VT.Signed)
-            printLiteralT(&E->as<int8_t>(), SS);
-          else
-            printLiteralT(&E->as<uint8_t>(), SS);
-          return;
-        case ValueType::ST_16:
-          if (VT.Signed)
-            printLiteralT(&E->as<int16_t>(), SS);
-          else
-            printLiteralT(&E->as<uint16_t>(), SS);
-          return;
-        case ValueType::ST_32:
-          if (VT.Signed)
-            printLiteralT(&E->as<int32_t>(), SS);
-          else
-            printLiteralT(&E->as<uint32_t>(), SS);
-          return;
-        case ValueType::ST_64:
-          if (VT.Signed)
-            printLiteralT(&E->as<int64_t>(), SS);
-          else
-            printLiteralT(&E->as<uint64_t>(), SS);
-          return;
-        default:
-          break;
-        }
-        break;
-      case ValueType::BT_Float:
-        switch (VT.Size) {
-        case ValueType::ST_32:
-          printLiteralT(&E->as<float>(), SS);
-          return;
-        case ValueType::ST_64:
-          printLiteralT(&E->as<double>(), SS);
-          return;
-        default:
-          break;
-        }
-        break;
-      case ValueType::BT_String:
-        SS << "\"";
-        printLiteralT(&E->as<StringRef>(), SS);
-        SS << "\"";
-        return;
-      case ValueType::BT_Pointer:
-        SS << "#ptr";
-        return;
-      case ValueType::BT_ValueRef:
-        SS << "#vref";
-        return;
-      }
+    case ValueType::BT_String:
+      SS << '\"' << E->as<StringRef>().value() << '\"';
+      return;
+    case ValueType::BT_Pointer:
+      SS << "nullptr"; // currently the only supported pointer literal.
+      return;
     }
-    SS << "#lit";
+    llvm_unreachable("Invalid BaseType");
   }
 
   void printLiteralPtr(const LiteralPtr *E, StreamType &SS) {
@@ -919,7 +882,7 @@ class PrettyPrinter {
   }
 };
 
-class StdPrinter : public PrettyPrinter<StdPrinter, std::ostream> {};
+class StdPrinter : public PrettyPrinter<StdPrinter, llvm::raw_ostream> {};
 
 } // namespace til
 } // namespace threadSafety
diff --git a/clang/lib/Analysis/ThreadSafetyCommon.cpp b/clang/lib/Analysis/ThreadSafetyCommon.cpp
index ddbd0a9ca904b..0797593f30377 100644
--- a/clang/lib/Analysis/ThreadSafetyCommon.cpp
+++ b/clang/lib/Analysis/ThreadSafetyCommon.cpp
@@ -300,16 +300,37 @@ til::SExpr *SExprBuilder::translate(const Stmt *S, CallingContext *Ctx) {
     return translate(cast<MaterializeTemporaryExpr>(S)->getSubExpr(), Ctx);
 
   // Collect all literals
-  case Stmt::CharacterLiteralClass:
+  case Stmt::CharacterLiteralClass: {
+    const auto *CL = cast<CharacterLiteral>(S);
+    unsigned Value = CL->getValue();
+    switch (CL->getKind()) {
+    case CharacterLiteralKind::Ascii:
+    case CharacterLiteralKind::UTF8:
+      return new (Arena) til::LiteralT<char>(Value);
+    case CharacterLiteralKind::Wide:
+      return new (Arena) til::LiteralT<wchar_t>(Value);
+    case CharacterLiteralKind::UTF16:
+      return new (Arena) til::LiteralT<char16_t>(Value);
+    case CharacterLiteralKind::UTF32:
+      return new (Arena) til::LiteralT<char32_t>(Value);
+    }
+    llvm_unreachable("Invalid CharacterLiteralKind");
+  }
   case Stmt::CXXNullPtrLiteralExprClass:
   case Stmt::GNUNullExprClass:
+    return new (Arena) til::LiteralT<void *>(nullptr);
   case Stmt::CXXBoolLiteralExprClass:
-  case Stmt::FloatingLiteralClass:
-  case Stmt::ImaginaryLiteralClass:
+    return new (Arena)
+        til::LiteralT<bool>(cast<CXXBoolLiteralExpr>(S)->getValue());
   case Stmt::IntegerLiteralClass:
+    return new (Arena)
+        til::LiteralT<llvm::APInt>(cast<IntegerLiteral>(S)->getValue());
   case Stmt::StringLiteralClass:
+    return new (Arena)
+        til::LiteralT<StringRef>(cast<StringLiteral>(S)->getString());
   case Stmt::ObjCStringLiteralClass:
-    return new (Arena) til::Literal(cast<Expr>(S));
+    return new (Arena) til::LiteralT<StringRef>(
+        cast<ObjCStringLiteral>(S)->getString()->getString());
 
   case Stmt::DeclStmtClass:
     return translateDeclStmt(cast<DeclStmt>(S), Ctx);
diff --git a/clang/test/SemaCXX/warn-thread-safety-analysis.cpp b/clang/test/SemaCXX/warn-thread-safety-analysis.cpp
index d64ed1e5f260a..f416c62aaf71a 100644
--- a/clang/test/SemaCXX/warn-thread-safety-analysis.cpp
+++ b/clang/test/SemaCXX/warn-thread-safety-analysis.cpp
@@ -2487,6 +2487,10 @@ class Bar {
   Foo& getFoo()              { return *f; }
   Foo& getFoo2(int c)        { return *f; }
   Foo& getFoo3(int c, int d) { return *f; }
+  Foo& getFoo4(bool)         { return *f; }
+  Foo& getFoo5(char)         { return *f; }
+  Foo& getFoo6(char16_t)     { return *f; }
+  Foo& getFoo7(const char*)  { return *f; }
 
   Foo& getFooey() { return *f; }
 };
@@ -2518,6 +2522,22 @@ void test() {
   bar.getFoo3(a, b).a = 0;
   bar.getFoo3(a, b).mu_.Unlock();
 
+  bar.getFoo4(true).mu_.Lock();
+  bar.getFoo4(true).a = 0;
+  bar.getFoo4(true).mu_.Unlock();
+
+  bar.getFoo5('a').mu_.Lock();
+  bar.getFoo5('a').a = 0;
+  bar.getFoo5('a').mu_.Unlock();
+
+  bar.getFoo6(u'\u1234').mu_.Lock();
+  bar.getFoo6(u'\u1234').a = 0;
+  bar.getFoo6(u'\u1234').mu_.Unlock();
+
+  bar.getFoo7("foo").mu_.Lock();
+  bar.getFoo7("foo").a = 0;
+  bar.getFoo7("foo").mu_.Unlock();
+
   getBarFoo(bar, a).mu_.Lock();
   getBarFoo(bar, a).a = 0;
   getBarFoo(bar, a).mu_.Unlock();
@@ -2559,12 +2579,42 @@ void test2() {
     // expected-note {{found near match 'bar.getFoo2(a).mu_'}}
   bar.getFoo2(a).mu_.Unlock();
 
+  bar.getFoo2(0).mu_.Lock();
+  bar.getFoo2(1).a = 0; // \
+    // expected-warning {{writing variable 'a' requires holding mutex 'bar.getFoo2(1).mu_' exclusively}} \
+    // expected-note {{found near match 'bar.getFoo2(0).mu_'}}
+  bar.getFoo2(0).mu_.Unlock();
+
   bar.getFoo3(a, b).mu_.Lock();
   bar.getFoo3(a, c).a = 0;  // \
     // expected-warning {{writing variable 'a' requires holding mutex 'bar.getFoo3(a, c).mu_' exclusively}} \
     // expected-note {{found near match 'bar.getFoo3(a, b).mu_'}}
   bar.getFoo3(a, b).mu_.Unlock();
 
+  bar.getFoo4(true).mu_.Lock();
+  bar.getFoo4(false).a = 0; // \
+    // expected-warning {{writing variable 'a' requires holding mutex 'bar.getFoo4(false).mu_' exclusively}} \
+    // expected-note {{found near match 'bar.getFoo4(true).mu_'}}
+  bar.getFoo4(true).mu_.Unlock();
+
+  bar.getFoo5('x').mu_.Lock();
+  bar.getFoo5('y').a = 0; // \
+    // expected-warning {{writing variable 'a' requires holding mutex 'bar.getFoo5('y').mu_' exclusively}} \
+    // expected-note {{found near match 'bar.getFoo5('x').mu_'}}
+  bar.getFoo5('x').mu_.Unlock();
+
+  bar.getFoo6(u'\u1234').mu_.Lock();
+  bar.getFoo6(u'\u4321').a = 0; // \
+    // expected-warning {{writing variable 'a' requires holding mutex 'bar.getFoo6(u'\u4321').mu_' exclusively}} \
+    // expected-note {{found near match 'bar.getFoo6(u'\u1234').mu_'}}
+  bar.getFoo6(u'\u1234').mu_.Unlock();
+
+  bar.getFoo7("foo").mu_.Lock();
+  bar.getFoo7("bar").a = 0; // \
+    // e...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 13, 2025

@llvm/pr-subscribers-clang-analysis

Author: Aaron Puchert (aaronpuchert)

Changes

The typical case for literals is an array of mutexes, where we want to distinguish mutex[0] from mutex[1] and so on. Currently they're treated as the same expression, in fact all literals are treated as the same expression.

The infrastructure for literals is already there, although it required some changes, and some simplifications seemed opportune:

  • The ValueType had fields for size and signedness. But Clang doesn't use native types and stores integer and (floating-point) literals as llvm::APInt regardless of size, so we don't need these properties. We could use them for characters, but it seems easier to just create different base types for now.
  • We remove the BT_Void: void literals don't exist in C++.
  • We remove BT_Float and BT_ValueRef: floating-point numbers and complex numbers are probably not used in lock expressions.

We turn Literal into a pure base class, as it seems to have been intended, and only create LiteralT instances of the correct type. Assertions on as ensure we're not mixing up types.

We print to llvm::raw_ostream instead of std::ostream because that's required for CharacterLiteral::print. Perhaps we should implement that ourselves though.

Fixes #58535.


Patch is 20.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148551.diff

5 Files Affected:

  • (modified) clang/include/clang/Analysis/Analyses/ThreadSafetyCommon.h (+4-3)
  • (modified) clang/include/clang/Analysis/Analyses/ThreadSafetyTIL.h (+68-145)
  • (modified) clang/include/clang/Analysis/Analyses/ThreadSafetyTraverse.h (+43-80)
  • (modified) clang/lib/Analysis/ThreadSafetyCommon.cpp (+25-4)
  • (modified) clang/test/SemaCXX/warn-thread-safety-analysis.cpp (+50)
diff --git a/clang/include/clang/Analysis/Analyses/ThreadSafetyCommon.h b/clang/include/clang/Analysis/Analyses/ThreadSafetyCommon.h
index 6c97905a2d7f9..e5cd1948c9314 100644
--- a/clang/include/clang/Analysis/Analyses/ThreadSafetyCommon.h
+++ b/clang/include/clang/Analysis/Analyses/ThreadSafetyCommon.h
@@ -35,7 +35,7 @@
 #include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
-#include <sstream>
+#include "llvm/Support/raw_ostream.h"
 #include <string>
 #include <utility>
 #include <vector>
@@ -90,9 +90,10 @@ inline bool partiallyMatches(const til::SExpr *E1, const til::SExpr *E2) {
 }
 
 inline std::string toString(const til::SExpr *E) {
-  std::stringstream ss;
+  std::string s;
+  llvm::raw_string_ostream ss(s);
   til::StdPrinter::print(E, ss);
-  return ss.str();
+  return s;
 }
 
 }  // namespace sx
diff --git a/clang/include/clang/Analysis/Analyses/ThreadSafetyTIL.h b/clang/include/clang/Analysis/Analyses/ThreadSafetyTIL.h
index 14c5b679428a3..890ba19465f7f 100644
--- a/clang/include/clang/Analysis/Analyses/ThreadSafetyTIL.h
+++ b/clang/include/clang/Analysis/Analyses/ThreadSafetyTIL.h
@@ -148,129 +148,63 @@ StringRef getBinaryOpcodeString(TIL_BinaryOpcode Op);
 /// All variables and expressions must have a value type.
 /// Pointer types are further subdivided into the various heap-allocated
 /// types, such as functions, records, etc.
-/// Structured types that are passed by value (e.g. complex numbers)
-/// require special handling; they use BT_ValueRef, and size ST_0.
 struct ValueType {
   enum BaseType : unsigned char {
-    BT_Void = 0,
     BT_Bool,
+    BT_AsciiChar,
+    BT_WideChar,
+    BT_UTF16Char,
+    BT_UTF32Char,
     BT_Int,
-    BT_Float,
-    BT_String,    // String literals
+    BT_String, // String literals
     BT_Pointer,
-    BT_ValueRef
   };
 
-  enum SizeType : unsigned char {
-    ST_0 = 0,
-    ST_1,
-    ST_8,
-    ST_16,
-    ST_32,
-    ST_64,
-    ST_128
-  };
-
-  ValueType(BaseType B, SizeType Sz, bool S, unsigned char VS)
-      : Base(B), Size(Sz), Signed(S), VectSize(VS) {}
-
-  inline static SizeType getSizeType(unsigned nbytes);
+  ValueType(BaseType B) : Base(B) {}
 
   template <class T>
   inline static ValueType getValueType();
 
   BaseType Base;
-  SizeType Size;
-  bool Signed;
-
-  // 0 for scalar, otherwise num elements in vector
-  unsigned char VectSize;
 };
 
-inline ValueType::SizeType ValueType::getSizeType(unsigned nbytes) {
-  switch (nbytes) {
-    case 1: return ST_8;
-    case 2: return ST_16;
-    case 4: return ST_32;
-    case 8: return ST_64;
-    case 16: return ST_128;
-    default: return ST_0;
-  }
-}
-
-template<>
-inline ValueType ValueType::getValueType<void>() {
-  return ValueType(BT_Void, ST_0, false, 0);
+inline bool operator==(const ValueType &a, const ValueType &b) {
+  return a.Base == b.Base;
 }
 
 template<>
 inline ValueType ValueType::getValueType<bool>() {
-  return ValueType(BT_Bool, ST_1, false, 0);
-}
-
-template<>
-inline ValueType ValueType::getValueType<int8_t>() {
-  return ValueType(BT_Int, ST_8, true, 0);
-}
-
-template<>
-inline ValueType ValueType::getValueType<uint8_t>() {
-  return ValueType(BT_Int, ST_8, false, 0);
+  return ValueType(BT_Bool);
 }
 
-template<>
-inline ValueType ValueType::getValueType<int16_t>() {
-  return ValueType(BT_Int, ST_16, true, 0);
-}
-
-template<>
-inline ValueType ValueType::getValueType<uint16_t>() {
-  return ValueType(BT_Int, ST_16, false, 0);
-}
-
-template<>
-inline ValueType ValueType::getValueType<int32_t>() {
-  return ValueType(BT_Int, ST_32, true, 0);
-}
-
-template<>
-inline ValueType ValueType::getValueType<uint32_t>() {
-  return ValueType(BT_Int, ST_32, false, 0);
-}
-
-template<>
-inline ValueType ValueType::getValueType<int64_t>() {
-  return ValueType(BT_Int, ST_64, true, 0);
+template <> inline ValueType ValueType::getValueType<char>() {
+  return ValueType(BT_AsciiChar);
 }
 
-template<>
-inline ValueType ValueType::getValueType<uint64_t>() {
-  return ValueType(BT_Int, ST_64, false, 0);
+template <> inline ValueType ValueType::getValueType<wchar_t>() {
+  return ValueType(BT_WideChar);
 }
 
-template<>
-inline ValueType ValueType::getValueType<float>() {
-  return ValueType(BT_Float, ST_32, true, 0);
+template <> inline ValueType ValueType::getValueType<char16_t>() {
+  return ValueType(BT_UTF16Char);
 }
 
-template<>
-inline ValueType ValueType::getValueType<double>() {
-  return ValueType(BT_Float, ST_64, true, 0);
+template <> inline ValueType ValueType::getValueType<char32_t>() {
+  return ValueType(BT_UTF32Char);
 }
 
-template<>
-inline ValueType ValueType::getValueType<long double>() {
-  return ValueType(BT_Float, ST_128, true, 0);
+template <> inline ValueType ValueType::getValueType<llvm::APInt>() {
+  return ValueType(BT_Int);
 }
 
 template<>
 inline ValueType ValueType::getValueType<StringRef>() {
-  return ValueType(BT_String, getSizeType(sizeof(StringRef)), false, 0);
+  return ValueType(BT_String);
 }
 
 template<>
 inline ValueType ValueType::getValueType<void*>() {
-  return ValueType(BT_Pointer, getSizeType(sizeof(void*)), false, 0);
+  return ValueType(BT_Pointer);
 }
 
 /// Base class for AST nodes in the typed intermediate language.
@@ -532,37 +466,29 @@ template <class T> class LiteralT;
 
 // Base class for literal values.
 class Literal : public SExpr {
-public:
-  Literal(const Expr *C)
-     : SExpr(COP_Literal), ValType(ValueType::getValueType<void>()), Cexpr(C) {}
+protected:
   Literal(ValueType VT) : SExpr(COP_Literal), ValType(VT) {}
-  Literal(const Literal &) = default;
 
+public:
   static bool classof(const SExpr *E) { return E->opcode() == COP_Literal; }
 
-  // The clang expression for this literal.
-  const Expr *clangExpr() const { return Cexpr; }
-
   ValueType valueType() const { return ValType; }
 
   template<class T> const LiteralT<T>& as() const {
+    assert(ValType == ValueType::getValueType<T>());
     return *static_cast<const LiteralT<T>*>(this);
   }
   template<class T> LiteralT<T>& as() {
+    assert(ValType == ValueType::getValueType<T>());
     return *static_cast<LiteralT<T>*>(this);
   }
 
   template <class V> typename V::R_SExpr traverse(V &Vs, typename V::R_Ctx Ctx);
 
-  template <class C>
-  typename C::CType compare(const Literal* E, C& Cmp) const {
-    // TODO: defer actual comparison to LiteralT
-    return Cmp.trueResult();
-  }
+  template <class C> typename C::CType compare(const Literal *E, C &Cmp) const;
 
 private:
   const ValueType ValType;
-  const Expr *Cexpr = nullptr;
 };
 
 // Derived class for literal values, which stores the actual value.
@@ -585,58 +511,55 @@ class LiteralT : public Literal {
 
 template <class V>
 typename V::R_SExpr Literal::traverse(V &Vs, typename V::R_Ctx Ctx) {
-  if (Cexpr)
-    return Vs.reduceLiteral(*this);
-
   switch (ValType.Base) {
-  case ValueType::BT_Void:
-    break;
   case ValueType::BT_Bool:
     return Vs.reduceLiteralT(as<bool>());
-  case ValueType::BT_Int: {
-    switch (ValType.Size) {
-    case ValueType::ST_8:
-      if (ValType.Signed)
-        return Vs.reduceLiteralT(as<int8_t>());
-      else
-        return Vs.reduceLiteralT(as<uint8_t>());
-    case ValueType::ST_16:
-      if (ValType.Signed)
-        return Vs.reduceLiteralT(as<int16_t>());
-      else
-        return Vs.reduceLiteralT(as<uint16_t>());
-    case ValueType::ST_32:
-      if (ValType.Signed)
-        return Vs.reduceLiteralT(as<int32_t>());
-      else
-        return Vs.reduceLiteralT(as<uint32_t>());
-    case ValueType::ST_64:
-      if (ValType.Signed)
-        return Vs.reduceLiteralT(as<int64_t>());
-      else
-        return Vs.reduceLiteralT(as<uint64_t>());
-    default:
-      break;
-    }
-  }
-  case ValueType::BT_Float: {
-    switch (ValType.Size) {
-    case ValueType::ST_32:
-      return Vs.reduceLiteralT(as<float>());
-    case ValueType::ST_64:
-      return Vs.reduceLiteralT(as<double>());
-    default:
-      break;
-    }
-  }
+  case ValueType::BT_AsciiChar:
+    return Vs.reduceLiteralT(as<char>());
+  case ValueType::BT_WideChar:
+    return Vs.reduceLiteralT(as<wchar_t>());
+  case ValueType::BT_UTF16Char:
+    return Vs.reduceLiteralT(as<char16_t>());
+  case ValueType::BT_UTF32Char:
+    return Vs.reduceLiteralT(as<char32_t>());
+  case ValueType::BT_Int:
+    return Vs.reduceLiteralT(as<llvm::APInt>());
   case ValueType::BT_String:
     return Vs.reduceLiteralT(as<StringRef>());
   case ValueType::BT_Pointer:
-    return Vs.reduceLiteralT(as<void*>());
-  case ValueType::BT_ValueRef:
-    break;
+    return Vs.reduceLiteralT(as<void *>());
+  }
+  llvm_unreachable("Invalid BaseType");
+}
+
+template <class C>
+typename C::CType Literal::compare(const Literal *E, C &Cmp) const {
+  typename C::CType Ct = Cmp.compareIntegers(ValType.Base, E->ValType.Base);
+  if (Cmp.notTrue(Ct))
+    return Ct;
+  switch (ValType.Base) {
+  case ValueType::BT_Bool:
+    return Cmp.compareIntegers(as<bool>().value(), E->as<bool>().value());
+  case ValueType::BT_AsciiChar:
+    return Cmp.compareIntegers(as<char>().value(), E->as<char>().value());
+  case ValueType::BT_WideChar:
+    return Cmp.compareIntegers(as<wchar_t>().value(), E->as<wchar_t>().value());
+  case ValueType::BT_UTF16Char:
+    return Cmp.compareIntegers(as<char16_t>().value(),
+                               E->as<char16_t>().value());
+  case ValueType::BT_UTF32Char:
+    return Cmp.compareIntegers(as<char32_t>().value(),
+                               E->as<char32_t>().value());
+  case ValueType::BT_Int:
+    return Cmp.compareIntegers(as<llvm::APInt>().value(),
+                               E->as<llvm::APInt>().value());
+  case ValueType::BT_String:
+    return Cmp.compareStrings(as<StringRef>().value(),
+                              E->as<StringRef>().value());
+  case ValueType::BT_Pointer:
+    return Cmp.trueResult();
   }
-  return Vs.reduceLiteral(*this);
+  llvm_unreachable("Invalid BaseType");
 }
 
 /// A Literal pointer to an object allocated in memory.
diff --git a/clang/include/clang/Analysis/Analyses/ThreadSafetyTraverse.h b/clang/include/clang/Analysis/Analyses/ThreadSafetyTraverse.h
index acab8bcdc1dab..6b0c240bc4a9b 100644
--- a/clang/include/clang/Analysis/Analyses/ThreadSafetyTraverse.h
+++ b/clang/include/clang/Analysis/Analyses/ThreadSafetyTraverse.h
@@ -192,7 +192,6 @@ class VisitReducer : public Traversal<Self, VisitReducerBase>,
   R_SExpr reduceUndefined(Undefined &Orig) { return true; }
   R_SExpr reduceWildcard(Wildcard &Orig) { return true; }
 
-  R_SExpr reduceLiteral(Literal &Orig) { return true; }
   template<class T>
   R_SExpr reduceLiteralT(LiteralT<T> &Orig) { return true; }
   R_SExpr reduceLiteralPtr(Literal &Orig) { return true; }
@@ -337,6 +336,9 @@ class EqualsComparator : public Comparator<EqualsComparator> {
   bool notTrue(CType ct) { return !ct; }
 
   bool compareIntegers(unsigned i, unsigned j) { return i == j; }
+  bool compareIntegers(const llvm::APInt &i, const llvm::APInt &j) {
+    return i == j;
+  }
   bool compareStrings (StringRef s, StringRef r) { return s == r; }
   bool comparePointers(const void* P, const void* Q) { return P == Q; }
 
@@ -365,6 +367,9 @@ class MatchComparator : public Comparator<MatchComparator> {
   bool notTrue(CType ct) { return !ct; }
 
   bool compareIntegers(unsigned i, unsigned j) { return i == j; }
+  bool compareIntegers(const llvm::APInt &i, const llvm::APInt &j) {
+    return i == j;
+  }
   bool compareStrings (StringRef s, StringRef r) { return s == r; }
   bool comparePointers(const void *P, const void *Q) { return P == Q; }
 
@@ -532,88 +537,46 @@ class PrettyPrinter {
     SS << "*";
   }
 
-  template<class T>
-  void printLiteralT(const LiteralT<T> *E, StreamType &SS) {
-    SS << E->value();
-  }
-
-  void printLiteralT(const LiteralT<uint8_t> *E, StreamType &SS) {
-    SS << "'" << E->value() << "'";
-  }
-
   void printLiteral(const Literal *E, StreamType &SS) {
-    if (E->clangExpr()) {
-      SS << getSourceLiteralString(E->clangExpr());
+    ValueType VT = E->valueType();
+    switch (VT.Base) {
+    case ValueType::BT_Bool:
+      if (E->as<bool>().value())
+        SS << "true";
+      else
+        SS << "false";
+      return;
+    case ValueType::BT_AsciiChar:
+      CharacterLiteral::print(E->as<char>().value(),
+                              CharacterLiteralKind::Ascii, SS);
+      return;
+    case ValueType::BT_WideChar:
+      CharacterLiteral::print(E->as<wchar_t>().value(),
+                              CharacterLiteralKind::Wide, SS);
+      return;
+    case ValueType::BT_UTF16Char:
+      CharacterLiteral::print(E->as<char16_t>().value(),
+                              CharacterLiteralKind::UTF16, SS);
+      return;
+    case ValueType::BT_UTF32Char:
+      CharacterLiteral::print(E->as<char32_t>().value(),
+                              CharacterLiteralKind::UTF32, SS);
+      return;
+    case ValueType::BT_Int: {
+      SmallVector<char, 32> Str;
+      E->as<llvm::APInt>().value().toStringSigned(Str);
+      Str.push_back('\0');
+      SS << Str.data();
       return;
     }
-    else {
-      ValueType VT = E->valueType();
-      switch (VT.Base) {
-      case ValueType::BT_Void:
-        SS << "void";
-        return;
-      case ValueType::BT_Bool:
-        if (E->as<bool>().value())
-          SS << "true";
-        else
-          SS << "false";
-        return;
-      case ValueType::BT_Int:
-        switch (VT.Size) {
-        case ValueType::ST_8:
-          if (VT.Signed)
-            printLiteralT(&E->as<int8_t>(), SS);
-          else
-            printLiteralT(&E->as<uint8_t>(), SS);
-          return;
-        case ValueType::ST_16:
-          if (VT.Signed)
-            printLiteralT(&E->as<int16_t>(), SS);
-          else
-            printLiteralT(&E->as<uint16_t>(), SS);
-          return;
-        case ValueType::ST_32:
-          if (VT.Signed)
-            printLiteralT(&E->as<int32_t>(), SS);
-          else
-            printLiteralT(&E->as<uint32_t>(), SS);
-          return;
-        case ValueType::ST_64:
-          if (VT.Signed)
-            printLiteralT(&E->as<int64_t>(), SS);
-          else
-            printLiteralT(&E->as<uint64_t>(), SS);
-          return;
-        default:
-          break;
-        }
-        break;
-      case ValueType::BT_Float:
-        switch (VT.Size) {
-        case ValueType::ST_32:
-          printLiteralT(&E->as<float>(), SS);
-          return;
-        case ValueType::ST_64:
-          printLiteralT(&E->as<double>(), SS);
-          return;
-        default:
-          break;
-        }
-        break;
-      case ValueType::BT_String:
-        SS << "\"";
-        printLiteralT(&E->as<StringRef>(), SS);
-        SS << "\"";
-        return;
-      case ValueType::BT_Pointer:
-        SS << "#ptr";
-        return;
-      case ValueType::BT_ValueRef:
-        SS << "#vref";
-        return;
-      }
+    case ValueType::BT_String:
+      SS << '\"' << E->as<StringRef>().value() << '\"';
+      return;
+    case ValueType::BT_Pointer:
+      SS << "nullptr"; // currently the only supported pointer literal.
+      return;
     }
-    SS << "#lit";
+    llvm_unreachable("Invalid BaseType");
   }
 
   void printLiteralPtr(const LiteralPtr *E, StreamType &SS) {
@@ -919,7 +882,7 @@ class PrettyPrinter {
   }
 };
 
-class StdPrinter : public PrettyPrinter<StdPrinter, std::ostream> {};
+class StdPrinter : public PrettyPrinter<StdPrinter, llvm::raw_ostream> {};
 
 } // namespace til
 } // namespace threadSafety
diff --git a/clang/lib/Analysis/ThreadSafetyCommon.cpp b/clang/lib/Analysis/ThreadSafetyCommon.cpp
index ddbd0a9ca904b..0797593f30377 100644
--- a/clang/lib/Analysis/ThreadSafetyCommon.cpp
+++ b/clang/lib/Analysis/ThreadSafetyCommon.cpp
@@ -300,16 +300,37 @@ til::SExpr *SExprBuilder::translate(const Stmt *S, CallingContext *Ctx) {
     return translate(cast<MaterializeTemporaryExpr>(S)->getSubExpr(), Ctx);
 
   // Collect all literals
-  case Stmt::CharacterLiteralClass:
+  case Stmt::CharacterLiteralClass: {
+    const auto *CL = cast<CharacterLiteral>(S);
+    unsigned Value = CL->getValue();
+    switch (CL->getKind()) {
+    case CharacterLiteralKind::Ascii:
+    case CharacterLiteralKind::UTF8:
+      return new (Arena) til::LiteralT<char>(Value);
+    case CharacterLiteralKind::Wide:
+      return new (Arena) til::LiteralT<wchar_t>(Value);
+    case CharacterLiteralKind::UTF16:
+      return new (Arena) til::LiteralT<char16_t>(Value);
+    case CharacterLiteralKind::UTF32:
+      return new (Arena) til::LiteralT<char32_t>(Value);
+    }
+    llvm_unreachable("Invalid CharacterLiteralKind");
+  }
   case Stmt::CXXNullPtrLiteralExprClass:
   case Stmt::GNUNullExprClass:
+    return new (Arena) til::LiteralT<void *>(nullptr);
   case Stmt::CXXBoolLiteralExprClass:
-  case Stmt::FloatingLiteralClass:
-  case Stmt::ImaginaryLiteralClass:
+    return new (Arena)
+        til::LiteralT<bool>(cast<CXXBoolLiteralExpr>(S)->getValue());
   case Stmt::IntegerLiteralClass:
+    return new (Arena)
+        til::LiteralT<llvm::APInt>(cast<IntegerLiteral>(S)->getValue());
   case Stmt::StringLiteralClass:
+    return new (Arena)
+        til::LiteralT<StringRef>(cast<StringLiteral>(S)->getString());
   case Stmt::ObjCStringLiteralClass:
-    return new (Arena) til::Literal(cast<Expr>(S));
+    return new (Arena) til::LiteralT<StringRef>(
+        cast<ObjCStringLiteral>(S)->getString()->getString());
 
   case Stmt::DeclStmtClass:
     return translateDeclStmt(cast<DeclStmt>(S), Ctx);
diff --git a/clang/test/SemaCXX/warn-thread-safety-analysis.cpp b/clang/test/SemaCXX/warn-thread-safety-analysis.cpp
index d64ed1e5f260a..f416c62aaf71a 100644
--- a/clang/test/SemaCXX/warn-thread-safety-analysis.cpp
+++ b/clang/test/SemaCXX/warn-thread-safety-analysis.cpp
@@ -2487,6 +2487,10 @@ class Bar {
   Foo& getFoo()              { return *f; }
   Foo& getFoo2(int c)        { return *f; }
   Foo& getFoo3(int c, int d) { return *f; }
+  Foo& getFoo4(bool)         { return *f; }
+  Foo& getFoo5(char)         { return *f; }
+  Foo& getFoo6(char16_t)     { return *f; }
+  Foo& getFoo7(const char*)  { return *f; }
 
   Foo& getFooey() { return *f; }
 };
@@ -2518,6 +2522,22 @@ void test() {
   bar.getFoo3(a, b).a = 0;
   bar.getFoo3(a, b).mu_.Unlock();
 
+  bar.getFoo4(true).mu_.Lock();
+  bar.getFoo4(true).a = 0;
+  bar.getFoo4(true).mu_.Unlock();
+
+  bar.getFoo5('a').mu_.Lock();
+  bar.getFoo5('a').a = 0;
+  bar.getFoo5('a').mu_.Unlock();
+
+  bar.getFoo6(u'\u1234').mu_.Lock();
+  bar.getFoo6(u'\u1234').a = 0;
+  bar.getFoo6(u'\u1234').mu_.Unlock();
+
+  bar.getFoo7("foo").mu_.Lock();
+  bar.getFoo7("foo").a = 0;
+  bar.getFoo7("foo").mu_.Unlock();
+
   getBarFoo(bar, a).mu_.Lock();
   getBarFoo(bar, a).a = 0;
   getBarFoo(bar, a).mu_.Unlock();
@@ -2559,12 +2579,42 @@ void test2() {
     // expected-note {{found near match 'bar.getFoo2(a).mu_'}}
   bar.getFoo2(a).mu_.Unlock();
 
+  bar.getFoo2(0).mu_.Lock();
+  bar.getFoo2(1).a = 0; // \
+    // expected-warning {{writing variable 'a' requires holding mutex 'bar.getFoo2(1).mu_' exclusively}} \
+    // expected-note {{found near match 'bar.getFoo2(0).mu_'}}
+  bar.getFoo2(0).mu_.Unlock();
+
   bar.getFoo3(a, b).mu_.Lock();
   bar.getFoo3(a, c).a = 0;  // \
     // expected-warning {{writing variable 'a' requires holding mutex 'bar.getFoo3(a, c).mu_' exclusively}} \
     // expected-note {{found near match 'bar.getFoo3(a, b).mu_'}}
   bar.getFoo3(a, b).mu_.Unlock();
 
+  bar.getFoo4(true).mu_.Lock();
+  bar.getFoo4(false).a = 0; // \
+    // expected-warning {{writing variable 'a' requires holding mutex 'bar.getFoo4(false).mu_' exclusively}} \
+    // expected-note {{found near match 'bar.getFoo4(true).mu_'}}
+  bar.getFoo4(true).mu_.Unlock();
+
+  bar.getFoo5('x').mu_.Lock();
+  bar.getFoo5('y').a = 0; // \
+    // expected-warning {{writing variable 'a' requires holding mutex 'bar.getFoo5('y').mu_' exclusively}} \
+    // expected-note {{found near match 'bar.getFoo5('x').mu_'}}
+  bar.getFoo5('x').mu_.Unlock();
+
+  bar.getFoo6(u'\u1234').mu_.Lock();
+  bar.getFoo6(u'\u4321').a = 0; // \
+    // expected-warning {{writing variable 'a' requires holding mutex 'bar.getFoo6(u'\u4321').mu_' exclusively}} \
+    // expected-note {{found near match 'bar.getFoo6(u'\u1234').mu_'}}
+  bar.getFoo6(u'\u1234').mu_.Unlock();
+
+  bar.getFoo7("foo").mu_.Lock();
+  bar.getFoo7("bar").a = 0; // \
+    // e...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:analysis clang Clang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

False positive in -Wthread-safety-analysis when using mutexes in collections
2 participants