From 823604eb94228f68623103a06ea6cd11c6da972d Mon Sep 17 00:00:00 2001 From: jjjkkkjjj Date: Wed, 5 May 2021 19:51:16 +0900 Subject: [PATCH] change toBool func --- Sources/Matft/core/function/conversion.swift | 2 +- Sources/Matft/util/common/type.swift | 12 +--- Sources/Matft/util/library/vDSP.swift | 63 +++++++++++++++++++- Tests/MatftTests/MathTest.swift | 4 +- Tests/PerformanceTests/BoolPefTests.swift | 4 +- 5 files changed, 67 insertions(+), 18 deletions(-) diff --git a/Sources/Matft/core/function/conversion.swift b/Sources/Matft/core/function/conversion.swift index f43ea6f..645e8db 100644 --- a/Sources/Matft/core/function/conversion.swift +++ b/Sources/Matft/core/function/conversion.swift @@ -373,7 +373,7 @@ extension Matft{ - max: (optional) Maximum value. If nil is passed, handled as inf */ public static func clip(_ mfarray: MfArray, min: T? = nil, max: T? = nil) -> MfArray{ - func _clip(_ vDSP_func: vDSP_clip_func) -> MfArray{ + func _clip(_ vDSP_func: vDSP_clipcount_func) -> MfArray{ let min = min == nil ? -T.infinity : T.from(min!) let max = max == nil ? T.infinity : T.from(max!) return clip_by_vDSP(mfarray, min, max, vDSP_func) diff --git a/Sources/Matft/util/common/type.swift b/Sources/Matft/util/common/type.swift index 3f17fd1..b0c9bee 100644 --- a/Sources/Matft/util/common/type.swift +++ b/Sources/Matft/util/common/type.swift @@ -14,19 +14,11 @@ internal func to_Bool(_ mfarray: MfArray, thresholdF: Float = 1e-5, thresholdD: // TODO: use vDSP_vthr? switch ret.storedType { case .Float: - ret.withDataUnsafeMBPtrT(datatype: Float.self){ - [unowned ret] (dataptr) in - var newptr = dataptr.map{ abs($0) <= thresholdF ? Float.zero : Float(1) } - newptr.withUnsafeMutableBufferPointer{ - dataptr.baseAddress!.moveAssign(from: $0.baseAddress!, count: ret.storedSize) - } - } + let ret = toBool_by_vDSP(ret) + return ret case .Double: fatalError("Bug was occurred. Bool's storedType is not double.") } - - ret.mfdata._mftype = .Bool - return ret } /* internal func to_Bool_mm_op(l_mfarray: MfArray, r_mfarray: MfArray, op: (U, U) -> Bool) -> MfArray{ diff --git a/Sources/Matft/util/library/vDSP.swift b/Sources/Matft/util/library/vDSP.swift index cdf2829..3da7559 100644 --- a/Sources/Matft/util/library/vDSP.swift +++ b/Sources/Matft/util/library/vDSP.swift @@ -339,8 +339,8 @@ internal func sort_index_by_vDSP(_ mfarray: MfArray, _ axis: Int, } -internal typealias vDSP_clip_func = (UnsafePointer, vDSP_Stride, UnsafePointer, UnsafePointer, UnsafeMutablePointer, vDSP_Stride, vDSP_Length, UnsafeMutablePointer, UnsafeMutablePointer) -> Void -fileprivate func _run_clip(_ srcptr: UnsafePointer, dstptr: UnsafeMutablePointer, count: Int, _ min: T, _ max: T, _ vDSP_func: vDSP_clip_func){ +internal typealias vDSP_clipcount_func = (UnsafePointer, vDSP_Stride, UnsafePointer, UnsafePointer, UnsafeMutablePointer, vDSP_Stride, vDSP_Length, UnsafeMutablePointer, UnsafeMutablePointer) -> Void +fileprivate func _run_clip(_ srcptr: UnsafePointer, dstptr: UnsafeMutablePointer, count: Int, _ min: T, _ max: T, _ vDSP_func: vDSP_clipcount_func){ var min = min var max = max @@ -349,7 +349,7 @@ fileprivate func _run_clip(_ srcptr: UnsafePointer, dstptr: Un vDSP_func(srcptr, vDSP_Stride(1), &min, &max, dstptr, vDSP_Stride(1), vDSP_Length(count), &mincount, &maxcount) } -internal func clip_by_vDSP(_ mfarray: MfArray, _ min: T, _ max: T, _ vDSP_func: vDSP_clip_func) -> MfArray{ +internal func clip_by_vDSP(_ mfarray: MfArray, _ min: T, _ max: T, _ vDSP_func: vDSP_clipcount_func) -> MfArray{ //return mfarray must be either row or column major var mfarray = mfarray //print(mfarray) @@ -369,6 +369,63 @@ internal func clip_by_vDSP(_ mfarray: MfArray, _ min: T, _ max: T let newmfstructure = copy_mfstructure(mfarray.mfstructure) return MfArray(mfdata: newdata, mfstructure: newmfstructure) } + +/* +internal typealias vDSP_venvlp_func = (UnsafePointer, vDSP_Stride, UnsafePointer, vDSP_Stride, UnsafePointer, vDSP_Stride, UnsafeMutablePointer, vDSP_Stride, vDSP_Length) -> Void +internal typealias vDSP_clip_func = (UnsafePointer, vDSP_Stride, UnsafePointer, UnsafePointer, UnsafeMutablePointer, vDSP_Stride, vDSP_Length) -> Void + +internal func sign_by_vDSP(_ mfarray: MfArray, low: T, high: T){ + var low = low + var high = high + + let newdata = withDummyDataMRPtr(mfarray.mftype, storedSize: mfarray.storedSize){ + dstptr in + let dstptrT = dstptr.bindMemory(to: T.self, capacity: mfarray.storedSize) + mfarray.withDataUnsafeMBPtrT(datatype: T.self){ + [unowned mfarray] in + // if low <= + vDSP_venvlp_func(&high, vDSP_Stride(0), &low, vDSP_Stride(0), $0.baseAddress!, vDSP_Stride(1), dstptrT, vDSP_Stride(1), vDSP_Length(mfarray.storedSize)) + } + } + + let newmfstructure = copy_mfstructure(mfarray.mfstructure) + + vDSP_venvlp(<#T##__A: UnsafePointer##UnsafePointer#>, <#T##__IA: vDSP_Stride##vDSP_Stride#>, <#T##__B: UnsafePointer##UnsafePointer#>, <#T##__IB: vDSP_Stride##vDSP_Stride#>, <#T##__C: UnsafePointer##UnsafePointer#>, <#T##__IC: vDSP_Stride##vDSP_Stride#>, <#T##__D: UnsafeMutablePointer##UnsafeMutablePointer#>, <#T##__ID: vDSP_Stride##vDSP_Stride#>, <#T##__N: vDSP_Length##vDSP_Length#>) + vDSP_vclip(<#T##__A: UnsafePointer##UnsafePointer#>, <#T##__IA: vDSP_Stride##vDSP_Stride#>, <#T##__B: UnsafePointer##UnsafePointer#>, <#T##__C: UnsafePointer##UnsafePointer#>, <#T##__D: UnsafeMutablePointer##UnsafeMutablePointer#>, <#T##__ID: vDSP_Stride##vDSP_Stride#>, <#T##__N: vDSP_Length##vDSP_Length#>) +}*/ + +internal typealias vDSP_vminmg_func = (UnsafePointer, vDSP_Stride, UnsafePointer, vDSP_Stride, UnsafeMutablePointer, vDSP_Stride, vDSP_Length) -> Void + +internal typealias vDSP_vthres_func = (UnsafePointer, vDSP_Stride, UnsafePointer, UnsafeMutablePointer, vDSP_Stride, vDSP_Length) -> Void + +internal func toBool_by_vDSP(_ mfarray: MfArray) -> MfArray{ + assert(mfarray.storedType == .Float, "Must be bool") + + let size = mfarray.storedSize + let newdata = withDummyDataMRPtr(.Bool, storedSize: size){ + dstptr in + let dstptrT = dstptr.bindMemory(to: Float.self, capacity: size) + mfarray.withDataUnsafeMBPtrT(datatype: Float.self){ + srcptr in + var zero = Float.zero + var one = Float.from(1) + // if |src| <= 1 => dst = |src| + // |src| > 1 => dst = 1 + // Note that the 0<= dst <= 1 + vDSP_vminmg(srcptr.baseAddress!, vDSP_Stride(1), &one, vDSP_Stride(0), dstptrT, vDSP_Stride(1), vDSP_Length(size)) + + one = Float.from(1) + // if src <= 0, 1 <= src => dst = src + // 0 < src <= 1 => dst = 1 + vDSP_viclip(dstptrT, vDSP_Stride(1), &zero, &one, dstptrT, vDSP_Stride(1), vDSP_Length(size)) + //vDSP_vthres(dstptrT, vDSP_Stride(1), &one, dstptrT, vDSP_Stride(1), vDSP_Length(size)) + } + } + + let newmfstructure = copy_mfstructure(mfarray.mfstructure) + return MfArray(mfdata: newdata, mfstructure: newmfstructure) +} + // generate(arange) /* internal typealias vDSP_arange_func = (UnsafePointer, UnsafePointer, UnsafeMutablePointer, vDSP_Stride, vDSP_Length) -> Void diff --git a/Tests/MatftTests/MathTest.swift b/Tests/MatftTests/MathTest.swift index f0f7f5b..b2c5fca 100644 --- a/Tests/MatftTests/MathTest.swift +++ b/Tests/MatftTests/MathTest.swift @@ -26,8 +26,8 @@ final class MathTests: XCTestCase { [3, 1, 4, -5]], mftype: .Float) let aret = MfArray([[1.4142135, 1.0 , -Float.nan, 0.0 ], - [1.732050, 1.0 , 2.0 , -Float.nan]], mftype: .Float) - let aTret = MfArray([[1.4142135, 1.732050], + [1.7320508, 1.0 , 2.0 , -Float.nan]], mftype: .Float) + let aTret = MfArray([[1.4142135, 1.7320508], [1.0 , 1.0 ], [ -Float.nan, 2.0 ], [0.0 , -Float.nan]], mftype: .Float) diff --git a/Tests/PerformanceTests/BoolPefTests.swift b/Tests/PerformanceTests/BoolPefTests.swift index 44c2bd5..6c457f5 100644 --- a/Tests/PerformanceTests/BoolPefTests.swift +++ b/Tests/PerformanceTests/BoolPefTests.swift @@ -19,8 +19,8 @@ final class BoolPefTests: XCTestCase { let _ = a > 0 } /* - average: 0.007, relative standard deviation: 42.315%, values: [0.015937, 0.005105, 0.005490, 0.006968, 0.006461, 0.005805, 0.006472, 0.005748, 0.009101, 0.005661] - 7.27ms + average: 0.005, relative standard deviation: 29.541%, values: [0.010153, 0.005457, 0.004667, 0.004294, 0.004840, 0.004900, 0.004404, 0.004938, 0.005424, 0.005643] + 5.47ms */ } }