diff --git a/src/common/BaseLightAccount.sol b/src/common/BaseLightAccount.sol index 473f8e5..bcf5e48 100644 --- a/src/common/BaseLightAccount.sol +++ b/src/common/BaseLightAccount.sol @@ -21,6 +21,7 @@ abstract contract BaseLightAccount is BaseAccount, TokenCallbackHandler, UUPSUpg } error ArrayLengthMismatch(); + error CreateFailed(); error InvalidSignatureType(); error NotAuthorized(address caller); error ZeroAddressNotAllowed(); @@ -75,6 +76,71 @@ abstract contract BaseLightAccount is BaseAccount, TokenCallbackHandler, UUPSUpg } } + /// @notice Creates a contract. + /// @param value The value to send to the new contract constructor. + /// @param initCode The initCode to deploy. + /// @return createdAddr The created contract address. + /// + /// @dev Assembly procedure: + /// 1. Load the free memory pointer. + /// 2. Get the initCode length. + /// 3. Copy the initCode from callata to memory at the free memory pointer. + /// 4. Create the contract. + /// 5. If creation failed (the address returned is zero), revert with CreateFailed(). + function performCreate(uint256 value, bytes calldata initCode) + external + payable + virtual + onlyAuthorized + returns (address createdAddr) + { + assembly ("memory-safe") { + let fmp := mload(0x40) + let len := initCode.length + calldatacopy(fmp, initCode.offset, len) + + createdAddr := create(value, fmp, len) + + if iszero(createdAddr) { + mstore(0x00, 0x7e16b8cd) + revert(0x1c, 0x04) + } + } + } + + /// @notice Creates a contract using create2 deterministic deployment. + /// @param value The value to send to the new contract constructor. + /// @param initCode The initCode to deploy. + /// @param salt The salt to use for the create2 operation. + /// @return createdAddr The created contract address. + /// + /// @dev Assembly procedure: + /// 1. Load the free memory pointer. + /// 2. Get the initCode length. + /// 3. Copy the initCode from callata to memory at the free memory pointer. + /// 4. Create the contract using Create2 with the passed salt parameter. + /// 5. If creation failed (the address returned is zero), revert with CreateFailed(). + function performCreate2(uint256 value, bytes calldata initCode, bytes32 salt) + external + payable + virtual + onlyAuthorized + returns (address createdAddr) + { + assembly ("memory-safe") { + let fmp := mload(0x40) + let len := initCode.length + calldatacopy(fmp, initCode.offset, len) + + createdAddr := create2(value, fmp, len, salt) + + if iszero(createdAddr) { + mstore(0x00, 0x7e16b8cd) + revert(0x1c, 0x04) + } + } + } + /// @notice Deposit more funds for this account in the entry point. function addDeposit() external payable { entryPoint().depositTo{value: msg.value}(address(this)); @@ -122,11 +188,25 @@ abstract contract BaseLightAccount is BaseAccount, TokenCallbackHandler, UUPSUpg return success ? SIG_VALIDATION_SUCCESS : SIG_VALIDATION_FAILED; } + /// @dev Assembly procedure: + /// 1. Execute the call, passing: + /// 1. The gas + /// 2. The target address + /// 3. The call value + /// 4. The pointer to the start location of the callData in memory + /// 5. The length of the calldata + /// 2. If the call failed, bubble up the revert reason by doing the following: + /// 1. Load the free memory pointer + /// 2. Copy the return data (which is the revert reason) to memory at the free memory pointer + /// 3. Revert with the copied return data function _call(address target, uint256 value, bytes memory data) internal { - (bool success, bytes memory result) = target.call{value: value}(data); - if (!success) { - assembly ("memory-safe") { - revert(add(result, 32), mload(result)) + assembly ("memory-safe") { + let succ := call(gas(), target, value, add(data, 0x20), mload(data), 0x00, 0) + + if iszero(succ) { + let fmp := mload(0x40) + returndatacopy(fmp, 0x00, returndatasize()) + revert(fmp, returndatasize()) } } } diff --git a/test/LightAccount.t.sol b/test/LightAccount.t.sol index 96e24fa..cbf9385 100644 --- a/test/LightAccount.t.sol +++ b/test/LightAccount.t.sol @@ -477,6 +477,77 @@ contract LightAccountTest is Test { assertEq(initialized, 1); } + function testRevertCreate_IncorrectCaller() public { + vm.expectRevert(abi.encodeWithSelector(BaseLightAccount.NotAuthorized.selector, address(this))); + account.performCreate(0, hex"1234"); + } + + function testRevertCreate_CreateFailed() public { + vm.prank(eoaAddress); + vm.expectRevert(BaseLightAccount.CreateFailed.selector); + account.performCreate(0, hex"3d3dfd"); + } + + function testRevertCreate2_IncorrectCaller() public { + vm.expectRevert(abi.encodeWithSelector(BaseLightAccount.NotAuthorized.selector, address(this))); + account.performCreate2(0, hex"1234", bytes32(0)); + } + + function testRevertCreate2_CreateFailed() public { + vm.prank(eoaAddress); + vm.expectRevert(BaseLightAccount.CreateFailed.selector); + account.performCreate2(0, hex"3d3dfd", bytes32(0)); + } + + function testCreate() public { + vm.prank(eoaAddress); + address expected = vm.computeCreateAddress(address(account), vm.getNonce(address(account))); + + address returnedAddress = + account.performCreate(0, abi.encodePacked(type(LightAccount).creationCode, abi.encode(address(entryPoint)))); + assertEq(address(LightAccount(payable(expected)).entryPoint()), address(entryPoint)); + assertEq(returnedAddress, expected); + } + + function testCreateValue() public { + vm.prank(eoaAddress); + address expected = vm.computeCreateAddress(address(account), vm.getNonce(address(account))); + + uint256 value = 1 ether; + deal(address(account), value); + + address returnedAddress = account.performCreate(value, ""); + assertEq(returnedAddress, expected); + assertEq(returnedAddress.balance, value); + } + + function testCreate2() public { + vm.prank(eoaAddress); + bytes memory initCode = abi.encodePacked(type(LightAccount).creationCode, abi.encode(address(entryPoint))); + bytes32 initCodeHash = keccak256(initCode); + bytes32 salt = bytes32(hex"04546b"); + address expected = vm.computeCreate2Address(salt, initCodeHash, address(account)); + + address returnedAddress = account.performCreate2(0, initCode, salt); + assertEq(address(LightAccount(payable(expected)).entryPoint()), address(entryPoint)); + assertEq(returnedAddress, expected); + } + + function testCreate2Value() public { + vm.prank(eoaAddress); + bytes memory initCode = ""; + bytes32 initCodeHash = keccak256(initCode); + bytes32 salt = bytes32(hex"04546b"); + address expected = vm.computeCreate2Address(salt, initCodeHash, address(account)); + + uint256 value = 1 ether; + deal(address(account), value); + + address returnedAddress = account.performCreate2(value, initCode, salt); + assertEq(returnedAddress, expected); + assertEq(returnedAddress.balance, value); + } + function _useContractOwner() internal { vm.prank(eoaAddress); account.transferOwnership(address(contractOwner));