Skip to content

Commit

Permalink
[release/8.0-rc2] Address feedback and fix some TensorPrimitives issu…
Browse files Browse the repository at this point in the history
…es (#92437)

* Address feedback and fix some TensorPrimitives issues

- Added a few APIs based on initial feedback: Abs (vectorized), Log2, and element-wise Max/Min{Magnitude}
- Renamed L2Normalize to Norm
- Fixed semantics of Min/MaxMagnitude to return original value rather than the absolute value
- Renamed a few helper types for consistency
- Added tests

* Add a few more uses of Tolerance

---------

Co-authored-by: Stephen Toub <stoub@microsoft.com>
  • Loading branch information
github-actions[bot] and stephentoub committed Sep 22, 2023
1 parent b1dbbb9 commit 63d0c64
Show file tree
Hide file tree
Showing 5 changed files with 410 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace System.Numerics.Tensors
{
public static partial class TensorPrimitives
{
public static void Abs(System.ReadOnlySpan<float> x, System.Span<float> destination) { }
public static void Add(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.Span<float> destination) { }
public static void Add(System.ReadOnlySpan<float> x, float y, System.Span<float> destination) { }
public static void AddMultiply(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.ReadOnlySpan<float> multiplier, System.Span<float> destination) { }
Expand All @@ -24,12 +25,17 @@ public static void Exp(System.ReadOnlySpan<float> x, System.Span<float> destinat
public static int IndexOfMaxMagnitude(System.ReadOnlySpan<float> x) { throw null; }
public static int IndexOfMin(System.ReadOnlySpan<float> x) { throw null; }
public static int IndexOfMinMagnitude(System.ReadOnlySpan<float> x) { throw null; }
public static float L2Normalize(System.ReadOnlySpan<float> x) { throw null; }
public static float Norm(System.ReadOnlySpan<float> x) { throw null; }
public static void Log(System.ReadOnlySpan<float> x, System.Span<float> destination) { }
public static void Log2(System.ReadOnlySpan<float> x, System.Span<float> destination) { }
public static float Max(System.ReadOnlySpan<float> x) { throw null; }
public static void Max(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.Span<float> destination) { throw null; }
public static float MaxMagnitude(System.ReadOnlySpan<float> x) { throw null; }
public static void MaxMagnitude(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.Span<float> destination) { throw null; }
public static float Min(System.ReadOnlySpan<float> x) { throw null; }
public static void Min(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.Span<float> destination) { throw null; }
public static float MinMagnitude(System.ReadOnlySpan<float> x) { throw null; }
public static void MinMagnitude(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.Span<float> destination) { throw null; }
public static void Multiply(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.Span<float> destination) { }
public static void Multiply(System.ReadOnlySpan<float> x, float y, System.Span<float> destination) { }
public static void MultiplyAdd(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.ReadOnlySpan<float> addend, System.Span<float> destination) { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ public static void Divide(ReadOnlySpan<float> x, float y, Span<float> destinatio
public static void Negate(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<NegateOperator>(x, destination);

/// <summary>Computes the element-wise result of: <c>MathF.Abs(<paramref name="x" />)</c>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
/// <param name="destination">The destination tensor, represented as a span.</param>
/// <exception cref="ArgumentException">Destination is too short.</exception>
/// <remarks>This method effectively does <c><paramref name="destination" />[i] = MathF.Abs(<paramref name="x" />[i])</c>.</remarks>
public static void Abs(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<AbsoluteOperator>(x, destination);

/// <summary>Computes the element-wise result of: <c>(<paramref name="x" /> + <paramref name="y" />) * <paramref name="multiplier" /></c>.</summary>
/// <param name="x">The first tensor, represented as a span.</param>
/// <param name="y">The second tensor, represented as a span.</param>
Expand Down Expand Up @@ -200,6 +208,24 @@ public static void Log(ReadOnlySpan<float> x, Span<float> destination)
}
}

/// <summary>Computes the element-wise result of: <c>log2(<paramref name="x" />)</c>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
/// <param name="destination">The destination tensor, represented as a span.</param>
/// <exception cref="ArgumentException">Destination is too short.</exception>
/// <remarks>This method effectively does <c><paramref name="destination" />[i] = <see cref="MathF" />.Log2(<paramref name="x" />[i])</c>.</remarks>
public static void Log2(ReadOnlySpan<float> x, Span<float> destination)
{
if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

for (int i = 0; i < x.Length; i++)
{
destination[i] = Log2(x[i]);
}
}

/// <summary>Computes the element-wise result of: <c>cosh(<paramref name="x" />)</c>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
/// <param name="destination">The destination tensor, represented as a span.</param>
Expand Down Expand Up @@ -318,9 +344,9 @@ public static float Dot(ReadOnlySpan<float> x, ReadOnlySpan<float> y) // BLAS1:
/// </summary>
/// <param name="x">The first tensor, represented as a span.</param>
/// <returns>The L2 norm.</returns>
public static float L2Normalize(ReadOnlySpan<float> x) // BLAS1: nrm2
public static float Norm(ReadOnlySpan<float> x) // BLAS1: nrm2
{
return MathF.Sqrt(Aggregate<LoadSquared, AddOperator>(0f, x));
return MathF.Sqrt(Aggregate<SquaredOperator, AddOperator>(0f, x));
}

/// <summary>
Expand All @@ -345,7 +371,7 @@ public static void SoftMax(ReadOnlySpan<float> x, Span<float> destination)

for (int i = 0; i < x.Length; i++)
{
expSum += MathF.Pow((float)Math.E, x[i]);
expSum += MathF.Exp(x[i]);
}

for (int i = 0; i < x.Length; i++)
Expand Down Expand Up @@ -421,6 +447,31 @@ public static float Max(ReadOnlySpan<float> x)
return result;
}

/// <summary>Computes the element-wise result of: <c>MathF.Max(<paramref name="x" />, <paramref name="y" />)</c>.</summary>
/// <param name="x">The first tensor, represented as a span.</param>
/// <param name="y">The second tensor, represented as a span.</param>
/// <param name="destination">The destination tensor, represented as a span.</param>
/// <exception cref="ArgumentException">Length of '<paramref name="x" />' must be same as length of '<paramref name="y" />'.</exception>
/// <exception cref="ArgumentException">Destination is too short.</exception>
/// <remarks>This method effectively does <c><paramref name="destination" />[i] = MathF.Max(<paramref name="x" />[i], <paramref name="y" />[i])</c>.</remarks>
public static void Max(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination)
{
if (x.Length != y.Length)
{
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
}

if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Max(x[i], y[i]);
}
}

/// <summary>Computes the minimum element in <paramref name="x"/>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
/// <returns>The minimum element in <paramref name="x"/>.</returns>
Expand Down Expand Up @@ -464,6 +515,31 @@ public static float Min(ReadOnlySpan<float> x)
return result;
}

/// <summary>Computes the element-wise result of: <c>MathF.Min(<paramref name="x" />, <paramref name="y" />)</c>.</summary>
/// <param name="x">The first tensor, represented as a span.</param>
/// <param name="y">The second tensor, represented as a span.</param>
/// <param name="destination">The destination tensor, represented as a span.</param>
/// <exception cref="ArgumentException">Length of '<paramref name="x" />' must be same as length of '<paramref name="y" />'.</exception>
/// <exception cref="ArgumentException">Destination is too short.</exception>
/// <remarks>This method effectively does <c><paramref name="destination" />[i] = MathF.Min(<paramref name="x" />[i], <paramref name="y" />[i])</c>.</remarks>
public static void Min(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination)
{
if (x.Length != y.Length)
{
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
}

if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Min(x[i], y[i]);
}
}

/// <summary>Computes the maximum magnitude of any element in <paramref name="x"/>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
/// <returns>The maximum magnitude of any element in <paramref name="x"/>.</returns>
Expand Down Expand Up @@ -508,7 +584,32 @@ public static float MaxMagnitude(ReadOnlySpan<float> x)
}
}

