Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Fix nullability check for nested record types. fixes domaindrivendev/…
Browse files Browse the repository at this point in the history
  • Loading branch information
Havunen committed Feb 17, 2024
1 parent f3fb6ba commit ab7e8e6
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,13 @@
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using Microsoft.AspNetCore.Mvc;

namespace DotSwashbuckle.AspNetCore.SwaggerGen
{
public static class MemberInfoExtensions
{
private const string NullableAttributeFullTypeName = "System.Runtime.CompilerServices.NullableAttribute";
private const string NullableFlagsFieldName = "NullableFlags";
private const string NullableContextAttributeFullTypeName = "System.Runtime.CompilerServices.NullableContextAttribute";
private const string FlagFieldName = "Flag";

public static ICollection<object> GetInlineAndMetadataAttributes(this MemberInfo memberInfo)
{
var attributes = memberInfo.GetCustomAttributes(true)
Expand Down Expand Up @@ -48,9 +44,7 @@ public static bool IsNonNullableReferenceType(this MemberInfo memberInfo)
return memberInfo.GetNullableFallbackValue();
}

if (nullableAttribute.GetType().GetField(NullableFlagsFieldName) is FieldInfo field &&
field.GetValue(nullableAttribute) is byte[] flags &&
flags.Length >= 1 && flags[0] == 1)
if (nullableAttribute.NullableFlags.Length >= 1 && nullableAttribute.NullableFlags[0] == 1)
{
return true;
}
Expand All @@ -73,23 +67,17 @@ public static bool IsDictionaryValueNonNullable(this MemberInfo memberInfo)
return memberInfo.GetNullableFallbackValue();
}

if (nullableAttribute.GetType().GetField(NullableFlagsFieldName) is FieldInfo field &&
field.GetValue(nullableAttribute) is byte[] flags &&
flags.Length == 3 && flags[2] == 1)
if (nullableAttribute.NullableFlags.Length == 3 && nullableAttribute.NullableFlags[2] == 1)
{
return true;
}

return false;
}

private static object GetNullableAttribute(this MemberInfo memberInfo)
private static NullableAttribute GetNullableAttribute(this MemberInfo memberInfo)
{
var nullableAttribute = memberInfo.GetCustomAttributes()
.Where(attr => string.Equals(attr.GetType().FullName, NullableAttributeFullTypeName, StringComparison.Ordinal))
.FirstOrDefault();

return nullableAttribute;
return memberInfo.GetCustomAttribute<NullableAttribute>();
}

private static bool GetNullableFallbackValue(this MemberInfo memberInfo)
Expand All @@ -98,25 +86,30 @@ private static bool GetNullableFallbackValue(this MemberInfo memberInfo)
? new Type[] { memberInfo.DeclaringType, memberInfo.DeclaringType.DeclaringType }
: new Type[] { memberInfo.DeclaringType };

// https://github.com/dotnet/roslyn/blob/main/docs/features/nullable-metadata.md
// Check NullableContextAttribute first
foreach (var declaringType in declaringTypes)
{
var attributes = (IEnumerable<object>)declaringType.GetCustomAttributes(false);
var attributes = declaringType.GetCustomAttributes(true);

var nullableContext = attributes
.Where(attr => string.Equals(attr.GetType().FullName, NullableContextAttributeFullTypeName, StringComparison.Ordinal))
.FirstOrDefault();
// NullableContextAttribute is optional
var nullableContextAttribute = (NullableContextAttribute)attributes.FirstOrDefault(a => a is NullableContextAttribute);

if (nullableContextAttribute != null)
{
return nullableContextAttribute.Flag == 1;
}
}

