Skip to content

Commit

Permalink
Test coverage for data source enum mapping support
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Jan 29, 2023
1 parent aa9ee53 commit c78e984
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 56 deletions.
79 changes: 54 additions & 25 deletions test/EFCore.PG.FunctionalTests/Query/EnumQueryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ public EnumQueryTest(EnumFixture fixture, ITestOutputHelper testOutputHelper)
Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

#region Roundtrip

[Fact]
public void Roundtrip()
{
Expand All @@ -22,10 +20,6 @@ public void Roundtrip()
Assert.Equal(MappedEnum.Happy, x.MappedEnum);
}

#endregion

#region Where

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task Where_with_constant(bool async)
Expand All @@ -39,7 +33,7 @@ await AssertQuery(

AssertSql(
"""
SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."MappedEnum" = 'sad'::test.mapped_enum
""");
Expand All @@ -58,7 +52,7 @@ await AssertQuery(

AssertSql(
"""
SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."SchemaQualifiedEnum" = 'Happy (PgName)'::test.schema_qualified_enum
""");
Expand All @@ -80,7 +74,7 @@ await AssertQuery(
"""
@__sad_0='Sad' (DbType = Object)

SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."MappedEnum" = @__sad_0
""");
Expand All @@ -102,7 +96,7 @@ await AssertQuery(
"""
@__sad_0='1'

SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."UnmappedEnum" = @__sad_0
""");
Expand All @@ -124,7 +118,7 @@ await AssertQuery(
"""
@__sad_0='1'

SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."UnmappedEnum" = @__sad_0
""");
Expand All @@ -146,7 +140,7 @@ await AssertQuery(
"""
@__sad_0='Sad' (DbType = Object)

SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."MappedEnum" = @__sad_0
""");
Expand All @@ -166,7 +160,7 @@ await AssertQuery(

AssertSql(
"""
SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE strpos(s."MappedEnum"::text, 'sa') > 0
""");
Expand All @@ -189,7 +183,7 @@ await AssertQuery(
"""
@__values_0='0x01' (DbType = Object)

SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."ByteEnum" = ANY (@__values_0)
""");
Expand All @@ -211,13 +205,30 @@ await AssertQuery(
"""
@__values_0='0x01' (DbType = Object)

SELECT s."Id", s."ByteEnum", s."EnumValue", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."UnmappedByteEnum" = ANY (@__values_0)
""");
}

#endregion
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task Global_enum_mapping(bool async)
{
using var ctx = CreateContext();

await AssertQuery(
async,
ss => ss.Set<SomeEnumEntity>().Where(e => e.GloballyMappedEnum == GloballyMappedEnum.Sad),
entryCount: 1);

AssertSql(
"""
SELECT s."Id", s."ByteEnum", s."EnumValue", s."GloballyMappedEnum", s."InferredEnum", s."MappedEnum", s."SchemaQualifiedEnum", s."UnmappedByteEnum", s."UnmappedEnum"
FROM test."SomeEntities" AS s
WHERE s."GloballyMappedEnum" = 'sad'::test.globally_mapped_enum
""");
}

#region Support

Expand All @@ -234,21 +245,20 @@ public class EnumContext : PoolableDbContext
static EnumContext()
{
#pragma warning disable CS0618 // NpgsqlConnection.GlobalTypeMapper is obsolete
NpgsqlConnection.GlobalTypeMapper.MapEnum<MappedEnum>("test.mapped_enum");
NpgsqlConnection.GlobalTypeMapper.MapEnum<InferredEnum>("test.inferred_enum");
NpgsqlConnection.GlobalTypeMapper.MapEnum<ByteEnum>("test.byte_enum");
NpgsqlConnection.GlobalTypeMapper.MapEnum<SchemaQualifiedEnum>("test.schema_qualified_enum");
NpgsqlConnection.GlobalTypeMapper.MapEnum<GloballyMappedEnum>("test.globally_mapped_enum");
#pragma warning restore CS0618
}

public EnumContext(DbContextOptions options) : base(options) {}

protected override void OnModelCreating(ModelBuilder builder)
=> builder.HasPostgresEnum("mapped_enum", new[] { "happy", "sad" })
=> builder
.HasPostgresEnum("mapped_enum", new[] { "happy", "sad" })
.HasPostgresEnum<InferredEnum>()
.HasPostgresEnum<ByteEnum>()
.HasDefaultSchema("test")
.HasPostgresEnum<SchemaQualifiedEnum>();
.HasPostgresEnum<SchemaQualifiedEnum>()
.HasPostgresEnum<GloballyMappedEnum>();

public static void Seed(EnumContext context)
{
Expand All @@ -270,6 +280,7 @@ public class SomeEnumEntity
public ByteEnum ByteEnum { get; set; }
public UnmappedByteEnum UnmappedByteEnum { get; set; }
public int EnumValue { get; set; }
public GloballyMappedEnum GloballyMappedEnum { get; set; }
}

public enum MappedEnum
Expand All @@ -290,6 +301,12 @@ public enum InferredEnum
Sad
}

public enum GloballyMappedEnum
{
Happy,
Sad
}

public enum SchemaQualifiedEnum
{
[PgName("Happy (PgName)")]
Expand All @@ -313,7 +330,16 @@ public enum UnmappedByteEnum : byte
public class EnumFixture : SharedStoreFixtureBase<EnumContext>, IQueryFixtureBase
{
protected override string StoreName => "EnumQueryTest";
protected override ITestStoreFactory TestStoreFactory => NpgsqlTestStoreFactory.Instance;

protected override ITestStoreFactory TestStoreFactory
=> NpgsqlTestStoreFactory.WithDataSourceConfiguration(
b =>
b
.MapEnum<MappedEnum>("test.mapped_enum")
.MapEnum<InferredEnum>("test.inferred_enum")
.MapEnum<ByteEnum>("test.byte_enum")
.MapEnum<SchemaQualifiedEnum>("test.schema_qualified_enum"));

public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ListLoggerFactory;

private EnumData _expectedData;
Expand Down Expand Up @@ -350,6 +376,7 @@ public IReadOnlyDictionary<Type, object> EntityAsserters
Assert.Equal(ee.ByteEnum, aa.ByteEnum);
Assert.Equal(ee.UnmappedByteEnum, aa.UnmappedByteEnum);
Assert.Equal(ee.EnumValue, aa.EnumValue);
Assert.Equal(ee.GloballyMappedEnum, aa.GloballyMappedEnum);
}
}
}
Expand Down Expand Up @@ -386,7 +413,8 @@ public static IReadOnlyList<SomeEnumEntity> CreateSomeEnumEntities()
SchemaQualifiedEnum = SchemaQualifiedEnum.Happy,
ByteEnum = ByteEnum.Happy,
UnmappedByteEnum = UnmappedByteEnum.Happy,
EnumValue = (int)MappedEnum.Happy
EnumValue = (int)MappedEnum.Happy,
GloballyMappedEnum = GloballyMappedEnum.Happy
},
new()
{
Expand All @@ -397,7 +425,8 @@ public static IReadOnlyList<SomeEnumEntity> CreateSomeEnumEntities()
SchemaQualifiedEnum = SchemaQualifiedEnum.Sad,
ByteEnum = ByteEnum.Sad,
UnmappedByteEnum = UnmappedByteEnum.Sad,
EnumValue = (int)MappedEnum.Sad
EnumValue = (int)MappedEnum.Sad,
GloballyMappedEnum = GloballyMappedEnum.Sad
}
};
}
Expand Down
2 changes: 1 addition & 1 deletion test/EFCore.PG.FunctionalTests/Query/TimestampQueryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ public class TimestampQueryFixture : SharedStoreFixtureBase<TimestampQueryContex
// don't depend on the database's time zone, and also that operations which shouldn't take TimeZone into account indeed
// don't.
protected override ITestStoreFactory TestStoreFactory
=> NpgsqlTestStoreFactory.WithConnectionStringOptions("-c TimeZone=Europe/Berlin");
=> NpgsqlTestStoreFactory.WithDataSourceConfiguration(b => b.ConnectionStringBuilder.Options = "-c TimeZone=Europe/Berlin");

