Skip to content

Commit

Permalink
Add support for built-in regex with Azure SQL DB
Browse files Browse the repository at this point in the history
fixes #156
  • Loading branch information
ErikEJ committed May 27, 2024
1 parent 68436ae commit a50c46a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 8 deletions.
31 changes: 27 additions & 4 deletions EFCore.CheckConstraints.Test/ValidationCheckConstraintTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Linq;
using EFCore.CheckConstraints.Internal;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;
Expand Down Expand Up @@ -180,6 +181,16 @@ public virtual void RegularExpression()
Assert.Equal("dbo.RegexMatch('^A', [StartsWithA]) > 0", checkConstraint.Sql);
}

[Fact]
public virtual void RegularExpressionNavtiveMethod()
{
var entityType = BuildEntityType<Blog>(isAzureSql: true);

var checkConstraint = Assert.Single(entityType.GetCheckConstraints(), c => c.Name == "CK_Blog_StartsWithA_RegularExpression");
Assert.NotNull(checkConstraint);
Assert.Equal("REGEXP_LIKE ([StartsWithA], '^A')", checkConstraint.Sql);
}

[Fact]
public virtual void Properties_on_complex_type()
{
Expand Down Expand Up @@ -258,9 +269,20 @@ public class Location
public double Latitude { get; set; }
}

private IModel BuildModel(Action<ModelBuilder> buildAction, bool useRegex)
private IModel BuildModel(Action<ModelBuilder> buildAction, bool useRegex, bool isAzureSql)
{
var serviceProvider = SqlServerTestHelpers.Instance.CreateContextServices();

var dbContextOptions = serviceProvider.GetRequiredService<IDbContextOptions>();

var sqlServerOptionsExtension = dbContextOptions.Extensions
.Where(o => o.GetType().Name == "SqlServerOptionsExtension")
.FirstOrDefault();

sqlServerOptionsExtension!.GetType()
.GetField("_azureSql", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)?
.SetValue(sqlServerOptionsExtension, isAzureSql);

var conventionSet = serviceProvider.GetRequiredService<IConventionSetBuilder>().CreateConventionSet();

conventionSet.ModelFinalizingConventions.Add(
Expand All @@ -269,20 +291,21 @@ private IModel BuildModel(Action<ModelBuilder> buildAction, bool useRegex)
serviceProvider.GetRequiredService<IRelationalTypeMappingSource>(),
serviceProvider.GetRequiredService<ISqlGenerationHelper>(),
serviceProvider.GetRequiredService<IRelationalTypeMappingSource>(),
serviceProvider.GetRequiredService<IDatabaseProvider>()));
serviceProvider.GetRequiredService<IDatabaseProvider>(),
dbContextOptions));

var builder = new ModelBuilder(conventionSet);
buildAction(builder);
return builder.FinalizeModel();
}

private IEntityType BuildEntityType<TEntity>(Action<EntityTypeBuilder<TEntity>>? buildAction = null, bool useRegex = true)
private IEntityType BuildEntityType<TEntity>(Action<EntityTypeBuilder<TEntity>>? buildAction = null, bool useRegex = true, bool isAzureSql = false)
where TEntity : class
{
return BuildModel(buildAction is null
? b => b.Entity<TEntity>()
: b => buildAction(b.Entity<TEntity>()),
useRegex).GetEntityTypes().Single();
useRegex, isAzureSql).GetEntityTypes().Single();
}

#endregion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,22 @@ public class CheckConstraintsConventionSetPlugin : IConventionSetPlugin
private readonly ISqlGenerationHelper _sqlGenerationHelper;
private readonly IRelationalTypeMappingSource _relationalTypeMappingSource;
private readonly IDatabaseProvider _databaseProvider;
private readonly IDbContextOptions _dbContextOptions;

public CheckConstraintsConventionSetPlugin(
IDbContextOptions options,
IRelationalTypeMappingSource typeMappingSource,
ISqlGenerationHelper sqlGenerationHelper,
IRelationalTypeMappingSource relationalTypeMappingSource,
IDatabaseProvider databaseProvider)
IDatabaseProvider databaseProvider,
IDbContextOptions dbContextOptions)
{
_options = options;
_typeMappingSource = typeMappingSource;
_sqlGenerationHelper = sqlGenerationHelper;
_relationalTypeMappingSource = relationalTypeMappingSource;
_databaseProvider = databaseProvider;
_dbContextOptions = dbContextOptions;
}

public ConventionSet ModifyConventions(ConventionSet conventionSet)
Expand All @@ -50,7 +53,7 @@ public ConventionSet ModifyConventions(ConventionSet conventionSet)
conventionSet.ModelFinalizingConventions.Add(
new ValidationCheckConstraintConvention(
extension.ValidationCheckConstraintOptions!, _typeMappingSource, _sqlGenerationHelper,
_relationalTypeMappingSource, _databaseProvider));
_relationalTypeMappingSource, _databaseProvider, _dbContextOptions));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Text;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Metadata.Conventions;
Expand All @@ -30,6 +31,7 @@ public class ValidationCheckConstraintConvention : IModelFinalizingConvention
private readonly IRelationalTypeMappingSource _typeMappingSource;
private readonly ISqlGenerationHelper _sqlGenerationHelper;
private readonly IDatabaseProvider _databaseProvider;
private readonly IDbContextOptions _dbContextOptions;
private readonly RelationalTypeMapping? _intTypeMapping;

private readonly bool _useRegex;
Expand All @@ -40,11 +42,13 @@ public ValidationCheckConstraintConvention(
IRelationalTypeMappingSource typeMappingSource,
ISqlGenerationHelper sqlGenerationHelper,
IRelationalTypeMappingSource relationalTypeMappingSource,
IDatabaseProvider databaseProvider)
IDatabaseProvider databaseProvider,
IDbContextOptions dbContextOptions)
{
_typeMappingSource = typeMappingSource;
_sqlGenerationHelper = sqlGenerationHelper;
_databaseProvider = databaseProvider;
_dbContextOptions = dbContextOptions;
_intTypeMapping = relationalTypeMappingSource.FindMapping(typeof(int))!;

_useRegex = options.UseRegex && SupportsRegex;
Expand Down Expand Up @@ -320,7 +324,22 @@ protected virtual void AddListOfValuesConstraint(
}

protected virtual string GenerateRegexSql(string columnName, [RegexPattern] string regex)
=> string.Format(
{
var sqlServerOptionsExtension = _dbContextOptions.Extensions
.Where(o => o.GetType().Name == "SqlServerOptionsExtension")
.FirstOrDefault();

if (sqlServerOptionsExtension is not null
&& sqlServerOptionsExtension.GetType().GetProperty("IsAzureSql")?.GetValue(sqlServerOptionsExtension) is bool isAzureSql
&& isAzureSql)
{
return string.Format(
"REGEXP_LIKE ({0}, '{1}')",
_sqlGenerationHelper.DelimitIdentifier(columnName),
regex);
}

return string.Format(
_databaseProvider.Name switch
{
// For SQL Server, requires setup:
Expand All @@ -331,6 +350,7 @@ protected virtual string GenerateRegexSql(string columnName, [RegexPattern] stri
MySqlDatabaseProviderName => "{0} REGEXP '{1}'",
_ => throw new InvalidOperationException($"Provider {_databaseProvider.Name} doesn't support regular expressions")
}, _sqlGenerationHelper.DelimitIdentifier(columnName), regex);
}

protected virtual bool SupportsRegex
=> _databaseProvider.Name switch
Expand Down

0 comments on commit a50c46a

Please sign in to comment.