return resultMag;
return result;
}

/// <summary>Computes the element-wise result of: <c>MathF.MaxMagnitude(<paramref name="x" />, <paramref name="y" />)</c>.</summary>
/// <param name="x">The first tensor, represented as a span.</param>
/// <param name="y">The second tensor, represented as a span.</param>
/// <param name="destination">The destination tensor, represented as a span.</param>
/// <exception cref="ArgumentException">Length of '<paramref name="x" />' must be same as length of '<paramref name="y" />'.</exception>
/// <exception cref="ArgumentException">Destination is too short.</exception>
/// <remarks>This method effectively does <c><paramref name="destination" />[i] = MathF.MaxMagnitude(<paramref name="x" />[i], <paramref name="y" />[i])</c>.</remarks>
public static void MaxMagnitude(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination)
{
if (x.Length != y.Length)
{
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
}

if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

for (int i = 0; i < x.Length; i++)
{
destination[i] = MaxMagnitude(x[i], y[i]);
}
}

/// <summary>Computes the minimum magnitude of any element in <paramref name="x"/>.</summary>
Expand All @@ -522,6 +623,7 @@ public static float MinMagnitude(ReadOnlySpan<float> x)
ThrowHelper.ThrowArgument_SpansMustBeNonEmpty();
}

float result = float.PositiveInfinity;
float resultMag = float.PositiveInfinity;

