Skip to content

Commit

Permalink
Update server feature write approval
Browse files Browse the repository at this point in the history
A write is only approved, if all callbacks approve it. If one denies it, then the write callback is denied.
  • Loading branch information
DerAndereAndi committed May 18, 2024
1 parent 0a1f9a4 commit 8a8dfc0
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 18 deletions.
12 changes: 7 additions & 5 deletions api/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,16 @@ type FeatureLocalInterface interface {
AddResponseCallback(msgCounterReference model.MsgCounterType, function func(msg ResponseMessage)) error
// Add a callback function to be invoked when a result message comes in for this feature
AddResultCallback(function func(msg ResponseMessage))

// Add a callback method for a server feature which is invoked to
// check wether an incoming write message shall be permitted or declined
// check wether an incoming write message shall be approved or denied
AddWriteApprovalCallback(function WriteApprovalCallbackFunc) error
// Needs to be invoked within 1 minute of WritePermissionCheckCallbackFunc of
// SetWritePermissionCheckCallback being invoked by the stack returning wether
// the remote requested write command shall be allowed or not
// This function needs to be invoked within (default) 10 seconds after the via
// AddWriteApprovalCallback defined callback is being invoked.
//
// NOTE: To approve a write, ALL callbacks need to approve the write!
//
// ErrorType.ErrorNumber should be 0 if write is allowed
// ErrorType.ErrorNumber should be 0 if write is approved
ApproveOrDenyWrite(msg *Message, err model.ErrorType)
// Overwrite the default 1 minute timeout for write approvals
SetWriteApprovalTimeout(duration time.Duration)
Expand Down
38 changes: 32 additions & 6 deletions spine/feature_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ type FeatureLocal struct {
responseMsgCallback map[model.MsgCounterType]func(result api.ResponseMessage)
responseCallbacks []func(result api.ResponseMessage)

writeTimeout time.Duration
writeApprovalCallback []api.WriteApprovalCallbackFunc
pendingWriteApprovals map[model.MsgCounterType]*time.Timer
writeTimeout time.Duration
writeApprovalCallbacks []api.WriteApprovalCallbackFunc
muxWriteReceived sync.Mutex
writeApprovalReceived map[model.MsgCounterType]int
pendingWriteApprovals map[model.MsgCounterType]*time.Timer

bindings []*model.FeatureAddressType // bindings to remote features
subscriptions []*model.FeatureAddressType // subscriptions to remote features
Expand All @@ -41,6 +43,7 @@ func NewFeatureLocal(id uint, entity api.EntityLocalInterface, ftype model.Featu
entity: entity,
functionDataMap: make(map[model.FunctionType]api.FunctionDataCmdInterface),
responseMsgCallback: make(map[model.MsgCounterType]func(result api.ResponseMessage)),
writeApprovalReceived: make(map[model.MsgCounterType]int),
pendingWriteApprovals: make(map[model.MsgCounterType]*time.Timer),
writeTimeout: defaultMaxResponseDelay,
}
Expand Down Expand Up @@ -149,7 +152,7 @@ func (r *FeatureLocal) AddWriteApprovalCallback(function api.WriteApprovalCallba
r.muxResponseCB.Lock()
defer r.muxResponseCB.Unlock()

r.writeApprovalCallback = append(r.writeApprovalCallback, function)
r.writeApprovalCallbacks = append(r.writeApprovalCallbacks, function)

return nil
}
Expand All @@ -158,7 +161,7 @@ func (r *FeatureLocal) processWriteApprovalCallbacks(msg *api.Message) {
r.muxResponseCB.Lock()
defer r.muxResponseCB.Unlock()

for _, cb := range r.writeApprovalCallback {
for _, cb := range r.writeApprovalCallbacks {
go cb(msg)
}
}
Expand Down Expand Up @@ -189,15 +192,38 @@ func (r *FeatureLocal) ApproveOrDenyWrite(msg *api.Message, err model.ErrorType)

r.muxResponseCB.Lock()
timer, ok := r.pendingWriteApprovals[*msg.RequestHeader.MsgCounter]
count := len(r.writeApprovalCallbacks)
r.muxResponseCB.Unlock()

// if there is no timer running, we are too late and error has already been sent
if !ok || timer == nil {
return
}

// do we have enough approvals?
r.muxWriteReceived.Lock()
defer r.muxWriteReceived.Unlock()
if count > 1 && err.ErrorNumber == 0 {
amount, ok := r.writeApprovalReceived[*msg.RequestHeader.MsgCounter]
if ok {
r.writeApprovalReceived[*msg.RequestHeader.MsgCounter] = amount + 1
} else {
r.writeApprovalReceived[*msg.RequestHeader.MsgCounter] = 1
}
// do we have enough approve messages, if not exit
if r.writeApprovalReceived[*msg.RequestHeader.MsgCounter] < count {
return
}
}

timer.Stop()

delete(r.writeApprovalReceived, *msg.RequestHeader.MsgCounter)

r.muxResponseCB.Lock()
defer r.muxResponseCB.Unlock()
delete(r.pendingWriteApprovals, *msg.RequestHeader.MsgCounter)

if err.ErrorNumber == 0 {
r.processWrite(msg)
return
Expand Down Expand Up @@ -465,7 +491,7 @@ func (r *FeatureLocal) HandleMessage(message *api.Message) *model.ErrorType {
}
case model.CmdClassifierTypeWrite:
// if there is a write permission check callback set, invoke this instead of directly allowing the write
if r.writeApprovalCallback != nil {
if len(r.writeApprovalCallbacks) > 0 {
r.addPendingApproval(message)
r.processWriteApprovalCallbacks(message)
} else {
Expand Down
103 changes: 96 additions & 7 deletions spine/feature_local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,11 @@ func (suite *DeviceClassificationTestSuite) Test_AddPendingApproval_Invalid() {
assert.Nil(suite.T(), err1)
}

func (suite *DeviceClassificationTestSuite) Test_Write_Callback() {
func (suite *DeviceClassificationTestSuite) Test_Write_Callback_One() {
counter := model.MsgCounterType(1)
msg := &api.Message{
RequestHeader: &model.HeaderType{
MsgCounter: util.Ptr(model.MsgCounterType(1)),
MsgCounter: util.Ptr(counter),
},
CmdClassifier: model.CmdClassifierTypeWrite,
FeatureRemote: suite.remoteSubFeature,
Expand All @@ -414,29 +415,117 @@ func (suite *DeviceClassificationTestSuite) Test_Write_Callback() {
},
}

cb := func(msg *api.Message) {
cb1 := func(msg *api.Message) {
result := model.ErrorType{
ErrorNumber: 0,
}
suite.localServerFeatureWrite.ApproveOrDenyWrite(msg, result)
}
suite.localServerFeatureWrite.AddWriteApprovalCallback(cb1)

suite.localServerFeatureWrite.AddWriteApprovalCallback(cb)
err := suite.localServerFeatureWrite.HandleMessage(msg)
assert.Nil(suite.T(), err)

// callback is called asynchronously
time.Sleep(time.Millisecond * 200)
}

func (suite *DeviceClassificationTestSuite) Test_Write_Callback_One_Fail() {
msg := &api.Message{
RequestHeader: &model.HeaderType{
MsgCounter: util.Ptr(model.MsgCounterType(1)),
},
CmdClassifier: model.CmdClassifierTypeWrite,
FeatureRemote: suite.remoteSubFeature,
Cmd: model.CmdType{
LoadControlLimitListData: &model.LoadControlLimitListDataType{},
},
}

cb1 := func(msg *api.Message) {
result := model.ErrorType{
ErrorNumber: 7,
Description: util.Ptr(model.DescriptionType("not allowed by application")),
}
suite.localServerFeatureWrite.ApproveOrDenyWrite(msg, result)
}
suite.localServerFeatureWrite.AddWriteApprovalCallback(cb1)

suite.senderMock.EXPECT().ResultError(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
err := suite.localServerFeatureWrite.HandleMessage(msg)
assert.Nil(suite.T(), err)

// callback is called asynchronously
time.Sleep(time.Millisecond * 200)
}

func (suite *DeviceClassificationTestSuite) Test_Write_Callback_Two() {
msg := &api.Message{
RequestHeader: &model.HeaderType{
MsgCounter: util.Ptr(model.MsgCounterType(1)),
},
CmdClassifier: model.CmdClassifierTypeWrite,
FeatureRemote: suite.remoteSubFeature,
Cmd: model.CmdType{
LoadControlLimitListData: &model.LoadControlLimitListDataType{},
},
}

cb = func(msg *api.Message) {
cb1 := func(msg *api.Message) {
result := model.ErrorType{
ErrorNumber: 0,
}
suite.localServerFeatureWrite.ApproveOrDenyWrite(msg, result)
}
suite.localServerFeatureWrite.AddWriteApprovalCallback(cb1)

cb2 := func(msg *api.Message) {
result := model.ErrorType{
ErrorNumber: 0,
}
suite.localServerFeatureWrite.ApproveOrDenyWrite(msg, result)
}
suite.localServerFeatureWrite.AddWriteApprovalCallback(cb2)

err := suite.localServerFeatureWrite.HandleMessage(msg)
assert.Nil(suite.T(), err)

// callback is called asynchronously
time.Sleep(time.Millisecond * 200)
}

func (suite *DeviceClassificationTestSuite) Test_Write_Callback_Two_Fail() {
msg := &api.Message{
RequestHeader: &model.HeaderType{
MsgCounter: util.Ptr(model.MsgCounterType(1)),
},
CmdClassifier: model.CmdClassifierTypeWrite,
FeatureRemote: suite.remoteSubFeature,
Cmd: model.CmdType{
LoadControlLimitListData: &model.LoadControlLimitListDataType{},
},
}

cb1 := func(msg *api.Message) {
result := model.ErrorType{
ErrorNumber: 0,
}
suite.localServerFeatureWrite.ApproveOrDenyWrite(msg, result)
}
suite.localServerFeatureWrite.AddWriteApprovalCallback(cb1)

cb2 := func(msg *api.Message) {
result := model.ErrorType{
ErrorNumber: 7,
Description: util.Ptr(model.DescriptionType("not allowed by application")),
}
suite.localServerFeatureWrite.ApproveOrDenyWrite(msg, result)
}
suite.localServerFeatureWrite.AddWriteApprovalCallback(cb2)

suite.localServerFeatureWrite.AddWriteApprovalCallback(cb)
err = suite.localServerFeatureWrite.HandleMessage(msg)
suite.senderMock.EXPECT().ResultError(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()

err := suite.localServerFeatureWrite.HandleMessage(msg)
assert.Nil(suite.T(), err)

// callback is called asynchronously
Expand Down

0 comments on commit 8a8dfc0

Please sign in to comment.