diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index 2448c483108..ee1e9380c68 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -207,7 +207,7 @@ func (k Keeper) WriteUpgradeTryChannel(ctx sdk.Context, portID, channelID string // WriteUpgradeCancelChannel writes a channel which has canceled the upgrade process.Auxiliary upgrade state is // also deleted. -func (k Keeper) WriteUpgradeCancelChannel(ctx sdk.Context, portID, channelID string) { +func (k Keeper) WriteUpgradeCancelChannel(ctx sdk.Context, portID, channelID string, newUpgradeSequence uint64) { defer telemetry.IncrCounter(1, "ibc", "channel", "upgrade-cancel") upgrade, found := k.GetUpgrade(ctx, portID, channelID) @@ -222,7 +222,7 @@ func (k Keeper) WriteUpgradeCancelChannel(ctx sdk.Context, portID, channelID str previousState := channel.State - k.restoreChannel(ctx, portID, channelID, channel) + k.restoreChannel(ctx, portID, channelID, newUpgradeSequence, channel) k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState, "new-state", types.OPEN.String()) emitChannelUpgradeCancelEvent(ctx, portID, channelID, channel, upgrade) @@ -354,9 +354,6 @@ func (k Keeper) ChanUpgradeCancel(ctx sdk.Context, portID, channelID string, err return errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "error receipt sequence (%d) must be greater than or equal to current sequence (%d)", counterpartySequence, currentSequence) } - channel.UpgradeSequence = errorReceipt.Sequence + 1 - k.SetChannel(ctx, portID, channelID, channel) - return nil } @@ -715,7 +712,9 @@ func (k Keeper) AbortUpgrade(ctx sdk.Context, portID, channelID string, err erro return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) } - k.restoreChannel(ctx, portID, channelID, channel) + // the channel upgrade sequence has already been updated in ChannelUpgradeTry, so we can pass + // its updated value. + k.restoreChannel(ctx, portID, channelID, channel.UpgradeSequence, channel) // in the case of application callbacks, the error may not be an upgrade error. // in this case we need to construct one in order to write the error receipt. @@ -728,16 +727,14 @@ func (k Keeper) AbortUpgrade(ctx sdk.Context, portID, channelID string, err erro return err } - // TODO: callback execution - // cbs.OnChanUpgradeRestore() - return nil } // restoreChannel will restore the channel state and flush status to their pre-upgrade state so that upgrade is aborted. -func (k Keeper) restoreChannel(ctx sdk.Context, portID, channelID string, currentChannel types.Channel) { +func (k Keeper) restoreChannel(ctx sdk.Context, portID, channelID string, upgradeSequence uint64, currentChannel types.Channel) { currentChannel.State = types.OPEN currentChannel.FlushStatus = types.NOTINFLUSH + currentChannel.UpgradeSequence = upgradeSequence k.SetChannel(ctx, portID, channelID, currentChannel) diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index bb152f17af3..0dc7c0c71c8 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -1183,8 +1183,6 @@ func (suite *KeeperTestSuite) TestChanUpgradeCancel() { expPass := tc.expError == nil if expPass { suite.Require().NoError(err) - channel := path.EndpointA.GetChannel() - suite.Require().Equal(errorReceipt.Sequence+1, channel.UpgradeSequence, "upgrade sequence should be incremented") } else { suite.Require().ErrorIs(err, tc.expError) } diff --git a/modules/core/keeper/msg_server.go b/modules/core/keeper/msg_server.go index 33df77fb5b9..b8d78997e0b 100644 --- a/modules/core/keeper/msg_server.go +++ b/modules/core/keeper/msg_server.go @@ -820,5 +820,30 @@ func (k Keeper) ChannelUpgradeTimeout(goCtx context.Context, msg *channeltypes.M // ChannelUpgradeCancel defines a rpc handler method for MsgChannelUpgradeCancel. func (k Keeper) ChannelUpgradeCancel(goCtx context.Context, msg *channeltypes.MsgChannelUpgradeCancel) (*channeltypes.MsgChannelUpgradeCancelResponse, error) { - return nil, nil + ctx := sdk.UnwrapSDKContext(goCtx) + + module, _, err := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.PortId, msg.ChannelId) + if err != nil { + ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", errorsmod.Wrap(err, "could not retrieve module from port-id")) + return nil, errorsmod.Wrap(err, "could not retrieve module from port-id") + } + + cbs, ok := k.Router.GetRoute(module) + if !ok { + ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)) + return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module) + } + + if err := k.ChannelKeeper.ChanUpgradeCancel(ctx, msg.PortId, msg.ChannelId, msg.ErrorReceipt, msg.ProofErrorReceipt, msg.ProofHeight); err != nil { + ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", err.Error()) + return nil, errorsmod.Wrap(err, "channel upgrade cancel failed") + } + + cbs.OnChanUpgradeRestore(ctx, msg.PortId, msg.ChannelId) + + k.ChannelKeeper.WriteUpgradeCancelChannel(ctx, msg.PortId, msg.ChannelId, msg.ErrorReceipt.Sequence) + + ctx.Logger().Info("channel upgrade cancel succeeded", "port-id", msg.PortId, "channel-id", msg.ChannelId) + + return &channeltypes.MsgChannelUpgradeCancelResponse{}, nil } diff --git a/modules/core/keeper/msg_server_test.go b/modules/core/keeper/msg_server_test.go index d2e23861d7c..7cdc0a36c10 100644 --- a/modules/core/keeper/msg_server_test.go +++ b/modules/core/keeper/msg_server_test.go @@ -924,3 +924,132 @@ func (suite *KeeperTestSuite) TestChannelUpgradeTry() { }) } } + +func (suite *KeeperTestSuite) TestChannelUpgradeCancel() { + var ( + path *ibctesting.Path + msg *channeltypes.MsgChannelUpgradeCancel + ) + + cases := []struct { + name string + malleate func() + expErr error + }{ + { + name: "success", + malleate: func() {}, + expErr: nil, + }, + { + name: "invalid proof", + malleate: func() { + msg.ProofErrorReceipt = []byte("invalid proof") + }, + expErr: commitmenttypes.ErrInvalidProof, + }, + { + name: "invalid error receipt sequence", + malleate: func() { + const invalidSequence = 0 + + errorReceipt, ok := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + + errorReceipt.Sequence = invalidSequence + + // overwrite the error receipt with an invalid sequence. + suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, errorReceipt) + + // ensure that the error receipt is committed to state. + suite.coordinator.CommitBlock(suite.chainB) + suite.Require().NoError(path.EndpointA.UpdateClient()) + + // retrieve the error receipt proof and proof height. + errorReceiptProof, proofHeight := path.EndpointB.QueryProof(host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)) + + // provide a valid proof of the error receipt with an invalid sequence. + msg.ErrorReceipt.Sequence = invalidSequence + msg.ProofErrorReceipt = errorReceiptProof + msg.ProofHeight = proofHeight + }, + expErr: channeltypes.ErrInvalidUpgradeSequence, + }, + { + name: "capability not found", + malleate: func() { + msg.ChannelId = ibctesting.InvalidID + }, + expErr: capabilitytypes.ErrCapabilityNotFound, + }, + } + + for _, tc := range cases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.Setup(path) + + // configure the channel upgrade version on testing endpoints + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion + + suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + + // fetch the previous channel when it is in the INITUPGRADE state. + prevChannel := path.EndpointA.GetChannel() + + // cause the upgrade to fail on chain b so an error receipt is written. + suite.chainB.GetSimApp().IBCMockModule.IBCApp.OnChanUpgradeTry = func( + ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, counterpartyVersion string, + ) (string, error) { + return "", fmt.Errorf("mock app callback failed") + } + + suite.Require().NoError(path.EndpointB.ChanUpgradeTry()) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + errorReceiptProof, proofHeight := path.EndpointB.QueryProof(upgradeErrorReceiptKey) + + errorReceipt, ok := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + + msg = &channeltypes.MsgChannelUpgradeCancel{ + PortId: path.EndpointA.ChannelConfig.PortID, + ChannelId: path.EndpointA.ChannelID, + ErrorReceipt: errorReceipt, + ProofErrorReceipt: errorReceiptProof, + ProofHeight: proofHeight, + Signer: suite.chainA.SenderAccount.GetAddress().String(), + } + + tc.malleate() + + res, err := suite.chainA.GetSimApp().GetIBCKeeper().ChannelUpgradeCancel(suite.chainA.GetContext(), msg) + + expPass := tc.expErr == nil + if expPass { + suite.Require().NoError(err) + channel := path.EndpointA.GetChannel() + suite.Require().Equal(prevChannel.Version, channel.Version, "channel version should be reverted") + suite.Require().Equalf(channeltypes.OPEN, channel.State, "channel state should be %s", channeltypes.OPEN.String()) + suite.Require().Equalf(channeltypes.NOTINFLUSH, channel.FlushStatus, "channel flush status should be %s", channeltypes.NOTINFLUSH.String()) + suite.Require().Equal(errorReceipt.Sequence, channel.UpgradeSequence, "channel upgrade sequence should be set to error receipt sequence") + } else { + suite.Require().Nil(res) + suite.Require().ErrorIs(err, tc.expErr) + + channel := path.EndpointA.GetChannel() + + suite.Require().Equal(prevChannel.Version, channel.Version, "channel version should not be changed") + suite.Require().Equalf(prevChannel.State, channel.State, "channel state should be %s", prevChannel.State.String()) + suite.Require().Equalf(prevChannel.FlushStatus, channel.FlushStatus, "channel flush status should be %s", prevChannel.FlushStatus.String()) + suite.Require().Equal(prevChannel.UpgradeSequence, channel.UpgradeSequence, "channel upgrade sequence should not incremented") + } + }) + } +} diff --git a/testing/mock/ibc_module.go b/testing/mock/ibc_module.go index 4df7925c227..3a02aee3670 100644 --- a/testing/mock/ibc_module.go +++ b/testing/mock/ibc_module.go @@ -14,6 +14,13 @@ import ( "github.com/cosmos/ibc-go/v7/modules/core/exported" ) +// applicationCallbackError is a custom error type that will be unique for testing purposes. +type applicationCallbackError struct{} + +func (e applicationCallbackError) Error() string { + return "mock application callback failed" +} + // IBCModule implements the ICS26 callbacks for testing/mock. type IBCModule struct { appModule *AppModule diff --git a/testing/mock/mock.go b/testing/mock/mock.go index e50585eb49d..050c62823ef 100644 --- a/testing/mock/mock.go +++ b/testing/mock/mock.go @@ -37,6 +37,9 @@ var ( MockAckCanaryCapabilityName = "mock acknowledgement canary capability name" MockTimeoutCanaryCapabilityName = "mock timeout canary capability name" UpgradeVersion = fmt.Sprintf("%s-v2", Version) + // MockApplicationCallbackError should be returned when an application callback should fail. It is possible to + // test that this error was returned using ErrorIs. + MockApplicationCallbackError error = &applicationCallbackError{} ) var _ porttypes.IBCModule = IBCModule{}