public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ListLoggerFactory;

Expand Down
48 changes: 27 additions & 21 deletions test/EFCore.PG.FunctionalTests/TestUtilities/NpgsqlTestStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@

namespace Npgsql.EntityFrameworkCore.PostgreSQL.TestUtilities;

// ReSharper disable VirtualMemberCallInConstructor

public class NpgsqlTestStore : RelationalTestStore
{
private readonly NpgsqlDataSource _dataSource;

private readonly string _scriptPath;
private readonly string _additionalSql;

Expand All @@ -27,11 +31,14 @@ public static NpgsqlTestStore GetOrCreate(
string name,
string scriptPath = null,
string additionalSql = null,
string connectionStringOptions = null)
=> new(name, scriptPath, additionalSql, connectionStringOptions);
string connectionStringOptions = null,
Action<NpgsqlDataSourceBuilder> dataSourceBuilderAction = null)
=> new(name, scriptPath, additionalSql, dataSourceBuilderAction);

public static NpgsqlTestStore Create(string name, string connectionStringOptions = null)
=> new(name, connectionStringOptions: connectionStringOptions, shared: false);
public static NpgsqlTestStore Create(
string name,
Action<NpgsqlDataSourceBuilder> dataSourceBuilderAction = null)
=> new(name, dataSourceBuilderAction: dataSourceBuilderAction, shared: false);