// Next check NullableAttribute
foreach (var declaringType in declaringTypes)
{
var attributes = declaringType.GetCustomAttributes(true);
var nullableAttribute = (NullableAttribute)attributes.FirstOrDefault(a => a is NullableAttribute);

if (nullableContext != null)
if (nullableAttribute != null)
{
if (nullableContext.GetType().GetField(FlagFieldName) is FieldInfo field &&
field.GetValue(nullableContext) is byte flag && flag == 1)
{
return true;
}
else
{
return false;
}
return nullableAttribute.NullableFlags.Length >= 1 && nullableAttribute.NullableFlags[0] == 1;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,36 @@ public void GenerateSchema_SetsValidationProperties_IfComplexTypeHasValidationAt
Assert.Equal(new[] { "StringWithRequired", "StringWithRequiredAllowEmptyTrue", "StringWithRequiredModifier" }, schema.Required.ToArray());
}

[Fact]
public void GenerateSchema_NestedRecords_RecordTypeWithNonNestedChild_NullableCheck()
{
var schemaRepository = new SchemaRepository();

Subject(
(options) => options.SupportNonNullableReferenceTypes = true
).GenerateSchema(typeof(RecordTypeWithNonNestedChild), schemaRepository);

var schema = schemaRepository.Schemas[nameof(RecordChild)];

Assert.False(schema.Properties[nameof(RecordChild.NonNullable)].Nullable);
Assert.True(schema.Properties[nameof(RecordChild.Nullable)].Nullable);
}

[Fact]
public void GenerateSchema_NestedRecords_RecordTypeWithNestedChild_NullableCheck()
{
var schemaRepository = new SchemaRepository();

Subject(
(options) => options.SupportNonNullableReferenceTypes = true
).GenerateSchema(typeof(RecordTypeWithNestedChild), schemaRepository);

var schema = schemaRepository.Schemas[nameof(RecordTypeWithNestedChild.NestedChild)];

Assert.False(schema.Properties[nameof(RecordTypeWithNestedChild.NestedChild.NonNullable)].Nullable);
Assert.True(schema.Properties[nameof(RecordTypeWithNestedChild.NestedChild.Nullable)].Nullable);
}

[Fact]
public void GenerateSchema_SetsReadOnlyAndWriteOnlyFlags_IfPropertyIsRestricted()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
<IsUnitTestProject>false</IsUnitTestProject>
<IsTestProject>false</IsTestProject>
<DefineTrace>true</DefineTrace>
<Nullable>enable</Nullable>
</PropertyGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace DotSwashbuckle.AspNetCore.TestSupport.Fixtures
{
public record RecordTypeWithNestedChild(RecordTypeWithNestedChild.NestedChild Child)
{
public record NestedChild(string NonNullable, string? Nullable);
}

public record RecordTypeWithNonNestedChild(RecordChild RecordChild);
public record RecordChild(string NonNullable, string? Nullable);
}
1 change: 1 addition & 0 deletions test/WebSites/Basic/Basic.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<NoWarn>$(NoWarn);1591</NoWarn>
<TargetFrameworks>net8.0</TargetFrameworks>
<Nullable>enable</Nullable>
</PropertyGroup>

<ItemGroup>
Expand Down
15 changes: 15 additions & 0 deletions test/WebSites/Basic/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public void ConfigureServices(IServiceCollection services)

services.AddSwaggerGen(c =>
{
c.SupportNonNullableReferenceTypes();
c.SwaggerDoc("v1",
new OpenApiInfo
{
Expand Down Expand Up @@ -78,8 +79,11 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env)
{
c.SerializeAsV2 = true;
});
endpoints.MapPost("/requestWithNestedChild", (Requests.RequestWithNestedChild request) => "ok");
endpoints.MapPost("/requestWithNonNestedChild", (Requests.RequestWithNonNestedChild request) => "ok");
});


var supportedCultures = new[]
{
new CultureInfo("en-US"),
Expand All @@ -102,5 +106,16 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env)
c.SwaggerEndpoint("/swagger/v1/swagger.json", "V1 Docs");
});
}

public class Requests
{
public record RequestWithNestedChild(RequestWithNestedChild.NestedChild Child)
{
public record NestedChild(string NonNullable, string? Nullable);
}

public record RequestWithNonNestedChild(Child Child);
public record Child(string NonNullable, string? Nullable);
}
}
}

0 comments on commit ab7e8e6

Please sign in to comment.