Skip to content

Commit

Permalink
change toBool func
Browse files Browse the repository at this point in the history
  • Loading branch information
jjjkkkjjj committed May 5, 2021
1 parent 02501fa commit 823604e
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Sources/Matft/core/function/conversion.swift
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ extension Matft{
- max: (optional) Maximum value. If nil is passed, handled as inf
*/
public static func clip<T: MfTypable>(_ mfarray: MfArray, min: T? = nil, max: T? = nil) -> MfArray{
func _clip<T: MfStorable>(_ vDSP_func: vDSP_clip_func<T>) -> MfArray{
func _clip<T: MfStorable>(_ vDSP_func: vDSP_clipcount_func<T>) -> MfArray{
let min = min == nil ? -T.infinity : T.from(min!)
let max = max == nil ? T.infinity : T.from(max!)
return clip_by_vDSP(mfarray, min, max, vDSP_func)
Expand Down
12 changes: 2 additions & 10 deletions Sources/Matft/util/common/type.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,11 @@ internal func to_Bool(_ mfarray: MfArray, thresholdF: Float = 1e-5, thresholdD:
// TODO: use vDSP_vthr?
switch ret.storedType {
case .Float:
ret.withDataUnsafeMBPtrT(datatype: Float.self){
[unowned ret] (dataptr) in
var newptr = dataptr.map{ abs($0) <= thresholdF ? Float.zero : Float(1) }
newptr.withUnsafeMutableBufferPointer{
dataptr.baseAddress!.moveAssign(from: $0.baseAddress!, count: ret.storedSize)
}
}
let ret = toBool_by_vDSP(ret)
return ret
case .Double:
fatalError("Bug was occurred. Bool's storedType is not double.")
}

ret.mfdata._mftype = .Bool
return ret
}
/*
internal func to_Bool_mm_op<U: MfStorable>(l_mfarray: MfArray, r_mfarray: MfArray, op: (U, U) -> Bool) -> MfArray{
Expand Down
63 changes: 60 additions & 3 deletions Sources/Matft/util/library/vDSP.swift
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ internal func sort_index_by_vDSP<T: MfStorable>(_ mfarray: MfArray, _ axis: Int,

}

internal typealias vDSP_clip_func<T: MfStorable> = (UnsafePointer<T>, vDSP_Stride, UnsafePointer<T>, UnsafePointer<T>, UnsafeMutablePointer<T>, vDSP_Stride, vDSP_Length, UnsafeMutablePointer<vDSP_Length>, UnsafeMutablePointer<vDSP_Length>) -> Void
fileprivate func _run_clip<T: MfStorable>(_ srcptr: UnsafePointer<T>, dstptr: UnsafeMutablePointer<T>, count: Int, _ min: T, _ max: T, _ vDSP_func: vDSP_clip_func<T>){
internal typealias vDSP_clipcount_func<T: MfStorable> = (UnsafePointer<T>, vDSP_Stride, UnsafePointer<T>, UnsafePointer<T>, UnsafeMutablePointer<T>, vDSP_Stride, vDSP_Length, UnsafeMutablePointer<vDSP_Length>, UnsafeMutablePointer<vDSP_Length>) -> Void
fileprivate func _run_clip<T: MfStorable>(_ srcptr: UnsafePointer<T>, dstptr: UnsafeMutablePointer<T>, count: Int, _ min: T, _ max: T, _ vDSP_func: vDSP_clipcount_func<T>){
var min = min
var max = max

Expand All @@ -349,7 +349,7 @@ fileprivate func _run_clip<T: MfStorable>(_ srcptr: UnsafePointer<T>, dstptr: Un

vDSP_func(srcptr, vDSP_Stride(1), &min, &max, dstptr, vDSP_Stride(1), vDSP_Length(count), &mincount, &maxcount)
}
internal func clip_by_vDSP<T: MfStorable>(_ mfarray: MfArray, _ min: T, _ max: T, _ vDSP_func: vDSP_clip_func<T>) -> MfArray{
internal func clip_by_vDSP<T: MfStorable>(_ mfarray: MfArray, _ min: T, _ max: T, _ vDSP_func: vDSP_clipcount_func<T>) -> MfArray{
//return mfarray must be either row or column major
var mfarray = mfarray
//print(mfarray)
Expand All @@ -369,6 +369,63 @@ internal func clip_by_vDSP<T: MfStorable>(_ mfarray: MfArray, _ min: T, _ max: T
let newmfstructure = copy_mfstructure(mfarray.mfstructure)
return MfArray(mfdata: newdata, mfstructure: newmfstructure)
}

/*
internal typealias vDSP_venvlp_func<T: MfStorable> = (UnsafePointer<T>, vDSP_Stride, UnsafePointer<T>, vDSP_Stride, UnsafePointer<T>, vDSP_Stride, UnsafeMutablePointer<T>, vDSP_Stride, vDSP_Length) -> Void
internal typealias vDSP_clip_func<T: MfStorable> = (UnsafePointer<T>, vDSP_Stride, UnsafePointer<T>, UnsafePointer<T>, UnsafeMutablePointer<T>, vDSP_Stride, vDSP_Length) -> Void
internal func sign_by_vDSP<T: MfStorable>(_ mfarray: MfArray, low: T, high: T){
var low = low
var high = high
let newdata = withDummyDataMRPtr(mfarray.mftype, storedSize: mfarray.storedSize){
dstptr in
let dstptrT = dstptr.bindMemory(to: T.self, capacity: mfarray.storedSize)
mfarray.withDataUnsafeMBPtrT(datatype: T.self){
[unowned mfarray] in
// if low <=
vDSP_venvlp_func(&high, vDSP_Stride(0), &low, vDSP_Stride(0), $0.baseAddress!, vDSP_Stride(1), dstptrT, vDSP_Stride(1), vDSP_Length(mfarray.storedSize))
}
}
let newmfstructure = copy_mfstructure(mfarray.mfstructure)
vDSP_venvlp(<#T##__A: UnsafePointer<Float>##UnsafePointer<Float>#>, <#T##__IA: vDSP_Stride##vDSP_Stride#>, <#T##__B: UnsafePointer<Float>##UnsafePointer<Float>#>, <#T##__IB: vDSP_Stride##vDSP_Stride#>, <#T##__C: UnsafePointer<Float>##UnsafePointer<Float>#>, <#T##__IC: vDSP_Stride##vDSP_Stride#>, <#T##__D: UnsafeMutablePointer<Float>##UnsafeMutablePointer<Float>#>, <#T##__ID: vDSP_Stride##vDSP_Stride#>, <#T##__N: vDSP_Length##vDSP_Length#>)
vDSP_vclip(<#T##__A: UnsafePointer<Float>##UnsafePointer<Float>#>, <#T##__IA: vDSP_Stride##vDSP_Stride#>, <#T##__B: UnsafePointer<Float>##UnsafePointer<Float>#>, <#T##__C: UnsafePointer<Float>##UnsafePointer<Float>#>, <#T##__D: UnsafeMutablePointer<Float>##UnsafeMutablePointer<Float>#>, <#T##__ID: vDSP_Stride##vDSP_Stride#>, <#T##__N: vDSP_Length##vDSP_Length#>)
}*/

internal typealias vDSP_vminmg_func<T: MfStorable> = (UnsafePointer<T>, vDSP_Stride, UnsafePointer<T>, vDSP_Stride, UnsafeMutablePointer<T>, vDSP_Stride, vDSP_Length) -> Void

internal typealias vDSP_vthres_func<T: MfStorable> = (UnsafePointer<T>, vDSP_Stride, UnsafePointer<T>, UnsafeMutablePointer<T>, vDSP_Stride, vDSP_Length) -> Void

internal func toBool_by_vDSP(_ mfarray: MfArray) -> MfArray{
assert(mfarray.storedType == .Float, "Must be bool")

let size = mfarray.storedSize
let newdata = withDummyDataMRPtr(.Bool, storedSize: size){
dstptr in
let dstptrT = dstptr.bindMemory(to: Float.self, capacity: size)
mfarray.withDataUnsafeMBPtrT(datatype: Float.self){
srcptr in
var zero = Float.zero
var one = Float.from(1)
// if |src| <= 1 => dst = |src|
// |src| > 1 => dst = 1
// Note that the 0<= dst <= 1
vDSP_vminmg(srcptr.baseAddress!, vDSP_Stride(1), &one, vDSP_Stride(0), dstptrT, vDSP_Stride(1), vDSP_Length(size))

one = Float.from(1)
// if src <= 0, 1 <= src => dst = src
// 0 < src <= 1 => dst = 1
vDSP_viclip(dstptrT, vDSP_Stride(1), &zero, &one, dstptrT, vDSP_Stride(1), vDSP_Length(size))
//vDSP_vthres(dstptrT, vDSP_Stride(1), &one, dstptrT, vDSP_Stride(1), vDSP_Length(size))
}
}

let newmfstructure = copy_mfstructure(mfarray.mfstructure)
return MfArray(mfdata: newdata, mfstructure: newmfstructure)
}

// generate(arange)
/*
internal typealias vDSP_arange_func<T> = (UnsafePointer<T>, UnsafePointer<T>, UnsafeMutablePointer<T>, vDSP_Stride, vDSP_Length) -> Void
Expand Down
4 changes: 2 additions & 2 deletions Tests/MatftTests/MathTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ final class MathTests: XCTestCase {
[3, 1, 4, -5]], mftype: .Float)

let aret = MfArray([[1.4142135, 1.0 , -Float.nan, 0.0 ],
[1.732050, 1.0 , 2.0 , -Float.nan]], mftype: .Float)
let aTret = MfArray([[1.4142135, 1.732050],
[1.7320508, 1.0 , 2.0 , -Float.nan]], mftype: .Float)
let aTret = MfArray([[1.4142135, 1.7320508],
[1.0 , 1.0 ],
[ -Float.nan, 2.0 ],
[0.0 , -Float.nan]], mftype: .Float)
Expand Down
4 changes: 2 additions & 2 deletions Tests/PerformanceTests/BoolPefTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ final class BoolPefTests: XCTestCase {
let _ = a > 0
}
/*
average: 0.007, relative standard deviation: 42.315%, values: [0.015937, 0.005105, 0.005490, 0.006968, 0.006461, 0.005805, 0.006472, 0.005748, 0.009101, 0.005661]
7.27ms
average: 0.005, relative standard deviation: 29.541%, values: [0.010153, 0.005457, 0.004667, 0.004294, 0.004840, 0.004900, 0.004404, 0.004938, 0.005424, 0.005643]
5.47ms
*/
}
}
Expand Down

0 comments on commit 823604e

Please sign in to comment.