From 7defd2ade2aff9f14bcb5a6fb2648de540694659 Mon Sep 17 00:00:00 2001 From: soundsonacid Date: Thu, 23 May 2024 12:23:25 -0500 Subject: [PATCH 1/3] factor out sequence enforcer logic --- src/driftpy/drift_client.py | 153 +++++--------------- src/driftpy/enforcers/position_enforcer.py | 0 src/driftpy/enforcers/sequence_enforcer.py | 155 +++++++++++++++++++++ 3 files changed, 188 insertions(+), 120 deletions(-) create mode 100644 src/driftpy/enforcers/position_enforcer.py create mode 100644 src/driftpy/enforcers/sequence_enforcer.py diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 6e63d59c..fb218444 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -28,6 +28,7 @@ from driftpy.constants.numeric_constants import ( QUOTE_SPOT_MARKET_INDEX, ) +from driftpy.enforcers.sequence_enforcer import SequenceEnforcer from driftpy.decode.utils import decode_name from driftpy.drift_user import DriftUser from driftpy.accounts import * @@ -155,29 +156,9 @@ def __init__( self.tx_version = tx_version if tx_version is not None else Legacy - self.enforce_tx_sequencing = enforce_tx_sequencing - if self.enforce_tx_sequencing is True: - file = Path(str(driftpy.__path__[0]) + "/idl/sequence_enforcer.json") - with file.open() as f: - raw = file.read_text() - idl = Idl.from_json(raw) - - provider = Provider(connection, wallet, opts) - self.sequence_enforcer_pid = ( - SEQUENCER_PROGRAM_ID - if env == "mainnet" - else DEVNET_SEQUENCER_PROGRAM_ID - ) - self.sequence_enforcer_program = Program( - idl, - self.sequence_enforcer_pid, - provider, - ) - self.sequence_number_by_subaccount = {} - self.sequence_bump_by_subaccount = {} - self.sequence_initialized_by_subaccount = {} - self.sequence_address_by_subaccount = {} - self.resetting_sequence = False + self.sequence_enforcer = None + if enforce_tx_sequencing is True: + self.sequence_enforcer = SequenceEnforcer(self.connection, self.wallet) if jito_params is not None: from driftpy.tx.jito_tx_sender import JitoTxSender @@ -199,8 +180,8 @@ def __init__( async def subscribe(self): await self.account_subscriber.subscribe() - if self.enforce_tx_sequencing: - await self.load_sequence_info() + if self.sequence_enforcer: + await self.sequence_enforcer.load_sequence_info(self.sub_account_ids) for sub_account_id in self.sub_account_ids: await self.add_user(sub_account_id) @@ -357,12 +338,18 @@ async def send_ixs( subaccount = sequencer_subaccount or self.active_sub_account_id if ( - self.enforce_tx_sequencing - and self.sequence_initialized_by_subaccount[subaccount] - and not self.resetting_sequence + self.sequence_enforcer + and self.sequence_enforcer.get_sequence_init_for_subaccount(subaccount) + is True + and not self.sequence_enforcer.get_resetting_sequence() ): - sequence_instruction = self.get_check_and_set_sequence_number_ix( - self.sequence_number_by_subaccount[subaccount], subaccount + sequence_instruction = ( + self.sequence_enforcer.get_check_and_set_sequence_number_ix( + self.sequence_enforcer.get_sequence_number_for_subaccount( + subaccount + ), + subaccount, + ) ) ixs.insert(len(compute_unit_instructions), sequence_instruction) @@ -2846,106 +2833,32 @@ def get_update_prelaunch_oracle_ix(self, market_index: int): ) async def init_sequence(self, subaccount: int = 0) -> Signature: + if self.sequence_enforcer is None: + raise Exception("Sequence enforcer is not initialized") try: - sig = (await self.send_ixs([self.get_sequence_init_ix(subaccount)])).tx_sig - self.sequence_initialized_by_subaccount[subaccount] = True + sig = ( + await self.send_ixs( + [self.sequence_enforcer.get_sequence_init_ix(subaccount)] + ) + ).tx_sig + self.sequence_enforcer.set_sequence_init_for_subaccount(subaccount, True) return sig except Exception as e: print(f"WARNING: failed to initialize sequence: {e}") - def get_sequence_init_ix(self, subaccount: int = 0) -> Instruction: - if self.enforce_tx_sequencing is False: - raise ValueError("tx sequencing is disabled") - return self.sequence_enforcer_program.instruction["initialize"]( - self.sequence_bump_by_subaccount[subaccount], - str(subaccount), - ctx=Context( - accounts={ - "sequence_account": self.sequence_address_by_subaccount[subaccount], - "authority": self.wallet.payer.pubkey(), - "system_program": ID, - } - ), - ) - async def reset_sequence_number( self, sequence_number: int = 0, subaccount: int = 0 ) -> Signature: + if self.sequence_enforcer is None: + raise Exception("Sequence enforcer is not initialized") try: - ix = self.get_reset_sequence_number_ix(sequence_number) - self.resetting_sequence = True + ix = self.sequence_enforcer.get_reset_sequence_number_ix(sequence_number) + self.sequence_enforcer.set_resetting_sequence(True) sig = (await self.send_ixs(ix)).tx_sig - self.resetting_sequence = False - self.sequence_number_by_subaccount[subaccount] = sequence_number + self.sequence_enforcer.set_resetting_sequence(False) + self.sequence_enforcer.set_sequence_number_for_subaccount( + subaccount, sequence_number + ) return sig except Exception as e: print(f"WARNING: failed to reset sequence number: {e}") - - def get_reset_sequence_number_ix( - self, sequence_number: int, subaccount: int = 0 - ) -> Instruction: - if self.enforce_tx_sequencing is False: - raise ValueError("tx sequencing is disabled") - return self.sequence_enforcer_program.instruction["reset_sequence_number"]( - sequence_number, - ctx=Context( - accounts={ - "sequence_account": self.sequence_address_by_subaccount[subaccount], - "authority": self.wallet.payer.pubkey(), - } - ), - ) - - def get_check_and_set_sequence_number_ix( - self, sequence_number: Optional[int] = None, subaccount: int = 0 - ): - if self.enforce_tx_sequencing is False: - raise ValueError("tx sequencing is disabled") - sequence_number = ( - sequence_number or self.sequence_number_by_subaccount[subaccount] - ) - - if ( - sequence_number < self.sequence_number_by_subaccount[subaccount] - 1 - ): # we increment after creating the ix, so we check - 1 - print( - f"WARNING: sequence number {sequence_number} < last used {self.sequence_number_by_subaccount[subaccount] - 1}" - ) - - ix = self.sequence_enforcer_program.instruction[ - "check_and_set_sequence_number" - ]( - sequence_number, - ctx=Context( - accounts={ - "sequence_account": self.sequence_address_by_subaccount[subaccount], - "authority": self.wallet.payer.pubkey(), - } - ), - ) - - self.sequence_number_by_subaccount[subaccount] += 1 - return ix - - async def load_sequence_info(self): - for subaccount in self.sub_account_ids: - address, bump = get_sequencer_public_key_and_bump( - self.sequence_enforcer_pid, self.wallet.payer.pubkey(), subaccount - ) - try: - sequence_account_raw = await self.sequence_enforcer_program.account[ - "SequenceAccount" - ].fetch(address) - except anchorpy.error.AccountDoesNotExistError as e: - self.sequence_address_by_subaccount[subaccount] = address - self.sequence_number_by_subaccount[subaccount] = 1 - self.sequence_bump_by_subaccount[subaccount] = bump - self.sequence_initialized_by_subaccount[subaccount] = False - continue - sequence_account = cast(SequenceAccount, sequence_account_raw) - self.sequence_number_by_subaccount[subaccount] = ( - sequence_account.sequence_num + 1 - ) - self.sequence_bump_by_subaccount[subaccount] = bump - self.sequence_initialized_by_subaccount[subaccount] = True - self.sequence_address_by_subaccount[subaccount] = address diff --git a/src/driftpy/enforcers/position_enforcer.py b/src/driftpy/enforcers/position_enforcer.py new file mode 100644 index 00000000..e69de29b diff --git a/src/driftpy/enforcers/sequence_enforcer.py b/src/driftpy/enforcers/sequence_enforcer.py new file mode 100644 index 00000000..e869f5f8 --- /dev/null +++ b/src/driftpy/enforcers/sequence_enforcer.py @@ -0,0 +1,155 @@ +from typing import cast, Optional +import anchorpy +import driftpy + +from pathlib import Path +from anchorpy import Program, Idl, Provider, Wallet, Context +from solders.pubkey import Pubkey +from solders.instruction import Instruction +from solders.system_program import ID as SystemProgram + +from solana.rpc.async_api import AsyncClient +from driftpy.addresses import get_sequencer_public_key_and_bump +from driftpy.constants.config import SEQUENCER_PROGRAM_ID +from driftpy.types import SequenceAccount + + +class SequenceEnforcer: + def __init__(self, connection: AsyncClient, wallet: Wallet): + self.sequence_number_by_subaccount = {} + self.sequence_init_by_subaccount = {} + self.sequence_address_by_subaccount = {} + self.sequence_bump_by_subaccount = {} + self.resetting_sequence = False + self.wallet = wallet + file = Path(str(driftpy.__path__[0]) + "/idl/sequence_enforcer.json") + with file.open(): + raw = file.read_text() + idl = Idl.from_json(raw) + + provider = Provider(connection, wallet) + self.sequence_enforcer_pid = SEQUENCER_PROGRAM_ID + self.sequence_enforcer_program = Program( + idl, + self.sequence_enforcer_pid, + provider, + ) + + async def load_sequence_info(self, subaccounts: list): + for subaccount in subaccounts: + address, bump = get_sequencer_public_key_and_bump( + self.sequence_enforcer_pid, self.wallet.payer.pubkey(), subaccount + ) + try: + sequence_account_raw = await self.sequence_enforcer_program.account[ + "SequenceAccount" + ].fetch(address) + except anchorpy.error.AccountDoesNotExistError: + self.set_sequence_address_for_subaccount(subaccount, address) + self.set_sequence_bump_for_subaccount(subaccount, bump) + self.set_sequence_number_for_subaccount(subaccount, 1) + self.set_sequence_init_for_subaccount(subaccount, False) + continue + sequence_account = cast(SequenceAccount, sequence_account_raw) + self.set_sequence_address_for_subaccount(subaccount, address) + self.set_sequence_bump_for_subaccount(subaccount, bump) + self.set_sequence_number_for_subaccount( + subaccount, sequence_account.sequence_num + 1 + ) + self.set_sequence_init_for_subaccount(subaccount, True) + + def get_sequence_init_ix(self, subaccount: int = 0) -> Instruction: + return self.sequence_enforcer_program.instruction["initialize"]( + self.get_sequence_bump_for_subaccount(subaccount), + str(subaccount), + ctx=Context( + accounts={ + "sequence_account": self.get_sequence_address_for_subaccount( + subaccount + ), + "authority": self.wallet.payer.pubkey(), + "system_program": SystemProgram, + } + ), + ) + + def get_reset_sequence_number_ix( + self, sequence_number: int, subaccount: int = 0 + ) -> Instruction: + return self.sequence_enforcer_program.instruction["reset_sequence_number"]( + sequence_number, + ctx=Context( + accounts={ + "sequence_account": self.get_sequence_address_for_subaccount( + subaccount + ), + "authority": self.wallet.payer.pubkey(), + } + ), + ) + + def get_check_and_set_sequence_number_ix( + self, sequence_number: Optional[int] = None, subaccount: int = 0 + ): + current_for_subaccount = self.get_sequence_number_for_subaccount(subaccount) + sequence_number = sequence_number or current_for_subaccount + + if ( + sequence_number < current_for_subaccount - 1 + ): # we increment after creating the ix, so we check - 1 + print( + f"WARNING: sequence number {sequence_number} < last used {current_for_subaccount - 1}" + ) + + ix = self.sequence_enforcer_program.instruction[ + "check_and_set_sequence_number" + ]( + sequence_number, + ctx=Context( + accounts={ + "sequence_account": self.get_sequence_address_for_subaccount( + subaccount + ), + "authority": self.wallet.payer.pubkey(), + } + ), + ) + + if sequence_number - current_for_subaccount > 0: + self.set_sequence_number_for_subaccount(subaccount, sequence_number) + else: + self.set_sequence_number_for_subaccount( + subaccount, current_for_subaccount + 1 + ) + + return ix + + def get_sequence_number_for_subaccount(self, subaccount: int) -> Optional[int]: + return self.sequence_number_by_subaccount.get(subaccount, None) + + def get_sequence_init_for_subaccount(self, subaccount: int) -> bool: + return self.sequence_init_by_subaccount.get(subaccount, False) + + def get_sequence_address_for_subaccount(self, subaccount: int) -> Optional[Pubkey]: + return self.sequence_address_by_subaccount.get(subaccount, None) + + def get_sequence_bump_for_subaccount(self, subaccount: int) -> Optional[int]: + return self.sequence_bump_by_subaccount.get(subaccount, None) + + def set_sequence_number_for_subaccount(self, subaccount: int, sequence_number: int): + self.sequence_number_by_subaccount[subaccount] = sequence_number + + def set_sequence_init_for_subaccount(self, subaccount: int, sequence_init: bool): + self.sequence_init_by_subaccount[subaccount] = sequence_init + + def set_sequence_address_for_subaccount(self, subaccount: int, address: Pubkey): + self.sequence_address_by_subaccount[subaccount] = address + + def set_sequence_bump_for_subaccount(self, subaccount: int, bump: int): + self.sequence_bump_by_subaccount[subaccount] = bump + + def get_resetting_sequence(self): + return self.resetting_sequence + + def set_resetting_sequence(self, resetting_sequence: bool): + self.resetting_sequence = resetting_sequence From 5d38bb2a873752673e2420e601d6f93f38a105d1 Mon Sep 17 00:00:00 2001 From: soundsonacid Date: Thu, 23 May 2024 13:56:13 -0500 Subject: [PATCH 2/3] very basic position enforcer --- src/driftpy/drift_client.py | 7 +++ src/driftpy/enforcers/position_enforcer.py | 68 ++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index fb218444..03d7d517 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -28,6 +28,7 @@ from driftpy.constants.numeric_constants import ( QUOTE_SPOT_MARKET_INDEX, ) +from driftpy.enforcers.position_enforcer import PositionEnforcer from driftpy.enforcers.sequence_enforcer import SequenceEnforcer from driftpy.decode.utils import decode_name from driftpy.drift_user import DriftUser @@ -93,6 +94,7 @@ def __init__( market_lookup_table: Optional[Pubkey] = None, jito_params: Optional[JitoParams] = None, enforce_tx_sequencing: bool = False, + enforce_position_sizing: bool = False, ): """Initializes the drift client object @@ -160,6 +162,10 @@ def __init__( if enforce_tx_sequencing is True: self.sequence_enforcer = SequenceEnforcer(self.connection, self.wallet) + self.position_enforcer = None + if enforce_position_sizing is True: + self.position_enforcer = PositionEnforcer(self.connection, self.wallet) + if jito_params is not None: from driftpy.tx.jito_tx_sender import JitoTxSender @@ -873,6 +879,7 @@ async def place_perp_order( self, order_params: OrderParams, sub_account_id: int = None, + expected_size: Optional[int] = None, ): tx_sig_and_slot = await self.send_ixs( [ diff --git a/src/driftpy/enforcers/position_enforcer.py b/src/driftpy/enforcers/position_enforcer.py index e69de29b..295ad078 100644 --- a/src/driftpy/enforcers/position_enforcer.py +++ b/src/driftpy/enforcers/position_enforcer.py @@ -0,0 +1,68 @@ +from typing import Union +from driftpy.types import ( + OrderParams, + SpotPosition, + PerpPosition, + MarketType, + UserAccount, + is_variant, +) + + +class PositionEnforcer: + def __init__(self): + pass + + def set_and_check_order_params( + self, expected_size: int, order_params: OrderParams, user: UserAccount + ) -> OrderParams: + size_adjustment = self._get_size_adjustment( + expected_size, order_params.market_index, order_params.market_type, user + ) + order_params.base_asset_amount = max( + order_params.base_asset_amount + size_adjustment, 0 + ) + if order_params.base_asset_amount == 0: + print("WARNING: PositionEnforcer has reduced order size to ZERO.") + return order_params + + def _get_size_adjustment( + self, + expected_size: int, + market_index: int, + market_type: MarketType, + user: UserAccount, + ) -> int: + position: Union[SpotPosition, PerpPosition] + if is_variant(market_type, "Perp"): + position = next( + ( + pos + for pos in user.perp_positions + if pos.market_index == market_index + ), + None, + ) + if position is None: + raise Exception( + f"Position market_index: {market_index} market_type: {market_type} not found" + ) + + difference = position.base_asset_amount - expected_size + return difference * -1 # positive if too short, negative if too long + else: + position = next( + ( + pos + for pos in user.spot_positions + if pos.market_index == market_index + ), + None, + ) + if position is None: + raise Exception( + f"Position market_index: {market_index} market_type: {market_type} not found" + ) + + difference = position.scaled_balance - expected_size + return difference * -1 # positive if too short, negative if too long From 5084b8a49acb9c77fd48fb6d02332b2d9d1329f8 Mon Sep 17 00:00:00 2001 From: soundsonacid Date: Thu, 23 May 2024 14:50:06 -0500 Subject: [PATCH 3/3] add devnet sequencer program config --- src/driftpy/drift_client.py | 2 +- src/driftpy/enforcers/sequence_enforcer.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 03d7d517..e2ba839f 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -160,7 +160,7 @@ def __init__( self.sequence_enforcer = None if enforce_tx_sequencing is True: - self.sequence_enforcer = SequenceEnforcer(self.connection, self.wallet) + self.sequence_enforcer = SequenceEnforcer(self.connection, self.wallet, env) self.position_enforcer = None if enforce_position_sizing is True: diff --git a/src/driftpy/enforcers/sequence_enforcer.py b/src/driftpy/enforcers/sequence_enforcer.py index e869f5f8..d31c13a5 100644 --- a/src/driftpy/enforcers/sequence_enforcer.py +++ b/src/driftpy/enforcers/sequence_enforcer.py @@ -10,12 +10,12 @@ from solana.rpc.async_api import AsyncClient from driftpy.addresses import get_sequencer_public_key_and_bump -from driftpy.constants.config import SEQUENCER_PROGRAM_ID +from driftpy.constants.config import SEQUENCER_PROGRAM_ID, DEVNET_SEQUENCER_PROGRAM_ID from driftpy.types import SequenceAccount class SequenceEnforcer: - def __init__(self, connection: AsyncClient, wallet: Wallet): + def __init__(self, connection: AsyncClient, wallet: Wallet, env): self.sequence_number_by_subaccount = {} self.sequence_init_by_subaccount = {} self.sequence_address_by_subaccount = {} @@ -28,7 +28,7 @@ def __init__(self, connection: AsyncClient, wallet: Wallet): idl = Idl.from_json(raw) provider = Provider(connection, wallet) - self.sequence_enforcer_pid = SEQUENCER_PROGRAM_ID + self.sequence_enforcer_pid = SEQUENCER_PROGRAM_ID if env == "mainnet" else DEVNET_SEQUENCER_PROGRAM_ID self.sequence_enforcer_program = Program( idl, self.sequence_enforcer_pid,