# Source code for ifermi.kpoints

"""k-point manipulation functions."""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import numpy as np

from ifermi.defaults import KTOL

if TYPE_CHECKING:
    from pymatgen.electronic_structure.bandstructure import BandStructure

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "kpoints_to_first_bz",
    "kpoints_from_bandstructure",
    "get_kpoint_mesh_dim",
    "get_kpoint_spacing",
    "sort_boltztrap_to_spglib",
]


def kpoints_to_first_bz(kpoints: np.ndarray, tol: float = KTOL) -> np.ndarray:
    """Translate fractional k-points to the first Brillouin zone.

    I.e. all k-points will have fractional coordinates:

        -0.5 < fractional coordinates <= 0.5

    Coordinates on the zone boundary are always reported as +0.5, never -0.5.
    This includes coordinates that round to -0.5 at the number of decimal
    places implied by ``tol``, so equivalent boundary points share a single
    representation.

    Args:
        kpoints: A (n, 3) float array of the k-points in fractional coordinates.
        tol: Tolerance for treating two k-points as equivalent.

    Returns:
        A (n, 3) float array of the translated k-points. The input array is not
        modified.
    """
    kpoints = np.asarray(kpoints)

    # np.round rounds exact halves to the nearest even integer, so a
    # half-integer coordinate can end up at either +0.5 (e.g. 0.5) or -0.5
    # (e.g. 1.5) after this subtraction.
    kp = kpoints - np.round(kpoints)

    # Account for small rounding errors around 0.5: anything that rounds to
    # -0.5 at the precision implied by ``tol`` is moved to the +0.5 side.
    round_dp = int(np.log10(1 / tol))
    krounded = np.round(kp, round_dp)
    kp[krounded == -0.5] = 0.5
    return kp


def get_kpoint_mesh_dim(
    kpoints: np.ndarray, tol: float = KTOL
) -> tuple[int, int, int]:
    """Get the k-point mesh dimensions.

    The dimension along each reciprocal lattice direction is the number of
    distinct fractional coordinates along that direction. Coordinates are
    first rounded to the number of decimal places implied by ``tol``.

    Note:
        Coordinates that are periodic images of each other (e.g. -0.5 and
        0.5) are counted as distinct values.

    Args:
        kpoints: A (n, 3) float array of the k-points in fractional coordinates.
        tol: Tolerance for treating two k-points as equivalent.

    Returns:
        A tuple of three ints ``(nx, ny, nz)`` giving the k-point mesh
        dimensions.
    """
    kpoints = np.asarray(kpoints)
    round_dp = int(np.log10(1 / tol))
    round_kpoints = np.round(kpoints, round_dp)
    nx, ny, nz = (len(np.unique(round_kpoints[:, axis])) for axis in range(3))
    return nx, ny, nz


def sort_boltztrap_to_spglib(kpoints: np.ndarray) -> np.ndarray:
    """Get an index array that sorts the k-points from BoltzTraP2 to the spglib order.

    Two stable lexicographic sorts are performed. Along each axis, points with
    a non-negative coordinate come before points with a negative coordinate,
    and ties are broken by the coordinate value.

    1. The k-points are arranged in "BoltzTraP2" order: the first axis is the
       most significant and the last axis varies fastest.
    2. That arrangement is sorted again into "spglib" order: the last axis is
       the most significant and the first axis varies fastest.

    NOTE(review): the returned indices refer to the array produced by step 1,
    not to the input array. The two coincide only when ``kpoints`` is already
    in BoltzTraP2 order.

    Args:
        kpoints: A (n, 3) float array of the k-points in fractional coordinates.

    Returns:
        A (n, ) int array of the sort order.
    """

    def _paired_keys(points: np.ndarray, axes: tuple[int, ...]) -> tuple:
        # np.lexsort treats the LAST key as the primary one. For every axis,
        # the sign flag is emitted after (i.e. is more significant than) the
        # coordinate value itself.
        keys = []
        for axis in axes:
            keys.append(points[:, axis])
            keys.append(points[:, axis] < 0)
        return tuple(keys)

    boltztrap_order = np.lexsort(_paired_keys(kpoints, (2, 1, 0)))
    boltztrap_kpoints = kpoints[boltztrap_order]
    return np.lexsort(_paired_keys(boltztrap_kpoints, (0, 1, 2)))


def get_kpoint_spacing(kpoints: np.ndarray) -> np.ndarray:
    """Get the spacing between fractional k-points.

    Along each reciprocal lattice direction, the spacing is the gap between
    the first two distinct (8-decimal-rounded) coordinates. A warning is
    emitted if the gaps along any direction are not all equal.

    If only a single distinct coordinate exists along a direction (e.g. a 2D
    mesh or a Gamma-only calculation), the nearest equivalent point is its
    periodic image one reciprocal lattice vector away, so the spacing along
    that direction is reported as 1.0.

    Args:
        kpoints: A (n, 3) float array of the k-points in fractional coordinates.

    Returns:
        A (3, ) float array of the spacing along each reciprocal lattice direction.

    Raises:
        ValueError: If no k-points are given.
    """
    kpoints = np.asarray(kpoints).round(8)
    if len(kpoints) == 0:
        raise ValueError("At least one k-point is required to compute the spacing")

    spacing = np.ones(3)
    uniform = True
    for axis in range(3):
        diffs = np.diff(np.unique(kpoints[:, axis]))
        if len(diffs) == 0:
            # single distinct coordinate: keep the periodic spacing of 1.0
            continue
        spacing[axis] = diffs[0]
        uniform = uniform and np.allclose(diffs - diffs[0], 0, atol=1e-7)

    if not uniform:
        warnings.warn("k-point mesh is not uniform", stacklevel=2)
    return spacing


def kpoints_from_bandstructure(
    bandstructure: BandStructure, cartesian: bool = False
) -> np.ndarray:
    """Collect the k-point coordinates of a band structure into an array.

    Args:
        bandstructure: A band structure object.
        cartesian: If True, use each k-point's ``cart_coords``; otherwise use
            its ``frac_coords`` (fractional coordinates).

    Returns:
        A (n, 3) float array of the k-points, in the same order as
        ``bandstructure.kpoints``.
    """
    coords = [
        kpoint.cart_coords if cartesian else kpoint.frac_coords
        for kpoint in bandstructure.kpoints
    ]
    return np.array(coords)