diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 6e63d59c..e2ba839f 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -28,6 +28,8 @@ from driftpy.constants.numeric_constants import ( QUOTE_SPOT_MARKET_INDEX, ) +from driftpy.enforcers.position_enforcer import PositionEnforcer +from driftpy.enforcers.sequence_enforcer import SequenceEnforcer from driftpy.decode.utils import decode_name from driftpy.drift_user import DriftUser from driftpy.accounts import * @@ -92,6 +94,7 @@ def __init__( market_lookup_table: Optional[Pubkey] = None, jito_params: Optional[JitoParams] = None, enforce_tx_sequencing: bool = False, + enforce_position_sizing: bool = False, ): """Initializes the drift client object @@ -155,29 +158,13 @@ def __init__( self.tx_version = tx_version if tx_version is not None else Legacy - self.enforce_tx_sequencing = enforce_tx_sequencing - if self.enforce_tx_sequencing is True: - file = Path(str(driftpy.__path__[0]) + "/idl/sequence_enforcer.json") - with file.open() as f: - raw = file.read_text() - idl = Idl.from_json(raw) - - provider = Provider(connection, wallet, opts) - self.sequence_enforcer_pid = ( - SEQUENCER_PROGRAM_ID - if env == "mainnet" - else DEVNET_SEQUENCER_PROGRAM_ID - ) - self.sequence_enforcer_program = Program( - idl, - self.sequence_enforcer_pid, - provider, - ) - self.sequence_number_by_subaccount = {} - self.sequence_bump_by_subaccount = {} - self.sequence_initialized_by_subaccount = {} - self.sequence_address_by_subaccount = {} - self.resetting_sequence = False + self.sequence_enforcer = None + if enforce_tx_sequencing is True: + self.sequence_enforcer = SequenceEnforcer(self.connection, self.wallet, env) + + self.position_enforcer = None + if enforce_position_sizing is True: + self.position_enforcer = PositionEnforcer(self.connection, self.wallet) if jito_params is not None: from driftpy.tx.jito_tx_sender import JitoTxSender @@ -199,8 +186,8 @@ def __init__( async def subscribe(self): await self.account_subscriber.subscribe() - if self.enforce_tx_sequencing: - await self.load_sequence_info() + if self.sequence_enforcer: + await self.sequence_enforcer.load_sequence_info(self.sub_account_ids) for sub_account_id in self.sub_account_ids: await self.add_user(sub_account_id) @@ -357,12 +344,18 @@ async def send_ixs( subaccount = sequencer_subaccount or self.active_sub_account_id if ( - self.enforce_tx_sequencing - and self.sequence_initialized_by_subaccount[subaccount] - and not self.resetting_sequence + self.sequence_enforcer + and self.sequence_enforcer.get_sequence_init_for_subaccount(subaccount) + is True + and not self.sequence_enforcer.get_resetting_sequence() ): - sequence_instruction = self.get_check_and_set_sequence_number_ix( - self.sequence_number_by_subaccount[subaccount], subaccount + sequence_instruction = ( + self.sequence_enforcer.get_check_and_set_sequence_number_ix( + self.sequence_enforcer.get_sequence_number_for_subaccount( + subaccount + ), + subaccount, + ) ) ixs.insert(len(compute_unit_instructions), sequence_instruction) @@ -886,6 +879,7 @@ async def place_perp_order( self, order_params: OrderParams, sub_account_id: int = None, + expected_size: Optional[int] = None, ): tx_sig_and_slot = await self.send_ixs( [ @@ -2846,106 +2840,32 @@ def get_update_prelaunch_oracle_ix(self, market_index: int): ) async def init_sequence(self, subaccount: int = 0) -> Signature: + if self.sequence_enforcer is None: + raise Exception("Sequence enforcer is not initialized") try: - sig = (await self.send_ixs([self.get_sequence_init_ix(subaccount)])).tx_sig - self.sequence_initialized_by_subaccount[subaccount] = True + sig = ( + await self.send_ixs( + [self.sequence_enforcer.get_sequence_init_ix(subaccount)] + ) + ).tx_sig + self.sequence_enforcer.set_sequence_init_for_subaccount(subaccount, True) return sig except Exception as e: print(f"WARNING: failed to initialize sequence: {e}") - def get_sequence_init_ix(self, subaccount: int = 0) -> Instruction: - if self.enforce_tx_sequencing is False: - raise ValueError("tx sequencing is disabled") - return self.sequence_enforcer_program.instruction["initialize"]( - self.sequence_bump_by_subaccount[subaccount], - str(subaccount), - ctx=Context( - accounts={ - "sequence_account": self.sequence_address_by_subaccount[subaccount], - "authority": self.wallet.payer.pubkey(), - "system_program": ID, - } - ), - ) - async def reset_sequence_number( self, sequence_number: int = 0, subaccount: int = 0 ) -> Signature: + if self.sequence_enforcer is None: + raise Exception("Sequence enforcer is not initialized") try: - ix = self.get_reset_sequence_number_ix(sequence_number) - self.resetting_sequence = True + ix = self.sequence_enforcer.get_reset_sequence_number_ix(sequence_number) + self.sequence_enforcer.set_resetting_sequence(True) sig = (await self.send_ixs(ix)).tx_sig - self.resetting_sequence = False - self.sequence_number_by_subaccount[subaccount] = sequence_number + self.sequence_enforcer.set_resetting_sequence(False) + self.sequence_enforcer.set_sequence_number_for_subaccount( + subaccount, sequence_number + ) return sig except Exception as e: print(f"WARNING: failed to reset sequence number: {e}") - - def get_reset_sequence_number_ix( - self, sequence_number: int, subaccount: int = 0 - ) -> Instruction: - if self.enforce_tx_sequencing is False: - raise ValueError("tx sequencing is disabled") - return self.sequence_enforcer_program.instruction["reset_sequence_number"]( - sequence_number, - ctx=Context( - accounts={ - "sequence_account": self.sequence_address_by_subaccount[subaccount], - "authority": self.wallet.payer.pubkey(), - } - ), - ) - - def get_check_and_set_sequence_number_ix( - self, sequence_number: Optional[int] = None, subaccount: int = 0 - ): - if self.enforce_tx_sequencing is False: - raise ValueError("tx sequencing is disabled") - sequence_number = ( - sequence_number or self.sequence_number_by_subaccount[subaccount] - ) - - if ( - sequence_number < self.sequence_number_by_subaccount[subaccount] - 1 - ): # we increment after creating the ix, so we check - 1 - print( - f"WARNING: sequence number {sequence_number} < last used {self.sequence_number_by_subaccount[subaccount] - 1}" - ) - - ix = self.sequence_enforcer_program.instruction[ - "check_and_set_sequence_number" - ]( - sequence_number, - ctx=Context( - accounts={ - "sequence_account": self.sequence_address_by_subaccount[subaccount], - "authority": self.wallet.payer.pubkey(), - } - ), - ) - - self.sequence_number_by_subaccount[subaccount] += 1 - return ix - - async def load_sequence_info(self): - for subaccount in self.sub_account_ids: - address, bump = get_sequencer_public_key_and_bump( - self.sequence_enforcer_pid, self.wallet.payer.pubkey(), subaccount - ) - try: - sequence_account_raw = await self.sequence_enforcer_program.account[ - "SequenceAccount" - ].fetch(address) - except anchorpy.error.AccountDoesNotExistError as e: - self.sequence_address_by_subaccount[subaccount] = address - self.sequence_number_by_subaccount[subaccount] = 1 - self.sequence_bump_by_subaccount[subaccount] = bump - self.sequence_initialized_by_subaccount[subaccount] = False - continue - sequence_account = cast(SequenceAccount, sequence_account_raw) - self.sequence_number_by_subaccount[subaccount] = ( - sequence_account.sequence_num + 1 - ) - self.sequence_bump_by_subaccount[subaccount] = bump - self.sequence_initialized_by_subaccount[subaccount] = True - self.sequence_address_by_subaccount[subaccount] = address diff --git a/src/driftpy/enforcers/position_enforcer.py b/src/driftpy/enforcers/position_enforcer.py new file mode 100644 index 00000000..295ad078 --- /dev/null +++ b/src/driftpy/enforcers/position_enforcer.py @@ -0,0 +1,68 @@ +from typing import Union +from driftpy.types import ( + OrderParams, + SpotPosition, + PerpPosition, + MarketType, + UserAccount, + is_variant, +) + + +class PositionEnforcer: + def __init__(self): + pass + + def set_and_check_order_params( + self, expected_size: int, order_params: OrderParams, user: UserAccount + ) -> OrderParams: + size_adjustment = self._get_size_adjustment( + expected_size, order_params.market_index, order_params.market_type, user + ) + order_params.base_asset_amount = max( + order_params.base_asset_amount + size_adjustment, 0 + ) + if order_params.base_asset_amount == 0: + print("WARNING: PositionEnforcer has reduced order size to ZERO.") + return order_params + + def _get_size_adjustment( + self, + expected_size: int, + market_index: int, + market_type: MarketType, + user: UserAccount, + ) -> int: + position: Union[SpotPosition, PerpPosition] + if is_variant(market_type, "Perp"): + position = next( + ( + pos + for pos in user.perp_positions + if pos.market_index == market_index + ), + None, + ) + if position is None: + raise Exception( + f"Position market_index: {market_index} market_type: {market_type} not found" + ) + + difference = position.base_asset_amount - expected_size + return difference * -1 # positive if too short, negative if too long + else: + position = next( + ( + pos + for pos in user.spot_positions + if pos.market_index == market_index + ), + None, + ) + if position is None: + raise Exception( + f"Position market_index: {market_index} market_type: {market_type} not found" + ) + + difference = position.scaled_balance - expected_size + return difference * -1 # positive if too short, negative if too long diff --git a/src/driftpy/enforcers/sequence_enforcer.py b/src/driftpy/enforcers/sequence_enforcer.py new file mode 100644 index 00000000..d31c13a5 --- /dev/null +++ b/src/driftpy/enforcers/sequence_enforcer.py @@ -0,0 +1,155 @@ +from typing import cast, Optional +import anchorpy +import driftpy + +from pathlib import Path +from anchorpy import Program, Idl, Provider, Wallet, Context +from solders.pubkey import Pubkey +from solders.instruction import Instruction +from solders.system_program import ID as SystemProgram + +from solana.rpc.async_api import AsyncClient +from driftpy.addresses import get_sequencer_public_key_and_bump +from driftpy.constants.config import SEQUENCER_PROGRAM_ID, DEVNET_SEQUENCER_PROGRAM_ID +from driftpy.types import SequenceAccount + + +class SequenceEnforcer: + def __init__(self, connection: AsyncClient, wallet: Wallet, env): + self.sequence_number_by_subaccount = {} + self.sequence_init_by_subaccount = {} + self.sequence_address_by_subaccount = {} + self.sequence_bump_by_subaccount = {} + self.resetting_sequence = False + self.wallet = wallet + file = Path(str(driftpy.__path__[0]) + "/idl/sequence_enforcer.json") + with file.open(): + raw = file.read_text() + idl = Idl.from_json(raw) + + provider = Provider(connection, wallet) + self.sequence_enforcer_pid = SEQUENCER_PROGRAM_ID if env == "mainnet" else DEVNET_SEQUENCER_PROGRAM_ID + self.sequence_enforcer_program = Program( + idl, + self.sequence_enforcer_pid, + provider, + ) + + async def load_sequence_info(self, subaccounts: list): + for subaccount in subaccounts: + address, bump = get_sequencer_public_key_and_bump( + self.sequence_enforcer_pid, self.wallet.payer.pubkey(), subaccount + ) + try: + sequence_account_raw = await self.sequence_enforcer_program.account[ + "SequenceAccount" + ].fetch(address) + except anchorpy.error.AccountDoesNotExistError: + self.set_sequence_address_for_subaccount(subaccount, address) + self.set_sequence_bump_for_subaccount(subaccount, bump) + self.set_sequence_number_for_subaccount(subaccount, 1) + self.set_sequence_init_for_subaccount(subaccount, False) + continue + sequence_account = cast(SequenceAccount, sequence_account_raw) + self.set_sequence_address_for_subaccount(subaccount, address) + self.set_sequence_bump_for_subaccount(subaccount, bump) + self.set_sequence_number_for_subaccount( + subaccount, sequence_account.sequence_num + 1 + ) + self.set_sequence_init_for_subaccount(subaccount, True) + + def get_sequence_init_ix(self, subaccount: int = 0) -> Instruction: + return self.sequence_enforcer_program.instruction["initialize"]( + self.get_sequence_bump_for_subaccount(subaccount), + str(subaccount), + ctx=Context( + accounts={ + "sequence_account": self.get_sequence_address_for_subaccount( + subaccount + ), + "authority": self.wallet.payer.pubkey(), + "system_program": SystemProgram, + } + ), + ) + + def get_reset_sequence_number_ix( + self, sequence_number: int, subaccount: int = 0 + ) -> Instruction: + return self.sequence_enforcer_program.instruction["reset_sequence_number"]( + sequence_number, + ctx=Context( + accounts={ + "sequence_account": self.get_sequence_address_for_subaccount( + subaccount + ), + "authority": self.wallet.payer.pubkey(), + } + ), + ) + + def get_check_and_set_sequence_number_ix( + self, sequence_number: Optional[int] = None, subaccount: int = 0 + ): + current_for_subaccount = self.get_sequence_number_for_subaccount(subaccount) + sequence_number = sequence_number or current_for_subaccount + + if ( + sequence_number < current_for_subaccount - 1 + ): # we increment after creating the ix, so we check - 1 + print( + f"WARNING: sequence number {sequence_number} < last used {current_for_subaccount - 1}" + ) + + ix = self.sequence_enforcer_program.instruction[ + "check_and_set_sequence_number" + ]( + sequence_number, + ctx=Context( + accounts={ + "sequence_account": self.get_sequence_address_for_subaccount( + subaccount + ), + "authority": self.wallet.payer.pubkey(), + } + ), + ) + + if sequence_number - current_for_subaccount > 0: + self.set_sequence_number_for_subaccount(subaccount, sequence_number) + else: + self.set_sequence_number_for_subaccount( + subaccount, current_for_subaccount + 1 + ) + + return ix + + def get_sequence_number_for_subaccount(self, subaccount: int) -> Optional[int]: + return self.sequence_number_by_subaccount.get(subaccount, None) + + def get_sequence_init_for_subaccount(self, subaccount: int) -> bool: + return self.sequence_init_by_subaccount.get(subaccount, False) + + def get_sequence_address_for_subaccount(self, subaccount: int) -> Optional[Pubkey]: + return self.sequence_address_by_subaccount.get(subaccount, None) + + def get_sequence_bump_for_subaccount(self, subaccount: int) -> Optional[int]: + return self.sequence_bump_by_subaccount.get(subaccount, None) + + def set_sequence_number_for_subaccount(self, subaccount: int, sequence_number: int): + self.sequence_number_by_subaccount[subaccount] = sequence_number + + def set_sequence_init_for_subaccount(self, subaccount: int, sequence_init: bool): + self.sequence_init_by_subaccount[subaccount] = sequence_init + + def set_sequence_address_for_subaccount(self, subaccount: int, address: Pubkey): + self.sequence_address_by_subaccount[subaccount] = address + + def set_sequence_bump_for_subaccount(self, subaccount: int, bump: int): + self.sequence_bump_by_subaccount[subaccount] = bump + + def get_resetting_sequence(self): + return self.resetting_sequence + + def set_resetting_sequence(self, resetting_sequence: bool): + self.resetting_sequence = resetting_sequence