From d223b6fb7c0146602a777d8d99c57cf5aa522f16 Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Fri, 16 Feb 2024 17:19:59 -0800 Subject: [PATCH 01/13] Generic type support Instance and static fields on generic types Re-enable and add to field tests --- src/coreclr/vm/callconvbuilder.cpp | 11 +-- src/coreclr/vm/dynamicmethod.cpp | 31 ++++--- src/coreclr/vm/dynamicmethod.h | 12 ++- src/coreclr/vm/ilstubresolver.cpp | 40 ++++++--- src/coreclr/vm/ilstubresolver.h | 2 +- src/coreclr/vm/jitinterface.cpp | 71 ++++++++-------- src/coreclr/vm/prestub.cpp | 85 ++++++++++++++++--- src/coreclr/vm/stubgen.cpp | 11 +++ src/coreclr/vm/stubgen.h | 77 ++++++++++++++++- src/coreclr/vm/typehandle.h | 12 ++- .../UnsafeAccessors/UnsafeAccessorsTests.cs | 37 ++++++-- 11 files changed, 298 insertions(+), 91 deletions(-) diff --git a/src/coreclr/vm/callconvbuilder.cpp b/src/coreclr/vm/callconvbuilder.cpp index 20f95f1222410..3075087ee7b83 100644 --- a/src/coreclr/vm/callconvbuilder.cpp +++ b/src/coreclr/vm/callconvbuilder.cpp @@ -298,15 +298,12 @@ namespace { STANDARD_VM_CONTRACT; - TypeHandle type; - MethodDesc* pMD; - FieldDesc* pFD; + ResolvedToken resolved{}; + pResolver->ResolveToken(token, &resolved); - pResolver->ResolveToken(token, &type, &pMD, &pFD); + _ASSERTE(!resolved.TypeHandle.IsNull()); - _ASSERTE(!type.IsNull()); - - *nameOut = type.GetMethodTable()->GetFullyQualifiedNameInfo(namespaceOut); + *nameOut = resolved.TypeHandle.GetMethodTable()->GetFullyQualifiedNameInfo(namespaceOut); return S_OK; } diff --git a/src/coreclr/vm/dynamicmethod.cpp b/src/coreclr/vm/dynamicmethod.cpp index bd5bebcce50f2..065d80d57fcc1 100644 --- a/src/coreclr/vm/dynamicmethod.cpp +++ b/src/coreclr/vm/dynamicmethod.cpp @@ -1325,7 +1325,7 @@ void LCGMethodResolver::AddToUsedIndCellList(BYTE * indcell) } -void LCGMethodResolver::ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** ppMD, FieldDesc ** ppFD) +void LCGMethodResolver::ResolveToken(mdToken token, ResolvedToken* resolvedToken) { STANDARD_VM_CONTRACT; @@ -1335,24 +1335,35 @@ void LCGMethodResolver::ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc DECLARE_ARGHOLDER_ARRAY(args, 5); + TypeHandle handle; + MethodDesc* pMD = NULL; + FieldDesc* pFD = NULL; args[ARGNUM_0] = OBJECTREF_TO_ARGHOLDER(ObjectFromHandle(m_managedResolver)); args[ARGNUM_1] = DWORD_TO_ARGHOLDER(token); - args[ARGNUM_2] = pTH; - args[ARGNUM_3] = ppMD; - args[ARGNUM_4] = ppFD; + args[ARGNUM_2] = &handle; + args[ARGNUM_3] = &pMD; + args[ARGNUM_4] = &pFD; CALL_MANAGED_METHOD_NORET(args); - _ASSERTE(*ppMD == NULL || *ppFD == NULL); + _ASSERTE(pMD == NULL || pFD == NULL); - if (pTH->IsNull()) + if (handle.IsNull()) { - if (*ppMD != NULL) *pTH = (*ppMD)->GetMethodTable(); - else - if (*ppFD != NULL) *pTH = (*ppFD)->GetEnclosingMethodTable(); + if (pMD != NULL) + { + handle = pMD->GetMethodTable(); + } + else if (pFD != NULL) + { + handle = pFD->GetEnclosingMethodTable(); + } } - _ASSERTE(!pTH->IsNull()); + _ASSERTE(!handle.IsNull()); + resolvedToken->TypeHandle = handle; + resolvedToken->Method = pMD; + resolvedToken->Field = pFD; } //--------------------------------------------------------------------------------------- diff --git a/src/coreclr/vm/dynamicmethod.h b/src/coreclr/vm/dynamicmethod.h index ddbe3c795cfe3..91cd0ac1d03e9 100644 --- a/src/coreclr/vm/dynamicmethod.h +++ b/src/coreclr/vm/dynamicmethod.h @@ -37,6 +37,14 @@ class ChunkAllocator void Delete(); }; +struct ResolvedToken final +{ + TypeHandle TypeHandle; + SigPointer Signature; // Needed for generic look-up (for example, static fields on generic types) + MethodDesc* Method; + FieldDesc* Field; +}; + //--------------------------------------------------------------------------------------- // class DynamicResolver @@ -90,7 +98,7 @@ class DynamicResolver virtual OBJECTHANDLE ConstructStringLiteral(mdToken metaTok) = 0; virtual BOOL IsValidStringRef(mdToken metaTok) = 0; virtual STRINGREF GetStringLiteral(mdToken metaTok) = 0; - virtual void ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** ppMD, FieldDesc ** ppFD) = 0; + virtual void ResolveToken(mdToken token, ResolvedToken* resolvedToken) = 0; virtual SigPointer ResolveSignature(mdToken token) = 0; virtual SigPointer ResolveSignatureForVarArg(mdToken token) = 0; virtual void GetEHInfo(unsigned EHnumber, CORINFO_EH_CLAUSE* clause) = 0; @@ -141,7 +149,7 @@ class LCGMethodResolver : public DynamicResolver OBJECTHANDLE ConstructStringLiteral(mdToken metaTok); BOOL IsValidStringRef(mdToken metaTok); - void ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** ppMD, FieldDesc ** ppFD); + void ResolveToken(mdToken token, ResolvedToken* resolvedToken); SigPointer ResolveSignature(mdToken token); SigPointer ResolveSignatureForVarArg(mdToken token); void GetEHInfo(unsigned EHnumber, CORINFO_EH_CLAUSE* clause); diff --git a/src/coreclr/vm/ilstubresolver.cpp b/src/coreclr/vm/ilstubresolver.cpp index c24be260c692e..c4fe7dae69af8 100644 --- a/src/coreclr/vm/ilstubresolver.cpp +++ b/src/coreclr/vm/ilstubresolver.cpp @@ -133,13 +133,10 @@ STRINGREF ILStubResolver::GetStringLiteral(mdToken metaTok) return NULL; } -void ILStubResolver::ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** ppMD, FieldDesc ** ppFD) +void ILStubResolver::ResolveToken(mdToken token, ResolvedToken* resolvedToken) { STANDARD_VM_CONTRACT; - - *pTH = NULL; - *ppMD = NULL; - *ppFD = NULL; + _ASSERTE(resolvedToken != NULL); switch (TypeFromToken(token)) { @@ -147,8 +144,8 @@ void ILStubResolver::ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** { MethodDesc* pMD = m_pCompileTimeState->m_tokenLookupMap.LookupMethodDef(token); _ASSERTE(pMD); - *ppMD = pMD; - *pTH = TypeHandle(pMD->GetMethodTable()); + resolvedToken->Method = pMD; + resolvedToken->TypeHandle = TypeHandle(pMD->GetMethodTable()); } break; @@ -156,7 +153,7 @@ void ILStubResolver::ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** { TypeHandle typeHnd = m_pCompileTimeState->m_tokenLookupMap.LookupTypeDef(token); _ASSERTE(!typeHnd.IsNull()); - *pTH = typeHnd; + resolvedToken->TypeHandle = typeHnd; } break; @@ -164,10 +161,33 @@ void ILStubResolver::ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** { FieldDesc* pFD = m_pCompileTimeState->m_tokenLookupMap.LookupFieldDef(token); _ASSERTE(pFD); - *ppFD = pFD; - *pTH = TypeHandle(pFD->GetEnclosingMethodTable()); + resolvedToken->Field = pFD; + resolvedToken->TypeHandle = TypeHandle(pFD->GetEnclosingMethodTable()); + } + break; + +#if !defined(DACCESS_COMPILE) + case mdtMemberRef: + { + TokenLookupMap::MemberRefEntry entry = m_pCompileTimeState->m_tokenLookupMap.LookupMemberRef(token); + if (entry.Type == mdtFieldDef) + { + _ASSERTE(entry.Entry.Field != NULL); + + if (entry.ClassSignatureToken != mdTokenNil) + resolvedToken->Signature = m_pCompileTimeState->m_tokenLookupMap.LookupSig(entry.ClassSignatureToken); + + resolvedToken->Field = entry.Entry.Field; + resolvedToken->TypeHandle = TypeHandle(entry.Entry.Field->GetApproxEnclosingMethodTable()).Instantiate(entry.ClassInstantiation); + } + else + { + _ASSERTE(entry.Type == mdtMethodDef); + _ASSERTE(entry.Entry.Method != NULL); + } } break; +#endif // !defined(DACCESS_COMPILE) default: UNREACHABLE_MSG("unexpected metadata token type"); diff --git a/src/coreclr/vm/ilstubresolver.h b/src/coreclr/vm/ilstubresolver.h index 82a1217d79c7e..ea823e7f77380 100644 --- a/src/coreclr/vm/ilstubresolver.h +++ b/src/coreclr/vm/ilstubresolver.h @@ -35,7 +35,7 @@ class ILStubResolver : DynamicResolver OBJECTHANDLE ConstructStringLiteral(mdToken metaTok); BOOL IsValidStringRef(mdToken metaTok); STRINGREF GetStringLiteral(mdToken metaTok); - void ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** ppMD, FieldDesc ** ppFD); + void ResolveToken(mdToken token, ResolvedToken* resolvedToken); SigPointer ResolveSignature(mdToken token); SigPointer ResolveSignatureForVarArg(mdToken token); void GetEHInfo(unsigned EHnumber, CORINFO_EH_CLAUSE* clause); diff --git a/src/coreclr/vm/jitinterface.cpp b/src/coreclr/vm/jitinterface.cpp index f2f7d229d546f..c7061060b9581 100644 --- a/src/coreclr/vm/jitinterface.cpp +++ b/src/coreclr/vm/jitinterface.cpp @@ -156,15 +156,13 @@ inline CORINFO_MODULE_HANDLE GetScopeHandle(MethodDesc* method) //This is common refactored code from within several of the access check functions. static BOOL ModifyCheckForDynamicMethod(DynamicResolver *pResolver, TypeHandle *pOwnerTypeForSecurity, - AccessCheckOptions::AccessCheckType *pAccessCheckType, - DynamicResolver** ppAccessContext) + AccessCheckOptions::AccessCheckType *pAccessCheckType) { CONTRACTL { STANDARD_VM_CHECK; PRECONDITION(CheckPointer(pResolver)); PRECONDITION(CheckPointer(pOwnerTypeForSecurity)); PRECONDITION(CheckPointer(pAccessCheckType)); - PRECONDITION(CheckPointer(ppAccessContext)); PRECONDITION(*pAccessCheckType == AccessCheckOptions::kNormalAccessibilityChecks); } CONTRACTL_END; @@ -883,7 +881,12 @@ void CEEInfo::resolveToken(/* IN, OUT */ CORINFO_RESOLVED_TOKEN * pResolvedToken if (IsDynamicScope(pResolvedToken->tokenScope)) { - GetDynamicResolver(pResolvedToken->tokenScope)->ResolveToken(pResolvedToken->token, &th, &pMD, &pFD); + ResolvedToken resolved{}; + GetDynamicResolver(pResolvedToken->tokenScope)->ResolveToken(pResolvedToken->token, &resolved); + + th = resolved.TypeHandle; + pMD = resolved.Method; + pFD = resolved.Field; // // Check that we got the expected handles and fill in missing data if necessary @@ -897,8 +900,6 @@ void CEEInfo::resolveToken(/* IN, OUT */ CORINFO_RESOLVED_TOKEN * pResolvedToken ThrowBadTokenException(pResolvedToken); if ((tokenType & CORINFO_TOKENKIND_Method) == 0) ThrowBadTokenException(pResolvedToken); - if (th.IsNull()) - th = pMD->GetMethodTable(); // "PermitUninstDefOrRef" check if ((tokenType != CORINFO_TOKENKIND_Ldtoken) && pMD->ContainsGenericVariables()) @@ -924,8 +925,12 @@ void CEEInfo::resolveToken(/* IN, OUT */ CORINFO_RESOLVED_TOKEN * pResolvedToken ThrowBadTokenException(pResolvedToken); if ((tokenType & CORINFO_TOKENKIND_Field) == 0) ThrowBadTokenException(pResolvedToken); - if (th.IsNull()) - th = pFD->GetApproxEnclosingMethodTable(); + + // If a signature was supplied, set the TypeSpec for resolution. + if (!resolved.Signature.IsNull()) + { + resolved.Signature.GetSignature(&pResolvedToken->pTypeSpec, &pResolvedToken->cbTypeSpec); + } if (pFD->IsStatic() && (tokenType != CORINFO_TOKENKIND_Ldtoken)) { @@ -1703,7 +1708,9 @@ void CEEInfo::getFieldInfo (CORINFO_RESOLVED_TOKEN * pResolvedToken, SigTypeContext::InitTypeContext(pCallerForSecurity, &typeContext); SigPointer sigptr(pResolvedToken->pTypeSpec, pResolvedToken->cbTypeSpec); - fieldTypeForSecurity = sigptr.GetTypeHandleThrowing((Module *)pResolvedToken->tokenScope, &typeContext); + + Module* targetModule = GetModule(pResolvedToken->tokenScope); + fieldTypeForSecurity = sigptr.GetTypeHandleThrowing(targetModule, &typeContext); // typeHnd can be a variable type if (fieldTypeForSecurity.GetMethodTable() == NULL) @@ -1715,15 +1722,13 @@ void CEEInfo::getFieldInfo (CORINFO_RESOLVED_TOKEN * pResolvedToken, BOOL doAccessCheck = TRUE; AccessCheckOptions::AccessCheckType accessCheckType = AccessCheckOptions::kNormalAccessibilityChecks; - DynamicResolver * pAccessContext = NULL; - //More in code:CEEInfo::getCallInfo, but the short version is that the caller and callee Descs do //not completely describe the type. TypeHandle callerTypeForSecurity = TypeHandle(pCallerForSecurity->GetMethodTable()); if (IsDynamicScope(pResolvedToken->tokenScope)) { doAccessCheck = ModifyCheckForDynamicMethod(GetDynamicResolver(pResolvedToken->tokenScope), &callerTypeForSecurity, - &accessCheckType, &pAccessContext); + &accessCheckType); } //Now for some link time checks. @@ -1735,7 +1740,7 @@ void CEEInfo::getFieldInfo (CORINFO_RESOLVED_TOKEN * pResolvedToken, { //Well, let's check some visibility at least. AccessCheckOptions accessCheckOptions(accessCheckType, - pAccessContext, + NULL, FALSE, pField); @@ -1849,22 +1854,19 @@ CEEInfo::findCallSiteSig( { _ASSERTE(TypeFromToken(sigMethTok) == mdtMethodDef); - TypeHandle classHandle; - MethodDesc * pMD = NULL; - FieldDesc * pFD = NULL; - // in this case a method is asked for its sig. Resolve the method token and get the sig - pResolver->ResolveToken(sigMethTok, &classHandle, &pMD, &pFD); - if (pMD == NULL) + ResolvedToken resolved{}; + pResolver->ResolveToken(sigMethTok, &resolved); + if (resolved.Method == NULL) COMPlusThrow(kInvalidProgramException); PCCOR_SIGNATURE pSig = NULL; DWORD cbSig; - pMD->GetSig(&pSig, &cbSig); + resolved.Method->GetSig(&pSig, &cbSig); sig = SigPointer(pSig, cbSig); - context = MAKE_METHODCONTEXT(pMD); - scopeHnd = GetScopeHandle(pMD->GetModule()); + context = MAKE_METHODCONTEXT(resolved.Method); + scopeHnd = GetScopeHandle(resolved.Method->GetModule()); } sig.GetSignature(&pSig, &cbSig); @@ -3248,7 +3250,7 @@ void CEEInfo::ComputeRuntimeLookupForSharedGenericToken(DictionaryEntryKind entr sigBuilder.AppendData(pContextMT->GetNumDicts() - 1); } - Module * pModule = (Module *)pResolvedToken->tokenScope; + Module * pModule = GetModule(pResolvedToken->tokenScope); switch (entryKind) { @@ -4929,7 +4931,6 @@ CorInfoIsAccessAllowedResult CEEInfo::canAccessClass( BOOL doAccessCheck = TRUE; AccessCheckOptions::AccessCheckType accessCheckType = AccessCheckOptions::kNormalAccessibilityChecks; - DynamicResolver * pAccessContext = NULL; //All access checks must be done on the open instantiation. MethodDesc * pCallerForSecurity = GetMethodForSecurity(callerHandle); @@ -4953,8 +4954,7 @@ CorInfoIsAccessAllowedResult CEEInfo::canAccessClass( if (IsDynamicScope(pResolvedToken->tokenScope)) { doAccessCheck = ModifyCheckForDynamicMethod(GetDynamicResolver(pResolvedToken->tokenScope), - &callerTypeForSecurity, &accessCheckType, - &pAccessContext); + &callerTypeForSecurity, &accessCheckType); } //Since this is a check against a TypeHandle, there are some things we can stick in a TypeHandle that @@ -4969,7 +4969,7 @@ CorInfoIsAccessAllowedResult CEEInfo::canAccessClass( if (doAccessCheck) { AccessCheckOptions accessCheckOptions(accessCheckType, - pAccessContext, + NULL, FALSE /*throw on error*/, pCalleeForSecurity.GetMethodTable()); @@ -5602,14 +5602,13 @@ void CEEInfo::getCallInfo( BOOL doAccessCheck = TRUE; BOOL canAccessMethod = TRUE; AccessCheckOptions::AccessCheckType accessCheckType = AccessCheckOptions::kNormalAccessibilityChecks; - DynamicResolver * pAccessContext = NULL; callerTypeForSecurity = TypeHandle(pCallerForSecurity->GetMethodTable()); if (pCallerForSecurity->IsDynamicMethod()) { doAccessCheck = ModifyCheckForDynamicMethod(pCallerForSecurity->AsDynamicMethodDesc()->GetResolver(), &callerTypeForSecurity, - &accessCheckType, &pAccessContext); + &accessCheckType); } pResult->accessAllowed = CORINFO_ACCESS_ALLOWED; @@ -5617,7 +5616,7 @@ void CEEInfo::getCallInfo( if (doAccessCheck) { AccessCheckOptions accessCheckOptions(accessCheckType, - pAccessContext, + NULL, FALSE, pCalleeForSecurity); @@ -12332,10 +12331,11 @@ void CEEJitInfo::setEHinfo ( ((pEHClause->Flags & COR_ILEXCEPTION_CLAUSE_FILTER) == 0) && (clause->ClassToken != NULL)) { - MethodDesc * pMD; FieldDesc * pFD; - m_pMethodBeingCompiled->AsDynamicMethodDesc()->GetResolver()->ResolveToken(clause->ClassToken, (TypeHandle *)&pEHClause->TypeHandle, &pMD, &pFD); + ResolvedToken resolved{}; + m_pMethodBeingCompiled->AsDynamicMethodDesc()->GetResolver()->ResolveToken(clause->ClassToken, &resolved); + pEHClause->TypeHandle = (void*)resolved.TypeHandle.AsPtr(); SetHasCachedTypeHandle(pEHClause); - LOG((LF_EH, LL_INFO1000000, " CachedTypeHandle: 0x%08lx -> 0x%08lx\n", clause->ClassToken, pEHClause->TypeHandle)); + LOG((LF_EH, LL_INFO1000000, " CachedTypeHandle: 0x%08x -> %p\n", clause->ClassToken, pEHClause->TypeHandle)); } EE_TO_JIT_TRANSITION(); @@ -12939,18 +12939,17 @@ PCODE UnsafeJitFunction(PrepareCodeConfig* config, //and its return type. AccessCheckOptions::AccessCheckType accessCheckType = AccessCheckOptions::kNormalAccessibilityChecks; TypeHandle ownerTypeForSecurity = TypeHandle(pMethodForSecurity->GetMethodTable()); - DynamicResolver *pAccessContext = NULL; BOOL doAccessCheck = TRUE; if (pMethodForSecurity->IsDynamicMethod()) { doAccessCheck = ModifyCheckForDynamicMethod(pMethodForSecurity->AsDynamicMethodDesc()->GetResolver(), &ownerTypeForSecurity, - &accessCheckType, &pAccessContext); + &accessCheckType); } if (doAccessCheck) { AccessCheckOptions accessCheckOptions(accessCheckType, - pAccessContext, + NULL, TRUE /*Throw on error*/, pMethodForSecurity); diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index 7df5865af7920..423f15a935936 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -1116,6 +1116,7 @@ namespace : Kind{ kind } , Declaration{ pMD } , DeclarationSig{ pMD } + , TargetTypeSig{} , TargetType{} , IsTargetStatic{ false } , TargetMethod{} @@ -1125,6 +1126,7 @@ namespace UnsafeAccessorKind Kind; MethodDesc* Declaration; MetaSig DeclarationSig; + SigPointer TargetTypeSig; TypeHandle TargetType; bool IsTargetStatic; MethodDesc* TargetMethod; @@ -1321,6 +1323,7 @@ namespace TypeHandle targetType = cxt.TargetType; _ASSERTE(!targetType.IsTypeDesc()); + CorElementType elemType = fieldType.GetSignatureCorElementType(); ApproxFieldDescIterator fdIterator( targetType.AsMethodTable(), (cxt.IsTargetStatic ? ApproxFieldDescIterator::STATIC_FIELDS : ApproxFieldDescIterator::INSTANCE_FIELDS)); @@ -1328,12 +1331,26 @@ namespace while ((pField = fdIterator.Next()) != NULL) { // Validate the name and target type match. - if (strcmp(fieldName, pField->GetName()) == 0 - && fieldType == pField->LookupFieldTypeHandle()) + if (strcmp(fieldName, pField->GetName()) != 0) + continue; + + // We check if the possible field is class or valuetype + // since generic fields need resolution. + CorElementType fieldTypeMaybe = pField->GetFieldType(); + if (fieldTypeMaybe == ELEMENT_TYPE_CLASS + || fieldTypeMaybe == ELEMENT_TYPE_VALUETYPE) { - cxt.TargetField = pField; - return true; + if (fieldType != pField->LookupFieldTypeHandle()) + continue; } + else + { + if (elemType != fieldTypeMaybe) + continue; + } + + cxt.TargetField = pField; + return true; } return false; } @@ -1351,12 +1368,14 @@ namespace ilResolver->SetStubMethodDesc(cxt.Declaration); ilResolver->SetStubTargetMethodDesc(cxt.TargetMethod); - // [TODO] Handle generics - SigTypeContext emptyContext; + SigTypeContext genericContext; + if (cxt.Declaration->GetClassification() == mcInstantiated) + SigTypeContext::InitTypeContext(cxt.Declaration, &genericContext); + ILStubLinker sl( cxt.Declaration->GetModule(), cxt.Declaration->GetSignature(), - &emptyContext, + &genericContext, cxt.TargetMethod, (ILStubLinkerFlags)ILSTUB_LINKER_FLAG_NONE); @@ -1389,12 +1408,44 @@ namespace pCode->EmitCALL(pCode->GetToken(cxt.TargetMethod), targetArgCount, targetRetCount); break; case UnsafeAccessorKind::Field: + { _ASSERTE(cxt.TargetField != NULL); - pCode->EmitLDFLDA(pCode->GetToken(cxt.TargetField)); + mdToken target; + if (!cxt.TargetType.HasInstantiation()) + { + target = pCode->GetToken(cxt.TargetField); + } + else + { + // See the static field case for why this can be mdTokenNil. + mdToken targetTypeSigToken = mdTokenNil; + target = pCode->GetToken(cxt.TargetField, targetTypeSigToken, cxt.TargetType.GetInstantiation()); + } + pCode->EmitLDFLDA(target); break; + } case UnsafeAccessorKind::StaticField: _ASSERTE(cxt.TargetField != NULL); - pCode->EmitLDSFLDA(pCode->GetToken(cxt.TargetField)); + mdToken target; + if (!cxt.TargetType.HasInstantiation()) + { + target = pCode->GetToken(cxt.TargetField); + } + else + { + // For accessing a generic instance field, every instantiation will + // be at the same offset, and be the same size, with the same GC layout, + // as long as the generic is canonically equivalent. However, for static fields, + // while the offset, size and GC layout remain the same, the address of the + // field is different, and needs to be found by a lookup of some form. The + // current form of lookup means the exact type isn't with a type signature. + PCCOR_SIGNATURE sig; + uint32_t sigLen; + cxt.TargetTypeSig.GetSignature(&sig, &sigLen); + mdToken targetTypeSigToken = pCode->GetSigToken(sig, sigLen); + target = pCode->GetToken(cxt.TargetField, targetTypeSigToken, cxt.TargetType.GetInstantiation()); + } + pCode->EmitLDSFLDA(target); break; default: _ASSERTE(!"Unknown UnsafeAccessorKind"); @@ -1449,10 +1500,6 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET if (!IsStatic()) ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); - // Block generic support early - if (HasClassOrMethodInstantiation()) - ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); - UnsafeAccessorKind kind; SString name; @@ -1460,6 +1507,15 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET if (!TryParseUnsafeAccessorAttribute(this, ca, kind, name)) ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); + // Block generic support on methods + if (HasClassOrMethodInstantiation() + && (kind == UnsafeAccessorKind::Constructor + || kind == UnsafeAccessorKind::Method + || kind == UnsafeAccessorKind::StaticMethod)) + { + ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); + } + GenerationContext context{ kind, this }; // Parse the signature to determine the type to use: @@ -1473,6 +1529,9 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET if (argCount > 0) { context.DeclarationSig.NextArg(); + + // Get the target type signature and resolve to a type handle. + context.TargetTypeSig = context.DeclarationSig.GetArgProps(); firstArgType = context.DeclarationSig.GetLastTypeHandleThrowing(); } diff --git a/src/coreclr/vm/stubgen.cpp b/src/coreclr/vm/stubgen.cpp index 5f18e8f5d9123..43c1d1bfca01c 100644 --- a/src/coreclr/vm/stubgen.cpp +++ b/src/coreclr/vm/stubgen.cpp @@ -3151,6 +3151,12 @@ int ILStubLinker::GetToken(FieldDesc* pFD) return m_tokenMap.GetToken(pFD); } +int ILStubLinker::GetToken(FieldDesc* pFD, mdToken typeSignature, Instantiation inst) +{ + STANDARD_VM_CONTRACT; + return m_tokenMap.GetToken(pFD, typeSignature, inst); +} + int ILStubLinker::GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig) { STANDARD_VM_CONTRACT; @@ -3242,6 +3248,11 @@ int ILCodeStream::GetToken(FieldDesc* pFD) STANDARD_VM_CONTRACT; return m_pOwner->GetToken(pFD); } +int ILCodeStream::GetToken(FieldDesc* pFD, mdToken typeSignature, Instantiation inst) +{ + STANDARD_VM_CONTRACT; + return m_pOwner->GetToken(pFD, typeSignature, inst); +} int ILCodeStream::GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig) { STANDARD_VM_CONTRACT; diff --git a/src/coreclr/vm/stubgen.h b/src/coreclr/vm/stubgen.h index 595de649220cc..eb4c86f4bc400 100644 --- a/src/coreclr/vm/stubgen.h +++ b/src/coreclr/vm/stubgen.h @@ -307,10 +307,12 @@ class TokenLookupMap for (COUNT_T i = 0; i < pSrc->m_signatures.GetCount(); i++) { const CQuickBytesSpecifySize<16>& src = pSrc->m_signatures[i]; - CQuickBytesSpecifySize<16>& dst = *m_signatures.Append(); - dst.AllocThrows(src.Size()); - memcpy(dst.Ptr(), src.Ptr(), src.Size()); + auto dst = m_signatures.Append(); + dst->AllocThrows(src.Size()); + memcpy(dst->Ptr(), src.Ptr(), src.Size()); } + + m_memberRefs.Set(pSrc->m_memberRefs); } TypeHandle LookupTypeDef(mdToken token) @@ -328,6 +330,34 @@ class TokenLookupMap WRAPPER_NO_CONTRACT; return LookupTokenWorker(token); } + + struct MemberRefEntry final + { + CorTokenType Type; + mdToken ClassSignatureToken; + Instantiation ClassInstantiation; + union + { + FieldDesc* Field; + MethodDesc* Method; + } Entry; + }; + MemberRefEntry LookupMemberRef(mdToken token) + { + CONTRACTL + { + NOTHROW; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(RidFromToken(token) - 1 < m_memberRefs.GetCount()); + PRECONDITION(RidFromToken(token) != 0); + PRECONDITION(TypeFromToken(token) == mdtMemberRef); + } + CONTRACTL_END; + + return m_memberRefs[static_cast(RidFromToken(token) - 1)]; + } + SigPointer LookupSig(mdToken token) { CONTRACTL @@ -362,6 +392,26 @@ class TokenLookupMap WRAPPER_NO_CONTRACT; return GetTokenWorker(pFieldDesc); } + mdToken GetToken(FieldDesc* pFieldDesc, mdToken typeSignature, Instantiation inst) + { + CONTRACTL + { + THROWS; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(pFieldDesc != NULL); + PRECONDITION(!inst.IsEmpty()); + } + CONTRACTL_END; + + MemberRefEntry* entry; + mdToken token = GetMemberRefWorker(&entry); + entry->Type = mdtFieldDef; + entry->ClassSignatureToken = typeSignature; + entry->ClassInstantiation = inst; + entry->Entry.Field = pFieldDesc; + return token; + } mdToken GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig) { @@ -382,6 +432,22 @@ class TokenLookupMap } protected: + mdToken GetMemberRefWorker(MemberRefEntry** entry) + { + CONTRACTL + { + THROWS; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(entry != NULL); + } + CONTRACTL_END; + + mdToken token = TokenFromRid(m_memberRefs.GetCount(), mdtMemberRef) + 1; + *entry = &*m_memberRefs.Append(); // Dereference the iterator and then take the address + return token; + } + template HandleType LookupTokenWorker(mdToken token) { @@ -423,9 +489,10 @@ class TokenLookupMap return token; } - unsigned int m_nextAvailableRid; + uint32_t m_nextAvailableRid; CQuickBytesSpecifySize m_qbEntries; SArray, FALSE> m_signatures; + SArray m_memberRefs; }; class ILCodeLabel; @@ -595,6 +662,7 @@ class ILStubLinker int GetToken(MethodTable* pMT); int GetToken(TypeHandle th); int GetToken(FieldDesc* pFD); + int GetToken(FieldDesc* pFD, mdToken typeSignature, Instantiation inst); int GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig); DWORD NewLocal(CorElementType typ = ELEMENT_TYPE_I); DWORD NewLocal(LocalDesc loc); @@ -824,6 +892,7 @@ class ILCodeStream int GetToken(MethodTable* pMT); int GetToken(TypeHandle th); int GetToken(FieldDesc* pFD); + int GetToken(FieldDesc* pFD, mdToken typeSignature, Instantiation inst); int GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig); DWORD NewLocal(CorElementType typ = ELEMENT_TYPE_I); diff --git a/src/coreclr/vm/typehandle.h b/src/coreclr/vm/typehandle.h index 8483a935af613..f0f5a4604ab22 100644 --- a/src/coreclr/vm/typehandle.h +++ b/src/coreclr/vm/typehandle.h @@ -647,9 +647,7 @@ inline CHECK CheckPointer(TypeHandle th, IsNullOK ok = NULL_NOT_OK) /*************************************************************************/ // Instantiation is representation of generic instantiation. -// It is simple read-only array of TypeHandles. In NGen, the type handles -// may be encoded using indirections. That's one reason why it is convenient -// to have wrapper class that performs the decoding. +// It is simple read-only array of TypeHandles. class Instantiation { public: @@ -695,6 +693,14 @@ class Instantiation } #endif + Instantiation& operator=(const Instantiation& inst) + { + _ASSERTE(this != &inst); + m_pArgs = inst.m_pArgs; + m_nArgs = inst.m_nArgs; + return *this; + } + // Return i-th instantiation argument TypeHandle operator[](DWORD iArg) const { diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.cs index 6e0a562f32a9b..8d059042ea739 100644 --- a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.cs +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.cs @@ -100,10 +100,12 @@ class UserDataGenericClass private static T _GF; private T _gf; + public static void SetStaticGenericField(T val) => _GF = val; + private static string _F = PrivateStatic; private string _f; - public UserDataGenericClass() { _f = Private; } + public UserDataGenericClass(T t) { _f = Private; _gf = t; } private static string _GM(T s, ref T sr, in T si) => typeof(T).ToString(); private string _gm(T s, ref T sr, in T si) => typeof(T).ToString(); @@ -216,7 +218,6 @@ public static void Verify_AccessFieldClass() } [Fact] - [ActiveIssue("https://github.com/dotnet/runtime/issues/92633")] public static void Verify_AccessStaticFieldGenericClass() { Console.WriteLine($"Running {nameof(Verify_AccessStaticFieldGenericClass)}"); @@ -225,11 +226,25 @@ public static void Verify_AccessStaticFieldGenericClass() Assert.Equal(PrivateStatic, GetPrivateStaticFieldString((UserDataGenericClass)null)); + { + int expected = 10; + UserDataGenericClass.SetStaticGenericField(expected); + Assert.Equal(expected, GetPrivateStaticField((UserDataGenericClass)null)); + } + { + string expected = "abc"; + UserDataGenericClass.SetStaticGenericField(expected); + Assert.Equal(expected, GetPrivateStaticField((UserDataGenericClass)null)); + } + [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=UserDataGenericClass.StaticFieldName)] extern static ref string GetPrivateStaticFieldInt(UserDataGenericClass d); [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=UserDataGenericClass.StaticFieldName)] extern static ref string GetPrivateStaticFieldString(UserDataGenericClass d); + + [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=UserDataGenericClass.StaticGenericFieldName)] + extern static ref T GetPrivateStaticField(UserDataGenericClass d); } [Fact] @@ -260,20 +275,32 @@ public static void Verify_AccessFieldValue() } [Fact] - [ActiveIssue("https://github.com/dotnet/runtime/issues/92633")] public static void Verify_AccessFieldGenericClass() { Console.WriteLine($"Running {nameof(Verify_AccessFieldGenericClass)}"); - Assert.Equal(Private, GetPrivateFieldInt(new UserDataGenericClass())); + Assert.Equal(Private, GetPrivateFieldInt(new UserDataGenericClass(0))); - Assert.Equal(Private, GetPrivateFieldString(new UserDataGenericClass())); + Assert.Equal(Private, GetPrivateFieldString(new UserDataGenericClass(string.Empty))); + + { + int expected = 10; + Assert.Equal(expected, GetPrivateField(new UserDataGenericClass(expected))); + } + + { + string expected = "abc"; + Assert.Equal(expected, GetPrivateField(new UserDataGenericClass(expected))); + } [UnsafeAccessor(UnsafeAccessorKind.Field, Name=UserDataGenericClass.FieldName)] extern static ref string GetPrivateFieldInt(UserDataGenericClass d); [UnsafeAccessor(UnsafeAccessorKind.Field, Name=UserDataGenericClass.FieldName)] extern static ref string GetPrivateFieldString(UserDataGenericClass d); + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name=UserDataGenericClass.GenericFieldName)] + extern static ref T GetPrivateField(UserDataGenericClass d); } [Fact] From 86bc15f6a95d74962f189bcf871301e4af073bce Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Thu, 29 Feb 2024 16:31:46 -0800 Subject: [PATCH 02/13] Generic methods Generic types with non-generic methods. --- src/coreclr/vm/dynamicmethod.h | 3 +- src/coreclr/vm/ilstubresolver.cpp | 30 ++++++- src/coreclr/vm/jitinterface.cpp | 30 +++---- src/coreclr/vm/prestub.cpp | 126 +++++++++++++++++++++++++++--- src/coreclr/vm/siginfo.hpp | 2 +- src/coreclr/vm/stubgen.cpp | 30 ++++++- src/coreclr/vm/stubgen.h | 91 +++++++++++++++++++-- 7 files changed, 270 insertions(+), 42 deletions(-) diff --git a/src/coreclr/vm/dynamicmethod.h b/src/coreclr/vm/dynamicmethod.h index 91cd0ac1d03e9..a26a241006113 100644 --- a/src/coreclr/vm/dynamicmethod.h +++ b/src/coreclr/vm/dynamicmethod.h @@ -40,7 +40,8 @@ class ChunkAllocator struct ResolvedToken final { TypeHandle TypeHandle; - SigPointer Signature; // Needed for generic look-up (for example, static fields on generic types) + SigPointer TypeSignature; + SigPointer MethodSignature; MethodDesc* Method; FieldDesc* Field; }; diff --git a/src/coreclr/vm/ilstubresolver.cpp b/src/coreclr/vm/ilstubresolver.cpp index c4fe7dae69af8..1efb9c2975e16 100644 --- a/src/coreclr/vm/ilstubresolver.cpp +++ b/src/coreclr/vm/ilstubresolver.cpp @@ -175,18 +175,44 @@ void ILStubResolver::ResolveToken(mdToken token, ResolvedToken* resolvedToken) _ASSERTE(entry.Entry.Field != NULL); if (entry.ClassSignatureToken != mdTokenNil) - resolvedToken->Signature = m_pCompileTimeState->m_tokenLookupMap.LookupSig(entry.ClassSignatureToken); + resolvedToken->TypeSignature = m_pCompileTimeState->m_tokenLookupMap.LookupSig(entry.ClassSignatureToken); resolvedToken->Field = entry.Entry.Field; - resolvedToken->TypeHandle = TypeHandle(entry.Entry.Field->GetApproxEnclosingMethodTable()).Instantiate(entry.ClassInstantiation); + resolvedToken->TypeHandle = TypeHandle(entry.Entry.Field->GetApproxEnclosingMethodTable()); } else { _ASSERTE(entry.Type == mdtMethodDef); _ASSERTE(entry.Entry.Method != NULL); + + if (entry.ClassSignatureToken != mdTokenNil) + resolvedToken->TypeSignature = m_pCompileTimeState->m_tokenLookupMap.LookupSig(entry.ClassSignatureToken); + + resolvedToken->Method = entry.Entry.Method; + MethodTable* pMT = entry.Entry.Method->GetMethodTable(); + _ASSERTE(!pMT->ContainsGenericVariables()); + resolvedToken->TypeHandle = TypeHandle(pMT); } } break; + + case mdtMethodSpec: + { + TokenLookupMap::MethodSpecEntry entry = m_pCompileTimeState->m_tokenLookupMap.LookupMethodSpec(token); + _ASSERTE(entry.Method != NULL); + + if (entry.ClassSignatureToken != mdTokenNil) + resolvedToken->TypeSignature = m_pCompileTimeState->m_tokenLookupMap.LookupSig(entry.ClassSignatureToken); + + if (entry.MethodSignatureToken != mdTokenNil) + resolvedToken->MethodSignature = m_pCompileTimeState->m_tokenLookupMap.LookupSig(entry.MethodSignatureToken); + + resolvedToken->Method = entry.Method; + MethodTable* pMT = entry.Method->GetMethodTable(); + _ASSERTE(!pMT->ContainsGenericVariables()); + resolvedToken->TypeHandle = TypeHandle(pMT); + } + break; #endif // !defined(DACCESS_COMPILE) default: diff --git a/src/coreclr/vm/jitinterface.cpp b/src/coreclr/vm/jitinterface.cpp index c7061060b9581..c49e3fb2f8e4c 100644 --- a/src/coreclr/vm/jitinterface.cpp +++ b/src/coreclr/vm/jitinterface.cpp @@ -888,6 +888,12 @@ void CEEInfo::resolveToken(/* IN, OUT */ CORINFO_RESOLVED_TOKEN * pResolvedToken pMD = resolved.Method; pFD = resolved.Field; + // Record supplied signatures. + if (!resolved.TypeSignature.IsNull()) + resolved.TypeSignature.GetSignature(&pResolvedToken->pTypeSpec, &pResolvedToken->cbTypeSpec); + if (!resolved.MethodSignature.IsNull()) + resolved.MethodSignature.GetSignature(&pResolvedToken->pMethodSpec, &pResolvedToken->cbMethodSpec); + // // Check that we got the expected handles and fill in missing data if necessary // @@ -896,17 +902,11 @@ void CEEInfo::resolveToken(/* IN, OUT */ CORINFO_RESOLVED_TOKEN * pResolvedToken if (pMD != NULL) { - if ((tkType != mdtMethodDef) && (tkType != mdtMemberRef)) + if ((tkType != mdtMethodDef) && (tkType != mdtMemberRef) && (tkType != mdtMethodSpec)) ThrowBadTokenException(pResolvedToken); if ((tokenType & CORINFO_TOKENKIND_Method) == 0) ThrowBadTokenException(pResolvedToken); - // "PermitUninstDefOrRef" check - if ((tokenType != CORINFO_TOKENKIND_Ldtoken) && pMD->ContainsGenericVariables()) - { - COMPlusThrow(kInvalidProgramException); - } - // if this is a BoxedEntryPointStub get the UnboxedEntryPoint one if (pMD->IsUnboxingStub()) { @@ -926,12 +926,6 @@ void CEEInfo::resolveToken(/* IN, OUT */ CORINFO_RESOLVED_TOKEN * pResolvedToken if ((tokenType & CORINFO_TOKENKIND_Field) == 0) ThrowBadTokenException(pResolvedToken); - // If a signature was supplied, set the TypeSpec for resolution. - if (!resolved.Signature.IsNull()) - { - resolved.Signature.GetSignature(&pResolvedToken->pTypeSpec, &pResolvedToken->cbTypeSpec); - } - if (pFD->IsStatic() && (tokenType != CORINFO_TOKENKIND_Ldtoken)) { EnsureActive(th); @@ -964,7 +958,7 @@ void CEEInfo::resolveToken(/* IN, OUT */ CORINFO_RESOLVED_TOKEN * pResolvedToken else { mdToken metaTOK = pResolvedToken->token; - Module * pModule = (Module *)pResolvedToken->tokenScope; + Module * pModule = GetModule(pResolvedToken->tokenScope); switch (TypeFromToken(metaTOK)) { @@ -4943,7 +4937,7 @@ CorInfoIsAccessAllowedResult CEEInfo::canAccessClass( SigTypeContext::InitTypeContext(pCallerForSecurity, &typeContext); SigPointer sigptr(pResolvedToken->pTypeSpec, pResolvedToken->cbTypeSpec); - pCalleeForSecurity = sigptr.GetTypeHandleThrowing((Module *)pResolvedToken->tokenScope, &typeContext); + pCalleeForSecurity = sigptr.GetTypeHandleThrowing(GetModule(pResolvedToken->tokenScope), &typeContext); } while (pCalleeForSecurity.HasTypeParam()) @@ -5541,7 +5535,7 @@ void CEEInfo::getCallInfo( if (pResolvedToken->pTypeSpec != NULL) { SigPointer sigptr(pResolvedToken->pTypeSpec, pResolvedToken->cbTypeSpec); - calleeTypeForSecurity = sigptr.GetTypeHandleThrowing((Module *)pResolvedToken->tokenScope, &typeContext); + calleeTypeForSecurity = sigptr.GetTypeHandleThrowing(GetModule(pResolvedToken->tokenScope), &typeContext); // typeHnd can be a variable type if (calleeTypeForSecurity.GetMethodTable() == NULL) @@ -5568,7 +5562,7 @@ void CEEInfo::getCallInfo( IfFailThrow(sp.GetByte(&etype)); // Load the generic method instantiation - THROW_BAD_FORMAT_MAYBE(etype == (BYTE)IMAGE_CEE_CS_CALLCONV_GENERICINST, 0, (Module *)pResolvedToken->tokenScope); + THROW_BAD_FORMAT_MAYBE(etype == (BYTE)IMAGE_CEE_CS_CALLCONV_GENERICINST, 0, GetModule(pResolvedToken->tokenScope)); IfFailThrow(sp.GetData(&nGenericMethodArgs)); @@ -5582,7 +5576,7 @@ void CEEInfo::getCallInfo( for (uint32_t i = 0; i < nGenericMethodArgs; i++) { - genericMethodArgs[i] = sp.GetTypeHandleThrowing((Module *)pResolvedToken->tokenScope, &typeContext); + genericMethodArgs[i] = sp.GetTypeHandleThrowing(GetModule(pResolvedToken->tokenScope), &typeContext); _ASSERTE (!genericMethodArgs[i].IsNull()); IfFailThrow(sp.SkipExactlyOne()); } diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index 423f15a935936..332c4a92ccc58 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -1150,6 +1150,7 @@ namespace bool DoesMethodMatchUnsafeAccessorDeclaration( GenerationContext& cxt, MethodDesc* method, + const Substitution* pMethodSubst, MetaSig::CompareState& state) { STANDARD_VM_CONTRACT; @@ -1167,7 +1168,11 @@ namespace method->GetSig(&pSig2, &cSig2); PCCOR_SIGNATURE pEndSig2 = pSig2 + cSig2; ModuleBase* pModule2 = method->GetModule(); - const Substitution* pSubst2 = NULL; + const Substitution* pSubst2 = pMethodSubst; + + // + // Parsing the signature follows details defined in ECMA-335 - II.23.2.1 + // // Validate calling convention if ((*pSig1 & IMAGE_CEE_CS_CALLCONV_MASK) != (*pSig2 & IMAGE_CEE_CS_CALLCONV_MASK)) @@ -1175,10 +1180,19 @@ namespace return false; } - BYTE callConv = *pSig1; + BYTE callConvDecl = *pSig1; + BYTE callConvMethod = *pSig2; pSig1++; pSig2++; + // Handle generic param count + DWORD declGenericCount = 0; + DWORD methodGenericCount = 0; + if (callConvDecl & IMAGE_CEE_CS_CALLCONV_GENERIC) + IfFailThrow(CorSigUncompressData_EndPtr(pSig1, pEndSig1, &declGenericCount)); + if (callConvMethod & IMAGE_CEE_CS_CALLCONV_GENERIC) + IfFailThrow(CorSigUncompressData_EndPtr(pSig2, pEndSig2, &methodGenericCount)); + DWORD declArgCount; DWORD methodArgCount; IfFailThrow(CorSigUncompressData_EndPtr(pSig1, pEndSig1, &declArgCount)); @@ -1266,7 +1280,33 @@ namespace TypeHandle targetType = cxt.TargetType; _ASSERTE(!targetType.IsTypeDesc()); + // We do not support the Canon type as a valid target. + if (targetType.AsMethodTable() == g_pCanonMethodTableClass) + ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); + MethodDesc* targetMaybe = NULL; + Substitution* pLookupSubst = NULL; + + // Build up a Substitution to use when looking up methods involving generics. + Substitution substitution; + SigBuilder sigBuilder; + DWORD targetGenericParamCount = targetType.AsMethodTable()->GetNumGenericArgs(); + if (targetGenericParamCount > 0) + { + // Create a temporary signature that translate VARs to MVARs. + for (DWORD i = 0; i < targetGenericParamCount; ++i) + { + sigBuilder.AppendElementType(ELEMENT_TYPE_MVAR); + sigBuilder.AppendData(i); // Represents the generic parameter index - II.23.2.12 + } + + DWORD tmpSigLen; + PVOID tmpSigRaw = sigBuilder.GetSignature(&tmpSigLen); + + SigPointer tmpSig{ (PCCOR_SIGNATURE)tmpSigRaw, tmpSigLen }; + substitution = Substitution{ cxt.Declaration->GetModule(), tmpSig, NULL }; + pLookupSubst = &substitution; + } // Following a similar iteration pattern found in MemberLoader::FindMethod(). // However, we are only operating on the current type not walking the type hierarchy. @@ -1287,7 +1327,7 @@ namespace TokenPairList list { nullptr }; MetaSig::CompareState state{ &list }; state.IgnoreCustomModifiers = ignoreCustomModifiers; - if (!DoesMethodMatchUnsafeAccessorDeclaration(cxt, curr, state)) + if (!DoesMethodMatchUnsafeAccessorDeclaration(cxt, curr, pLookupSubst, state)) continue; // Check if there is some ambiguity. @@ -1323,6 +1363,10 @@ namespace TypeHandle targetType = cxt.TargetType; _ASSERTE(!targetType.IsTypeDesc()); + // We do not support the Canon type as a valid target. + if (targetType.AsMethodTable() == g_pCanonMethodTableClass) + ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); + CorElementType elemType = fieldType.GetSignatureCorElementType(); ApproxFieldDescIterator fdIterator( targetType.AsMethodTable(), @@ -1396,13 +1440,75 @@ namespace switch (cxt.Kind) { case UnsafeAccessorKind::Constructor: + { _ASSERTE(cxt.TargetMethod != NULL); - pCode->EmitNEWOBJ(pCode->GetToken(cxt.TargetMethod), targetArgCount); + mdToken target; + if (!cxt.TargetType.HasInstantiation()) + { + target = pCode->GetToken(cxt.TargetMethod); + } + else + { + PCCOR_SIGNATURE sig; + uint32_t sigLen; + cxt.TargetTypeSig.GetSignature(&sig, &sigLen); + mdToken targetTypeSigToken = pCode->GetSigToken(sig, sigLen); + target = pCode->GetToken(cxt.TargetMethod, targetTypeSigToken); + } + pCode->EmitNEWOBJ(target, targetArgCount); break; + } case UnsafeAccessorKind::Method: + { _ASSERTE(cxt.TargetMethod != NULL); - pCode->EmitCALLVIRT(pCode->GetToken(cxt.TargetMethod), targetArgCount, targetRetCount); + mdToken target; + if (!cxt.TargetMethod->HasMethodInstantiation()) + { + target = pCode->GetToken(cxt.TargetMethod); + } + else + { + // Create signature for the MethodSpec. See ECMA-335 - II.23.2.15 + DWORD targetGenericCount = cxt.TargetMethod->GetNumGenericMethodArgs(); + _ASSERTE(targetGenericCount != 0); + + SigBuilder sigBuilder; + sigBuilder.AppendByte(IMAGE_CEE_CS_CALLCONV_GENERICINST); + sigBuilder.AppendData(targetGenericCount); + for (DWORD i = 0; i < targetGenericCount; ++i) + { + sigBuilder.AppendElementType(ELEMENT_TYPE_MVAR); + sigBuilder.AppendData(i); + } + uint32_t sigLen; + PCCOR_SIGNATURE sig = (PCCOR_SIGNATURE)sigBuilder.GetSignature((DWORD*)&sigLen); + mdToken methodSpecSigToken = pCode->GetSigToken(sig, sigLen); + + // Create a MethodSpec + cxt.TargetTypeSig.GetSignature(&sig, &sigLen); + mdToken targetTypeSigToken = pCode->GetSigToken(sig, sigLen); + + // Convert the declaration Instantiation to one that can be used + // to find the instantiated MethodDesc target. + Instantiation methodInst = cxt.Declaration->GetMethodInstantiation(); + DWORD declGenericCount = cxt.Declaration->GetNumGenericMethodArgs(); + + // Generic parameters that are used for the type, must come at the end + // of the parameter list. For example, + // + // [UnsafeAccessor(UnsafeAccessorKind.Method)] + // extern static void M(C c, T element); + // + if (declGenericCount > targetGenericCount) + methodInst = Instantiation{ methodInst.GetRawArgs(), targetGenericCount }; + + // Look up the instantiated MethodDesc target. + MethodDesc* instantiatedTarget = MethodDesc::FindOrCreateAssociatedMethodDesc(cxt.TargetMethod, cxt.TargetType.GetMethodTable(), FALSE, methodInst, TRUE); + target = pCode->GetToken(instantiatedTarget, targetTypeSigToken, methodSpecSigToken); + } + pCode->EmitCALLVIRT(target, targetArgCount, targetRetCount); break; + } case UnsafeAccessorKind::StaticMethod: _ASSERTE(cxt.TargetMethod != NULL); pCode->EmitCALL(pCode->GetToken(cxt.TargetMethod), targetArgCount, targetRetCount); @@ -1419,7 +1525,7 @@ namespace { // See the static field case for why this can be mdTokenNil. mdToken targetTypeSigToken = mdTokenNil; - target = pCode->GetToken(cxt.TargetField, targetTypeSigToken, cxt.TargetType.GetInstantiation()); + target = pCode->GetToken(cxt.TargetField, targetTypeSigToken); } pCode->EmitLDFLDA(target); break; @@ -1443,7 +1549,7 @@ namespace uint32_t sigLen; cxt.TargetTypeSig.GetSignature(&sig, &sigLen); mdToken targetTypeSigToken = pCode->GetSigToken(sig, sigLen); - target = pCode->GetToken(cxt.TargetField, targetTypeSigToken, cxt.TargetType.GetInstantiation()); + target = pCode->GetToken(cxt.TargetField, targetTypeSigToken); } pCode->EmitLDSFLDA(target); break; @@ -1509,9 +1615,7 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET // Block generic support on methods if (HasClassOrMethodInstantiation() - && (kind == UnsafeAccessorKind::Constructor - || kind == UnsafeAccessorKind::Method - || kind == UnsafeAccessorKind::StaticMethod)) + && kind == UnsafeAccessorKind::StaticMethod) { ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); } @@ -1550,6 +1654,8 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); } + // Get the target type signature from the return type. + context.TargetTypeSig = context.DeclarationSig.GetReturnProps(); context.TargetType = ValidateTargetType(retType); if (!TrySetTargetMethod(context, ".ctor")) MemberLoader::ThrowMissingMethodException(context.TargetType.AsMethodTable(), ".ctor"); diff --git a/src/coreclr/vm/siginfo.hpp b/src/coreclr/vm/siginfo.hpp index a0ec6b3d4a26c..fab9a79260d2d 100644 --- a/src/coreclr/vm/siginfo.hpp +++ b/src/coreclr/vm/siginfo.hpp @@ -394,7 +394,7 @@ class Substitution Substitution( ModuleBase * pModuleArg, - const SigPointer & sigInst, + SigPointer sigInst, const Substitution * pNextSubstitution) { LIMITED_METHOD_CONTRACT; diff --git a/src/coreclr/vm/stubgen.cpp b/src/coreclr/vm/stubgen.cpp index 43c1d1bfca01c..88cc6fde7679d 100644 --- a/src/coreclr/vm/stubgen.cpp +++ b/src/coreclr/vm/stubgen.cpp @@ -3133,6 +3133,18 @@ int ILStubLinker::GetToken(MethodDesc* pMD) return m_tokenMap.GetToken(pMD); } +int ILStubLinker::GetToken(MethodDesc* pMD, mdToken typeSignature) +{ + STANDARD_VM_CONTRACT; + return m_tokenMap.GetToken(pMD, typeSignature); +} + +int ILStubLinker::GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken methodSignature) +{ + STANDARD_VM_CONTRACT; + return m_tokenMap.GetToken(pMD, typeSignature, methodSignature); +} + int ILStubLinker::GetToken(MethodTable* pMT) { STANDARD_VM_CONTRACT; @@ -3151,10 +3163,10 @@ int ILStubLinker::GetToken(FieldDesc* pFD) return m_tokenMap.GetToken(pFD); } -int ILStubLinker::GetToken(FieldDesc* pFD, mdToken typeSignature, Instantiation inst) +int ILStubLinker::GetToken(FieldDesc* pFD, mdToken typeSignature) { STANDARD_VM_CONTRACT; - return m_tokenMap.GetToken(pFD, typeSignature, inst); + return m_tokenMap.GetToken(pFD, typeSignature); } int ILStubLinker::GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig) @@ -3233,6 +3245,16 @@ int ILCodeStream::GetToken(MethodDesc* pMD) STANDARD_VM_CONTRACT; return m_pOwner->GetToken(pMD); } +int ILCodeStream::GetToken(MethodDesc* pMD, mdToken typeSignature) +{ + STANDARD_VM_CONTRACT; + return m_pOwner->GetToken(pMD, typeSignature); +} +int ILCodeStream::GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken methodSignature) +{ + STANDARD_VM_CONTRACT; + return m_pOwner->GetToken(pMD, typeSignature, methodSignature); +} int ILCodeStream::GetToken(MethodTable* pMT) { STANDARD_VM_CONTRACT; @@ -3248,10 +3270,10 @@ int ILCodeStream::GetToken(FieldDesc* pFD) STANDARD_VM_CONTRACT; return m_pOwner->GetToken(pFD); } -int ILCodeStream::GetToken(FieldDesc* pFD, mdToken typeSignature, Instantiation inst) +int ILCodeStream::GetToken(FieldDesc* pFD, mdToken typeSignature) { STANDARD_VM_CONTRACT; - return m_pOwner->GetToken(pFD, typeSignature, inst); + return m_pOwner->GetToken(pFD, typeSignature); } int ILCodeStream::GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig) { diff --git a/src/coreclr/vm/stubgen.h b/src/coreclr/vm/stubgen.h index eb4c86f4bc400..4c30867761240 100644 --- a/src/coreclr/vm/stubgen.h +++ b/src/coreclr/vm/stubgen.h @@ -313,6 +313,7 @@ class TokenLookupMap } m_memberRefs.Set(pSrc->m_memberRefs); + m_methodSpecs.Set(pSrc->m_methodSpecs); } TypeHandle LookupTypeDef(mdToken token) @@ -335,7 +336,6 @@ class TokenLookupMap { CorTokenType Type; mdToken ClassSignatureToken; - Instantiation ClassInstantiation; union { FieldDesc* Field; @@ -358,6 +358,28 @@ class TokenLookupMap return m_memberRefs[static_cast(RidFromToken(token) - 1)]; } + struct MethodSpecEntry final + { + mdToken ClassSignatureToken; + mdToken MethodSignatureToken; + MethodDesc* Method; + }; + MethodSpecEntry LookupMethodSpec(mdToken token) + { + CONTRACTL + { + NOTHROW; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(RidFromToken(token) - 1 < m_methodSpecs.GetCount()); + PRECONDITION(RidFromToken(token) != 0); + PRECONDITION(TypeFromToken(token) == mdtMethodSpec); + } + CONTRACTL_END; + + return m_methodSpecs[static_cast(RidFromToken(token) - 1)]; + } + SigPointer LookupSig(mdToken token) { CONTRACTL @@ -387,12 +409,50 @@ class TokenLookupMap WRAPPER_NO_CONTRACT; return GetTokenWorker(pMD); } + mdToken GetToken(MethodDesc* pMD, mdToken typeSignature) + { + CONTRACTL + { + THROWS; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(pMD != NULL); + } + CONTRACTL_END; + + MemberRefEntry* entry; + mdToken token = GetMemberRefWorker(&entry); + entry->Type = mdtMethodDef; + entry->ClassSignatureToken = typeSignature; + entry->Entry.Method = pMD; + return token; + } + mdToken GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken methodSignature) + { + CONTRACTL + { + THROWS; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(pMD != NULL); + PRECONDITION(typeSignature != mdTokenNil); + PRECONDITION(methodSignature != mdTokenNil); + } + CONTRACTL_END; + + MethodSpecEntry* entry; + mdToken token = GetMethodSpecWorker(&entry); + entry->ClassSignatureToken = typeSignature; + entry->MethodSignatureToken = methodSignature; + entry->Method = pMD; + return token; + } mdToken GetToken(FieldDesc* pFieldDesc) { WRAPPER_NO_CONTRACT; return GetTokenWorker(pFieldDesc); } - mdToken GetToken(FieldDesc* pFieldDesc, mdToken typeSignature, Instantiation inst) + mdToken GetToken(FieldDesc* pFieldDesc, mdToken typeSignature) { CONTRACTL { @@ -400,7 +460,6 @@ class TokenLookupMap MODE_ANY; GC_NOTRIGGER; PRECONDITION(pFieldDesc != NULL); - PRECONDITION(!inst.IsEmpty()); } CONTRACTL_END; @@ -408,7 +467,6 @@ class TokenLookupMap mdToken token = GetMemberRefWorker(&entry); entry->Type = mdtFieldDef; entry->ClassSignatureToken = typeSignature; - entry->ClassInstantiation = inst; entry->Entry.Field = pFieldDesc; return token; } @@ -448,6 +506,22 @@ class TokenLookupMap return token; } + mdToken GetMethodSpecWorker(MethodSpecEntry** entry) + { + CONTRACTL + { + THROWS; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(entry != NULL); + } + CONTRACTL_END; + + mdToken token = TokenFromRid(m_methodSpecs.GetCount(), mdtMethodSpec) + 1; + *entry = &*m_methodSpecs.Append(); // Dereference the iterator and then take the address + return token; + } + template HandleType LookupTokenWorker(mdToken token) { @@ -493,6 +567,7 @@ class TokenLookupMap CQuickBytesSpecifySize m_qbEntries; SArray, FALSE> m_signatures; SArray m_memberRefs; + SArray m_methodSpecs; }; class ILCodeLabel; @@ -659,10 +734,12 @@ class ILStubLinker // ILCodeLabel* NewCodeLabel(); int GetToken(MethodDesc* pMD); + int GetToken(MethodDesc* pMD, mdToken typeSignature); + int GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken methodSignature); int GetToken(MethodTable* pMT); int GetToken(TypeHandle th); int GetToken(FieldDesc* pFD); - int GetToken(FieldDesc* pFD, mdToken typeSignature, Instantiation inst); + int GetToken(FieldDesc* pFD, mdToken typeSignature); int GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig); DWORD NewLocal(CorElementType typ = ELEMENT_TYPE_I); DWORD NewLocal(LocalDesc loc); @@ -889,10 +966,12 @@ class ILCodeStream // int GetToken(MethodDesc* pMD); + int GetToken(MethodDesc* pMD, mdToken typeSignature); + int GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken methodSignature); int GetToken(MethodTable* pMT); int GetToken(TypeHandle th); int GetToken(FieldDesc* pFD); - int GetToken(FieldDesc* pFD, mdToken typeSignature, Instantiation inst); + int GetToken(FieldDesc* pFD, mdToken typeSignature); int GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig); DWORD NewLocal(CorElementType typ = ELEMENT_TYPE_I); From 7495477e9135f5c4cbd6d1eac1649ff2967bae45 Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Fri, 1 Mar 2024 16:30:38 -0800 Subject: [PATCH 03/13] Generic static methods --- src/coreclr/vm/prestub.cpp | 96 ++++++++++++++++++++++---------------- 1 file changed, 55 insertions(+), 41 deletions(-) diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index 332c4a92ccc58..0d84b6a81efae 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -1363,13 +1363,15 @@ namespace TypeHandle targetType = cxt.TargetType; _ASSERTE(!targetType.IsTypeDesc()); + MethodTable* pMT = targetType.AsMethodTable(); + // We do not support the Canon type as a valid target. - if (targetType.AsMethodTable() == g_pCanonMethodTableClass) + if (pMT == g_pCanonMethodTableClass) ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); CorElementType elemType = fieldType.GetSignatureCorElementType(); ApproxFieldDescIterator fdIterator( - targetType.AsMethodTable(), + pMT, (cxt.IsTargetStatic ? ApproxFieldDescIterator::STATIC_FIELDS : ApproxFieldDescIterator::INSTANCE_FIELDS)); PTR_FieldDesc pField; while ((pField = fdIterator.Next()) != NULL) @@ -1393,6 +1395,17 @@ namespace continue; } + if (cxt.Kind == UnsafeAccessorKind::StaticField && pMT->HasGenericsStaticsInfo()) + { + // Statics require the exact typed field as opposed to the canonically + // typed field. In order to do that we lookup the current index of the + // approx field and then use that index to get the precise field from + // the approx field. + MethodTable* pFieldMT = pField->GetApproxEnclosingMethodTable(); + DWORD index = pFieldMT->GetIndexForFieldDesc(pField); + pField = pMT->GetFieldDescByIndex(index); + } + cxt.TargetField = pField; return true; } @@ -1459,60 +1472,68 @@ namespace break; } case UnsafeAccessorKind::Method: + case UnsafeAccessorKind::StaticMethod: { _ASSERTE(cxt.TargetMethod != NULL); mdToken target; - if (!cxt.TargetMethod->HasMethodInstantiation()) + if (!cxt.TargetMethod->HasClassOrMethodInstantiation()) { target = pCode->GetToken(cxt.TargetMethod); } else { - // Create signature for the MethodSpec. See ECMA-335 - II.23.2.15 DWORD targetGenericCount = cxt.TargetMethod->GetNumGenericMethodArgs(); - _ASSERTE(targetGenericCount != 0); + mdToken methodSpecSigToken = mdTokenNil; SigBuilder sigBuilder; - sigBuilder.AppendByte(IMAGE_CEE_CS_CALLCONV_GENERICINST); - sigBuilder.AppendData(targetGenericCount); - for (DWORD i = 0; i < targetGenericCount; ++i) + uint32_t sigLen; + PCCOR_SIGNATURE sig; + if (targetGenericCount != 0) { - sigBuilder.AppendElementType(ELEMENT_TYPE_MVAR); - sigBuilder.AppendData(i); + // Create signature for the MethodSpec. See ECMA-335 - II.23.2.15 + sigBuilder.AppendByte(IMAGE_CEE_CS_CALLCONV_GENERICINST); + sigBuilder.AppendData(targetGenericCount); + for (DWORD i = 0; i < targetGenericCount; ++i) + { + sigBuilder.AppendElementType(ELEMENT_TYPE_MVAR); + sigBuilder.AppendData(i); + } + sigLen; + sig = (PCCOR_SIGNATURE)sigBuilder.GetSignature((DWORD*)&sigLen); + methodSpecSigToken = pCode->GetSigToken(sig, sigLen); } - uint32_t sigLen; - PCCOR_SIGNATURE sig = (PCCOR_SIGNATURE)sigBuilder.GetSignature((DWORD*)&sigLen); - mdToken methodSpecSigToken = pCode->GetSigToken(sig, sigLen); - // Create a MethodSpec cxt.TargetTypeSig.GetSignature(&sig, &sigLen); mdToken targetTypeSigToken = pCode->GetSigToken(sig, sigLen); - // Convert the declaration Instantiation to one that can be used - // to find the instantiated MethodDesc target. - Instantiation methodInst = cxt.Declaration->GetMethodInstantiation(); - DWORD declGenericCount = cxt.Declaration->GetNumGenericMethodArgs(); + if (methodSpecSigToken == mdTokenNil) + { + // Create a MemberRef + target = pCode->GetToken(cxt.TargetMethod, targetTypeSigToken); + _ASSERTE(TypeFromToken(target) == mdtMemberRef); + } + else + { + // Use the method declaration Instantiation to find the instantiated MethodDesc target. + Instantiation methodInst = cxt.Declaration->GetMethodInstantiation(); + MethodDesc* instantiatedTarget = MethodDesc::FindOrCreateAssociatedMethodDesc(cxt.TargetMethod, cxt.TargetType.GetMethodTable(), FALSE, methodInst, TRUE); - // Generic parameters that are used for the type, must come at the end - // of the parameter list. For example, - // - // [UnsafeAccessor(UnsafeAccessorKind.Method)] - // extern static void M(C c, T element); - // - if (declGenericCount > targetGenericCount) - methodInst = Instantiation{ methodInst.GetRawArgs(), targetGenericCount }; + // Create a MethodSpec + target = pCode->GetToken(instantiatedTarget, targetTypeSigToken, methodSpecSigToken); + _ASSERTE(TypeFromToken(target) == mdtMethodSpec); + } + } - // Look up the instantiated MethodDesc target. - MethodDesc* instantiatedTarget = MethodDesc::FindOrCreateAssociatedMethodDesc(cxt.TargetMethod, cxt.TargetType.GetMethodTable(), FALSE, methodInst, TRUE); - target = pCode->GetToken(instantiatedTarget, targetTypeSigToken, methodSpecSigToken); + if (cxt.Kind == UnsafeAccessorKind::StaticMethod) + { + pCode->EmitCALL(target, targetArgCount, targetRetCount); + } + else + { + pCode->EmitCALLVIRT(target, targetArgCount, targetRetCount); } - pCode->EmitCALLVIRT(target, targetArgCount, targetRetCount); break; } - case UnsafeAccessorKind::StaticMethod: - _ASSERTE(cxt.TargetMethod != NULL); - pCode->EmitCALL(pCode->GetToken(cxt.TargetMethod), targetArgCount, targetRetCount); - break; case UnsafeAccessorKind::Field: { _ASSERTE(cxt.TargetField != NULL); @@ -1613,13 +1634,6 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET if (!TryParseUnsafeAccessorAttribute(this, ca, kind, name)) ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); - // Block generic support on methods - if (HasClassOrMethodInstantiation() - && kind == UnsafeAccessorKind::StaticMethod) - { - ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); - } - GenerationContext context{ kind, this }; // Parse the signature to determine the type to use: From 4c099189c4f68d8638265fa0ec6507166f99634e Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Mon, 4 Mar 2024 11:37:01 -0800 Subject: [PATCH 04/13] Add testing for generic scenarios --- .../UnsafeAccessorsTests.Generics.cs | 370 ++++++++++++++++++ .../UnsafeAccessors/UnsafeAccessorsTests.cs | 99 ----- .../UnsafeAccessorsTests.csproj | 1 + 3 files changed, 371 insertions(+), 99 deletions(-) create mode 100644 src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs new file mode 100644 index 0000000000000..fe23497fd8ec0 --- /dev/null +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs @@ -0,0 +1,370 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +using Xunit; + +struct Struct { } + +public static unsafe class UnsafeAccessorsTestsGenerics +{ + class MyList + { + public const string StaticGenericFieldName = nameof(_GF); + public const string StaticFieldName = nameof(_F); + public const string GenericFieldName = nameof(_list); + + static MyList() + { + _F = typeof(T).ToString(); + } + + public static void SetStaticGenericField(T val) => _GF = val; + private static T _GF; + private static string _F; + + private List _list; + + public MyList() => _list = new(); + + private MyList(int i) => _list = new(i); + + private MyList(List list) => _list = list; + + private void Clear() => _list.Clear(); + + private void Add(T t) => _list.Add(t); + + private bool CanCastToElementType(U t) => t is T; + + private static bool CanUseElementType(U t) => t is T; + + private static Type ElementType() => typeof(T); + + private void Add(int a) => + Unsafe.As>(_list).Add(a); + + private void Add(string a) => + Unsafe.As>(_list).Add(a); + + private void Add(Struct a) => + Unsafe.As>(_list).Add(a); + + public int Count => _list.Count; + + public int Capacity => _list.Capacity; + } + + [Fact] + [ActiveIssue("", TestRuntimes.Mono)] + public static void Verify_Generic_AccessStaticFieldClass() + { + Console.WriteLine($"Running {nameof(Verify_Generic_AccessStaticFieldClass)}"); + + Assert.Equal(typeof(int).ToString(), GetPrivateStaticFieldInt((MyList)null)); + + Assert.Equal(typeof(string).ToString(), GetPrivateStaticFieldString((MyList)null)); + + Assert.Equal(typeof(Struct).ToString(), GetPrivateStaticFieldStruct((MyList)null)); + + { + int expected = 10; + MyList.SetStaticGenericField(expected); + Assert.Equal(expected, GetPrivateStaticField((MyList)null)); + } + { + string expected = "abc"; + MyList.SetStaticGenericField(expected); + Assert.Equal(expected, GetPrivateStaticField((MyList)null)); + } + + [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=MyList.StaticFieldName)] + extern static ref string GetPrivateStaticFieldInt(MyList d); + + [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=MyList.StaticFieldName)] + extern static ref string GetPrivateStaticFieldString(MyList d); + + [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=MyList.StaticFieldName)] + extern static ref string GetPrivateStaticFieldStruct(MyList d); + + [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=MyList.StaticGenericFieldName)] + extern static ref T GetPrivateStaticField(MyList d); + } + + [Fact] + [ActiveIssue("", TestRuntimes.Mono)] + public static void Verify_Generic_AccessFieldClass() + { + Console.WriteLine($"Running {nameof(Verify_Generic_AccessFieldClass)}"); + { + MyList a = new(); + Assert.NotNull(GetPrivateField(a)); + } + { + MyList a = new(); + Assert.NotNull(GetPrivateField(a)); + } + { + MyList a = new(); + Assert.NotNull(GetPrivateField(a)); + } + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name=MyList.GenericFieldName)] + extern static ref List GetPrivateField(MyList a); + } + + class Base + { + protected virtual string CreateMessage(T t) => $"{nameof(Base)}:{t}"; + } + + class Derived1 : Base + { + protected override string CreateMessage(T t) => $"{nameof(Derived1)}:{t}"; + } + + sealed class Derived2 : Derived1 + { + protected override string CreateMessage(T t) => $"{nameof(Derived2)}:{t}"; + } + + [Fact] + [ActiveIssue("", TestRuntimes.Mono)] + public static void Verify_Generic_InheritanceMethodResolution() + { + string expect = "abc"; + Console.WriteLine($"Running {nameof(Verify_Generic_InheritanceMethodResolution)}"); + { + Base a = new(); + Assert.Equal($"{nameof(Base)}:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(Base)}:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(Base)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + } + { + Derived1 a = new(); + Assert.Equal($"{nameof(Derived1)}:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(Derived1)}:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(Derived1)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + } + { + Derived2 a = new(); + Assert.Equal($"{nameof(Derived2)}:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(Derived2)}:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(Derived2)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + } + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CreateMessage")] + extern static string CreateMessage(Base b, U t); + } + + sealed class Accessors + { + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + public extern static MyList Create(int a); + + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + public extern static MyList CreateWithList(List a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = ".ctor")] + public extern static void CallCtorAsMethod(MyList l, List a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] + public extern static void AddInt(MyList l, int a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] + public extern static void AddString(MyList l, string a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] + public extern static void AddStruct(MyList l, Struct a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Clear")] + public extern static void Clear(MyList l); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] + public extern static void Add(MyList l, U element); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CanCastToElementType")] + public extern static bool CanCastToElementType(MyList l, U element); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "ElementType")] + public extern static Type ElementType(MyList l); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "CanUseElementType")] + public extern static bool CanUseElementType(MyList l, U element); + } + + [Fact] + [ActiveIssue("", TestRuntimes.Mono)] + public static void Verify_Generic_CallCtor() + { + Console.WriteLine($"Running {nameof(Verify_Generic_CallCtor)}"); + + // Call constructor with non-generic parameter + { + MyList a = Accessors.Create(1); + Assert.Equal(1, a.Capacity); + } + { + MyList a = Accessors.Create(2); + Assert.Equal(2, a.Capacity); + } + { + MyList a = Accessors.Create(3); + Assert.Equal(3, a.Capacity); + } + + // Call constructor using generic parameter + { + MyList a = Accessors.CreateWithList([ 1 ]); + Assert.Equal(1, a.Count); + } + { + MyList a = Accessors.CreateWithList([ "1", "2" ]); + Assert.Equal(2, a.Count); + } + { + MyList a = Accessors.CreateWithList([new Struct(), new Struct(), new Struct()]); + Assert.Equal(3, a.Count); + } + + // Call constructors as methods + { + MyList a = (MyList)RuntimeHelpers.GetUninitializedObject(typeof(MyList)); + Accessors.CallCtorAsMethod(a, [1]); + Assert.Equal(1, a.Count); + } + { + MyList a = (MyList)RuntimeHelpers.GetUninitializedObject(typeof(MyList)); + Accessors.CallCtorAsMethod(a, ["1", "2"]); + Assert.Equal(2, a.Count); + } + { + MyList a = (MyList)RuntimeHelpers.GetUninitializedObject(typeof(MyList)); + Accessors.CallCtorAsMethod(a, [new Struct(), new Struct(), new Struct()]); + Assert.Equal(3, a.Count); + } + } + + [Fact] + [ActiveIssue("", TestRuntimes.Mono)] + public static void Verify_Generic_GenericTypeNonGenericInstanceMethod() + { + Console.WriteLine($"Running {nameof(Verify_Generic_GenericTypeNonGenericInstanceMethod)}"); + { + MyList a = new(); + Accessors.AddInt(a, 1); + Assert.Equal(1, a.Count); + Accessors.Clear(a); + Assert.Equal(0, a.Count); + } + { + MyList a = new(); + Accessors.AddString(a, "1"); + Accessors.AddString(a, "2"); + Assert.Equal(2, a.Count); + Accessors.Clear(a); + Assert.Equal(0, a.Count); + } + { + MyList a = new(); + Accessors.AddStruct(a, new Struct()); + Accessors.AddStruct(a, new Struct()); + Accessors.AddStruct(a, new Struct()); + Assert.Equal(3, a.Count); + Accessors.Clear(a); + Assert.Equal(0, a.Count); + } + } + + [Fact] + [ActiveIssue("", TestRuntimes.Mono)] + public static void Verify_Generic_GenericTypeGenericInstanceMethod() + { + Console.WriteLine($"Running {nameof(Verify_Generic_GenericTypeGenericInstanceMethod)}"); + { + MyList a = new(); + Assert.True(Accessors.CanCastToElementType(a, 1)); + Assert.False(Accessors.CanCastToElementType(a, string.Empty)); + Assert.False(Accessors.CanCastToElementType(a, new Struct())); + Accessors.Add(a, 1); + Assert.Equal(1, a.Count); + } + { + MyList a = new(); + Assert.False(Accessors.CanCastToElementType(a, 1)); + Assert.True(Accessors.CanCastToElementType(a, string.Empty)); + Assert.False(Accessors.CanCastToElementType(a, new Struct())); + Accessors.Add(a, string.Empty); + Assert.Equal(1, a.Count); + } + { + MyList a = new(); + Assert.False(Accessors.CanCastToElementType(a, 1)); + Assert.False(Accessors.CanCastToElementType(a, string.Empty)); + Assert.True(Accessors.CanCastToElementType(a, new Struct())); + Accessors.Add(a, new Struct()); + Assert.Equal(1, a.Count); + } + } + + [Fact] + [ActiveIssue("", TestRuntimes.Mono)] + public static void Verify_Generic_GenericTypeNonGenericStaticMethod() + { + Console.WriteLine($"Running {nameof(Verify_Generic_GenericTypeNonGenericStaticMethod)}"); + { + Assert.Equal(typeof(int), Accessors.ElementType(null)); + Assert.Equal(typeof(string), Accessors.ElementType(null)); + Assert.Equal(typeof(Struct), Accessors.ElementType(null)); + } + } + + [Fact] + [ActiveIssue("", TestRuntimes.Mono)] + public static void Verify_Generic_GenericTypeGenericStaticMethod() + { + Console.WriteLine($"Running {nameof(Verify_Generic_GenericTypeGenericStaticMethod)}"); + { + Assert.True(Accessors.CanUseElementType(null, 1)); + Assert.False(Accessors.CanUseElementType(null, string.Empty)); + Assert.False(Accessors.CanUseElementType(null, new Struct())); + } + { + Assert.False(Accessors.CanUseElementType(null, 1)); + Assert.True(Accessors.CanUseElementType(null, string.Empty)); + Assert.False(Accessors.CanUseElementType(null, new Struct())); + } + { + Assert.False(Accessors.CanUseElementType(null, 1)); + Assert.False(Accessors.CanUseElementType(null, string.Empty)); + Assert.True(Accessors.CanUseElementType(null, new Struct())); + } + } + + class Invalid + { + [UnsafeAccessor(UnsafeAccessorKind.Method, Name=nameof(ToString))] + public static extern string CallToString(U a); + } + + class Invalid + { + [UnsafeAccessor(UnsafeAccessorKind.Method, Name=nameof(ToString))] + public static extern string CallToString(T a); + } + + [Fact] + [ActiveIssue("", TestRuntimes.Mono)] + public static void Verify_Generic_InvalidUseUnsafeAccessor() + { + Console.WriteLine($"Running {nameof(Verify_Generic_InvalidUseUnsafeAccessor)}"); + + Assert.Throws(() => Invalid.CallToString(string.Empty)); + Assert.Throws(() => Invalid.CallToString(string.Empty)); + } +} diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.cs index 8d059042ea739..30f65993da6cc 100644 --- a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.cs +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.cs @@ -85,35 +85,6 @@ struct UserDataValue public string GetFieldValue() => _f; } - class UserDataGenericClass - { - public const string StaticGenericFieldName = nameof(_GF); - public const string GenericFieldName = nameof(_gf); - public const string StaticGenericMethodName = nameof(_GM); - public const string GenericMethodName = nameof(_gm); - - public const string StaticFieldName = nameof(_F); - public const string FieldName = nameof(_f); - public const string StaticMethodName = nameof(_M); - public const string MethodName = nameof(_m); - - private static T _GF; - private T _gf; - - public static void SetStaticGenericField(T val) => _GF = val; - - private static string _F = PrivateStatic; - private string _f; - - public UserDataGenericClass(T t) { _f = Private; _gf = t; } - - private static string _GM(T s, ref T sr, in T si) => typeof(T).ToString(); - private string _gm(T s, ref T sr, in T si) => typeof(T).ToString(); - - private static string _M(string s, ref string sr, in string si) => s; - private string _m(string s, ref string sr, in string si) => s; - } - [UnsafeAccessor(UnsafeAccessorKind.Constructor)] extern static UserDataClass CallPrivateConstructorClass(); @@ -217,36 +188,6 @@ public static void Verify_AccessFieldClass() extern static ref string GetPrivateField(UserDataClass d); } - [Fact] - public static void Verify_AccessStaticFieldGenericClass() - { - Console.WriteLine($"Running {nameof(Verify_AccessStaticFieldGenericClass)}"); - - Assert.Equal(PrivateStatic, GetPrivateStaticFieldInt((UserDataGenericClass)null)); - - Assert.Equal(PrivateStatic, GetPrivateStaticFieldString((UserDataGenericClass)null)); - - { - int expected = 10; - UserDataGenericClass.SetStaticGenericField(expected); - Assert.Equal(expected, GetPrivateStaticField((UserDataGenericClass)null)); - } - { - string expected = "abc"; - UserDataGenericClass.SetStaticGenericField(expected); - Assert.Equal(expected, GetPrivateStaticField((UserDataGenericClass)null)); - } - - [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=UserDataGenericClass.StaticFieldName)] - extern static ref string GetPrivateStaticFieldInt(UserDataGenericClass d); - - [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=UserDataGenericClass.StaticFieldName)] - extern static ref string GetPrivateStaticFieldString(UserDataGenericClass d); - - [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=UserDataGenericClass.StaticGenericFieldName)] - extern static ref T GetPrivateStaticField(UserDataGenericClass d); - } - [Fact] public static void Verify_AccessStaticFieldValue() { @@ -274,35 +215,6 @@ public static void Verify_AccessFieldValue() extern static ref string GetPrivateField(ref UserDataValue d); } - [Fact] - public static void Verify_AccessFieldGenericClass() - { - Console.WriteLine($"Running {nameof(Verify_AccessFieldGenericClass)}"); - - Assert.Equal(Private, GetPrivateFieldInt(new UserDataGenericClass(0))); - - Assert.Equal(Private, GetPrivateFieldString(new UserDataGenericClass(string.Empty))); - - { - int expected = 10; - Assert.Equal(expected, GetPrivateField(new UserDataGenericClass(expected))); - } - - { - string expected = "abc"; - Assert.Equal(expected, GetPrivateField(new UserDataGenericClass(expected))); - } - - [UnsafeAccessor(UnsafeAccessorKind.Field, Name=UserDataGenericClass.FieldName)] - extern static ref string GetPrivateFieldInt(UserDataGenericClass d); - - [UnsafeAccessor(UnsafeAccessorKind.Field, Name=UserDataGenericClass.FieldName)] - extern static ref string GetPrivateFieldString(UserDataGenericClass d); - - [UnsafeAccessor(UnsafeAccessorKind.Field, Name=UserDataGenericClass.GenericFieldName)] - extern static ref T GetPrivateField(UserDataGenericClass d); - } - [Fact] public static void Verify_AccessStaticMethodClass() { @@ -614,15 +526,6 @@ class Invalid { [UnsafeAccessor(UnsafeAccessorKind.Method, Name=nameof(ToString))] public extern string NonStatic(string a); - - [UnsafeAccessor(UnsafeAccessorKind.Method, Name=nameof(ToString))] - public static extern string CallToString(U a); - } - - class Invalid - { - [UnsafeAccessor(UnsafeAccessorKind.Method, Name=nameof(ToString))] - public static extern string CallToString(T a); } [Fact] @@ -647,8 +550,6 @@ public static void Verify_InvalidUseUnsafeAccessor() Assert.Throws(() => LookUpFailsOnPointers(null)); Assert.Throws(() => LookUpFailsOnFunctionPointers(null)); Assert.Throws(() => new Invalid().NonStatic(string.Empty)); - Assert.Throws(() => Invalid.CallToString(string.Empty)); - Assert.Throws(() => Invalid.CallToString(string.Empty)); Assert.Throws(() => { string str = string.Empty; diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.csproj b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.csproj index 876d006ea96eb..f551f9b48c249 100644 --- a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.csproj +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.csproj @@ -6,6 +6,7 @@ + From b1b22cf766dde1abddf5ed9b03d6c84a373b9fe5 Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Fri, 8 Mar 2024 15:41:27 -0800 Subject: [PATCH 05/13] Remove canonical check to VAR and MVAR --- src/coreclr/vm/prestub.cpp | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index 0d84b6a81efae..83a08bbd4ba33 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -1133,7 +1133,7 @@ namespace FieldDesc* TargetField; }; - TypeHandle ValidateTargetType(TypeHandle targetTypeMaybe) + TypeHandle ValidateTargetType(TypeHandle targetTypeMaybe, CorElementType targetFromSig) { TypeHandle targetType = targetTypeMaybe.IsByRef() ? targetTypeMaybe.GetTypeParam() @@ -1144,6 +1144,12 @@ namespace if (targetType.IsTypeDesc()) ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); + // We do not support generic signature types as valid targets. + if (targetFromSig == ELEMENT_TYPE_VAR || targetFromSig == ELEMENT_TYPE_MVAR) + { + ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); + } + return targetType; } @@ -1280,10 +1286,6 @@ namespace TypeHandle targetType = cxt.TargetType; _ASSERTE(!targetType.IsTypeDesc()); - // We do not support the Canon type as a valid target. - if (targetType.AsMethodTable() == g_pCanonMethodTableClass) - ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); - MethodDesc* targetMaybe = NULL; Substitution* pLookupSubst = NULL; @@ -1365,10 +1367,6 @@ namespace MethodTable* pMT = targetType.AsMethodTable(); - // We do not support the Canon type as a valid target. - if (pMT == g_pCanonMethodTableClass) - ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); - CorElementType elemType = fieldType.GetSignatureCorElementType(); ApproxFieldDescIterator fdIterator( pMT, @@ -1641,7 +1639,10 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET // * Instance member access - examine type of first parameter // * Static member access - examine type of first parameter TypeHandle retType; + CorElementType retCorType; TypeHandle firstArgType; + CorElementType firstArgCorType = ELEMENT_TYPE_END; + retCorType = context.DeclarationSig.GetReturnType(); retType = context.DeclarationSig.GetRetTypeHandleThrowing(); UINT argCount = context.DeclarationSig.NumFixedArgs(); if (argCount > 0) @@ -1650,6 +1651,7 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET // Get the target type signature and resolve to a type handle. context.TargetTypeSig = context.DeclarationSig.GetArgProps(); + (void)context.TargetTypeSig.PeekElemType(&firstArgCorType); firstArgType = context.DeclarationSig.GetLastTypeHandleThrowing(); } @@ -1670,7 +1672,7 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET // Get the target type signature from the return type. context.TargetTypeSig = context.DeclarationSig.GetReturnProps(); - context.TargetType = ValidateTargetType(retType); + context.TargetType = ValidateTargetType(retType, retCorType); if (!TrySetTargetMethod(context, ".ctor")) MemberLoader::ThrowMissingMethodException(context.TargetType.AsMethodTable(), ".ctor"); break; @@ -1690,7 +1692,7 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); } - context.TargetType = ValidateTargetType(firstArgType); + context.TargetType = ValidateTargetType(firstArgType, firstArgCorType); context.IsTargetStatic = kind == UnsafeAccessorKind::StaticMethod; if (!TrySetTargetMethod(context, name.GetUTF8())) MemberLoader::ThrowMissingMethodException(context.TargetType.AsMethodTable(), name.GetUTF8()); @@ -1715,7 +1717,7 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); } - context.TargetType = ValidateTargetType(firstArgType); + context.TargetType = ValidateTargetType(firstArgType, firstArgCorType); context.IsTargetStatic = kind == UnsafeAccessorKind::StaticField; if (!TrySetTargetField(context, name.GetUTF8(), retType.GetTypeParam())) MemberLoader::ThrowMissingFieldException(context.TargetType.AsMethodTable(), name.GetUTF8()); From 6271919e3fbd8cd3461b32c0ca8448bb877ec0a4 Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Fri, 8 Mar 2024 15:45:00 -0800 Subject: [PATCH 06/13] Native AOT support Update tests --- .../Common/TypeSystem/IL/Stubs/ILEmitter.cs | 10 ++-- .../Common/TypeSystem/IL/UnsafeAccessors.cs | 59 +++++++++++++++---- .../UnsafeAccessorsTests.Generics.cs | 7 +++ 3 files changed, 59 insertions(+), 17 deletions(-) diff --git a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ILEmitter.cs b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ILEmitter.cs index b6f31e130142e..4a4723d2b7afd 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ILEmitter.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ILEmitter.cs @@ -686,27 +686,27 @@ private ILToken NewToken(object value, int tokenType) public ILToken NewToken(TypeDesc value) { - return NewToken(value, 0x01000000); + return NewToken(value, 0x01000000); // mdtTypeRef } public ILToken NewToken(MethodDesc value) { - return NewToken(value, 0x0a000000); + return NewToken(value, 0x0a000000); // mdtMemberRef } public ILToken NewToken(FieldDesc value) { - return NewToken(value, 0x0a000000); + return NewToken(value, 0x0a000000); // mdtMemberRef } public ILToken NewToken(string value) { - return NewToken(value, 0x70000000); + return NewToken(value, 0x70000000); // mdtString } public ILToken NewToken(MethodSignature value) { - return NewToken(value, 0x11000000); + return NewToken(value, 0x11000000); // mdtSignature } public ILLocalVariable NewLocal(TypeDesc localType, bool isPinned = false) diff --git a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs index 6338f725be223..f1e8e903d4deb 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs @@ -29,12 +29,6 @@ public static MethodIL TryGetIL(EcmaMethod method) return GenerateAccessorBadImageFailure(method); } - // Block generic support early - if (method.HasInstantiation || method.OwningType.HasInstantiation) - { - return GenerateAccessorBadImageFailure(method); - } - if (!TryParseUnsafeAccessorAttribute(method, decodedAttribute.Value, out UnsafeAccessorKind kind, out string name)) { return GenerateAccessorBadImageFailure(method); @@ -232,15 +226,22 @@ private static bool ValidateTargetType(TypeDesc targetTypeMaybe, out TypeDesc va targetType = null; } + // We do not support signature variables as a target (for example, VAR and MVAR). + if (targetType is SignatureVariable) + { + targetType = null; + } + validated = targetType; return validated != null; } - private static bool DoesMethodMatchUnsafeAccessorDeclaration(ref GenerationContext context, MethodDesc method, bool ignoreCustomModifiers) + private static bool DoesMethodMatchUnsafeAccessorDeclaration( + ref GenerationContext context, + MethodSignature declSig, + MethodSignature maybeSig, + bool ignoreCustomModifiers) { - MethodSignature declSig = context.Declaration.Signature; - MethodSignature maybeSig = method.Signature; - // Check if we need to also validate custom modifiers. // If we are, do it first. if (!ignoreCustomModifiers) @@ -366,10 +367,22 @@ private static bool DoesMethodMatchUnsafeAccessorDeclaration(ref GenerationConte return true; } - private static bool TrySetTargetMethod(ref GenerationContext context, string name, out bool isAmbiguous, bool ignoreCustomModifiers = true) + private static unsafe bool TrySetTargetMethod(ref GenerationContext context, string name, out bool isAmbiguous, bool ignoreCustomModifiers = true) { TypeDesc targetType = context.TargetType; + // Build up a substitution Instantiation to use when looking up methods involving generics. + Instantiation substitutionForSig = default; + if (targetType.Instantiation.Length > 0) + { + TypeDesc[] types = new TypeDesc[targetType.Instantiation.Length]; + for (int i = 0; i < types.Length; ++i) + { + types[i] = targetType.Context.GetSignatureVariable(i, true); + } + substitutionForSig = new Instantiation(types); + } + MethodDesc targetMaybe = null; foreach (MethodDesc md in targetType.GetMethods()) { @@ -385,8 +398,18 @@ private static bool TrySetTargetMethod(ref GenerationContext context, string nam continue; } + // Create a substitution signature for the current signature if appropriate. + MethodSignature sigToCompare = md.Signature; + if (!substitutionForSig.IsNull) + { + sigToCompare = sigToCompare.ApplySubstitution(substitutionForSig); + } + // Check signature - if (!DoesMethodMatchUnsafeAccessorDeclaration(ref context, md, ignoreCustomModifiers)) + if (!DoesMethodMatchUnsafeAccessorDeclaration(ref context, + context.Declaration.Signature, + sigToCompare, + ignoreCustomModifiers)) { continue; } @@ -411,6 +434,18 @@ private static bool TrySetTargetMethod(ref GenerationContext context, string nam } isAmbiguous = false; + + if (targetMaybe != null && targetMaybe.HasInstantiation) + { + TypeDesc[] methodInstantiation = new TypeDesc[targetMaybe.Instantiation.Length]; + for (int i = 0; i < methodInstantiation.Length; ++i) + { + methodInstantiation[i] = targetMaybe.Context.GetSignatureVariable(i, true); + } + + targetMaybe = targetMaybe.Context.GetInstantiatedMethod(targetMaybe, new Instantiation(methodInstantiation)); + } + context.TargetMethod = targetMaybe; return context.TargetMethod != null; } diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs index fe23497fd8ec0..fa8373119f2a8 100644 --- a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs @@ -291,6 +291,7 @@ public static void Verify_Generic_GenericTypeGenericInstanceMethod() Assert.True(Accessors.CanCastToElementType(a, 1)); Assert.False(Accessors.CanCastToElementType(a, string.Empty)); Assert.False(Accessors.CanCastToElementType(a, new Struct())); + Assert.Equal(0, a.Count); Accessors.Add(a, 1); Assert.Equal(1, a.Count); } @@ -299,6 +300,7 @@ public static void Verify_Generic_GenericTypeGenericInstanceMethod() Assert.False(Accessors.CanCastToElementType(a, 1)); Assert.True(Accessors.CanCastToElementType(a, string.Empty)); Assert.False(Accessors.CanCastToElementType(a, new Struct())); + Assert.Equal(0, a.Count); Accessors.Add(a, string.Empty); Assert.Equal(1, a.Count); } @@ -307,6 +309,7 @@ public static void Verify_Generic_GenericTypeGenericInstanceMethod() Assert.False(Accessors.CanCastToElementType(a, 1)); Assert.False(Accessors.CanCastToElementType(a, string.Empty)); Assert.True(Accessors.CanCastToElementType(a, new Struct())); + Assert.Equal(0, a.Count); Accessors.Add(a, new Struct()); Assert.Equal(1, a.Count); } @@ -364,7 +367,11 @@ public static void Verify_Generic_InvalidUseUnsafeAccessor() { Console.WriteLine($"Running {nameof(Verify_Generic_InvalidUseUnsafeAccessor)}"); + Assert.Throws(() => Invalid.CallToString(0)); + Assert.Throws(() => Invalid.CallToString(0)); Assert.Throws(() => Invalid.CallToString(string.Empty)); Assert.Throws(() => Invalid.CallToString(string.Empty)); + Assert.Throws(() => Invalid.CallToString(new Struct())); + Assert.Throws(() => Invalid.CallToString(new Struct())); } } From 1de45f5e0408ea7865c64aad3bb9471f82938c3a Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Tue, 12 Mar 2024 11:00:09 -0700 Subject: [PATCH 07/13] Review feedback. --- .../Common/TypeSystem/IL/UnsafeAccessors.cs | 23 +--- src/coreclr/vm/prestub.cpp | 27 +---- .../UnsafeAccessorsTests.Generics.cs | 110 ++++++++++++------ 3 files changed, 80 insertions(+), 80 deletions(-) diff --git a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs index f1e8e903d4deb..bf578c74f17f7 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs @@ -367,22 +367,10 @@ private static bool DoesMethodMatchUnsafeAccessorDeclaration( return true; } - private static unsafe bool TrySetTargetMethod(ref GenerationContext context, string name, out bool isAmbiguous, bool ignoreCustomModifiers = true) + private static bool TrySetTargetMethod(ref GenerationContext context, string name, out bool isAmbiguous, bool ignoreCustomModifiers = true) { TypeDesc targetType = context.TargetType; - // Build up a substitution Instantiation to use when looking up methods involving generics. - Instantiation substitutionForSig = default; - if (targetType.Instantiation.Length > 0) - { - TypeDesc[] types = new TypeDesc[targetType.Instantiation.Length]; - for (int i = 0; i < types.Length; ++i) - { - types[i] = targetType.Context.GetSignatureVariable(i, true); - } - substitutionForSig = new Instantiation(types); - } - MethodDesc targetMaybe = null; foreach (MethodDesc md in targetType.GetMethods()) { @@ -398,17 +386,10 @@ private static unsafe bool TrySetTargetMethod(ref GenerationContext context, str continue; } - // Create a substitution signature for the current signature if appropriate. - MethodSignature sigToCompare = md.Signature; - if (!substitutionForSig.IsNull) - { - sigToCompare = sigToCompare.ApplySubstitution(substitutionForSig); - } - // Check signature if (!DoesMethodMatchUnsafeAccessorDeclaration(ref context, context.Declaration.Signature, - sigToCompare, + md.Signature, ignoreCustomModifiers)) { continue; diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index 83a08bbd4ba33..cc7592211f8cc 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -1156,7 +1156,6 @@ namespace bool DoesMethodMatchUnsafeAccessorDeclaration( GenerationContext& cxt, MethodDesc* method, - const Substitution* pMethodSubst, MetaSig::CompareState& state) { STANDARD_VM_CONTRACT; @@ -1174,7 +1173,7 @@ namespace method->GetSig(&pSig2, &cSig2); PCCOR_SIGNATURE pEndSig2 = pSig2 + cSig2; ModuleBase* pModule2 = method->GetModule(); - const Substitution* pSubst2 = pMethodSubst; + const Substitution* pSubst2 = NULL; // // Parsing the signature follows details defined in ECMA-335 - II.23.2.1 @@ -1287,28 +1286,6 @@ namespace _ASSERTE(!targetType.IsTypeDesc()); MethodDesc* targetMaybe = NULL; - Substitution* pLookupSubst = NULL; - - // Build up a Substitution to use when looking up methods involving generics. - Substitution substitution; - SigBuilder sigBuilder; - DWORD targetGenericParamCount = targetType.AsMethodTable()->GetNumGenericArgs(); - if (targetGenericParamCount > 0) - { - // Create a temporary signature that translate VARs to MVARs. - for (DWORD i = 0; i < targetGenericParamCount; ++i) - { - sigBuilder.AppendElementType(ELEMENT_TYPE_MVAR); - sigBuilder.AppendData(i); // Represents the generic parameter index - II.23.2.12 - } - - DWORD tmpSigLen; - PVOID tmpSigRaw = sigBuilder.GetSignature(&tmpSigLen); - - SigPointer tmpSig{ (PCCOR_SIGNATURE)tmpSigRaw, tmpSigLen }; - substitution = Substitution{ cxt.Declaration->GetModule(), tmpSig, NULL }; - pLookupSubst = &substitution; - } // Following a similar iteration pattern found in MemberLoader::FindMethod(). // However, we are only operating on the current type not walking the type hierarchy. @@ -1329,7 +1306,7 @@ namespace TokenPairList list { nullptr }; MetaSig::CompareState state{ &list }; state.IgnoreCustomModifiers = ignoreCustomModifiers; - if (!DoesMethodMatchUnsafeAccessorDeclaration(cxt, curr, pLookupSubst, state)) + if (!DoesMethodMatchUnsafeAccessorDeclaration(cxt, curr, state)) continue; // Check if there is some ambiguity. diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs index fa8373119f2a8..17f1880f1a6ae 100644 --- a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs @@ -40,6 +40,8 @@ static MyList() private void Add(T t) => _list.Add(t); + private void AddWithIgnore(T t, U _) => _list.Add(t); + private bool CanCastToElementType(U t) => t is T; private static bool CanUseElementType(U t) => t is T; @@ -61,7 +63,7 @@ private void Add(Struct a) => } [Fact] - [ActiveIssue("", TestRuntimes.Mono)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] public static void Verify_Generic_AccessStaticFieldClass() { Console.WriteLine($"Running {nameof(Verify_Generic_AccessStaticFieldClass)}"); @@ -97,7 +99,7 @@ public static void Verify_Generic_AccessStaticFieldClass() } [Fact] - [ActiveIssue("", TestRuntimes.Mono)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] public static void Verify_Generic_AccessFieldClass() { Console.WriteLine($"Running {nameof(Verify_Generic_AccessFieldClass)}"); @@ -120,21 +122,27 @@ public static void Verify_Generic_AccessFieldClass() class Base { - protected virtual string CreateMessage(T t) => $"{nameof(Base)}:{t}"; + protected virtual string CreateMessageGeneric(T t) => $"{nameof(Base)}:{t}"; } - class Derived1 : Base + class GenericBase : Base { - protected override string CreateMessage(T t) => $"{nameof(Derived1)}:{t}"; + protected virtual string CreateMessage(T u) => $"{nameof(GenericBase)}:{u}"; + protected override string CreateMessageGeneric(U t) => $"{nameof(GenericBase)}:{t}"; } - sealed class Derived2 : Derived1 + sealed class Derived1 : GenericBase + { + protected override string CreateMessage(string u) => $"{nameof(Derived1)}:{u}"; + protected override string CreateMessageGeneric(U t) => $"{nameof(Derived1)}:{t}"; + } + + sealed class Derived2 : GenericBase { - protected override string CreateMessage(T t) => $"{nameof(Derived2)}:{t}"; } [Fact] - [ActiveIssue("", TestRuntimes.Mono)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] public static void Verify_Generic_InheritanceMethodResolution() { string expect = "abc"; @@ -145,6 +153,24 @@ public static void Verify_Generic_InheritanceMethodResolution() Assert.Equal($"{nameof(Base)}:{expect}", CreateMessage(a, expect)); Assert.Equal($"{nameof(Base)}:{nameof(Struct)}", CreateMessage(a, new Struct())); } + { + GenericBase a = new(); + Assert.Equal($"{nameof(GenericBase)}:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(GenericBase)}:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(GenericBase)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + } + { + GenericBase a = new(); + Assert.Equal($"{nameof(GenericBase)}:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(GenericBase)}:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(GenericBase)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + } + { + GenericBase a = new(); + Assert.Equal($"{nameof(GenericBase)}:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(GenericBase)}:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(GenericBase)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + } { Derived1 a = new(); Assert.Equal($"{nameof(Derived1)}:1", CreateMessage(a, 1)); @@ -152,13 +178,14 @@ public static void Verify_Generic_InheritanceMethodResolution() Assert.Equal($"{nameof(Derived1)}:{nameof(Struct)}", CreateMessage(a, new Struct())); } { - Derived2 a = new(); - Assert.Equal($"{nameof(Derived2)}:1", CreateMessage(a, 1)); - Assert.Equal($"{nameof(Derived2)}:{expect}", CreateMessage(a, expect)); - Assert.Equal($"{nameof(Derived2)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + // Verify resolution of generic override logic. + Derived1 a1 = new(); + Derived2 a2 = new(); + Assert.Equal($"{nameof(Derived1)}:{expect}", Accessors.CreateMessage(a1, expect)); + Assert.Equal($"{nameof(GenericBase)}:{expect}", Accessors.CreateMessage(a2, expect)); } - [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CreateMessage")] + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CreateMessageGeneric")] extern static string CreateMessage(Base b, U t); } @@ -168,10 +195,10 @@ sealed class Accessors public extern static MyList Create(int a); [UnsafeAccessor(UnsafeAccessorKind.Constructor)] - public extern static MyList CreateWithList(List a); + public extern static MyList CreateWithList(List a); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = ".ctor")] - public extern static void CallCtorAsMethod(MyList l, List a); + public extern static void CallCtorAsMethod(MyList l, List a); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] public extern static void AddInt(MyList l, int a); @@ -186,11 +213,17 @@ sealed class Accessors public extern static void Clear(MyList l); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] - public extern static void Add(MyList l, U element); + public extern static void Add(MyList l, T element); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "AddWithIgnore")] + public extern static void AddWithIgnore(MyList l, T element, U ignore); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CanCastToElementType")] public extern static bool CanCastToElementType(MyList l, U element); + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CreateMessage")] + public extern static string CreateMessage(GenericBase b, T t); + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "ElementType")] public extern static Type ElementType(MyList l); @@ -199,7 +232,7 @@ sealed class Accessors } [Fact] - [ActiveIssue("", TestRuntimes.Mono)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] public static void Verify_Generic_CallCtor() { Console.WriteLine($"Running {nameof(Verify_Generic_CallCtor)}"); @@ -220,38 +253,38 @@ public static void Verify_Generic_CallCtor() // Call constructor using generic parameter { - MyList a = Accessors.CreateWithList([ 1 ]); + MyList a = Accessors.CreateWithList([ 1 ]); Assert.Equal(1, a.Count); } { - MyList a = Accessors.CreateWithList([ "1", "2" ]); + MyList a = Accessors.CreateWithList([ "1", "2" ]); Assert.Equal(2, a.Count); } { - MyList a = Accessors.CreateWithList([new Struct(), new Struct(), new Struct()]); + MyList a = Accessors.CreateWithList([new Struct(), new Struct(), new Struct()]); Assert.Equal(3, a.Count); } // Call constructors as methods { MyList a = (MyList)RuntimeHelpers.GetUninitializedObject(typeof(MyList)); - Accessors.CallCtorAsMethod(a, [1]); + Accessors.CallCtorAsMethod(a, [1]); Assert.Equal(1, a.Count); } { MyList a = (MyList)RuntimeHelpers.GetUninitializedObject(typeof(MyList)); - Accessors.CallCtorAsMethod(a, ["1", "2"]); + Accessors.CallCtorAsMethod(a, ["1", "2"]); Assert.Equal(2, a.Count); } { MyList a = (MyList)RuntimeHelpers.GetUninitializedObject(typeof(MyList)); - Accessors.CallCtorAsMethod(a, [new Struct(), new Struct(), new Struct()]); + Accessors.CallCtorAsMethod(a, [new Struct(), new Struct(), new Struct()]); Assert.Equal(3, a.Count); } } [Fact] - [ActiveIssue("", TestRuntimes.Mono)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] public static void Verify_Generic_GenericTypeNonGenericInstanceMethod() { Console.WriteLine($"Running {nameof(Verify_Generic_GenericTypeNonGenericInstanceMethod)}"); @@ -282,7 +315,7 @@ public static void Verify_Generic_GenericTypeNonGenericInstanceMethod() } [Fact] - [ActiveIssue("", TestRuntimes.Mono)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] public static void Verify_Generic_GenericTypeGenericInstanceMethod() { Console.WriteLine($"Running {nameof(Verify_Generic_GenericTypeGenericInstanceMethod)}"); @@ -292,8 +325,11 @@ public static void Verify_Generic_GenericTypeGenericInstanceMethod() Assert.False(Accessors.CanCastToElementType(a, string.Empty)); Assert.False(Accessors.CanCastToElementType(a, new Struct())); Assert.Equal(0, a.Count); - Accessors.Add(a, 1); - Assert.Equal(1, a.Count); + Accessors.Add(a, 1); + Accessors.AddWithIgnore(a, 1, 1); + Accessors.AddWithIgnore(a, 1, string.Empty); + Accessors.AddWithIgnore(a, 1, new Struct()); + Assert.Equal(4, a.Count); } { MyList a = new(); @@ -301,8 +337,11 @@ public static void Verify_Generic_GenericTypeGenericInstanceMethod() Assert.True(Accessors.CanCastToElementType(a, string.Empty)); Assert.False(Accessors.CanCastToElementType(a, new Struct())); Assert.Equal(0, a.Count); - Accessors.Add(a, string.Empty); - Assert.Equal(1, a.Count); + Accessors.Add(a, string.Empty); + Accessors.AddWithIgnore(a, string.Empty, 1); + Accessors.AddWithIgnore(a, string.Empty, string.Empty); + Accessors.AddWithIgnore(a, string.Empty, new Struct()); + Assert.Equal(4, a.Count); } { MyList a = new(); @@ -310,13 +349,16 @@ public static void Verify_Generic_GenericTypeGenericInstanceMethod() Assert.False(Accessors.CanCastToElementType(a, string.Empty)); Assert.True(Accessors.CanCastToElementType(a, new Struct())); Assert.Equal(0, a.Count); - Accessors.Add(a, new Struct()); - Assert.Equal(1, a.Count); + Accessors.Add(a, new Struct()); + Accessors.AddWithIgnore(a, new Struct(), 1); + Accessors.AddWithIgnore(a, new Struct(), string.Empty); + Accessors.AddWithIgnore(a, new Struct(), new Struct()); + Assert.Equal(4, a.Count); } } [Fact] - [ActiveIssue("", TestRuntimes.Mono)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] public static void Verify_Generic_GenericTypeNonGenericStaticMethod() { Console.WriteLine($"Running {nameof(Verify_Generic_GenericTypeNonGenericStaticMethod)}"); @@ -328,7 +370,7 @@ public static void Verify_Generic_GenericTypeNonGenericStaticMethod() } [Fact] - [ActiveIssue("", TestRuntimes.Mono)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] public static void Verify_Generic_GenericTypeGenericStaticMethod() { Console.WriteLine($"Running {nameof(Verify_Generic_GenericTypeGenericStaticMethod)}"); @@ -362,7 +404,7 @@ class Invalid } [Fact] - [ActiveIssue("", TestRuntimes.Mono)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] public static void Verify_Generic_InvalidUseUnsafeAccessor() { Console.WriteLine($"Running {nameof(Verify_Generic_InvalidUseUnsafeAccessor)}"); From 9b26c9d3621fa3c50e5ad1dd2b08f7b97f2eba04 Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Wed, 13 Mar 2024 13:40:52 -0700 Subject: [PATCH 08/13] Add design document Add strict check for precisely matched Generic constraints --- docs/design/features/unsafeaccessors.md | 137 ++++++++++++++++++ src/coreclr/vm/methodtable.cpp | 11 ++ src/coreclr/vm/methodtable.h | 2 + src/coreclr/vm/prestub.cpp | 78 +++++++++- .../src/Resources/Strings.resx | 8 +- 5 files changed, 234 insertions(+), 2 deletions(-) create mode 100644 docs/design/features/unsafeaccessors.md diff --git a/docs/design/features/unsafeaccessors.md b/docs/design/features/unsafeaccessors.md new file mode 100644 index 0000000000000..f3463546e171b --- /dev/null +++ b/docs/design/features/unsafeaccessors.md @@ -0,0 +1,137 @@ +# `UnsafeAccessorAttribute` + +## Background and motivation + +Number of existing .NET serializers depend on skipping member visibility checks for data serialization. Examples include System.Text.Json or EF Core. In order to skip the visibility checks, the serializers typically use dynamically emitted code (Reflection.Emit or Linq.Expressions) and classic reflection APIs as slow fallback. Neither of these two options are great for source generated serializers and native AOT compilation. This API proposal introduces a first class zero-overhead mechanism for skipping visibility checks. + +## Semantics + +This attribute will be applied to an `extern static` method. The implementation of the `extern static` method annotated with this attribute will be provided by the runtime based on the information in the attribute and the signature of the method that the attribute is applied to. The runtime will try to find the matching method or field and forward the call to it. If the matching method or field is not found, the body of the `extern static` method will throw `MissingFieldException` or `MissingMethodException`. + +For `Method`, `StaticMethod`, `Field`, and `StaticField`, the type of the first argument of the annotated `extern static` method identifies the owning type. Only the specific type defined will be examined for inaccessible members. The type hierarchy is not walked looking for a match. + +The value of the first argument is treated as `this` pointer for instance fields and methods. + +The first argument must be passed as `ref` for instance fields and methods on structs. + +The value of the first argument is not used by the implementation for static fields and methods. + +The return value for an accessor to a field can be `ref` if setting of the field is desired. + +Constructors can be accessed using Constructor or Method. + +The return type is considered for the signature match. Modreqs and modopts are initially not considered for the signature match. However, if an ambiguity exists ignoring modreqs and modopts, a precise match is attempted. If an ambiguity still exists, `AmbiguousMatchException` is thrown. + +By default, the attributed method's name dictates the name of the method/field. This can cause confusion in some cases since language abstractions, like C# local functions, generate mangled IL names. The solution to this is to use the `nameof` mechanism and define the `Name` property. + +Scenarios involving Generics may require creating new Generic types to contain the `extern static` method definition. The decision was made to require all `ELEMENT_TYPE_VAR` and `ELEMENT_TYPE_MVAR` instances to match identically type and generic parameter index. This means if the target method for access uses an `ELEMENT_TYPE_VAR`, the `extern static` method must also use an `ELEMENT_TYPE_VAR`. For example: + +```csharp +class C +{ + T M(U u) => default; +} + +class Accessor +{ + // Correct - V is an ELEMENT_TYPE_VAR and W is ELEMENT_TYPE_VAR, + // respectively the same as T and U in the definition of C::M(). + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + extern static void CallM(C c, W w); + + // Incorrect - Since Y must be an ELEMENT_TYPE_VAR, but is ELEMENT_TYPE_MVAR below. + // [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + // extern static void CallM(C c, Z z); +} +``` + +Methods with the `UnsafeAccessorAttribute` that access members with Generic parameters are expected to have the same declared constraints with the target member. Failure to do so results in unspecified behavior. For example: + +```csharp +class C +{ + T M(U u) where U: Base => default; +} + +class Accessor +{ + // Correct - Constraints match the target member. + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + extern static void CallM(C c, W w) where W: Base; + + // Incorrect - Constraints do not match target member. + // [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + // extern static void CallM(C c, W w); +} +``` + +## API + +```csharp +namespace System.Runtime.CompilerServices; + +[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = false)] +public class UnsafeAccessorAttribute : Attribute +{ + public UnsafeAccessorAttribute(UnsafeAccessorKind kind); + + public UnsafeAccessorKind Kind { get; } + + // The name defaults to the annotated method name if not specified. + // The name must be null for constructors + public string? Name { get; set; } +} + +public enum UnsafeAccessorKind +{ + Constructor, // call instance constructor (`newobj` in IL) + Method, // call instance method (`callvirt` in IL) + StaticMethod, // call static method (`call` in IL) + Field, // address of instance field (`ldflda` in IL) + StaticField // address of static field (`ldsflda` in IL) +}; +``` + +## API Usage + +```csharp +class UserData +{ + private UserData() { } + public string Name { get; set; } +} + +[UnsafeAccessor(UnsafeAccessorKind.Constructor)] +extern static UserData CallPrivateConstructor(); + +// This API allows accessing backing fields for auto-implemented properties with unspeakable names. +[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")] +extern static ref string GetName(UserData userData); + +UserData ud = CallPrivateConstructor(); +GetName(ud) = "Joe"; +``` + +Using Generics + +```csharp +class UserData +{ + private T _field; + private UserData(T t) { _field = t; } + private U ConvertFieldToT() => (U)_field; +} + +// The Accessors class provides the Generic Type parameter for the method definitions. +class Accessors +{ + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + extern static UserData CallPrivateConstructor(V v); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "ConvertFieldToT")] + extern static U CallConvertFieldToT(UserData userData); +} + +UserData ud = Accessors.CallPrivateConstructor("Joe"); +Accessors.CallPrivateConstructor(ud); +``` \ No newline at end of file diff --git a/src/coreclr/vm/methodtable.cpp b/src/coreclr/vm/methodtable.cpp index b59a2d23d7d03..08ad1eede6468 100644 --- a/src/coreclr/vm/methodtable.cpp +++ b/src/coreclr/vm/methodtable.cpp @@ -472,6 +472,17 @@ WORD MethodTable::GetNumMethods() return GetClass()->GetNumMethods(); } +PTR_MethodTable MethodTable::GetTypicalMethodTable() +{ + LIMITED_METHOD_DAC_CONTRACT; + if (IsArray()) + return (PTR_MethodTable)this; + + PTR_MethodTable methodTableMaybe = GetModule()->LookupTypeDef(GetCl()).AsMethodTable(); + _ASSERTE(methodTableMaybe->IsTypicalTypeDefinition()); + return methodTableMaybe; +} + //========================================================================================== BOOL MethodTable::HasSameTypeDefAs(MethodTable *pMT) { diff --git a/src/coreclr/vm/methodtable.h b/src/coreclr/vm/methodtable.h index 83057c623fba9..6e4f68f29bce7 100644 --- a/src/coreclr/vm/methodtable.h +++ b/src/coreclr/vm/methodtable.h @@ -1183,6 +1183,8 @@ class MethodTable return !HasInstantiation() || IsGenericTypeDefinition(); } + PTR_MethodTable GetTypicalMethodTable(); + BOOL HasSameTypeDefAs(MethodTable *pMT); //------------------------------------------------------------------- diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index cc7592211f8cc..58e6a2fa73071 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -1271,6 +1271,64 @@ namespace return true; } + bool AreConstraintsEqual(const Instantiation& left, const Instantiation& right) + { + STANDARD_VM_CONTRACT; + + DWORD argCount = left.GetNumArgs(); + if (argCount != right.GetNumArgs()) + return false; + + for (DWORD i = 0; i < argCount; ++i) + { + TypeHandle tL = left[i]; + TypeHandle tR = right[i]; + + // Check generic variable state are the same. + BOOL isGeneric = tL.IsGenericVariable(); + if (isGeneric != tR.IsGenericVariable()) + return false; + + // Only generic variables have constraints. + if (!isGeneric) + continue; + + TypeVarTypeDesc* tsL = tL.AsGenericVariable(); + TypeVarTypeDesc* tsR = tR.AsGenericVariable(); + + // + // Verify general constraints + // + IMDInternalImport* importL = tsL->GetModule()->GetMDImport(); + IMDInternalImport* importR = tsR->GetModule()->GetMDImport(); + + DWORD flagsL; + DWORD flagsR; + IfFailThrow(importL->GetGenericParamProps(tsL->GetToken(), NULL, &flagsL, NULL, NULL, NULL)); + IfFailThrow(importR->GetGenericParamProps(tsR->GetToken(), NULL, &flagsR, NULL, NULL, NULL)); + if ((flagsL & gpSpecialConstraintMask) != (flagsR & gpSpecialConstraintMask)) + return false; + + // + // Verify type constraints + // + DWORD constraintCountL; + DWORD constraintCountR; + TypeHandle* cL = tsL->GetConstraints(&constraintCountL); + TypeHandle* cR = tsR->GetConstraints(&constraintCountR); + if (constraintCountL != constraintCountR) + return false; + + for (DWORD j = 0; j < constraintCountL; ++j) + { + if (cL[j] != cR[j]) + return false; + } + } + + return true; + } + bool TrySetTargetMethod( GenerationContext& cxt, LPCUTF8 methodName, @@ -1285,11 +1343,13 @@ namespace TypeHandle targetType = cxt.TargetType; _ASSERTE(!targetType.IsTypeDesc()); + MethodTable* pMT = targetType.AsMethodTable(); + MethodDesc* targetMaybe = NULL; // Following a similar iteration pattern found in MemberLoader::FindMethod(). // However, we are only operating on the current type not walking the type hierarchy. - MethodTable::IntroducedMethodIterator iter(targetType.AsMethodTable()); + MethodTable::IntroducedMethodIterator iter(pMT); for (; iter.IsValid(); iter.Next()) { MethodDesc* curr = iter.GetMethodDesc(); @@ -1325,6 +1385,22 @@ namespace targetMaybe = curr; } + if (pMT->HasInstantiation()) + { + Instantiation decl = cxt.Declaration->GetMethodTable()->GetTypicalMethodTable()->GetInstantiation(); + Instantiation target = pMT->GetTypicalMethodTable()->GetInstantiation(); + if (!AreConstraintsEqual(decl, target)) + COMPlusThrow(kInvalidProgramException, W("Argument_GenTypeConstraintsNotEqual")); + } + + if (targetMaybe->HasMethodInstantiation()) + { + Instantiation decl = cxt.Declaration->LoadTypicalMethodDefinition()->GetMethodInstantiation(); + Instantiation target = targetMaybe->LoadTypicalMethodDefinition()->GetMethodInstantiation(); + if (!AreConstraintsEqual(decl, target)) + COMPlusThrow(kInvalidProgramException, W("Argument_GenMethodConstraintsNotEqual")); + } + cxt.TargetMethod = targetMaybe; return cxt.TargetMethod != NULL; } diff --git a/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx b/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx index b6d8fd00aa404..f5777fc29413a 100644 --- a/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx +++ b/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx @@ -1101,6 +1101,12 @@ GenericArguments[{0}], '{1}', on '{2}' violates the constraint of type '{3}'. + + Generic type constraints do not match. + + + Generic method constraints do not match. + The number of generic arguments provided doesn't equal the arity of the generic type definition. @@ -3346,7 +3352,7 @@ Object type {0} does not match target type {1}. - + Non-static field requires a target. From 7946c69663e6a7d39a865619ebf602758509b792 Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Wed, 13 Mar 2024 13:58:29 -0700 Subject: [PATCH 09/13] Check for null prior to constraint check. --- src/coreclr/vm/prestub.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index 58e6a2fa73071..87897d11874af 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -1393,7 +1393,7 @@ namespace COMPlusThrow(kInvalidProgramException, W("Argument_GenTypeConstraintsNotEqual")); } - if (targetMaybe->HasMethodInstantiation()) + if (targetMaybe != NULL && targetMaybe->HasMethodInstantiation()) { Instantiation decl = cxt.Declaration->LoadTypicalMethodDefinition()->GetMethodInstantiation(); Instantiation target = targetMaybe->LoadTypicalMethodDefinition()->GetMethodInstantiation(); From 3546eebafccdbd8cda264fd99453617b24c2a36f Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Thu, 14 Mar 2024 10:17:36 -0700 Subject: [PATCH 10/13] native AOT feedback --- .../tools/Common/TypeSystem/IL/UnsafeAccessors.cs | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs index bf578c74f17f7..6232487a39693 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs @@ -236,12 +236,11 @@ private static bool ValidateTargetType(TypeDesc targetTypeMaybe, out TypeDesc va return validated != null; } - private static bool DoesMethodMatchUnsafeAccessorDeclaration( - ref GenerationContext context, - MethodSignature declSig, - MethodSignature maybeSig, - bool ignoreCustomModifiers) + private static bool DoesMethodMatchUnsafeAccessorDeclaration(ref GenerationContext context, MethodDesc method, bool ignoreCustomModifiers) { + MethodSignature declSig = context.Declaration.Signature; + MethodSignature maybeSig = method.Signature; + // Check if we need to also validate custom modifiers. // If we are, do it first. if (!ignoreCustomModifiers) @@ -387,10 +386,7 @@ private static bool TrySetTargetMethod(ref GenerationContext context, string nam } // Check signature - if (!DoesMethodMatchUnsafeAccessorDeclaration(ref context, - context.Declaration.Signature, - md.Signature, - ignoreCustomModifiers)) + if (!DoesMethodMatchUnsafeAccessorDeclaration(ref context, md, ignoreCustomModifiers)) { continue; } From 1c3c87469543a5e2537ce0a8210afad822c81f11 Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Thu, 14 Mar 2024 15:18:17 -0700 Subject: [PATCH 11/13] Use TypeVarTypeDesc::SatisfiesConstraints for constraint checking. --- src/coreclr/vm/prestub.cpp | 117 ++++++++++++++++++------------------- 1 file changed, 57 insertions(+), 60 deletions(-) diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index 87897d11874af..f1e1e40792aac 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -1271,62 +1271,72 @@ namespace return true; } - bool AreConstraintsEqual(const Instantiation& left, const Instantiation& right) + void VerifyDeclarationSatifiesTargetConstraints(MethodDesc* declaration, MethodTable* targetType, MethodDesc* targetMethod) { - STANDARD_VM_CONTRACT; + CONTRACTL + { + STANDARD_VM_CHECK; + PRECONDITION(declaration != NULL); + PRECONDITION(targetType != NULL); + PRECONDITION(targetMethod != NULL); + } + CONTRACTL_END; - DWORD argCount = left.GetNumArgs(); - if (argCount != right.GetNumArgs()) - return false; + // If the target method has no generic parameters there is nothing to verify + if (!targetMethod->HasClassOrMethodInstantiation()) + return; - for (DWORD i = 0; i < argCount; ++i) + // Construct a context for verifying target's constraints are + // satisfied by the declaration. + Instantiation declClassInst; + Instantiation declMethodInst; + Instantiation targetClassInst; + Instantiation targetMethodInst; + if (targetType->HasInstantiation()) { - TypeHandle tL = left[i]; - TypeHandle tR = right[i]; + declClassInst = declaration->GetMethodTable()->GetInstantiation(); + targetClassInst = targetType->GetTypicalMethodTable()->GetInstantiation(); + } + if (targetMethod->HasMethodInstantiation()) + { + declMethodInst = declaration->LoadTypicalMethodDefinition()->GetMethodInstantiation(); + targetMethodInst = targetMethod->LoadTypicalMethodDefinition()->GetMethodInstantiation(); + } - // Check generic variable state are the same. - BOOL isGeneric = tL.IsGenericVariable(); - if (isGeneric != tR.IsGenericVariable()) - return false; + SigTypeContext typeContext; + SigTypeContext::InitTypeContext(declClassInst, declMethodInst, &typeContext); - // Only generic variables have constraints. - if (!isGeneric) - continue; + InstantiationContext instContext{ &typeContext }; - TypeVarTypeDesc* tsL = tL.AsGenericVariable(); - TypeVarTypeDesc* tsR = tR.AsGenericVariable(); + // + // Validate constraints on Type parameters + // + DWORD typeParamCount = targetClassInst.GetNumArgs(); + if (typeParamCount != declClassInst.GetNumArgs()) + COMPlusThrow(kInvalidProgramException, W("Argument_GenTypeConstraintsNotEqual")); - // - // Verify general constraints - // - IMDInternalImport* importL = tsL->GetModule()->GetMDImport(); - IMDInternalImport* importR = tsR->GetModule()->GetMDImport(); - - DWORD flagsL; - DWORD flagsR; - IfFailThrow(importL->GetGenericParamProps(tsL->GetToken(), NULL, &flagsL, NULL, NULL, NULL)); - IfFailThrow(importR->GetGenericParamProps(tsR->GetToken(), NULL, &flagsR, NULL, NULL, NULL)); - if ((flagsL & gpSpecialConstraintMask) != (flagsR & gpSpecialConstraintMask)) - return false; + for (DWORD i = 0; i < typeParamCount; ++i) + { + TypeHandle arg = declClassInst[i]; + TypeVarTypeDesc* param = targetClassInst[i].AsGenericVariable(); + if (!param->SatisfiesConstraints(&typeContext, arg, &instContext)) + COMPlusThrow(kInvalidProgramException, W("Argument_GenTypeConstraintsNotEqual")); + } - // - // Verify type constraints - // - DWORD constraintCountL; - DWORD constraintCountR; - TypeHandle* cL = tsL->GetConstraints(&constraintCountL); - TypeHandle* cR = tsR->GetConstraints(&constraintCountR); - if (constraintCountL != constraintCountR) - return false; + // + // Validate constraints on Method parameters + // + DWORD methodParamCount = targetMethodInst.GetNumArgs(); + if (methodParamCount != declMethodInst.GetNumArgs()) + COMPlusThrow(kInvalidProgramException, W("Argument_GenMethodConstraintsNotEqual")); - for (DWORD j = 0; j < constraintCountL; ++j) - { - if (cL[j] != cR[j]) - return false; - } + for (DWORD i = 0; i < methodParamCount; ++i) + { + TypeHandle arg = declMethodInst[i]; + TypeVarTypeDesc* param = targetMethodInst[i].AsGenericVariable(); + if (!param->SatisfiesConstraints(&typeContext, arg, &instContext)) + COMPlusThrow(kInvalidProgramException, W("Argument_GenMethodConstraintsNotEqual")); } - - return true; } bool TrySetTargetMethod( @@ -1385,21 +1395,8 @@ namespace targetMaybe = curr; } - if (pMT->HasInstantiation()) - { - Instantiation decl = cxt.Declaration->GetMethodTable()->GetTypicalMethodTable()->GetInstantiation(); - Instantiation target = pMT->GetTypicalMethodTable()->GetInstantiation(); - if (!AreConstraintsEqual(decl, target)) - COMPlusThrow(kInvalidProgramException, W("Argument_GenTypeConstraintsNotEqual")); - } - - if (targetMaybe != NULL && targetMaybe->HasMethodInstantiation()) - { - Instantiation decl = cxt.Declaration->LoadTypicalMethodDefinition()->GetMethodInstantiation(); - Instantiation target = targetMaybe->LoadTypicalMethodDefinition()->GetMethodInstantiation(); - if (!AreConstraintsEqual(decl, target)) - COMPlusThrow(kInvalidProgramException, W("Argument_GenMethodConstraintsNotEqual")); - } + if (targetMaybe != NULL) + VerifyDeclarationSatifiesTargetConstraints(cxt.Declaration, pMT, targetMaybe); cxt.TargetMethod = targetMaybe; return cxt.TargetMethod != NULL; From adbfc528beb9b4f092e793a87bdc5993fa2a1517 Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Fri, 15 Mar 2024 09:00:04 -0700 Subject: [PATCH 12/13] Update design doc Add constraint test --- docs/design/features/unsafeaccessors.md | 8 ++-- .../UnsafeAccessorsTests.Generics.cs | 47 +++++++++++++++++++ 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/docs/design/features/unsafeaccessors.md b/docs/design/features/unsafeaccessors.md index f3463546e171b..8ce4d2a22ed26 100644 --- a/docs/design/features/unsafeaccessors.md +++ b/docs/design/features/unsafeaccessors.md @@ -24,7 +24,7 @@ The return type is considered for the signature match. Modreqs and modopts are i By default, the attributed method's name dictates the name of the method/field. This can cause confusion in some cases since language abstractions, like C# local functions, generate mangled IL names. The solution to this is to use the `nameof` mechanism and define the `Name` property. -Scenarios involving Generics may require creating new Generic types to contain the `extern static` method definition. The decision was made to require all `ELEMENT_TYPE_VAR` and `ELEMENT_TYPE_MVAR` instances to match identically type and generic parameter index. This means if the target method for access uses an `ELEMENT_TYPE_VAR`, the `extern static` method must also use an `ELEMENT_TYPE_VAR`. For example: +Scenarios involving generics may require creating new generic types to contain the `extern static` method definition. The decision was made to require all `ELEMENT_TYPE_VAR` and `ELEMENT_TYPE_MVAR` instances to match identically type and generic parameter index. This means if the target method for access uses an `ELEMENT_TYPE_VAR`, the `extern static` method must also use an `ELEMENT_TYPE_VAR`. For example: ```csharp class C @@ -45,7 +45,7 @@ class Accessor } ``` -Methods with the `UnsafeAccessorAttribute` that access members with Generic parameters are expected to have the same declared constraints with the target member. Failure to do so results in unspecified behavior. For example: +Methods with the `UnsafeAccessorAttribute` that access members with generic parameters are expected to have the same declared constraints with the target member. Failure to do so results in unspecified behavior. For example: ```csharp class C @@ -112,7 +112,7 @@ UserData ud = CallPrivateConstructor(); GetName(ud) = "Joe"; ``` -Using Generics +Using generics ```csharp class UserData @@ -122,7 +122,7 @@ class UserData private U ConvertFieldToT() => (U)_field; } -// The Accessors class provides the Generic Type parameter for the method definitions. +// The Accessors class provides the generic Type parameter for the method definitions. class Accessors { [UnsafeAccessor(UnsafeAccessorKind.Constructor)] diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs index 17f1880f1a6ae..9305a1f4c1476 100644 --- a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs @@ -7,6 +7,7 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using TestLibrary; using Xunit; struct Struct { } @@ -391,6 +392,52 @@ public static void Verify_Generic_GenericTypeGenericStaticMethod() } } + class ClassWithConstraints + { + private string M() where T : U, IEquatable + => $"{typeof(T)}|{typeof(U)}"; + + private static string SM() where T : U, IEquatable + => $"{typeof(T)}|{typeof(U)}"; + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] + public static void Verify_Generic_ConstraintEnforcement() + { + Console.WriteLine($"Running {nameof(Verify_Generic_ConstraintEnforcement)}"); + + Assert.Equal($"{typeof(string)}|{typeof(object)}", CallMethod(new ClassWithConstraints())); + Assert.Equal($"{typeof(string)}|{typeof(object)}", CallStaticMethod(null)); + + // Constraint validation isn't performed in AOT scenarios. + if (Utilities.IsNotNativeAot) + { + Assert.Throws(() => CallMethod_NoConstraints(new ClassWithConstraints())); + Assert.Throws(() => CallMethod_MissingConstraint(new ClassWithConstraints())); + Assert.Throws(() => CallStaticMethod_NoConstraints(null)); + Assert.Throws(() => CallStaticMethod_MissingConstraint(null)); + } + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + extern static string CallMethod(ClassWithConstraints c) where T : U, IEquatable; + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + extern static string CallMethod_NoConstraints(ClassWithConstraints c); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + extern static string CallMethod_MissingConstraint(ClassWithConstraints c) where T : U; + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "SM")] + extern static string CallStaticMethod(ClassWithConstraints c) where T : U, IEquatable; + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "SM")] + extern static string CallStaticMethod_NoConstraints(ClassWithConstraints c); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "SM")] + extern static string CallStaticMethod_MissingConstraint(ClassWithConstraints c) where T : U; + } + class Invalid { [UnsafeAccessor(UnsafeAccessorKind.Method, Name=nameof(ToString))] From c34b61ea26c9ed3c617f79045a7f1040bca4cfbc Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Tue, 19 Mar 2024 12:58:28 -0700 Subject: [PATCH 13/13] Add constraints check in native AOT Update tests to use unique parameter names --- .../Common/TypeSystem/IL/UnsafeAccessors.cs | 104 +++++++++++++----- src/coreclr/vm/prestub.cpp | 4 +- .../UnsafeAccessorsTests.Generics.cs | 64 +++++------ 3 files changed, 108 insertions(+), 64 deletions(-) diff --git a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs index 6232487a39693..e8e97f2eb3197 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs @@ -48,7 +48,7 @@ public static MethodIL TryGetIL(EcmaMethod method) firstArgType = sig[0]; } - bool isAmbiguous = false; + SetTargetResult result; // Using the kind type, perform the following: // 1) Validate the basic type information from the signature. @@ -71,9 +71,10 @@ public static MethodIL TryGetIL(EcmaMethod method) } const string ctorName = ".ctor"; - if (!TrySetTargetMethod(ref context, ctorName, out isAmbiguous)) + result = TrySetTargetMethod(ref context, ctorName); + if (result is not SetTargetResult.Success) { - return GenerateAccessorSpecificFailure(ref context, ctorName, isAmbiguous); + return GenerateAccessorSpecificFailure(ref context, ctorName, result); } break; case UnsafeAccessorKind.Method: @@ -99,9 +100,10 @@ public static MethodIL TryGetIL(EcmaMethod method) } context.IsTargetStatic = kind == UnsafeAccessorKind.StaticMethod; - if (!TrySetTargetMethod(ref context, name, out isAmbiguous)) + result = TrySetTargetMethod(ref context, name); + if (result is not SetTargetResult.Success) { - return GenerateAccessorSpecificFailure(ref context, name, isAmbiguous); + return GenerateAccessorSpecificFailure(ref context, name, result); } break; @@ -130,9 +132,10 @@ public static MethodIL TryGetIL(EcmaMethod method) } context.IsTargetStatic = kind == UnsafeAccessorKind.StaticField; - if (!TrySetTargetField(ref context, name, ((ParameterizedType)retType).GetParameterType())) + result = TrySetTargetField(ref context, name, ((ParameterizedType)retType).GetParameterType()); + if (result is not SetTargetResult.Success) { - return GenerateAccessorSpecificFailure(ref context, name, isAmbiguous); + return GenerateAccessorSpecificFailure(ref context, name, result); } break; @@ -366,7 +369,45 @@ private static bool DoesMethodMatchUnsafeAccessorDeclaration(ref GenerationConte return true; } - private static bool TrySetTargetMethod(ref GenerationContext context, string name, out bool isAmbiguous, bool ignoreCustomModifiers = true) + private static bool VerifyDeclarationSatisfiesTargetConstraints(MethodDesc declaration, TypeDesc targetType, MethodDesc targetMethod) + { + Debug.Assert(declaration != null); + Debug.Assert(targetType != null); + Debug.Assert(targetMethod != null); + + if (targetType.HasInstantiation) + { + Instantiation declClassInst = declaration.OwningType.Instantiation; + var instType = targetType.Context.GetInstantiatedType((MetadataType)targetType.GetTypeDefinition(), declClassInst); + if (!instType.CheckConstraints()) + { + return false; + } + + targetMethod = instType.FindMethodOnExactTypeWithMatchingTypicalMethod(targetMethod); + } + + if (targetMethod.HasInstantiation) + { + Instantiation declMethodInst = declaration.Instantiation; + var instMethod = targetType.Context.GetInstantiatedMethod(targetMethod, declMethodInst); + if (!instMethod.CheckConstraints()) + { + return false; + } + } + return true; + } + + private enum SetTargetResult + { + Success, + Missing, + Ambiguous, + Invalid, + } + + private static SetTargetResult TrySetTargetMethod(ref GenerationContext context, string name, bool ignoreCustomModifiers = true) { TypeDesc targetType = context.TargetType; @@ -399,35 +440,39 @@ private static bool TrySetTargetMethod(ref GenerationContext context, string nam // We have detected ambiguity when ignoring custom modifiers. // Start over, but look for a match requiring custom modifiers // to match precisely. - if (TrySetTargetMethod(ref context, name, out isAmbiguous, ignoreCustomModifiers: false)) - return true; + if (SetTargetResult.Success == TrySetTargetMethod(ref context, name, ignoreCustomModifiers: false)) + return SetTargetResult.Success; } - - isAmbiguous = true; - return false; + return SetTargetResult.Ambiguous; } targetMaybe = md; } - isAmbiguous = false; - - if (targetMaybe != null && targetMaybe.HasInstantiation) + if (targetMaybe != null) { - TypeDesc[] methodInstantiation = new TypeDesc[targetMaybe.Instantiation.Length]; - for (int i = 0; i < methodInstantiation.Length; ++i) + if (!VerifyDeclarationSatisfiesTargetConstraints(context.Declaration, targetType, targetMaybe)) { - methodInstantiation[i] = targetMaybe.Context.GetSignatureVariable(i, true); + return SetTargetResult.Invalid; } - targetMaybe = targetMaybe.Context.GetInstantiatedMethod(targetMaybe, new Instantiation(methodInstantiation)); + if (targetMaybe.HasInstantiation) + { + TypeDesc[] methodInstantiation = new TypeDesc[targetMaybe.Instantiation.Length]; + for (int i = 0; i < methodInstantiation.Length; ++i) + { + methodInstantiation[i] = targetMaybe.Context.GetSignatureVariable(i, true); + } + targetMaybe = targetMaybe.Context.GetInstantiatedMethod(targetMaybe, new Instantiation(methodInstantiation)); + } + Debug.Assert(targetMaybe is not null); } context.TargetMethod = targetMaybe; - return context.TargetMethod != null; + return context.TargetMethod != null ? SetTargetResult.Success : SetTargetResult.Missing; } - private static bool TrySetTargetField(ref GenerationContext context, string name, TypeDesc fieldType) + private static SetTargetResult TrySetTargetField(ref GenerationContext context, string name, TypeDesc fieldType) { TypeDesc targetType = context.TargetType; @@ -443,10 +488,10 @@ private static bool TrySetTargetField(ref GenerationContext context, string name && fieldType == fd.FieldType) { context.TargetField = fd; - return true; + return SetTargetResult.Success; } } - return false; + return SetTargetResult.Missing; } private static MethodIL GenerateAccessor(ref GenerationContext context) @@ -498,7 +543,7 @@ private static MethodIL GenerateAccessor(ref GenerationContext context) return emit.Link(context.Declaration); } - private static MethodIL GenerateAccessorSpecificFailure(ref GenerationContext context, string name, bool ambiguous) + private static MethodIL GenerateAccessorSpecificFailure(ref GenerationContext context, string name, SetTargetResult result) { ILEmitter emit = new ILEmitter(); ILCodeStream codeStream = emit.NewCodeStream(); @@ -508,14 +553,19 @@ private static MethodIL GenerateAccessorSpecificFailure(ref GenerationContext co MethodDesc thrower; TypeSystemContext typeSysContext = context.Declaration.Context; - if (ambiguous) + if (result is SetTargetResult.Ambiguous) { codeStream.EmitLdc((int)ExceptionStringID.AmbiguousMatchUnsafeAccessor); thrower = typeSysContext.GetHelperEntryPoint("ThrowHelpers", "ThrowAmbiguousMatchException"); } + else if (result is SetTargetResult.Invalid) + { + codeStream.EmitLdc((int)ExceptionStringID.InvalidProgramDefault); + thrower = typeSysContext.GetHelperEntryPoint("ThrowHelpers", "ThrowInvalidProgramException"); + } else { - + Debug.Assert(result is SetTargetResult.Missing); ExceptionStringID id; if (context.Kind == UnsafeAccessorKind.Field || context.Kind == UnsafeAccessorKind.StaticField) { diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index f1e1e40792aac..ff447af8ab46e 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -1271,7 +1271,7 @@ namespace return true; } - void VerifyDeclarationSatifiesTargetConstraints(MethodDesc* declaration, MethodTable* targetType, MethodDesc* targetMethod) + void VerifyDeclarationSatisfiesTargetConstraints(MethodDesc* declaration, MethodTable* targetType, MethodDesc* targetMethod) { CONTRACTL { @@ -1396,7 +1396,7 @@ namespace } if (targetMaybe != NULL) - VerifyDeclarationSatifiesTargetConstraints(cxt.Declaration, pMT, targetMaybe); + VerifyDeclarationSatisfiesTargetConstraints(cxt.Declaration, pMT, targetMaybe); cxt.TargetMethod = targetMaybe; return cxt.TargetMethod != NULL; diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs index 9305a1f4c1476..e1029797bf12c 100644 --- a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs @@ -7,7 +7,6 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; -using TestLibrary; using Xunit; struct Struct { } @@ -96,7 +95,7 @@ public static void Verify_Generic_AccessStaticFieldClass() extern static ref string GetPrivateStaticFieldStruct(MyList d); [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=MyList.StaticGenericFieldName)] - extern static ref T GetPrivateStaticField(MyList d); + extern static ref V GetPrivateStaticField(MyList d); } [Fact] @@ -118,7 +117,7 @@ public static void Verify_Generic_AccessFieldClass() } [UnsafeAccessor(UnsafeAccessorKind.Field, Name=MyList.GenericFieldName)] - extern static ref List GetPrivateField(MyList a); + extern static ref List GetPrivateField(MyList a); } class Base @@ -128,8 +127,8 @@ class Base class GenericBase : Base { - protected virtual string CreateMessage(T u) => $"{nameof(GenericBase)}:{u}"; - protected override string CreateMessageGeneric(U t) => $"{nameof(GenericBase)}:{t}"; + protected virtual string CreateMessage(T t) => $"{nameof(GenericBase)}:{t}"; + protected override string CreateMessageGeneric(U u) => $"{nameof(GenericBase)}:{u}"; } sealed class Derived1 : GenericBase @@ -187,49 +186,49 @@ public static void Verify_Generic_InheritanceMethodResolution() } [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CreateMessageGeneric")] - extern static string CreateMessage(Base b, U t); + extern static string CreateMessage(Base b, W w); } - sealed class Accessors + sealed class Accessors { [UnsafeAccessor(UnsafeAccessorKind.Constructor)] - public extern static MyList Create(int a); + public extern static MyList Create(int a); [UnsafeAccessor(UnsafeAccessorKind.Constructor)] - public extern static MyList CreateWithList(List a); + public extern static MyList CreateWithList(List a); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = ".ctor")] - public extern static void CallCtorAsMethod(MyList l, List a); + public extern static void CallCtorAsMethod(MyList l, List a); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] - public extern static void AddInt(MyList l, int a); + public extern static void AddInt(MyList l, int a); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] - public extern static void AddString(MyList l, string a); + public extern static void AddString(MyList l, string a); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] - public extern static void AddStruct(MyList l, Struct a); + public extern static void AddStruct(MyList l, Struct a); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Clear")] - public extern static void Clear(MyList l); + public extern static void Clear(MyList l); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] - public extern static void Add(MyList l, T element); + public extern static void Add(MyList l, V element); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "AddWithIgnore")] - public extern static void AddWithIgnore(MyList l, T element, U ignore); + public extern static void AddWithIgnore(MyList l, V element, W ignore); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CanCastToElementType")] - public extern static bool CanCastToElementType(MyList l, U element); + public extern static bool CanCastToElementType(MyList l, W element); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CreateMessage")] - public extern static string CreateMessage(GenericBase b, T t); + public extern static string CreateMessage(GenericBase b, V v); [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "ElementType")] - public extern static Type ElementType(MyList l); + public extern static Type ElementType(MyList l); [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "CanUseElementType")] - public extern static bool CanUseElementType(MyList l, U element); + public extern static bool CanUseElementType(MyList l, W element); } [Fact] @@ -409,33 +408,28 @@ public static void Verify_Generic_ConstraintEnforcement() Assert.Equal($"{typeof(string)}|{typeof(object)}", CallMethod(new ClassWithConstraints())); Assert.Equal($"{typeof(string)}|{typeof(object)}", CallStaticMethod(null)); - - // Constraint validation isn't performed in AOT scenarios. - if (Utilities.IsNotNativeAot) - { - Assert.Throws(() => CallMethod_NoConstraints(new ClassWithConstraints())); - Assert.Throws(() => CallMethod_MissingConstraint(new ClassWithConstraints())); - Assert.Throws(() => CallStaticMethod_NoConstraints(null)); - Assert.Throws(() => CallStaticMethod_MissingConstraint(null)); - } + Assert.Throws(() => CallMethod_NoConstraints(new ClassWithConstraints())); + Assert.Throws(() => CallMethod_MissingConstraint(new ClassWithConstraints())); + Assert.Throws(() => CallStaticMethod_NoConstraints(null)); + Assert.Throws(() => CallStaticMethod_MissingConstraint(null)); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] - extern static string CallMethod(ClassWithConstraints c) where T : U, IEquatable; + extern static string CallMethod(ClassWithConstraints c) where V : W, IEquatable; [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] - extern static string CallMethod_NoConstraints(ClassWithConstraints c); + extern static string CallMethod_NoConstraints(ClassWithConstraints c); [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] - extern static string CallMethod_MissingConstraint(ClassWithConstraints c) where T : U; + extern static string CallMethod_MissingConstraint(ClassWithConstraints c) where V : W; [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "SM")] - extern static string CallStaticMethod(ClassWithConstraints c) where T : U, IEquatable; + extern static string CallStaticMethod(ClassWithConstraints c) where V : W, IEquatable; [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "SM")] - extern static string CallStaticMethod_NoConstraints(ClassWithConstraints c); + extern static string CallStaticMethod_NoConstraints(ClassWithConstraints c); [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "SM")] - extern static string CallStaticMethod_MissingConstraint(ClassWithConstraints c) where T : U; + extern static string CallStaticMethod_MissingConstraint(ClassWithConstraints c) where V : W; } class Invalid