for (int i = 0; i < x.Length; i++)
Expand All @@ -543,16 +645,43 @@ public static float MinMagnitude(ReadOnlySpan<float> x)

if (currentMag < resultMag)
{
result = current;
resultMag = currentMag;
}
}
else if (IsNegative(current))
{
result = current;
resultMag = currentMag;
}
}

return resultMag;
return result;
}

/// <summary>Computes the element-wise result of: <c>MathF.MinMagnitude(<paramref name="x" />, <paramref name="y" />)</c>.</summary>
/// <param name="x">The first tensor, represented as a span.</param>
/// <param name="y">The second tensor, represented as a span.</param>
/// <param name="destination">The destination tensor, represented as a span.</param>
/// <exception cref="ArgumentException">Length of '<paramref name="x" />' must be same as length of '<paramref name="y" />'.</exception>
/// <exception cref="ArgumentException">Destination is too short.</exception>
/// <remarks>This method effectively does <c><paramref name="destination" />[i] = MathF.MinMagnitude(<paramref name="x" />[i], <paramref name="y" />[i])</c>.</remarks>
public static void MinMagnitude(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination)
{
if (x.Length != y.Length)
{
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
}

if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

for (int i = 0; i < x.Length; i++)
{
destination[i] = MinMagnitude(x[i], y[i]);
}
}

/// <summary>Computes the index of the maximum element in <paramref name="x"/>.</summary>
Expand Down Expand Up @@ -744,14 +873,14 @@ public static unsafe int IndexOfMinMagnitude(ReadOnlySpan<float> x)
/// <param name="x">The tensor, represented as a span.</param>
/// <returns>The result of adding all elements in <paramref name="x"/>, or zero if <paramref name="x"/> is empty.</returns>
public static float Sum(ReadOnlySpan<float> x) =>
Aggregate<LoadIdentity, AddOperator>(0f, x);
Aggregate<IdentityOperator, AddOperator>(0f, x);

/// <summary>Computes the sum of the squares of every element in <paramref name="x"/>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
/// <returns>The result of adding every element in <paramref name="x"/> multiplied by itself, or zero if <paramref name="x"/> is empty.</returns>
/// <remarks>This method effectively does <c><see cref="TensorPrimitives" />.Sum(<see cref="TensorPrimitives" />.Multiply(<paramref name="x" />, <paramref name="x" />))</c>.</remarks>
public static float SumOfSquares(ReadOnlySpan<float> x) =>
Aggregate<LoadSquared, AddOperator>(0f, x);
Aggregate<SquaredOperator, AddOperator>(0f, x);