public static NpgsqlTestStore CreateInitialized(string name)
=> new NpgsqlTestStore(name, shared: false)
Expand All @@ -41,7 +48,7 @@ private NpgsqlTestStore(
string name,
string scriptPath = null,
string additionalSql = null,
string connectionStringOptions = null,
Action<NpgsqlDataSourceBuilder> dataSourceBuilderAction = null,
bool shared = true)
: base(name, shared)
{
Expand All @@ -55,10 +62,11 @@ private NpgsqlTestStore(

_additionalSql = additionalSql;

// ReSharper disable VirtualMemberCallInConstructor
ConnectionString = CreateConnectionString(Name, connectionStringOptions);
Connection = new NpgsqlConnection(ConnectionString);
// ReSharper restore VirtualMemberCallInConstructor
ConnectionString = CreateConnectionString(Name);
var dataSourceBuilder = new NpgsqlDataSourceBuilder(ConnectionString);
dataSourceBuilderAction?.Invoke(dataSourceBuilder);
_dataSource = dataSourceBuilder.Build();
Connection = _dataSource.CreateConnection();
}

// ReSharper disable once MemberCanBePrivate.Global
Expand Down Expand Up @@ -100,7 +108,7 @@ protected override void Initialize(Func<DbContext> createContext, Action<DbConte
}

public override DbContextOptionsBuilder AddProviderOptions(DbContextOptionsBuilder builder)
=> builder.UseNpgsql(Connection, b => b.ApplyConfiguration()
=> builder.UseNpgsql(_dataSource, b => b.ApplyConfiguration()
.CommandTimeout(CommandTimeout)
// The tests are written with the assumption that NULLs are sorted first (SQL Server and .NET behavior), but PostgreSQL
// sorts NULLs last by default. This configures the provider to emit NULLS FIRST.
Expand Down Expand Up @@ -415,20 +423,18 @@ private static DbCommand CreateCommand(
return command;
}

public static string CreateConnectionString(string name, string options = null)
{
var builder = new NpgsqlConnectionStringBuilder(TestEnvironment.DefaultConnection) { Database = name };

if (options is not null)
{
builder.Options = options;
}

return builder.ConnectionString;
}
public static string CreateConnectionString(string name)
=> new NpgsqlConnectionStringBuilder(TestEnvironment.DefaultConnection) { Database = name }.ConnectionString;

private static string CreateAdminConnectionString() => CreateConnectionString("postgres");

public override void Clean(DbContext context)
=> context.Database.EnsureClean();

public override void Dispose()
{
base.Dispose();

_dataSource.Dispose();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@

public class NpgsqlTestStoreFactory : RelationalTestStoreFactory
{
private string _connectionStringOptions;
private readonly Action<NpgsqlDataSourceBuilder> _dataSourceBuilderAction;

public static NpgsqlTestStoreFactory Instance { get; } = new();

public static NpgsqlTestStoreFactory WithConnectionStringOptions(string connectionStringOptions)
=> new(connectionStringOptions);
public static NpgsqlTestStoreFactory WithDataSourceConfiguration(Action<NpgsqlDataSourceBuilder> dataSourceBuilderAction)
=> new(dataSourceBuilderAction);

protected NpgsqlTestStoreFactory(string connectionStringOptions = null)
=> _connectionStringOptions = connectionStringOptions;
protected NpgsqlTestStoreFactory(Action<NpgsqlDataSourceBuilder> dataSourceBuilderAction = null)
=> _dataSourceBuilderAction = dataSourceBuilderAction;

public override TestStore Create(string storeName)
=> NpgsqlTestStore.Create(storeName, _connectionStringOptions);
=> NpgsqlTestStore.Create(storeName, _dataSourceBuilderAction);

public override TestStore GetOrCreate(string storeName)
=> NpgsqlTestStore.GetOrCreate(storeName, connectionStringOptions: _connectionStringOptions);
=> NpgsqlTestStore.GetOrCreate(storeName, dataSourceBuilderAction: _dataSourceBuilderAction);

public override IServiceCollection AddProviderServices(IServiceCollection serviceCollection)
=> serviceCollection.AddEntityFrameworkNpgsql();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1910,7 +1910,7 @@ public NodaTimeQueryNpgsqlFixture()
// don't depend on the database's time zone, and also that operations which shouldn't take TimeZone into account indeed
// don't.
protected override ITestStoreFactory TestStoreFactory
=> NpgsqlTestStoreFactory.WithConnectionStringOptions("-c TimeZone=Europe/Berlin");
=> NpgsqlTestStoreFactory.WithDataSourceConfiguration(b => b.ConnectionStringBuilder.Options = "-c TimeZone=Europe/Berlin");

public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ListLoggerFactory;

Expand Down

0 comments on commit c78e984

Please sign in to comment.