Skip to content

Commit

Permalink
db/python/layers/sample.py reduce cognitive complexity
Browse files Browse the repository at this point in the history
  • Loading branch information
violetbrina committed Oct 7, 2024
1 parent 163c8c6 commit 8a9e641
Showing 1 changed file with 71 additions and 81 deletions.
152 changes: 71 additions & 81 deletions db/python/layers/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
from db.python.layers.sequencing_group import SequencingGroupLayer
from db.python.tables.sample import SampleFilter, SampleTable
from db.python.utils import NoOpAenter, NotFoundError
from models.models.assay import AssayUpsertInternal
from models.models.project import (
FullWriteAccessRoles,
ProjectId,
ReadAccessRoles,
)
from models.models.sample import SampleInternal, SampleUpsertInternal
from models.models.sequencing_group import SequencingGroupUpsertInternal
from models.utils.sample_id_format import sample_id_format_list


Expand Down Expand Up @@ -223,10 +225,14 @@ async def upsert_sample(
with_function = (
self.connection.connection.transaction if open_transaction else NoOpAenter
)
# safely ignore nested samples here
async with with_function():
alayer = AssayLayer(self.connection)

for r in self.unwrap_nested_samples([sample]):
s = r.sample
sample_parent_id = getattr(r.parent, 'id', sample_parent_id)
sample_root_id = getattr(r.root, 'id', sample_root_id)

if not s.id:
s.id = await self.st.insert_sample(
external_ids=s.external_ids,
Expand All @@ -235,48 +241,57 @@ async def upsert_sample(
meta=s.meta,
participant_id=s.participant_id,
project=project,
sample_parent_id=r.parent.id if r.parent else sample_parent_id,
sample_root_id=r.root.id if r.root else sample_root_id,
sample_parent_id=sample_parent_id,
sample_root_id=sample_root_id,
)
else:
# Otherwise update
await self.st.update_sample(
id_=s.id, # type: ignore
id_=s.id,
external_ids=s.external_ids,
meta=s.meta,
participant_id=s.participant_id,
type_=s.type,
active=s.active,
sample_parent_id=r.parent.id if r.parent else sample_parent_id,
sample_root_id=r.root.id if r.root else sample_root_id,
sample_parent_id=sample_parent_id,
sample_root_id=sample_root_id,
)

if sample.sequencing_groups:
if process_sequencing_groups and sample.sequencing_groups:
sglayer = SequencingGroupLayer(self.connection)
for seqg in sample.sequencing_groups:
seqg.sample_id = sample.id

if process_sequencing_groups:
await sglayer.upsert_sequencing_groups(sample.sequencing_groups)
self.set_sample_ids(s.id, sample.sequencing_groups)
await sglayer.upsert_sequencing_groups(sample.sequencing_groups)

if sample.non_sequencing_assays:
alayer = AssayLayer(self.connection)
for assay in sample.non_sequencing_assays:
assay.sample_id = sample.id
if process_assays:
await alayer.upsert_assays(
sample.non_sequencing_assays, open_transaction=False
)
if process_assays and sample.non_sequencing_assays:
self.set_sample_ids(s.id, sample.non_sequencing_assays)
await alayer.upsert_assays(
sample.non_sequencing_assays, open_transaction=False
)

return sample

def set_sample_ids(
self,
sample_id: int,
internal_objs: list[SequencingGroupUpsertInternal] | list[AssayUpsertInternal],
):
"""
Set the sample id for upserting sequencing group or assay
These internal upsert models will be children of a SampleUpsertInternal
but may not have the correct sample_id set. This is to ensure that they do.
"""
for obj in internal_objs:
assert hasattr(obj, 'sample_id')
assert obj.sample_id is None or obj.sample_id == sample_id
obj.sample_id = sample_id

async def upsert_samples(
self,
samples: list[SampleUpsertInternal],
open_transaction: bool = True,
project: ProjectId = None,
) -> list[SampleUpsertInternal]:
"""Batch upsert a list of samples with sequences"""
alayer = AssayLayer(self.connection)
seqglayer: SequencingGroupLayer = SequencingGroupLayer(self.connection)

with_function = (
Expand All @@ -292,6 +307,8 @@ async def upsert_samples(

async with with_function():
# Create or update samples
sequencing_groups: list[SequencingGroupUpsertInternal] = []
assays: list[AssayUpsertInternal] = []
for sample in samples:
await self.upsert_sample(
sample,
Expand All @@ -301,21 +318,13 @@ async def upsert_samples(
open_transaction=False,
)

# Upsert all sequencing_groups (in turn relevant assays)
sequencing_groups = [
seqg for sample in samples for seqg in (sample.sequencing_groups or [])
]
if sequencing_groups:
await seqglayer.upsert_sequencing_groups(sequencing_groups)
# Collect all sequencing_groups and assays
sequencing_groups.extend(getattr(sample, 'sequencing_groups', []))
assays.extend(getattr(sample, 'non_sequencing_assays', []))

assays = [
assay
for sample in samples
for assay in (sample.non_sequencing_assays or [])
]
if assays:
alayer = AssayLayer(self.connection)
await alayer.upsert_assays(assays, open_transaction=False)
# Upsert all sequencing_groups (in turn relevant assays)
await seqglayer.upsert_sequencing_groups(sequencing_groups)
await alayer.upsert_assays(assays, open_transaction=False)

return samples

Expand All @@ -340,57 +349,38 @@ def unwrap_nested_samples(
out the insert order, keeping reference to the root, and parent.
Just keep a soft limit on the depth, as we don't want to go too deep.
NB: Opting for a non-recursive approach here, as I'm a bit afraid of recursive
Python after a weird Hail Batch thing, and sounded like a nightmare to debug
"""

retval: list[SampleLayer.UnwrappedSample] = []
seen_samples = set()
stack: list[
tuple[
SampleUpsertInternal | None,
SampleUpsertInternal | None,
SampleUpsertInternal | None,
int,
]
] = [(None, None, sample, 0) for sample in samples]

while stack:
root, parent, sample, depth = stack.pop()
if depth > max_depth:
raise SampleLayer.SampleUnwrapMaxDepthError(
f'Exceeded max depth of {max_depth} for nested samples. '
f'Parents: {parent}'
)

seen_samples = {id(s) for s in samples}
if id(sample) in seen_samples:
raise ValueError(f'Sample sample was seen in the list ({sample})')
seen_samples.add(id(sample))

rounds: list[
list[
tuple[
SampleUpsertInternal | None,
SampleUpsertInternal | None,
list[SampleUpsertInternal],
]
]
] = [[(None, None, samples)]]

round_idx = 0
while round_idx < len(rounds):
prev_round = rounds[round_idx]
new_round = []
round_idx += 1
for root, parent, nested_samples in prev_round:
for sample in nested_samples:
retval.append(
SampleLayer.UnwrappedSample(
root=root, parent=parent, sample=sample
)
)
if not sample.nested_samples:
continue

# do the seen check
for s in sample.nested_samples:
if id(s) in seen_samples:
raise ValueError(
f'Sample sample was seen in the list ({s})'
)
seen_samples.add(id(s))
new_round.append((root or sample, sample, sample.nested_samples))

if new_round:
if round_idx >= max_depth:
parents = ', '.join(str(s) for _, s, _ in new_round)
raise SampleLayer.SampleUnwrapMaxDepthError(
f'Exceeded max depth of {max_depth} for nested samples. '
f'Parents: {parents}'
)
rounds.append(new_round)
retval.append(
SampleLayer.UnwrappedSample(root=root, parent=parent, sample=sample)
)

if sample.nested_samples:
for nested_sample in sample.nested_samples:
stack.append((root or sample, sample, nested_sample, depth + 1))

return retval

Expand Down

0 comments on commit 8a9e641

Please sign in to comment.