/// <summary>Computes the sum of the absolute values of every element in <paramref name="x"/>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand All @@ -761,7 +890,7 @@ public static float SumOfSquares(ReadOnlySpan<float> x) =>
/// <para>This method corresponds to the <c>asum</c> method defined by <c>BLAS1</c>.</para>
/// </remarks>
public static float SumOfMagnitudes(ReadOnlySpan<float> x) =>
Aggregate<LoadAbsolute, AddOperator>(0f, x);
Aggregate<AbsoluteOperator, AddOperator>(0f, x);

/// <summary>Computes the product of all elements in <paramref name="x"/>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand All @@ -774,7 +903,7 @@ public static float Product(ReadOnlySpan<float> x)
ThrowHelper.ThrowArgument_SpansMustBeNonEmpty();
}

return Aggregate<LoadIdentity, MultiplyOperator>(1.0f, x);
return Aggregate<IdentityOperator, MultiplyOperator>(1.0f, x);
}

/// <summary>Computes the product of the element-wise result of: <c><paramref name="x" /> + <paramref name="y" /></c>.</summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ public static void ConvertToSingle(ReadOnlySpan<Half> source, Span<float> destin

private static bool IsNegative(float f) => float.IsNegative(f);

private static float MaxMagnitude(float x, float y) => MathF.MaxMagnitude(x, y);

private static float MinMagnitude(float x, float y) => MathF.MinMagnitude(x, y);

private static float Log2(float x) => MathF.Log2(x);

private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
{
// Compute the same as:
Expand Down Expand Up @@ -1184,7 +1190,7 @@ public static float Invoke(Vector512<float> x)
#endif
}

private readonly struct LoadIdentity : IUnaryOperator
private readonly struct IdentityOperator : IUnaryOperator
{
public static float Invoke(float x) => x;
public static Vector128<float> Invoke(Vector128<float> x) => x;
Expand All @@ -1194,7 +1200,7 @@ public static float Invoke(Vector512<float> x)
#endif
}

private readonly struct LoadSquared : IUnaryOperator
private readonly struct SquaredOperator : IUnaryOperator
{
public static float Invoke(float x) => x * x;
public static Vector128<float> Invoke(Vector128<float> x) => x * x;
Expand All @@ -1204,7 +1210,7 @@ public static float Invoke(Vector512<float> x)
#endif
}

private readonly struct LoadAbsolute : IUnaryOperator
private readonly struct AbsoluteOperator : IUnaryOperator
{
public static float Invoke(float x) => MathF.Abs(x);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ public static partial class TensorPrimitives
{
private static unsafe bool IsNegative(float f) => *(int*)&f < 0;

private static float MaxMagnitude(float x, float y) => MathF.Abs(x) >= MathF.Abs(y) ? x : y;

private static float MinMagnitude(float x, float y) => MathF.Abs(x) < MathF.Abs(y) ? x : y;

private static float Log2(float x) => MathF.Log(x, 2);

private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
{
// Compute the same as:
Expand Down Expand Up @@ -551,19 +557,19 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)
public Vector<float> Invoke(Vector<float> x, Vector<float> y, Vector<float> z) => (x * y) + z;
}

private readonly struct LoadIdentity : IUnaryOperator
private readonly struct IdentityOperator : IUnaryOperator
{
public float Invoke(float x) => x;
public Vector<float> Invoke(Vector<float> x) => x;
}

private readonly struct LoadSquared : IUnaryOperator
private readonly struct SquaredOperator : IUnaryOperator
{
public float Invoke(float x) => x * x;
public Vector<float> Invoke(Vector<float> x) => x * x;
}

private readonly struct LoadAbsolute : IUnaryOperator
private readonly struct AbsoluteOperator : IUnaryOperator
{
public float Invoke(float x) => MathF.Abs(x);

Expand Down
Loading

0 comments on commit 63d0c64

Please sign in to comment.