From 1811672eeb16fe199014d7f0a88a0b0c8749c692 Mon Sep 17 00:00:00 2001 From: Erik Ejlskov Jensen Date: Tue, 9 Apr 2024 21:27:44 +0200 Subject: [PATCH] Fix | ArgumentNullException on SqlDataRecord.GetValue when using Udt data type (#2448) --- .../Data/SqlClient/Server/MetadataUtilsSmi.cs | 4 --- .../Microsoft.Data.SqlClient.Tests.csproj | 2 ++ .../FunctionalTests/SqlDataRecordTest.cs | 35 ++++++++++++++++++- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Server/MetadataUtilsSmi.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Server/MetadataUtilsSmi.cs index 590429222a..2dc635320b 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Server/MetadataUtilsSmi.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Server/MetadataUtilsSmi.cs @@ -500,11 +500,7 @@ internal static SmiExtendedMetaData SqlMetaDataToSmiExtendedMetaData(SqlMetaData source.Scale, source.LocaleId, source.CompareOptions, -#if NETFRAMEWORK source.Type, -#else - null, -#endif source.Name, typeSpecificNamePart1, typeSpecificNamePart2, diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj index 7724f4c58e..506960d3b6 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj @@ -94,6 +94,8 @@ + + diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlDataRecordTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlDataRecordTest.cs index 87f6c167ec..0e748c5d74 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlDataRecordTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlDataRecordTest.cs @@ -8,6 +8,7 @@ using System.Data; using System.Data.SqlTypes; using Microsoft.Data.SqlClient.Server; +using Microsoft.SqlServer.Types; using Xunit; namespace Microsoft.Data.SqlClient.Tests @@ -318,6 +319,19 @@ public void GetChar_ThrowsNotSupported() Assert.Throws(() => record.GetChar(0)); } + [Theory] + [ClassData(typeof(GetUdtTypeTestData))] + public void GetUdt_ReturnsValue(Type udtType, object value, string serverTypeName) + { + SqlMetaData[] metadata = new SqlMetaData[] { new SqlMetaData(nameof(udtType.Name), SqlDbType.Udt, udtType, serverTypeName) }; + + SqlDataRecord record = new SqlDataRecord(metadata); + + record.SetValue(0, value); + + Assert.Equal(value.ToString(), record.GetValue(0).ToString()); + } + [Theory] [ClassData(typeof(GetXXXBadTypeTestData))] public void GetXXX_ThrowsIfBadType(Func getXXX) @@ -342,8 +356,8 @@ public void GetXXX_ReturnValue(SqlDbType dbType, object value, Func + { + public IEnumerator GetEnumerator() + { + yield return new object[] { typeof(SqlGeography), SqlGeography.Point(43, -81, 4326), "Geography" }; + yield return new object[] { typeof(SqlGeometry), SqlGeometry.Point(43, -81, 4326), "Geometry" }; + yield return new object[] { typeof(SqlHierarchyId), SqlHierarchyId.Parse("/"), "HierarchyId" }; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } + public class GetXXXCheckValueTestData : IEnumerable { public IEnumerator GetEnumerator() @@ -383,6 +412,10 @@ public IEnumerator GetEnumerator() yield return new object[] { SqlDbType.DateTime, DateTime.Now, new Func(r => r.GetDateTime(0)) }; yield return new object[] { SqlDbType.DateTimeOffset, new DateTimeOffset(DateTime.Now), new Func(r => r.GetDateTimeOffset(0)) }; yield return new object[] { SqlDbType.Time, TimeSpan.FromHours(1), new Func(r => r.GetTimeSpan(0)) }; + yield return new object[] { SqlDbType.Date, DateTime.Now.Date, new Func(r => r.GetDateTime(0)) }; + yield return new object[] { SqlDbType.Bit, bool.Parse(bool.TrueString), new Func(r => r.GetBoolean(0)) }; + yield return new object[] { SqlDbType.SmallDateTime, DateTime.Now, new Func(r => r.GetDateTime(0)) }; + yield return new object[] { SqlDbType.TinyInt, (byte)1, new Func(r => r.GetByte(0)) }; } IEnumerator IEnumerable.GetEnumerator()