diff --git a/modules/apps/27-interchain-accounts/keeper/relay.go b/modules/apps/27-interchain-accounts/keeper/relay.go index 2e3a896858a..bbd15f32f46 100644 --- a/modules/apps/27-interchain-accounts/keeper/relay.go +++ b/modules/apps/27-interchain-accounts/keeper/relay.go @@ -75,26 +75,19 @@ func (k Keeper) createOutgoingPacket( return packet.Sequence, nil } -func (k Keeper) AuthenticateTx(ctx sdk.Context, msgs []sdk.Msg, portId string) error { - seen := map[string]bool{} - var signers []sdk.AccAddress - for _, msg := range msgs { - for _, addr := range msg.GetSigners() { - if !seen[addr.String()] { - signers = append(signers, addr) - seen[addr.String()] = true - } - } - } - - interchainAccountAddr, found := k.GetInterchainAccountAddress(ctx, portId) +// AuthenticateTx ensures the provided msgs contain the correct interchain account signer address retrieved +// from state using the provided controller port identifier +func (k Keeper) AuthenticateTx(ctx sdk.Context, msgs []sdk.Msg, portID string) error { + interchainAccountAddr, found := k.GetInterchainAccountAddress(ctx, portID) if !found { - return sdkerrors.ErrUnauthorized + return sdkerrors.Wrapf(types.ErrInterchainAccountNotFound, "failed to retrieve interchain account on port %s", portID) } - for _, signer := range signers { - if interchainAccountAddr != signer.String() { - return sdkerrors.ErrUnauthorized + for _, msg := range msgs { + for _, signer := range msg.GetSigners() { + if interchainAccountAddr != signer.String() { + return sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "unexpected signer address: expected %s, got %s", interchainAccountAddr, signer.String()) + } } } @@ -102,33 +95,26 @@ func (k Keeper) AuthenticateTx(ctx sdk.Context, msgs []sdk.Msg, portId string) e } func (k Keeper) executeTx(ctx sdk.Context, sourcePort, destPort, destChannel string, msgs []sdk.Msg) error { - err := k.AuthenticateTx(ctx, msgs, sourcePort) - if err != nil { + if err := k.AuthenticateTx(ctx, msgs, sourcePort); err != nil { return err } for _, msg := range msgs { - err := msg.ValidateBasic() - if err != nil { + if err := msg.ValidateBasic(); err != nil { return err } } - cacheContext, writeFn := ctx.CacheContext() + // CacheContext returns a new context with the multi-store branched into a cached storage object + // writeCache is called only if all msgs succeed, performing state transitions atomically + cacheCtx, writeCache := ctx.CacheContext() for _, msg := range msgs { - _, msgErr := k.executeMsg(cacheContext, msg) - if msgErr != nil { - err = msgErr - break + if _, err := k.executeMsg(cacheCtx, msg); err != nil { + return err } } - if err != nil { - return err - } - - // Write the state transitions if all handlers succeed. - writeFn() + writeCache() return nil } @@ -158,8 +144,7 @@ func (k Keeper) OnRecvPacket(ctx sdk.Context, packet channeltypes.Packet) error return err } - err = k.executeTx(ctx, packet.SourcePort, packet.DestinationPort, packet.DestinationChannel, msgs) - if err != nil { + if err = k.executeTx(ctx, packet.SourcePort, packet.DestinationPort, packet.DestinationChannel, msgs); err != nil { return err } diff --git a/modules/apps/27-interchain-accounts/types/packet.go b/modules/apps/27-interchain-accounts/types/packet.go index 8342f911bbf..3c5223f4390 100644 --- a/modules/apps/27-interchain-accounts/types/packet.go +++ b/modules/apps/27-interchain-accounts/types/packet.go @@ -22,7 +22,6 @@ func (iapd InterchainAccountPacketData) ValidateBasic() error { if len(iapd.Memo) > MaxMemoCharLength { return sdkerrors.Wrapf(ErrInvalidOutgoingData, "packet data memo cannot be greater than %d characters", MaxMemoCharLength) } - // TODO: add type validation when data type enum supports unspecified type return nil }