Skip to content

Commit

Permalink
Merge pull request #227 from SMTG-Bham/fixes
Browse files Browse the repository at this point in the history
Fixes from #216
  • Loading branch information
utf authored Oct 11, 2023
2 parents 9d6b00e + 4ea323b commit 0d7361f
Show file tree
Hide file tree
Showing 18 changed files with 81 additions and 171 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ jobs:
python -m pip install -e '.[tests]'
- name: Test
run: pytest
run: pytest
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:
rev: 22.6.0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
- repo: https://github.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"h5py",
"pymatgen>=2020.10.20",
"phonopy>=2.1.3",
"matplotlib",
"matplotlib>=3.2.0",
"seekpath",
"castepxbin<1.0",
"colormath",
Expand Down
34 changes: 10 additions & 24 deletions sumo/cli/bandplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
from importlib.resources import files as ilr_files
except ImportError: # Python < 3.9
from importlib_resources import files as ilr_files

import matplotlib as mpl
from pymatgen.electronic_structure.bandstructure import (
get_reconstructed_band_structure,
)
from pymatgen.electronic_structure.bandstructure import get_reconstructed_band_structure
from pymatgen.electronic_structure.core import Spin
from pymatgen.io.vasp.outputs import BSVasprun

Expand Down Expand Up @@ -394,9 +393,7 @@ def bandplot(
else:
logging.info(f"Found PDOS file {pdos_file}")
else:
logging.info(
f"Cell file {cell_file} does not exist, cannot plot PDOS."
)
logging.info(f"Cell file {cell_file} does not exist, cannot plot PDOS.")

dos, pdos = read_castep_dos(
dos_file,
Expand Down Expand Up @@ -620,8 +617,7 @@ def _get_parser():
"-c",
"--code",
default="vasp",
help="Electronic structure code (default: vasp)."
'"questaal" also supported.',
help="Electronic structure code (default: vasp)." '"questaal" also supported.',
)
parser.add_argument(
"-p", "--prefix", metavar="P", help="prefix for the files generated"
Expand Down Expand Up @@ -762,24 +758,20 @@ def _get_parser():
"--orbitals",
type=_el_orb,
metavar="O",
help=(
"orbitals to split into lm-decomposed "
'contributions (e.g. "Ru.d")'
),
help="orbitals to split into lm-decomposed contributions (e.g. 'Ru.d')",
)
parser.add_argument(
"--atoms",
type=_atoms,
metavar="A",
help=('atoms to include (e.g. "O.1.2.3,Ru.1.2.3")'),
help='atoms to include (e.g. "O.1.2.3,Ru.1.2.3")',
)
parser.add_argument(
"--spin",
type=str,
default=None,
help=(
"select only one spin channel for a "
"spin-polarised calculation "
"select only one spin channel for a spin-polarised calculation "
"(options: up, 1; down, -1)"
),
)
Expand Down Expand Up @@ -829,9 +821,7 @@ def _get_parser():
parser.add_argument(
"--height", type=float, default=None, help="height of the graph"
)
parser.add_argument(
"--width", type=float, default=None, help="width of the graph"
)
parser.add_argument("--width", type=float, default=None, help="width of the graph")
parser.add_argument(
"--ymin", type=float, default=-6.0, help="minimum energy on the y-axis"
)
Expand Down Expand Up @@ -883,18 +873,14 @@ def main():
logging.getLogger("").addHandler(console)

if args.config is None:
config_path = os.path.join(
ilr_files("sumo.plotting"), "orbital_colours.conf"
)
config_path = ilr_files("sumo.plotting") / "orbital_colours.conf"
else:
config_path = args.config
colours = configparser.ConfigParser()
colours.read(os.path.abspath(config_path))

warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
warnings.filterwarnings(
"ignore", category=UnicodeWarning, module="matplotlib"
)
warnings.filterwarnings("ignore", category=UnicodeWarning, module="matplotlib")
warnings.filterwarnings("ignore", category=UserWarning, module="pymatgen")

bandplot(
Expand Down
27 changes: 8 additions & 19 deletions sumo/cli/dosplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import matplotlib as mpl
import numpy as np

try:
from importlib.resources import files as ilr_files
except ImportError: # Python < 3.9
Expand Down Expand Up @@ -441,9 +442,7 @@ def _get_parser():
"--code",
default="vasp",
metavar="C",
help=(
'Input file format: "vasp" (vasprun.xml) or ' '"questaal" (opt.ext)'
),
help='Input file format: "vasp" (vasprun.xml) or "questaal" (opt.ext)',
)
parser.add_argument(
"-p", "--prefix", metavar="P", help="prefix for the files generated"
Expand All @@ -463,25 +462,21 @@ def _get_parser():
"--orbitals",
type=_el_orb,
metavar="O",
help=(
"orbitals to split into lm-decomposed "
'contributions (e.g. "Ru.d")'
),
help="orbitals to split into lm-decomposed contributions (e.g. 'Ru.d')",
)
parser.add_argument(
"-a",
"--atoms",
type=_atoms,
metavar="A",
help=('atoms to include (e.g. "O.1.2.3,Ru.1.2.3")'),
help='atoms to include (e.g. "O.1.2.3,Ru.1.2.3")',
)
parser.add_argument(
"--spin",
type=str,
default=None,
help=(
"select one spin channel only for a "
"spin-polarised calculation "
"select one spin channel only for a spin-polarised calculation "
"(options: up, 1; down, -1)"
),
)
Expand Down Expand Up @@ -560,9 +555,7 @@ def _get_parser():
parser.add_argument(
"--height", type=float, default=None, help="height of the graph"
)
parser.add_argument(
"--width", type=float, default=None, help="width of the graph"
)
parser.add_argument("--width", type=float, default=None, help="width of the graph")
parser.add_argument(
"--xmin", type=float, default=-6.0, help="minimum energy on the x-axis"
)
Expand Down Expand Up @@ -634,18 +627,14 @@ def main():
logging.getLogger("").addHandler(console)

if args.config is None:
config_path = os.path.join(
ilr_files("sumo.plotting"), "orbital_colours.conf"
)
config_path = ilr_files("sumo.plotting") / "orbital_colours.conf"
else:
config_path = args.config
colours = configparser.ConfigParser()
colours.read(os.path.abspath(config_path))

warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
warnings.filterwarnings(
"ignore", category=UnicodeWarning, module="matplotlib"
)
warnings.filterwarnings("ignore", category=UnicodeWarning, module="matplotlib")
warnings.filterwarnings("ignore", category=UserWarning, module="pymatgen")

if args.zero_energy is not None:
Expand Down
2 changes: 1 addition & 1 deletion sumo/io/castep.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def labels_from_cell(cell_file, phonon=False):
line = f.readline() # Skip past block start line
while blockend.match(line.lower()) is None:
# Do not parse break lines
if 'break' not in line.lower():
if "break" not in line.lower():
kpt = tuple(map(float, line.split()[:3]))
if len(line.split()) > 3:
label = line.split()[-1]
Expand Down
43 changes: 13 additions & 30 deletions sumo/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"""
Subpackage providing helper functions for generating publication ready plots.
"""
from functools import wraps
import os
from functools import wraps

import matplotlib.pyplot
import numpy as np
Expand All @@ -19,15 +19,11 @@

colour_cache = {}

sumo_base_style = os.path.join(ilr_files("sumo.plotting"), "sumo_base.mplstyle")
sumo_dos_style = os.path.join(ilr_files("sumo.plotting"), "sumo_dos.mplstyle")
sumo_bs_style = os.path.join(ilr_files("sumo.plotting"), "sumo_bs.mplstyle")
sumo_phonon_style = os.path.join(
ilr_files("sumo.plotting"), "sumo_phonon.mplstyle"
)
sumo_optics_style = os.path.join(
ilr_files("sumo.plotting"), "sumo_optics.mplstyle"
)
sumo_base_style = ilr_files("sumo.plotting") / "sumo_base.mplstyle"
sumo_dos_style = ilr_files("sumo.plotting") / "sumo_dos.mplstyle"
sumo_bs_style = ilr_files("sumo.plotting") / "sumo_bs.mplstyle"
sumo_phonon_style = ilr_files("sumo.plotting") / "sumo_phonon.mplstyle"
sumo_optics_style = ilr_files("sumo.plotting") / "sumo_optics.mplstyle"


def styled_plot(*style_sheets):
Expand All @@ -47,9 +43,7 @@ def styled_plot(*style_sheets):

def decorator(get_plot):
@wraps(get_plot)
def wrapper(
*args, fonts=None, style=None, no_base_style=False, **kwargs
):
def wrapper(*args, fonts=None, style=None, no_base_style=False, **kwargs):
if no_base_style:
list_style = []
else:
Expand All @@ -62,9 +56,7 @@ def wrapper(
list_style += [style]

if fonts is not None:
list_style += [
{"font.family": "sans-serif", "font.sans-serif": fonts}
]
list_style += [{"font.family": "sans-serif", "font.sans-serif": fonts}]

matplotlib.pyplot.style.use(list_style)
return get_plot(*args, **kwargs)
Expand Down Expand Up @@ -277,9 +269,7 @@ def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"):
"xyz": XYZColor,
}
if colorspace not in list(colorspace_mapping.keys()):
raise ValueError(
f"colorspace must be one of {colorspace_mapping.keys()}"
)
raise ValueError(f"colorspace must be one of {colorspace_mapping.keys()}")

colorspace = colorspace_mapping[colorspace]

Expand All @@ -290,19 +280,13 @@ def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"):

# now convert to the colorspace basis for interpolation
basis1 = np.array(
convert_color(
color1_rgb, colorspace, target_illuminant="d50"
).get_value_tuple()
convert_color(color1_rgb, colorspace, target_illuminant="d50").get_value_tuple()
)
basis2 = np.array(
convert_color(
color2_rgb, colorspace, target_illuminant="d50"
).get_value_tuple()
convert_color(color2_rgb, colorspace, target_illuminant="d50").get_value_tuple()
)
basis3 = np.array(
convert_color(
color3_rgb, colorspace, target_illuminant="d50"
).get_value_tuple()
convert_color(color3_rgb, colorspace, target_illuminant="d50").get_value_tuple()
)

# ensure weights is a numpy array
Expand All @@ -317,8 +301,7 @@ def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"):

# convert colors to RGB
rgb_colors = [
convert_color(colorspace(*c), sRGBColor).get_value_tuple()
for c in colors
convert_color(colorspace(*c), sRGBColor).get_value_tuple() for c in colors
]

# ensure all rgb values are less than 1 (sometimes issues in interpolation
Expand Down
2 changes: 0 additions & 2 deletions sumo/plotting/optics_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np
import scipy.constants as scpc
from matplotlib import rcParams
from matplotlib.font_manager import FontProperties, findfont
from matplotlib.ticker import AutoMinorLocator, FuncFormatter, MaxNLocator

from sumo.plotting import (
Expand Down Expand Up @@ -242,7 +241,6 @@ def get_plot(
ax.set_ylim(ymin, ymax)

if spectrum_key == "absorption":
font = findfont(FontProperties(family=["sans-serif"]))
ax.yaxis.set_major_formatter(
FuncFormatter(curry_power_tick(times_sign=r"\times"))
)
Expand Down
7 changes: 2 additions & 5 deletions sumo/symmetry/brad_crack_kpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""

from json import load as load_json
import os

import numpy as np

Expand Down Expand Up @@ -66,9 +65,7 @@ def __init__(self, structure, symprec=1e-3, spg=None):
spg_symbol = self.spg_symbol
lattice_type = self.lattice_type

bravais = self._get_bravais_lattice(
spg_symbol, lattice_type, a, b, c, unique
)
bravais = self._get_bravais_lattice(spg_symbol, lattice_type, a, b, c, unique)
self._kpath = self._get_bradcrack_data(bravais)

@staticmethod
Expand All @@ -85,7 +82,7 @@ def _get_bradcrack_data(bravais):
'path': [['\Gamma', 'X', ..., 'P'], ['H', 'N', ...]]}
"""
json_file = os.path.join(ilr_files("sumo.symmetry"), "bradcrack.json")
json_file = ilr_files("sumo.symmetry") / "bradcrack.json"
with open(json_file) as f:
bradcrack_data = load_json(f)
return bradcrack_data[bravais]
Expand Down
16 changes: 5 additions & 11 deletions tests/tests_electronic_structure/test_optics.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
import json
import unittest
import os
import unittest

try:
from importlib.resources import files as ilr_files
except ImportError: # Python < 3.9
from importlib_resources import files as ilr_files

import numpy as np
from numpy.testing import assert_almost_equal
from pymatgen.io.vasp import Vasprun

from sumo.electronic_structure.optics import (
calculate_dielectric_properties,
kkr,
)
from sumo.electronic_structure.optics import calculate_dielectric_properties, kkr


class AbsorptionTestCase(unittest.TestCase):
def setUp(self):
diel_path = os.path.join(
ilr_files("tests"), "data", "Ge", "ge_diel.json"
)
diel_path = os.path.join(ilr_files("tests"), "data", "Ge", "ge_diel.json")
with open(diel_path) as f:
self.ge_diel = json.load(f)

Expand All @@ -35,9 +31,7 @@ def test_absorption(self):
self.ge_diel,
{"absorption"},
)
self.assertIsNone(
assert_almost_equal(properties["absorption"], self.ge_abs)
)
self.assertIsNone(assert_almost_equal(properties["absorption"], self.ge_abs))


class KramersKronigTestCase(unittest.TestCase):
Expand Down
Loading

0 comments on commit 0d7361f

Please sign in to comment.