diff --git a/atomate/utils/database.py b/atomate/utils/database.py index 33db98959..1c1fcbd88 100644 --- a/atomate/utils/database.py +++ b/atomate/utils/database.py @@ -57,7 +57,6 @@ def __init__( "compress" : Whether compression is used "endpoint_url" : the url used to access the S3 store maggma_store_prefix: when using maggma stores, you can set the prefix string. - **kwargs: """ if maggma_store_kwargs is None: diff --git a/atomate/vasp/firetasks/electrode_tasks.py b/atomate/vasp/firetasks/electrode_tasks.py index d94103d79..032904e63 100644 --- a/atomate/vasp/firetasks/electrode_tasks.py +++ b/atomate/vasp/firetasks/electrode_tasks.py @@ -98,21 +98,12 @@ def run_task(self, fw_spec): chgcar = chgcar["aeccar0"] + chgcar["aeccar2"] cia = ChargeInsertionAnalyzer(chgcar, **cia_kwargs) - cia.get_labels() + avg_chg_groups = list(cia.filter_and_group()) insert_sites = [] - seent = set() - - cia._extrema_df.sort_values(by=["avg_charge_den"], inplace=True) - for itr, li_site in cia._extrema_df.iterrows(): - if len(insert_sites) >= attempt_insertions: - break - li_site = cia._extrema_df.iloc[itr] - lab = li_site["site_label"] - if lab not in seent: - insert_sites.append([li_site["a"], li_site["b"], li_site["c"]]) - seent.add(lab) - + for _, g in avg_chg_groups[:attempt_insertions]: + insert_sites.append(g[ 0]) + logger.info( f"Found {len(insert_sites)} insertion sites for task : {base_task_id}" ) diff --git a/setup.py b/setup.py index 1f028bae2..b47d45780 100644 --- a/setup.py +++ b/setup.py @@ -56,6 +56,7 @@ "moto>=4.1.14", "pytest-cov>=4.1.0", "pytest>=7.4.0", + "pymatgen-analysis-defects>=2022.8.18", ], }, classifiers=[