diff --git a/src/DotSwashbuckle.AspNetCore.SwaggerGen/SchemaGenerator/MemberInfoExtensions.cs b/src/DotSwashbuckle.AspNetCore.SwaggerGen/SchemaGenerator/MemberInfoExtensions.cs index f19f21f..d209216 100644 --- a/src/DotSwashbuckle.AspNetCore.SwaggerGen/SchemaGenerator/MemberInfoExtensions.cs +++ b/src/DotSwashbuckle.AspNetCore.SwaggerGen/SchemaGenerator/MemberInfoExtensions.cs @@ -2,17 +2,13 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; +using System.Runtime.CompilerServices; using Microsoft.AspNetCore.Mvc; namespace DotSwashbuckle.AspNetCore.SwaggerGen { public static class MemberInfoExtensions { - private const string NullableAttributeFullTypeName = "System.Runtime.CompilerServices.NullableAttribute"; - private const string NullableFlagsFieldName = "NullableFlags"; - private const string NullableContextAttributeFullTypeName = "System.Runtime.CompilerServices.NullableContextAttribute"; - private const string FlagFieldName = "Flag"; - public static ICollection GetInlineAndMetadataAttributes(this MemberInfo memberInfo) { var attributes = memberInfo.GetCustomAttributes(true) @@ -48,9 +44,7 @@ public static bool IsNonNullableReferenceType(this MemberInfo memberInfo) return memberInfo.GetNullableFallbackValue(); } - if (nullableAttribute.GetType().GetField(NullableFlagsFieldName) is FieldInfo field && - field.GetValue(nullableAttribute) is byte[] flags && - flags.Length >= 1 && flags[0] == 1) + if (nullableAttribute.NullableFlags.Length >= 1 && nullableAttribute.NullableFlags[0] == 1) { return true; } @@ -73,9 +67,7 @@ public static bool IsDictionaryValueNonNullable(this MemberInfo memberInfo) return memberInfo.GetNullableFallbackValue(); } - if (nullableAttribute.GetType().GetField(NullableFlagsFieldName) is FieldInfo field && - field.GetValue(nullableAttribute) is byte[] flags && - flags.Length == 3 && flags[2] == 1) + if (nullableAttribute.NullableFlags.Length == 3 && nullableAttribute.NullableFlags[2] == 1) { return true; } @@ -83,13 +75,9 @@ public static bool IsDictionaryValueNonNullable(this MemberInfo memberInfo) return false; } - private static object GetNullableAttribute(this MemberInfo memberInfo) + private static NullableAttribute GetNullableAttribute(this MemberInfo memberInfo) { - var nullableAttribute = memberInfo.GetCustomAttributes() - .Where(attr => string.Equals(attr.GetType().FullName, NullableAttributeFullTypeName, StringComparison.Ordinal)) - .FirstOrDefault(); - - return nullableAttribute; + return memberInfo.GetCustomAttribute(); } private static bool GetNullableFallbackValue(this MemberInfo memberInfo) @@ -98,25 +86,30 @@ private static bool GetNullableFallbackValue(this MemberInfo memberInfo) ? new Type[] { memberInfo.DeclaringType, memberInfo.DeclaringType.DeclaringType } : new Type[] { memberInfo.DeclaringType }; + // https://github.com/dotnet/roslyn/blob/main/docs/features/nullable-metadata.md + // Check NullableContextAttribute first foreach (var declaringType in declaringTypes) { - var attributes = (IEnumerable)declaringType.GetCustomAttributes(false); + var attributes = declaringType.GetCustomAttributes(true); - var nullableContext = attributes - .Where(attr => string.Equals(attr.GetType().FullName, NullableContextAttributeFullTypeName, StringComparison.Ordinal)) - .FirstOrDefault(); + // NullableContextAttribute is optional + var nullableContextAttribute = (NullableContextAttribute)attributes.FirstOrDefault(a => a is NullableContextAttribute); + + if (nullableContextAttribute != null) + { + return nullableContextAttribute.Flag == 1; + } + } + + // Next check NullableAttribute + foreach (var declaringType in declaringTypes) + { + var attributes = declaringType.GetCustomAttributes(true); + var nullableAttribute = (NullableAttribute)attributes.FirstOrDefault(a => a is NullableAttribute); - if (nullableContext != null) + if (nullableAttribute != null) { - if (nullableContext.GetType().GetField(FlagFieldName) is FieldInfo field && - field.GetValue(nullableContext) is byte flag && flag == 1) - { - return true; - } - else - { - return false; - } + return nullableAttribute.NullableFlags.Length >= 1 && nullableAttribute.NullableFlags[0] == 1; } } diff --git a/test/DotSwashbuckle.AspNetCore.Newtonsoft.Test/SchemaGenerator/NewtonsoftSchemaGeneratorTests.cs b/test/DotSwashbuckle.AspNetCore.Newtonsoft.Test/SchemaGenerator/NewtonsoftSchemaGeneratorTests.cs index 2c7be2f..6ca77b5 100644 --- a/test/DotSwashbuckle.AspNetCore.Newtonsoft.Test/SchemaGenerator/NewtonsoftSchemaGeneratorTests.cs +++ b/test/DotSwashbuckle.AspNetCore.Newtonsoft.Test/SchemaGenerator/NewtonsoftSchemaGeneratorTests.cs @@ -353,6 +353,36 @@ public void GenerateSchema_SetsValidationProperties_IfComplexTypeHasValidationAt Assert.Equal(new[] { "StringWithRequired", "StringWithRequiredAllowEmptyTrue", "StringWithRequiredModifier" }, schema.Required.ToArray()); } + [Fact] + public void GenerateSchema_NestedRecords_RecordTypeWithNonNestedChild_NullableCheck() + { + var schemaRepository = new SchemaRepository(); + + Subject( + (options) => options.SupportNonNullableReferenceTypes = true + ).GenerateSchema(typeof(RecordTypeWithNonNestedChild), schemaRepository); + + var schema = schemaRepository.Schemas[nameof(RecordChild)]; + + Assert.False(schema.Properties[nameof(RecordChild.NonNullable)].Nullable); + Assert.True(schema.Properties[nameof(RecordChild.Nullable)].Nullable); + } + + [Fact] + public void GenerateSchema_NestedRecords_RecordTypeWithNestedChild_NullableCheck() + { + var schemaRepository = new SchemaRepository(); + + Subject( + (options) => options.SupportNonNullableReferenceTypes = true + ).GenerateSchema(typeof(RecordTypeWithNestedChild), schemaRepository); + + var schema = schemaRepository.Schemas[nameof(RecordTypeWithNestedChild.NestedChild)]; + + Assert.False(schema.Properties[nameof(RecordTypeWithNestedChild.NestedChild.NonNullable)].Nullable); + Assert.True(schema.Properties[nameof(RecordTypeWithNestedChild.NestedChild.Nullable)].Nullable); + } + [Fact] public void GenerateSchema_SetsReadOnlyAndWriteOnlyFlags_IfPropertyIsRestricted() { diff --git a/test/DotSwashbuckle.AspNetCore.TestSupport/DotSwashbuckle.AspNetCore.TestSupport.csproj b/test/DotSwashbuckle.AspNetCore.TestSupport/DotSwashbuckle.AspNetCore.TestSupport.csproj index 4e87434..edafee3 100644 --- a/test/DotSwashbuckle.AspNetCore.TestSupport/DotSwashbuckle.AspNetCore.TestSupport.csproj +++ b/test/DotSwashbuckle.AspNetCore.TestSupport/DotSwashbuckle.AspNetCore.TestSupport.csproj @@ -8,6 +8,7 @@ false false true + enable diff --git a/test/DotSwashbuckle.AspNetCore.TestSupport/Fixtures/NestedRecords.cs b/test/DotSwashbuckle.AspNetCore.TestSupport/Fixtures/NestedRecords.cs new file mode 100644 index 0000000..b3f7f65 --- /dev/null +++ b/test/DotSwashbuckle.AspNetCore.TestSupport/Fixtures/NestedRecords.cs @@ -0,0 +1,10 @@ +namespace DotSwashbuckle.AspNetCore.TestSupport.Fixtures +{ + public record RecordTypeWithNestedChild(RecordTypeWithNestedChild.NestedChild Child) + { + public record NestedChild(string NonNullable, string? Nullable); + } + + public record RecordTypeWithNonNestedChild(RecordChild RecordChild); + public record RecordChild(string NonNullable, string? Nullable); +} diff --git a/test/WebSites/Basic/Basic.csproj b/test/WebSites/Basic/Basic.csproj index 738d013..aa653d2 100644 --- a/test/WebSites/Basic/Basic.csproj +++ b/test/WebSites/Basic/Basic.csproj @@ -4,6 +4,7 @@ true $(NoWarn);1591 net8.0 + enable diff --git a/test/WebSites/Basic/Startup.cs b/test/WebSites/Basic/Startup.cs index 1ddaba6..dfadcfe 100644 --- a/test/WebSites/Basic/Startup.cs +++ b/test/WebSites/Basic/Startup.cs @@ -28,6 +28,7 @@ public void ConfigureServices(IServiceCollection services) services.AddSwaggerGen(c => { + c.SupportNonNullableReferenceTypes(); c.SwaggerDoc("v1", new OpenApiInfo { @@ -78,8 +79,11 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env) { c.SerializeAsV2 = true; }); + endpoints.MapPost("/requestWithNestedChild", (Requests.RequestWithNestedChild request) => "ok"); + endpoints.MapPost("/requestWithNonNestedChild", (Requests.RequestWithNonNestedChild request) => "ok"); }); + var supportedCultures = new[] { new CultureInfo("en-US"), @@ -102,5 +106,16 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env) c.SwaggerEndpoint("/swagger/v1/swagger.json", "V1 Docs"); }); } + + public class Requests + { + public record RequestWithNestedChild(RequestWithNestedChild.NestedChild Child) + { + public record NestedChild(string NonNullable, string? Nullable); + } + + public record RequestWithNonNestedChild(Child Child); + public record Child(string NonNullable, string? Nullable); + } } }