Source code for ifermi.plot

"""Tools to plot FermiSurface and FermiSlice objects."""

from __future__ import annotations

import os
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import numpy as np
from matplotlib.colors import Colormap, Normalize
from monty.dev import requires
from pymatgen.electronic_structure.core import Spin

from ifermi.defaults import AZIMUTH, COLORMAP, ELEVATION, SCALE, SYMPREC, VECTOR_SPACING
from ifermi.surface import FermiSurface

if TYPE_CHECKING:
    from collections.abc import Collection
    from pathlib import Path

    from ifermi.slice import FermiSlice

try:
    import mayavi.mlab as mlab
except ImportError:
    mlab = False

try:
    import kaleido
except ImportError:
    kaleido = False

try:
    import crystal_toolkit
except ImportError:
    crystal_toolkit = False


# Default plotly styling. Every axis of the 3D scene is rendered "invisible"
# (white background, no grid, ticks, labels or axis lines); each axis gets its
# own dictionary instance.
_plotly_scene = {
    **{
        f"{_axis}axis": dict(
            backgroundcolor="rgb(255, 255, 255)",
            title="",
            showgrid=False,
            zeroline=False,
            showline=False,
            ticks="",
            showticklabels=False,
        )
        for _axis in "xyz"
    },
    "aspectmode": "data",
}
_plotly_bz_style = dict(line=dict(color="black", width=3))
_plotly_sym_pt_style = dict(marker=dict(size=6, color="black"))
_plotly_sym_label_style = dict(
    xshift=15,
    yshift=15,
    showarrow=False,
    font=dict(size=20, color="black"),
)
_plotly_cbar_style = dict(lenmode="fraction", len=0.5, tickfont=dict(size=15))

# Default mayavi styling: high-symmetry labels and reciprocal cell outline.
_mayavi_sym_label_style = dict(
    color=(0, 0, 0),
    scale=0.1,
    orientation=(90.0, 0.0, 0.0),
)
_mayavi_rs_style = dict(
    color=(0.0, 0.0, 0.0),
    tube_radius=0.005,
    representation="surface",
)

# Default matplotlib styling: colorbar, cell outline, arrows and labels.
_mpl_cbar_style = dict(shrink=0.5)
_mpl_bz_style = dict(linewidth=1, color="k")
_mpl_arrow_style = dict(
    angles="xy",
    scale_units="xy",
    scale=1,
    zorder=10,
    units="dots",
    width=10,
    pivot="tail",
)

_mpl_sym_pt_style = dict(s=20, c="k", zorder=20)
_mpl_sym_label_style = dict(size=16, zorder=20)

# Names exported by ``from ifermi.plot import *``.
__all__ = [
    # plotter classes
    "FermiSlicePlotter",
    "FermiSurfacePlotter",
    # module-level helpers
    "save_plot",
    "show_plot",
    "get_plot_type",
    "get_isosurface_colors",
    "plotly_arrow",
    "rgb_to_plotly",
    "cmap_to_mayavi",
    "cmap_to_plotly",
    "get_segment_arrows",
    "get_face_arrows",
]


@dataclass
class _FermiSurfacePlotData:
    """Backend-independent data needed to draw a Fermi surface.

    Instances are built by ``FermiSurfacePlotter._get_plot_data`` and consumed by
    the backend-specific ``FermiSurfacePlotter._get_*_plot`` methods.
    """

    # One entry per isosurface; the plot methods unpack each as (verts, faces).
    isosurfaces: list[tuple[np.ndarray, np.ndarray]]
    # Camera viewpoint angles forwarded to the plotting backend.
    azimuth: float
    elevation: float
    # Solid colors, zipped with ``isosurfaces`` when ``properties`` is empty.
    colors: list[tuple[int, int, int]]
    # Values zipped with ``isosurfaces`` to color the surfaces; an empty list
    # makes the plot methods fall back to ``colors``.
    properties: list[np.ndarray]
    # Arrow data; the plot methods unpack each entry as
    # (starts, stops, intensities).
    arrows: list[tuple[np.ndarray, np.ndarray, np.ndarray]]
    # Colormaps for the property-colored surfaces and for the arrows.
    properties_colormap: Colormap | None
    arrow_colormap: Colormap | None
    # Lower/upper limits used when normalising property and arrow colors.
    cmin: float | None
    cmax: float | None
    # When True the high-symmetry labels / reciprocal cell outline are not drawn.
    hide_labels: bool
    hide_cell: bool


@dataclass
class _FermiSlicePlotData:
    """Data needed to draw a Fermi slice.

    ``FermiSlicePlotter.get_plot`` obtains an instance from its
    ``_get_plot_data`` method and draws from it.
    """

    # One array of line segments per isoline group; each entry is rotated and
    # handed to a matplotlib ``LineCollection`` by ``FermiSlicePlotter.get_plot``.
    slices: list[np.ndarray]
    # Solid colors, zipped with ``slices`` when ``properties`` is empty.
    colors: list[tuple[int, int, int]]
    # Values zipped with ``slices`` to color the segments; an empty list makes
    # the plot fall back to ``colors``.
    properties: list[np.ndarray]
    # Arrow data — presumably (starts, stops, intensities) per entry, mirroring
    # ``_FermiSurfacePlotData``; the consuming code is not visible here.
    arrows: list[tuple[np.ndarray, np.ndarray, np.ndarray]]
    # Colormaps for the property-colored segments and for the arrows.
    properties_colormap: Colormap | None
    arrow_colormap: Colormap | None
    # Lower/upper limits used when normalising property colors.
    cmin: float | None
    cmax: float | None
    # NOTE(review): presumably suppress the k-point labels / cell outline, as in
    # the surface plotter — the code that reads them is outside this view.
    hide_labels: bool
    hide_cell: bool


class FermiSurfacePlotter:
    """Class to plot a FermiSurface.

    Args:
        fermi_surface: A FermiSurface object.
        symprec: The symmetry precision in Angstrom for determining the high
            symmetry k-point labels.
    """

    def __init__(self, fermi_surface: FermiSurface, symprec: float = SYMPREC):
        self.fermi_surface = fermi_surface
        self.reciprocal_space = fermi_surface.reciprocal_space
        self.rlat = self.reciprocal_space.reciprocal_lattice
        self._symmetry_pts = self.get_symmetry_points(fermi_surface, symprec=symprec)

    @staticmethod
    def get_symmetry_points(
        fermi_surface: FermiSurface, symprec: float = SYMPREC
    ) -> tuple[np.ndarray, list[str]]:
        """Get the high symmetry k-points and labels for the Fermi surface.

        Args:
            fermi_surface: A fermi surface.
            symprec: The symmetry precision in Angstrom.

        Returns:
            The high symmetry k-points (multiplied by the reciprocal lattice) and
            their labels.
        """
        from pymatgen.symmetry.bandstructure import HighSymmKpath

        from ifermi.brillouin_zone import WignerSeitzCell
        from ifermi.kpoints import kpoints_to_first_bz

        hskp = HighSymmKpath(fermi_surface.structure, symprec=symprec)
        labels, kpoints = zip(*hskp.kpath["kpoints"].items())

        if not np.allclose(
            hskp.prim.lattice.matrix, fermi_surface.structure.lattice.matrix, 1e-5
        ):
            warnings.warn(
                "Structure does not match expected primitive cell", stacklevel=2
            )

        if not isinstance(fermi_surface.reciprocal_space, WignerSeitzCell):
            kpoints = kpoints_to_first_bz(np.array(kpoints))

        kpoints = np.dot(kpoints, fermi_surface.reciprocal_space.reciprocal_lattice)

        # return a real list so the result matches the annotated return type
        return kpoints, list(labels)

    def get_plot(
        self,
        plot_type: str = "plotly",
        spin: Spin | None = None,
        colors: str | dict | list | None = None,
        azimuth: float = AZIMUTH,
        elevation: float = ELEVATION,
        color_properties: str | bool = True,
        vector_properties: str | bool = False,
        projection_axis: tuple[int, int, int] | None = None,
        vector_spacing: float = VECTOR_SPACING,
        cmin: float | None = None,
        cmax: float | None = None,
        vnorm: float | None = None,
        hide_surface: bool = False,
        hide_labels: bool = False,
        hide_cell: bool = False,
        plot_index: list[int] | dict[Spin, list[int] | int] | int | None = None,
        **plot_kwargs,
    ):
        """Plot the Fermi surface.

        Args:
            plot_type: Method used for plotting. Valid options are: "matplotlib",
                "plotly", "mayavi", "crystal_toolkit".
            spin: Which spin channel to plot. By default plot both spin channels if
                available.
            azimuth: The azimuth of the viewpoint in degrees. i.e. the angle subtended
                by the position vector on a sphere projected on to the x-y plane.
            elevation: The zenith angle of the viewpoint in degrees, i.e. the angle
                subtended by the position vector and the z-axis.
            colors: The color specification for the iso-surfaces. Valid options are:

                - A single color to use for all Fermi surfaces, specified as a tuple of
                  rgb values from 0 to 1. E.g., red would be ``(1, 0, 0)``.
                - A list of colors, specified as above.
                - A dictionary of ``{Spin.up: color1, Spin.down: color2}``, where the
                  colors are specified as above.
                - A string specifying which matplotlib colormap to use. See
                  https://matplotlib.org/tutorials/colors/colormaps.html for more
                  information.
                - ``None``, in which case the default colors will be used.
            color_properties: Whether to use the properties to color the Fermi surface.
                If the properties is a vector then the norm of the properties will be
                used. Note, this will only take effect if the Fermi surface has
                properties. If set to True, the viridis colormap will be used.
                Alternative colormaps can be selected by setting ``color_properties``
                to a matplotlib colormap name. This setting will override the
                ``colors`` option. For vector properties, the arrows are colored
                according to the norm of the properties by default. If used in
                combination with the ``projection_axis`` option, the color will be
                determined by the dot product of the properties with the projection
                axis.
            vector_properties: Whether to plot arrows for vector properties. Note, this
                will only take effect if the Fermi surface has vector properties. If
                set to True, the viridis colormap will be used. Alternative colormaps
                can be selected by setting ``vector_properties`` to a matplotlib
                colormap name. By default, the arrows are colored according to the norm
                of the properties. If used in combination with the ``projection_axis``
                option, the color will be determined by the dot product of the
                properties with the projection axis.
            projection_axis: Projection axis that can be used to calculate the color of
                vector properties. If None, the norm of the properties will be used,
                otherwise the color will be determined by the dot product of the
                properties with the projection axis. Only has an effect when used with
                the ``vector_properties`` option.
            vector_spacing: The rough spacing between arrows. Uses a custom algorithm
                for resampling the Fermi surface to ensure that arrows are not too
                close together. Only has an effect when used with the
                ``vector_properties`` option.
            cmin: Minimum intensity for normalising properties colors (including
                vector colors). Only has an effect when used with ``color_properties``
                or ``vector_properties`` options.
            cmax: Maximum intensity for normalising properties colors (including
                vector colors). Only has an effect when used with ``color_properties``
                or ``vector_properties`` options.
            vnorm: The value by which to normalize the vector lengths. For example,
                spin properties should typically have a norm of 1 whereas group
                velocity properties can have larger or smaller norms depending on the
                structure. By changing this number, the size of the vectors will be
                scaled. Note that the properties of two materials can only be compared
                quantitatively if a fixed value is used for both plots. Only has an
                effect when used with the ``vector_properties`` option.
            hide_surface: Whether to hide the Fermi surface. Only recommended in
                combination with the ``vector_properties`` option.
            hide_labels: Whether to hide the high-symmetry k-point labels.
            hide_cell: Whether to hide the reciprocal cell boundary.
            plot_index: A choice of band indices (0-based). Valid options are:

                - A single integer, which will select that band index in both spin
                  channels (if both spin channels are present).
                - A list of integers, which will select that set of bands from both
                  spin channels (if both are present).
                - A dictionary of ``{Spin.up: band_index_1, Spin.down: band_index_2}``,
                  where band_index_1 and band_index_2 are either single integers (if
                  one wishes to plot a single band for that particular spin) or a list
                  of integers. Note that the choice of integer and list can be
                  different for different spin channels.
                - ``None`` in which case all bands will be plotted.
            **plot_kwargs: Other keyword arguments supported by the individual plotting
                methods.

        Returns:
            The plot object of the selected backend (pyplot module, plotly figure,
            mlab module or crystal toolkit scene).

        Raises:
            ValueError: If ``plot_type`` is not one of the supported options.
        """
        backends = {
            "matplotlib": self._get_matplotlib_plot,
            "plotly": self._get_plotly_plot,
            "mayavi": self._get_mayavi_plot,
            "crystal_toolkit": self._get_crystal_toolkit_plot,
        }
        # validate before doing the (potentially expensive) plot data calculation
        if plot_type not in backends:
            types = list(backends)
            error_msg = f"Plot type not recognised, valid options: {types}"
            raise ValueError(error_msg)

        plot_data = self._get_plot_data(
            spin=spin,
            azimuth=azimuth,
            elevation=elevation,
            colors=colors,
            color_properties=color_properties,
            vector_properties=vector_properties,
            projection_axis=projection_axis,
            vector_spacing=vector_spacing,
            cmin=cmin,
            cmax=cmax,
            vnorm=vnorm,
            hide_surface=hide_surface,
            hide_labels=hide_labels,
            hide_cell=hide_cell,
            plot_index=plot_index,
        )
        return backends[plot_type](plot_data, **plot_kwargs)

    def _get_plot_data(
        self,
        spin: Spin | None = None,
        azimuth: float = AZIMUTH,
        elevation: float = ELEVATION,
        colors: str | dict | list | None = None,
        color_properties: str | bool = True,
        vector_properties: str | bool = False,
        projection_axis: tuple[int, int, int] | None = None,
        vector_spacing: float = VECTOR_SPACING,
        cmin: float | None = None,
        cmax: float | None = None,
        vnorm: float | None = None,
        hide_surface: bool = False,
        hide_labels: bool = False,
        hide_cell: bool = False,
        plot_index: list[int] | dict[Spin, list[int] | int] | int | None = None,
    ) -> _FermiSurfacePlotData:
        """Get the Fermi surface plot data.

        See ``FermiSurfacePlotter.get_plot()`` for more details.

        Returns:
            The Fermi surface plot data.
        """
        # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; the pyplot
        # function has the same call signature and is still available.
        from matplotlib.pyplot import get_cmap

        if not spin:
            spin = self.fermi_surface.spins
        elif isinstance(spin, Spin):
            spin = [spin]

        isosurfaces = []
        if not hide_surface:
            isosurfaces = self.fermi_surface.all_vertices_faces(
                spins=spin, band_index=plot_index
            )

        properties = []
        properties_colormap = None
        if self.fermi_surface.has_properties:
            # always calculate properties if they are present so we can determine
            # cmin and cmax. These are also used for arrows and it is critical that
            # cmin and cmax are the same for properties and arrow color scales (even
            # if the colormap used is different)
            norm = self.fermi_surface.properties_ndim == 2
            properties = self.fermi_surface.all_properties(
                spins=spin,
                band_index=plot_index,
                projection_axis=projection_axis,
                norm=norm,
            )
            if isinstance(color_properties, str):
                properties_colormap = get_cmap(color_properties)
            else:
                properties_colormap = get_cmap(COLORMAP)
            cmin, cmax = _get_properties_limits(properties, cmin, cmax)

        if not color_properties or not self.fermi_surface.has_properties:
            colors = get_isosurface_colors(colors, self.fermi_surface, spin, plot_index)
            properties = []
            # NOTE(review): this also discards the limits for arrows drawn with
            # ``color_properties=False`` — confirm that is intended.
            cmin = None
            cmax = None

        arrows = []
        arrow_colormap = None
        if vector_properties and self.fermi_surface.has_properties:
            arrows = get_face_arrows(
                self.fermi_surface, spin, vector_spacing, vnorm, projection_axis
            )
            if isinstance(vector_properties, str):
                arrow_colormap = get_cmap(vector_properties)
            else:
                arrow_colormap = get_cmap(COLORMAP)

        return _FermiSurfacePlotData(
            isosurfaces=isosurfaces,
            azimuth=azimuth,
            elevation=elevation,
            colors=colors,
            properties=properties,
            arrows=arrows,
            properties_colormap=properties_colormap,
            arrow_colormap=arrow_colormap,
            cmin=cmin,
            cmax=cmax,
            hide_labels=hide_labels,
            hide_cell=hide_cell,
        )

    def _get_matplotlib_plot(
        self,
        plot_data: _FermiSurfacePlotData,
        ax: Any | None = None,
        trisurf_kwargs: dict[str, Any] | None = None,
        cbar_kwargs: dict[str, Any] | None = None,
        quiver_kwargs: dict[str, Any] | None = None,
        bz_kwargs: dict[str, Any] | None = None,
        sym_pt_kwargs: dict[str, Any] | None = None,
        sym_label_kwargs: dict[str, Any] | None = None,
    ):
        """Plot the Fermi surface using matplotlib.

        Args:
            plot_data: The plot data.
            ax: Matplotlib 3D axes on which to plot.
            trisurf_kwargs: Optional arguments that are passed to ``ax.plot_trisurf``
                and are used to style the iso-surface.
            cbar_kwargs: Optional arguments that are passed to ``fig.colorbar``.
            quiver_kwargs: Optional arguments that are passed to ``ax.quiver`` and are
                used to style the arrows.
            bz_kwargs: Optional arguments that passed to ``Line3DCollection`` and used
                to style the Brillouin zone boundary.
            sym_pt_kwargs: Optional arguments that are passed to ``ax.scatter`` and are
                used to style the high-symmetry k-point symbols.
            sym_label_kwargs: Optional arguments that are passed to ``ax.text`` and are
                used to style the high-symmetry k-point labels.

        Returns:
            matplotlib pyplot object.
        """
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d.art3d import Line3DCollection

        trisurf_kwargs = trisurf_kwargs or {}
        quiver_kwargs = quiver_kwargs or {}

        # Merge user options into *copies* of the module defaults so options given
        # for one plot do not leak into every later plot.
        cbar_style = {**_mpl_cbar_style, **(cbar_kwargs or {})}
        bz_style = {**_mpl_bz_style, **(bz_kwargs or {})}
        sym_pt_style = {**_mpl_sym_pt_style, **(sym_pt_kwargs or {})}
        sym_label_style = {**_mpl_sym_label_style, **(sym_label_kwargs or {})}

        if ax is None:
            fig = plt.figure(figsize=(6, 6))
            ax = fig.add_subplot(111, projection="3d", proj_type="persp")
        else:
            # use the figure owning the axes, which may not be the current figure
            fig = ax.get_figure()

        if plot_data.properties:
            polyc = None
            for (verts, faces), proj in zip(
                plot_data.isosurfaces, plot_data.properties
            ):
                x, y, z = verts.T
                polyc = ax.plot_trisurf(
                    x, y, faces, z, cmap=plot_data.properties_colormap, **trisurf_kwargs
                )
                polyc.set_array(proj)
                polyc.set_clim(plot_data.cmin, plot_data.cmax)

            if polyc is not None:
                # "shrink" comes from the style dict; also passing it explicitly
                # would raise a duplicate keyword TypeError
                fig.colorbar(polyc, ax=ax, **cbar_style)
        else:
            for c, (verts, faces) in zip(plot_data.colors, plot_data.isosurfaces):
                x, y, z = verts.T
                ax.plot_trisurf(x, y, faces, z, facecolor=c, **trisurf_kwargs)

        if plot_data.arrows:
            norm = Normalize(vmin=plot_data.cmin, vmax=plot_data.cmax)
            for starts, stops, intensities in plot_data.arrows:
                colors = plot_data.arrow_colormap(norm(intensities))
                vectors = stops - starts
                for (x, y, z), (u, v, w), color in zip(starts, vectors, colors):
                    ax.quiver(x, y, z, u, v, w, color=color, **quiver_kwargs)

        # add the cell outline to the plot
        if not plot_data.hide_cell:
            lines = Line3DCollection(self.reciprocal_space.lines, **bz_style)
            ax.add_collection3d(lines)

        if not plot_data.hide_labels:
            for coords, label in zip(*self._symmetry_pts):
                ax.scatter(*coords, **sym_pt_style)
                ax.text(*coords, f"${label}$", **sym_label_style)

        xlim, ylim, zlim = np.linalg.norm(self.rlat, axis=1) / 2
        ax.set(xlim=(-xlim, xlim), ylim=(-ylim, ylim), zlim=(-zlim, zlim))
        ax.view_init(elev=plot_data.elevation, azim=plot_data.azimuth)
        ax.axis("off")
        plt.tight_layout()

        return plt

    def _get_plotly_plot(
        self,
        plot_data: _FermiSurfacePlotData,
        mesh_kwargs: dict[str, Any] | None = None,
        arrow_line_kwargs: dict[str, Any] | None = None,
        arrow_cone_kwargs: dict[str, Any] | None = None,
        bz_kwargs: dict[str, Any] | None = None,
        sym_pt_kwargs: dict[str, Any] | None = None,
        sym_label_kwargs: dict[str, Any] | None = None,
    ):
        """Plot the Fermi surface using plotly.

        Args:
            plot_data: The data to plot.
            mesh_kwargs: Optional arguments that are passed to ``Mesh3d`` and are used
                to style the iso-surface.
            arrow_line_kwargs: Additional keyword arguments used to style the arrow
                shaft and that are passed to ``Scatter3d``.
            arrow_cone_kwargs: Additional keyword arguments used to style the arrow
                cone and that are passed to ``Cone``.
            bz_kwargs: Optional arguments that passed to ``Scatter3d`` and used to style
                the Brillouin zone boundary.
            sym_pt_kwargs: Optional arguments that are passed to ``Scatter3d`` and are
                used to style the high-symmetry k-point symbols.
            sym_label_kwargs: Optional arguments that are used in the annotations to
                style the high-symmetry k-point labels.

        Returns:
            Plotly figure object.
        """
        import plotly.graph_objs as go

        mesh_kwargs = mesh_kwargs or {}
        arrow_line_kwargs = arrow_line_kwargs or {}
        arrow_cone_kwargs = arrow_cone_kwargs or {}

        # Merge user options into *copies* of the module defaults so options given
        # for one plot do not leak into every later plot.
        bz_style = {**_plotly_bz_style, **(bz_kwargs or {})}
        sym_pt_style = {**_plotly_sym_pt_style, **(sym_pt_kwargs or {})}
        sym_label_style = {**_plotly_sym_label_style, **(sym_label_kwargs or {})}

        if _is_notebook():
            from plotly.offline import init_notebook_mode

            init_notebook_mode(connected=True)

        meshes = []
        if plot_data.properties:
            # plot mesh with colored properties
            colors = cmap_to_plotly(plot_data.properties_colormap)
            # defaults first so a user supplied "colorbar" overrides rather than
            # clashing with the keyword
            mesh_style = {"colorbar": _plotly_cbar_style, **mesh_kwargs}
            for (verts, faces), proj in zip(
                plot_data.isosurfaces, plot_data.properties
            ):
                x, y, z = verts.T
                i, j, k = faces.T
                trace = go.Mesh3d(
                    x=x,
                    y=y,
                    z=z,
                    i=i,
                    j=j,
                    k=k,
                    intensity=proj,
                    colorscale=colors,
                    intensitymode="cell",
                    cmin=plot_data.cmin,
                    cmax=plot_data.cmax,
                    **mesh_style,
                )
                meshes.append(trace)
        else:
            # defaults first so a user supplied "opacity" overrides the default
            mesh_style = {"opacity": 1, **mesh_kwargs}
            for c, (verts, faces) in zip(plot_data.colors, plot_data.isosurfaces):
                c = rgb_to_plotly(c)
                x, y, z = verts.T
                i, j, k = faces.T
                trace = go.Mesh3d(x=x, y=y, z=z, color=c, i=i, j=j, k=k, **mesh_style)
                meshes.append(trace)

        # add arrows
        if plot_data.arrows:
            norm = Normalize(vmin=plot_data.cmin, vmax=plot_data.cmax)
            for starts, ends, intensities in plot_data.arrows:
                arrow_colors = plot_data.arrow_colormap(norm(intensities))
                for start, end, color in zip(starts, ends, arrow_colors):
                    arrow = plotly_arrow(
                        start,
                        end,
                        color[:3],  # drop the alpha channel
                        line_kwargs=arrow_line_kwargs,
                        cone_kwargs=arrow_cone_kwargs,
                    )
                    meshes.extend(arrow)

        # add the cell outline to the plot
        if not plot_data.hide_cell:
            for line in self.reciprocal_space.lines:
                x, y, z = line.T
                trace = go.Scatter3d(x=x, y=y, z=z, mode="lines", **bz_style)
                meshes.append(trace)

        # shallow copy is enough: only the top-level "annotations" key is added
        scene = _plotly_scene.copy()
        if not plot_data.hide_labels:
            # plot high symmetry k-point markers
            labels = [f"${i}$" for i in self._symmetry_pts[1]]
            x, y, z = self._symmetry_pts[0].T
            trace = go.Scatter3d(x=x, y=y, z=z, mode="markers", **sym_pt_style)
            meshes.append(trace)

            # add high symmetry label
            scene["annotations"] = [
                dict(x=x, y=y, z=z, text=label, **sym_label_style)
                for label, (x, y, z) in zip(labels, self._symmetry_pts[0])
            ]

        # Specify plot parameters
        layout = go.Layout(
            scene=scene, showlegend=False, margin=go.layout.Margin(l=0, r=0, b=0, t=0)
        )
        fig = go.Figure(data=meshes, layout=layout)

        camera = _get_plotly_camera(plot_data.azimuth, plot_data.elevation)
        fig.update_layout(scene_camera=camera)

        return fig

    @requires(mlab, "mayavi option requires mayavi to be installed.")
    def _get_mayavi_plot(self, plot_data: _FermiSurfacePlotData):
        """Plot the Fermi surface using mayavi.

        Args:
            plot_data: The data to plot.

        Returns:
            mlab figure object.
        """
        from mlabtex import mlabtex

        mlab.figure(figure=None, bgcolor=(1, 1, 1), size=(800, 800), fgcolor=(0, 0, 0))

        if plot_data.properties:
            from tvtk.api import tvtk

            cmap = cmap_to_mayavi(plot_data.properties_colormap)
            for (verts, faces), proj in zip(
                plot_data.isosurfaces, plot_data.properties
            ):
                polydata = tvtk.PolyData(points=verts, polys=faces)
                polydata.cell_data.scalars = proj
                polydata.cell_data.scalars.name = "celldata"
                mesh = mlab.pipeline.surface(
                    polydata, vmin=plot_data.cmin, vmax=plot_data.cmax, opacity=0.8
                )
                mesh.module_manager.scalar_lut_manager.lut.table = cmap
                cb = mlab.colorbar(object=mesh, orientation="vertical", nb_labels=5)
                cb.label_text_property.bold = 0
                cb.label_text_property.italic = 0
        else:
            for c, (verts, faces) in zip(plot_data.colors, plot_data.isosurfaces):
                x, y, z = verts.T
                mlab.triangular_mesh(x, y, z, faces, color=tuple(c), opacity=0.8)

        # ``arrows`` is an empty list (and ``arrow_colormap`` is None) when no
        # arrows were requested, so only build the lookup table when needed
        if plot_data.arrows:
            cmap = cmap_to_mayavi(plot_data.arrow_colormap)
            for starts, stops, intensities in plot_data.arrows:
                centers = (stops + starts) / 2
                vectors = stops - starts
                x, y, z = (centers - (vectors * 0.8)).T  # leave room for arrow tip
                u, v, w = vectors.T

                pnts = mlab.quiver3d(
                    x,
                    y,
                    z,
                    u,
                    v,
                    w,
                    line_width=4.5,
                    mode="arrow",
                    resolution=25,
                    scale_mode="vector",
                    scale_factor=2,
                    scalars=intensities,
                    vmin=plot_data.cmin,
                    vmax=plot_data.cmax,
                )
                pnts.module_manager.scalar_lut_manager.lut.table = cmap
                pnts.glyph.color_mode = "color_by_scalar"
                pnts.glyph.glyph_source.glyph_source.shaft_radius = 0.035
                pnts.glyph.glyph_source.glyph_source.tip_length = 0.3

        if not plot_data.hide_cell:
            for line in self.reciprocal_space.lines:
                x, y, z = line.T
                mlab.plot3d(x, y, z, **_mayavi_rs_style)

        if not plot_data.hide_labels:
            # latexify labels
            labels = [f"${i}$" for i in self._symmetry_pts[1]]
            for coords, label in zip(self._symmetry_pts[0], labels):
                mlabtex(*coords, label, **_mayavi_sym_label_style)

        mlab.gcf().scene._lift()  # required to be able to set view
        mlab.view(
            azimuth=plot_data.azimuth - 180,
            elevation=plot_data.elevation - 90,
            distance="auto",
        )

        return mlab

    @requires(
        crystal_toolkit,
        "crystal_toolkit option requires crystal_toolkit to be installed.",
    )
    def _get_crystal_toolkit_plot(
        self, plot_data: _FermiSurfacePlotData, opacity: float = 1.0
    ):
        """Get a crystal toolkit Scene showing the Fermi surface.

        The Scene can be displayed in an interactive web app using Crystal Toolkit,
        can be shown interactively in Jupyter Lab using the crystal-toolkit lab
        extension, or can be converted to JSON to store for future use.

        Args:
            plot_data: The data to plot.
            opacity: Opacity of surface. Note that due to limitations of WebGL,
                overlapping semi-transparent surfaces might result in visual
                artefacts.

        Returns:
            Crystal-toolkit scene.
        """
        from crystal_toolkit.core.scene import Lines, Scene, Spheres, Surface

        # ``properties`` and ``arrows`` are lists that are empty when unused, so
        # test their truthiness (an ``is not None`` check would always warn)
        if plot_data.properties or plot_data.arrows:
            warnings.warn(
                "crystal_toolkit plot does not support properties or arrows",
                stacklevel=2,
            )

        # The implementation here is very similar to the plotly implementation, except
        # the crystal toolkit scene is constructed using the scene primitives from
        # crystal toolkit (Spheres, Surface, Lines, etc.)
        scene_contents = []

        # create a mesh for each electron band which has an isosurface at the Fermi
        # energy. Mesh data is generated by a marching cubes algorithm when the
        # FermiSurface object is created.
        # NOTE(review): when the Fermi surface has properties and
        # ``color_properties`` is truthy, ``_get_plot_data`` leaves
        # ``plot_data.colors`` as the raw user argument — confirm callers pass
        # ``color_properties=False`` for this backend.
        surfaces = []
        for c, (verts, faces) in zip(plot_data.colors, plot_data.isosurfaces):
            c = rgb_to_plotly(c)
            positions = verts[faces].reshape(-1, 3).tolist()
            surface = Surface(positions=positions, color=c, opacity=opacity)
            surfaces.append(surface)
        fermi_surface = Scene("fermi_object", contents=surfaces)
        scene_contents.append(fermi_surface)

        # add the cell outline to the plot (honouring ``hide_cell`` like the other
        # backends)
        if not plot_data.hide_cell:
            lines = Lines(positions=list(self.reciprocal_space.lines.flatten()))
            # alternatively,
            # cylinders have finite width and are lighted, but no strong reason to
            # choose one over the other
            # cylinders = Cylinders(
            #     positionPairs=self.reciprocal_space.lines.tolist(),
            #     radius=0.01, color="rgb(0,0,0)")
            scene_contents.append(lines)

        if not plot_data.hide_labels:
            spheres = []
            for position, label in zip(self._symmetry_pts[0], self._symmetry_pts[1]):
                sphere = Spheres(
                    positions=[list(position)],
                    tooltip=label,
                    radius=0.05,
                    color="rgb(0, 0, 0)",
                )
                spheres.append(sphere)
            label_scene = Scene("labels", contents=spheres)
            scene_contents.append(label_scene)

        return Scene("ifermi", contents=scene_contents)
class FermiSlicePlotter:
    """Class to plot 2D isolines through a FermiSurface.

    Args:
        fermi_slice: A slice through a Fermi surface.
        symprec: The symmetry precision in Angstrom for determining the high
            symmetry k-point labels.
    """

    def __init__(self, fermi_slice: FermiSlice, symprec: float = SYMPREC):
        self.fermi_slice = fermi_slice
        self.reciprocal_slice = fermi_slice.reciprocal_slice
        symmetry_pts = self.get_symmetry_points(fermi_slice, symprec=symprec)
        self._symmetry_pts = symmetry_pts

    @staticmethod
    def get_symmetry_points(
        fermi_slice: FermiSlice, symprec: float = SYMPREC
    ) -> tuple[np.ndarray, list[str]]:
        """Get the high symmetry k-points and labels for the Fermi slice.

        Args:
            fermi_slice: A fermi slice.
            symprec: The symmetry precision in Angstrom.

        Returns:
            The in-plane (first two) coordinates and the labels of the high
            symmetry k-points that lie on the slice.
        """
        from pymatgen.symmetry.bandstructure import HighSymmKpath
        from trimesh import transform_points

        from ifermi.brillouin_zone import WignerSeitzCell
        from ifermi.kpoints import kpoints_to_first_bz

        structure = fermi_slice.structure
        kpath = HighSymmKpath(structure, symprec=symprec)
        all_labels, points = zip(*kpath.kpath["kpoints"].items())

        cell_matches = np.allclose(
            kpath.prim.lattice.matrix, structure.lattice.matrix, 1e-5
        )
        if not cell_matches:
            warnings.warn(
                "Structure does not match expected primitive cell", stacklevel=2
            )

        reciprocal_slice = fermi_slice.reciprocal_slice
        if not isinstance(reciprocal_slice.reciprocal_space, WignerSeitzCell):
            points = kpoints_to_first_bz(np.array(points))

        # multiply by the reciprocal lattice, then apply the slice transformation
        points = np.dot(points, reciprocal_slice.reciprocal_space.reciprocal_lattice)
        points = transform_points(points, reciprocal_slice.transformation)

        # keep only the points whose third transformed coordinate is ~0, i.e.
        # those lying very close to the plane
        near_plane = np.abs(points[:, 2]) < 1e-4
        keep = np.where(near_plane)[0]
        points = points[keep]
        kept_labels = [all_labels[idx] for idx in keep]

        return points[:, :2], kept_labels
[docs] def get_plot( self, ax: Any | None = None, spin: Spin | None = None, colors: str | dict | list | None = None, color_properties: str | bool = True, vector_properties: str | bool = False, projection_axis: tuple[int, int, int] | None = None, scale_linewidth: bool | float = False, vector_spacing: float = VECTOR_SPACING, cmin: float | None = None, cmax: float | None = None, vnorm: float | None = None, hide_slice: bool = False, hide_labels: bool = False, hide_cell: bool = False, plot_index: list[int] | dict[Spin, list[int] | int] | int = None, arrow_pivot: str = "tail", slice_kwargs: dict[str, Any] | None = None, cbar_kwargs: dict[str, Any] | None = None, quiver_kwargs: dict[str, Any] | None = None, bz_kwargs: dict[str, Any] | None = None, sym_pt_kwargs: dict[str, Any] | None = None, sym_label_kwargs: dict[str, Any] | None = None, ): """Plot the Fermi slice. Args: ax: Matplotlib axes object on which to plot. spin: Which spin channel to plot. By default plot both spin channels if available. colors: The color specification for the iso-surfaces. Valid options are: - A single color to use for all Fermi isolines, specified as a tuple of rgb values from 0 to 1. E.g., red would be ``(1, 0, 0)``. - A list of colors, specified as above. - A dictionary of ``{Spin.up: color1, Spin.down: color2}``, where the colors are specified as above. - A string specifying which matplotlib colormap to use. See https://matplotlib.org/tutorials/colors/colormaps.html for more information. - ``None``, in which case the default colors will be used. color_properties: Whether to use the properties to color the Fermi isolines. If the properties is a vector then the norm of the properties will be used. Note, this will only take effect if the Fermi slice has properties. If set to True, the viridis colormap will be used. Alternative colormaps can be selected by setting ``color_properties`` to a matplotlib colormap name. This setting will override the ``colors`` option. 
For vector properties, the arrows are colored according to the norm of the properties by default. If used in combination with the ``projection_axis`` option, the color will be determined by the dot product of the properties with the projection axis. vector_properties: Whether to plot arrows for vector properties. Note, this will only take effect if the Fermi slice has vector properties. If set to True, the viridis colormap will be used. Alternative colormaps can be selected by setting ``vector_properties`` to a matplotlib colormap name. By default, the arrows are colored according to the norm of the properties. If used in combination with the ``projection_axis`` option, the color will be determined by the dot product of the properties with the projection axis. projection_axis: Projection axis that can be used to calculate the color of vector properties. If None, the norm of the properties will be used, otherwise the color will be determined by the dot product of the properties with the projection axis. Only has an effect when used with the ``vector_properties`` option. scale_linewidth: Scale the linewidth by the absolute value of the segment properties. Can be true, false or a number. If a number, then this will be used as the max linewidth for scaling. vector_spacing: The rough spacing between arrows. Uses a custom algorithm for resampling the Fermi surface to ensure that arrows are not too close together. Only has an effect when used with the ``vector_properties`` option. cmin: Minimum intensity for normalising properties colors (including vector colors). Only has an effect when used with ``color_properties`` or ``vector_properties`` options. cmax: Maximum intensity for normalising properties colors (including vector colors). Only has an effect when used with ``color_properties`` or ``vector_properties`` options. vnorm: The value by which to normalize the vector lengths. 
For example, spin properties should typically have a norm of 1 whereas group velocity properties can have larger or smaller norms depending on the structure. By changing this number, the size of the vectors will be scaled. Note that the properties of two materials can only be compared quantitatively if a fixed values is used for both plots. Only has an effect when used with the ``vector_properties`` option. hide_slice: Whether to hide the Fermi surface. Only recommended in combination with the ``vector_properties`` option. hide_labels: Whether to show the high-symmetry k-point labels. hide_cell: Whether to show the reciprocal cell boundary. plot_index: A choice of band indices (0-based). Valid options are: - A single integer, which will select that band index in both spin channels (if both spin channels are present). - A list of integers, which will select that set of bands from both spin channels (if both a present). - A dictionary of ``{Spin.up: band_index_1, Spin.down: band_index_2}``, where band_index_1 and band_index_2 are either single integers (if one wishes to plot a single band for that particular spin) or a list of integers. Note that the choice of integer and list can be different for different spin channels. - ``None`` in which case all bands will be plotted. arrow_pivot: The part of the arrow that is anchored to the X, Y grid. The arrow rotates about this point, options are: tail, middle, tip. slice_kwargs: Optional arguments that are passed to ``LineCollection`` and are used to style the iso slice. cbar_kwargs: Optional arguments that are passed to ``fig.colorbar``. quiver_kwargs: Optional arguments that are passed to ``ax.quiver`` and are used to style the arrows. bz_kwargs: Optional arguments that passed to ``LineCollection`` and used to style the Brillouin zone boundary. sym_pt_kwargs: Optional arguments that are passed to ``ax.scatter`` and are used to style the high-symmetry k-point symbols. 
sym_label_kwargs: Optional arguments that are passed to ``ax.text`` and are used to style the high-symmetry k-point labels. Returns: matplotlib pyplot object. """ import matplotlib.pyplot as plt from matplotlib.collections import LineCollection from matplotlib.transforms import ScaledTranslation slice_kwargs = slice_kwargs or {} cbar_kwargs = cbar_kwargs or {} quiver_kwargs = quiver_kwargs or {} bz_kwargs = bz_kwargs or {} sym_pt_kwargs = sym_pt_kwargs or {} sym_label_kwargs = sym_label_kwargs or {} plot_data = self._get_plot_data( spin=spin, colors=colors, color_properties=color_properties, vector_properties=vector_properties, projection_axis=projection_axis, vector_spacing=vector_spacing, cmin=cmin, cmax=cmax, vnorm=vnorm, hide_slice=hide_slice, hide_labels=hide_labels, hide_cell=hide_cell, plot_index=plot_index, ) if ax is None: fig = plt.figure(figsize=(6, 6)) ax = fig.add_subplot(111) else: fig = plt.gcf() # get rotation matrix that will align the longest slice length along the x-axis rotation = _get_rotation(self.fermi_slice.reciprocal_slice) if plot_data.properties: norm = Normalize(vmin=plot_data.cmin, vmax=plot_data.cmax) reference = max(abs(plot_data.cmax), abs(plot_data.cmin)) lines = None for segments, proj in zip(plot_data.slices, plot_data.properties): if scale_linewidth is False: linewidth = 2 else: base_width = 4 if isinstance(scale_linewidth, (float, int)) else 4 linewidth = abs(proj) * base_width / reference slice_style = {"antialiaseds": True, "linewidth": linewidth} slice_style.update(slice_kwargs) lines = LineCollection( np.dot(segments, rotation), cmap=plot_data.properties_colormap, norm=norm, **slice_style, ) lines.set_array(proj) # set the values used for color mapping ax.add_collection(lines) if lines: _mpl_cbar_style.update(cbar_kwargs) fig.colorbar(lines, ax=ax, **_mpl_cbar_style) else: slice_style = {"antialiasted": True, "linewidth": 2} slice_style.update(slice_kwargs) for c, segments in zip(plot_data.colors, plot_data.slices): lines = 
LineCollection( np.dot(segments, rotation), colors=c, **slice_kwargs ) ax.add_collection(lines) if not plot_data.hide_cell: # add the cell outline to the plot rotated_lines = np.dot(self.reciprocal_slice.lines, rotation) _mpl_bz_style.update(bz_kwargs) lines = LineCollection(rotated_lines, **_mpl_bz_style) ax.add_collection(lines) if not plot_data.hide_labels: # shift labels a few pixels away from the high-sym points offset = ScaledTranslation(4 / 72, 4 / 72, fig.dpi_scale_trans) for coords, label in zip(*self._symmetry_pts): coords = np.dot(coords, rotation) _mpl_sym_pt_style.update(sym_pt_kwargs) _mpl_sym_label_style.update(sym_label_kwargs) ax.scatter(*coords, **_mpl_sym_pt_style) ax.text( *coords, f"${label}$", **_mpl_sym_label_style, transform=ax.transData + offset, ) if plot_data.arrows is not None: norm = Normalize(vmin=plot_data.cmin, vmax=plot_data.cmax) _mpl_arrow_style["pivot"] = arrow_pivot _mpl_arrow_style.update(quiver_kwargs) for starts, stops, intensities in plot_data.arrows: colors = plot_data.arrow_colormap(norm(intensities)) starts = np.dot(starts, rotation) stops = np.dot(stops, rotation) u, v = (stops - starts).T x, y = starts.T ax.quiver(x, y, u, v, color=colors, **_mpl_arrow_style) ax.margins(y=0.1, x=0.1) ax.autoscale_view() ax.axis("equal") ax.axis("off") return plt
def _get_plot_data(
    self,
    spin: Spin | None = None,
    colors: str | dict | list | None = None,
    color_properties: str | bool = True,
    vector_properties: str | bool = False,
    projection_axis: tuple[int, int, int] | None = None,
    vector_spacing: float = VECTOR_SPACING,
    cmin: float | None = None,
    cmax: float | None = None,
    vnorm: float | None = None,
    hide_slice: bool = False,
    hide_labels: bool = False,
    hide_cell: bool = False,
    plot_index: list[int] | dict[Spin, list[int] | int] | int | None = None,
) -> _FermiSlicePlotData:
    """Assemble the data needed to draw the Fermi slice.

    See ``FermiSlicePlotter.get_plot()`` for a description of the arguments.

    Returns:
        The Fermi slice plot data.
    """
    # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9; the pyplot
    # function of the same name exists in both old and new releases.
    from matplotlib.pyplot import get_cmap

    if not spin:
        spin = self.fermi_slice.spins
    elif isinstance(spin, Spin):
        spin = [spin]

    has_properties = self.fermi_slice.has_properties
    draw_arrows = bool(vector_properties and has_properties)

    slices = []
    if not hide_slice:
        slices = self.fermi_slice.all_lines(spins=spin, band_index=plot_index)

    properties = []
    properties_colormap = None
    if has_properties:
        # always calculate properties if they are present so we can determine
        # cmin and cmax. These are also used for arrows and it is critical that
        # cmin and cmax are the same for properties and arrow color scales (even
        # if the colormap used is different)
        norm = self.fermi_slice.properties_ndim == 2
        properties = self.fermi_slice.all_properties(
            spins=spin,
            band_index=plot_index,
            projection_axis=projection_axis,
            norm=norm,
        )
        if isinstance(color_properties, str):
            properties_colormap = get_cmap(color_properties)
        else:
            properties_colormap = get_cmap(COLORMAP)
        cmin, cmax = _get_properties_limits(properties, cmin, cmax)

    if not color_properties or not has_properties:
        # the slice itself is drawn with flat colors rather than properties
        colors = get_isosurface_colors(colors, self.fermi_slice, spin, plot_index)
        properties = []
        if not draw_arrows:
            # only drop the limits when nothing will be color-mapped; arrows
            # still need cmin/cmax (including user-supplied values)
            cmin = None
            cmax = None

    arrows = []
    arrow_colormap = None
    if draw_arrows:
        arrows = get_segment_arrows(
            self.fermi_slice, spin, vector_spacing, vnorm, projection_axis
        )
        if isinstance(vector_properties, str):
            arrow_colormap = get_cmap(vector_properties)
        else:
            arrow_colormap = get_cmap(COLORMAP)

    return _FermiSlicePlotData(
        slices=slices,
        colors=colors,
        properties=properties,
        arrows=arrows,
        properties_colormap=properties_colormap,
        arrow_colormap=arrow_colormap,
        cmin=cmin,
        cmax=cmax,
        hide_labels=hide_labels,
        hide_cell=hide_cell,
    )
def show_plot(plot: Any):
    """Display a plot on screen.

    Args:
        plot: A plot object from ``FermiSurfacePlotter.get_plot()``. Supports
            matplotlib pyplot objects, plotly figure objects, and mlab figure
            objects.
    """
    backend = get_plot_type(plot)

    if backend in ("matplotlib", "mayavi"):
        # both of these plot objects are shown through their own ``show()``
        plot.show()
        return

    if backend == "plotly":
        from plotly.offline import plot as render_plotly

        render_plotly(plot, include_mathjax="cdn", filename="fermi-surface.html")
def save_plot(plot: Any, filename: Path | str, scale: float = SCALE):
    """Save a plot to file.

    Args:
        plot: A plot object from ``FermiSurfacePlotter.get_plot()``. Supports
            matplotlib pyplot objects, plotly figure objects, and mlab figure
            objects.
        filename: The output filename. For plotly figures, a ``.html``/``.htm``
            extension writes an HTML page; any other extension writes a static
            image through kaleido.
        scale: Scale for the figure size. Increases resolution but does not
            change the relative size of the figure and text.

    Raises:
        ValueError: If a static plotly image is requested but the kaleido
            package could not be imported.
    """
    plot_type = get_plot_type(plot)
    filename = str(filename)

    if plot_type == "matplotlib":
        # default dpi is ~100
        plot.savefig(filename, dpi=scale * 100, bbox_inches="tight")
    elif plot_type == "plotly":
        # decide on the extension only, so that e.g. "html_out/fs.png" is still
        # written as an image
        extension = os.path.splitext(filename)[1].lower()
        if extension in (".html", ".htm"):
            from plotly.offline import plot as show_plotly

            show_plotly(plot, include_mathjax="cdn", filename=filename, auto_open=False)
        else:
            # the module-level import guard sets ``kaleido`` to False (not None)
            # when the package is missing, so test for falsiness
            if not kaleido:
                raise ValueError(
                    "kaleido package required to save static ploty images\n"
                    "please install it using:\npip install kaleido"
                )
            plot.write_image(
                filename, engine="kaleido", scale=scale, width=750, height=750
            )
    elif plot_type == "mayavi":
        plot.savefig(filename, magnification=scale)
def get_plot_type(plot: Any) -> str:
    """Get the plot type.

    Args:
        plot: A plot object from ``FermiSurfacePlotter.get_plot()``. Supports
            matplotlib pyplot objects, plotly figure objects, and mlab figure
            objects.

    Returns:
        The plot type. Current options are "matplotlib", "mayavi", and "plotly".

    Raises:
        ValueError: If the plot type is not recognised.
    """
    # matplotlib and mayavi plots are recognised by their ``__name__``; check
    # this first so that plotly only has to be imported when it is actually
    # needed to identify the object
    module_name = getattr(plot, "__name__", None)
    if isinstance(module_name, str):
        if "matplotlib" in module_name:
            return "matplotlib"
        if "mayavi" in module_name:
            return "mayavi"

    from plotly.graph_objs import Figure

    if isinstance(plot, Figure):
        return "plotly"

    raise ValueError("Unrecognised plot type.")
def get_isosurface_colors(
    colors: str | dict | list | None,
    fermi_object: FermiSurface | FermiSlice,
    spins: list[Spin],
    plot_index: list[int] | dict[Spin, list[int] | int] | int | None,
) -> list[tuple[float, float, float]]:
    """Get colors for each Fermi surface.

    One color is assigned per selected (spin, band) pair and then repeated for
    every surface (or line) belonging to that band.

    Args:
        colors: The color specification. Valid options are:

            - A single color to use for all Fermi surfaces, specified as a tuple
              of rgb values from 0 to 1. E.g., red would be ``(1, 0, 0)``.
            - A list of colors, specified as above. The list is cycled if it is
              shorter than the number of bands to color.
            - A dictionary of ``{Spin.up: color1, Spin.down: color2}``, where the
              colors are specified as above.
            - A string specifying which matplotlib colormap to use. See
              https://matplotlib.org/tutorials/colors/colormaps.html for more
              information.
            - ``None`` (or an empty sequence), in which case the default colors
              will be used.
        fermi_object: A Fermi surface or Fermi slice object.
        spins: A list of spins for which colors will be generated.
        plot_index: A choice of band indices (0-based). Valid options are:

            - A single integer, which will select that band index in both spin
              channels (if both spin channels are present).
            - A list of integers, which will select that set of bands from both
              spin channels (if both are present).
            - A dictionary of ``{Spin.up: band_index_1, Spin.down: band_index_2}``,
              where band_index_1 and band_index_2 are either single integers (if
              one wishes to plot a single band for that particular spin) or a
              list of integers. Note that the choice of integer and list can be
              different for different spin channels.
            - ``None`` in which case all bands will be selected.

    Returns:
        The colors as a list of tuples, where each color is specified as the rgb
        values from 0 to 1. E.g., red would be ``(1, 0, 0)``.

    Raises:
        ValueError: If ``colors`` is a dict with fewer entries than ``spins``.
    """
    from itertools import cycle, islice

    if isinstance(fermi_object, FermiSurface):
        n_objects_per_band = fermi_object.n_surfaces_per_band
    else:
        n_objects_per_band = fermi_object.n_lines_per_band

    # one entry per selected (spin, band): how many objects share the color and
    # which spin channel the band belongs to
    surface_multiplicity = []
    band_spins = []
    for spin in spins:
        if isinstance(plot_index, dict):
            # if plot_index is a dict, get the idxs and make sure they are a list
            idxs = plot_index.get(spin, [])
            idxs = idxs if isinstance(idxs, (list, tuple)) else [idxs]
        elif isinstance(plot_index, int):
            idxs = [plot_index]
        elif isinstance(plot_index, (list, tuple)):
            idxs = plot_index
        else:
            # otherwise plot all bands
            idxs = sorted(n_objects_per_band[spin].keys())

        for band_idx in idxs:
            surface_multiplicity.append(n_objects_per_band[spin][band_idx])
            band_spins.append(spin)

    n_objects = len(surface_multiplicity)
    if n_objects == 0:
        # catch the case of no surfaces present
        return []

    if isinstance(colors, (tuple, list, np.ndarray)) and len(colors) == 0:
        # nothing to index into; treat as "no colors given"
        colors = None

    if isinstance(colors, (tuple, list, np.ndarray)):
        if isinstance(colors[0], (tuple, list, np.ndarray)):
            # colors is a list of colors; cycle it so that there is always one
            # color per band, however short the list is
            color_list = list(islice(cycle(colors), n_objects))
        else:
            # colors is a single color specification
            color_list = [colors] * n_objects
    elif isinstance(colors, dict):
        if len(colors) < len(spins):
            raise ValueError(
                "colors dict must have same number of spin channels as spins to plot"
            )
        # color each selected band by its spin channel; the shared expansion
        # below then honours plot_index and works for surfaces and slices alike
        color_list = [colors[s] for s in band_spins]
    elif isinstance(colors, str):
        # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9; the pyplot
        # function of the same name exists in both old and new releases
        from matplotlib.pyplot import get_cmap

        # get rid of alpha channel
        color_list = [i[:3] for i in get_cmap(colors)(np.linspace(0, 1, n_objects))]
    else:
        from plotly.colors import qualitative, unconvert_from_RGB_255, unlabel_rgb

        palette = islice(cycle(qualitative.Prism), n_objects)
        color_list = [unconvert_from_RGB_255(unlabel_rgb(c)) for c in palette]

    return [c for c, n in zip(color_list, surface_multiplicity) for _ in range(n)]
def get_face_arrows(
    fermi_surface: FermiSurface,
    spins: list[Spin],
    vector_spacing: float,
    vnorm: float | None,
    projection_axis: tuple[int, int, int] | None,
) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Get face arrows from vector properties.

    Args:
        fermi_surface: The fermi surface containing the isosurfaces and properties.
        spins: Spin channels from which to extract arrows.
        vector_spacing: The rough spacing between arrows. Uses a custom algorithm
            for resampling the Fermi surface to ensure that arrows are not too
            close together.
        vnorm: The value by which to normalize the vector lengths. For example,
            spin properties should typically have a norm of 1 whereas group
            velocity properties can have larger or smaller norms depending on
            the structure. By changing this number, the size of the vectors will
            be scaled. Note that the properties of two materials can only be
            compared quantitatively if a fixed value is used for both plots. If
            None, the largest property norm of the Fermi surface is used.
        projection_axis: Projection axis that can be used to calculate the color
            of vector projections. If None, the norm of the properties will be
            used, otherwise the color will be determined by the dot product of
            the properties with the projection axis.

    Returns:
        The arrows, as a list of (starts, stops, intensities) for each
        isosurface that has vector properties. The starts and stops are numpy
        arrays with the shape (narrows, 3) and intensities is a numpy array with
        the shape (narrows, ). The intensities are used to color the arrows
        during plotting. The list is empty if no isosurface has vector
        properties.
    """
    centers = []
    intensity = []
    vectors = []
    for spin in spins:
        for isosurface in fermi_surface.isosurfaces[spin]:
            if isosurface.properties_ndim != 2:
                # only vector properties can be drawn as arrows
                continue

            face_idx = isosurface.sample_uniform(vector_spacing)

            # get the center of each face in cartesian coords
            faces = isosurface.faces[face_idx]
            centers.append(isosurface.vertices[faces].mean(axis=1))
            vectors.append(isosurface.properties[face_idx])

            if projection_axis is None:
                # intensity is the norm of the properties
                intensities = isosurface.properties_norms[face_idx]
            else:
                # get intensity from projection of the vector onto axis
                intensities = isosurface.scalar_projection(projection_axis)[face_idx]
            intensity.append(intensities)

    if not vectors:
        # nothing to draw; also avoids taking the max of an empty list below
        return []

    if vnorm is None:
        property_norms = fermi_surface.all_properties(spins=spins, norm=True)
        vnorm = np.max([np.max(x) for x in property_norms])

    scale = 0.14 / vnorm  # 0.14 is magic scaling factor for vector length

    arrows = []
    for face_vectors, face_centers, face_intensity in zip(vectors, centers, intensity):
        # build a new array rather than scaling in place: this works for integer
        # properties and can never write back into the surface's own data
        scaled_vectors = face_vectors * scale
        start = face_centers - scaled_vectors / 2  # arrow is centred on the face
        stop = start + scaled_vectors
        arrows.append((start, stop, face_intensity))

    return arrows
def get_segment_arrows(
    fermi_slice: FermiSlice,
    spins: Collection[Spin],
    vector_spacing: float,
    vnorm: float | None,
    projection_axis: tuple[int, int, int] | None,
) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Get segment arrows from vector properties.

    Args:
        fermi_slice: The Fermi slice containing the isolines and properties.
        spins: Spin channels from which to extract arrows.
        vector_spacing: The rough spacing between arrows. Uses a custom algorithm
            for resampling the Fermi slice to ensure that arrows are not too
            close together.
        vnorm: The value by which to normalize the vector lengths. For example,
            spin properties should typically have a norm of 1 whereas group
            velocity properties can have larger or smaller norms depending on
            the structure. By changing this number, the size of the vectors will
            be scaled. Note that the properties of two materials can only be
            compared quantitatively if a fixed value is used for both plots. If
            None, the largest property norm of the Fermi slice is used.
        projection_axis: Projection axis that can be used to calculate the color
            of vector projections. If None, the norm of the properties will be
            used, otherwise the color will be determined by the dot product of
            the properties with the projection axis.

    Returns:
        The arrows, as a list of (starts, stops, intensities) for each isoline
        that has vector properties. The starts are the segment centers and the
        stops are the starts displaced by the first two components of the
        transformed, scaled vectors. The intensities are used to color the
        arrows during plotting. The list is empty if no isoline has vector
        properties.
    """
    centers = []
    intensity = []
    vectors = []
    for spin in spins:
        for isoline in fermi_slice.isolines[spin]:
            if isoline.properties_ndim != 2:
                # only vector properties can be drawn as arrows
                continue

            segment_idx = isoline.sample_uniform(vector_spacing)

            # get the center of each segment in cartesian coords
            centers.append(isoline.segments[segment_idx].mean(axis=1))
            vectors.append(isoline.properties[segment_idx])

            if projection_axis is None:
                # intensity is the norm of the properties
                intensities = isoline.properties_norms[segment_idx]
            else:
                # get intensity from projection of the vector onto axis
                intensities = isoline.scalar_projection(projection_axis)[segment_idx]
            intensity.append(intensities)

    if not vectors:
        # nothing to draw; also avoids taking the max of an empty list below
        return []

    # imported late so that trimesh is only loaded when arrows are produced
    from trimesh import transform_points

    if vnorm is None:
        property_norms = fermi_slice.all_properties(spins=spins, norm=True)
        vnorm = np.max([np.max(x) for x in property_norms])

    scale = 0.31 / vnorm  # magic scaling factor for length
    transformation = fermi_slice.reciprocal_slice.transformation

    arrows = []
    for segment_vectors, segment_centers, segment_intensity in zip(
        vectors, centers, intensity
    ):
        # build a new array rather than scaling in place: this works for integer
        # properties and can never write back into the slice's own data
        scaled_vectors = segment_vectors * scale

        # transform vectors onto 2D plane
        planar_vectors = transform_points(scaled_vectors, transformation)[:, :2]

        start = segment_centers
        stop = start + planar_vectors
        arrows.append((start, stop, segment_intensity))

    return arrows
def _get_properties_limits( projections: list[np.ndarray], cmin: float | None, cmax: float | None ) -> tuple[float, float]: """Get the min and max properties if they are not already set. Args: projections: The properties for each Fermi surface as a list of numpy arrays. cmin: A minimum value that overrides the one extracted from the properties. cmax: A maximum value that overrides the one extracted from the properties. Returns: The projection limits as a tuple of (min, max). """ if cmax is None: cmax = np.max([np.max(x) for x in projections]) if cmin is None: cmin = np.min([np.min(x) for x in projections]) return cmin, cmax def _get_plotly_camera(azimuth: float, elevation: float) -> dict[str, dict[str, float]]: """Get plotly viewpoint from azimuth and elevation.""" azimuth = np.radians(azimuth) elevation = np.radians(elevation) norm = np.linalg.norm([1.25, 1.25, 1.25]) # default plotly vector distance x = np.sin(azimuth) * np.cos(elevation) * norm y = np.cos(azimuth) * np.cos(elevation) * norm z = np.sin(elevation) * norm return { "up": {"x": 0, "y": 0, "z": 1}, "center": {"x": 0, "y": 0, "z": 0}, "eye": {"x": x, "y": y, "z": z}, }
def plotly_arrow(
    start: np.ndarray,
    stop: np.ndarray,
    color: tuple[float, float, float],
    line_kwargs: dict[str, Any] | None = None,
    cone_kwargs: dict[str, Any] | None = None,
) -> tuple[Any, Any]:
    """Build a plotly arrow from a line (shaft) and a cone (head).

    Args:
        start: The starting coordinates.
        stop: The ending coordinates.
        color: The arrow color in rgb format as a tuple of floats from 0 to 1.
        line_kwargs: Additional keyword arguments used to style the arrow shaft
            and that are passed to ``Scatter3d``.
        cone_kwargs: Additional keyword arguments used to style the arrow cone
            and that are passed to ``Cone``.

    Returns:
        The arrow, formed by a line and cone.
    """
    import plotly.graph_objs as go

    plotly_color = rgb_to_plotly(color)

    # unit vector along the arrow, used to orient the cone
    displacement = stop - start
    direction = displacement / np.linalg.norm(displacement)

    # defaults first so that user-supplied options take precedence
    shaft_style = {
        "line": {"width": 6, "color": plotly_color},
        "showlegend": False,
        **(line_kwargs or {}),
    }
    head_style = {
        "showscale": False,
        "sizemode": "absolute",
        "sizeref": 0.08,  # magic cone length
        "anchor": "cm",
        **(cone_kwargs or {}),
    }

    shaft = go.Scatter3d(
        x=[start[0], stop[0]],
        y=[start[1], stop[1]],
        z=[start[2], stop[2]],
        mode="lines",
        **shaft_style,
    )
    head = go.Cone(
        x=[stop[0]],
        y=[stop[1]],
        z=[stop[2]],
        u=[direction[0]],
        v=[direction[1]],
        w=[direction[2]],
        # a two-stop colorscale with identical ends gives a uniformly colored cone
        colorscale=[[0, plotly_color], [1, plotly_color]],
        **head_style,
    )
    return shaft, head
def rgb_to_plotly(color: tuple[float, float, float]) -> str:
    """Convert an rgb color to the plotly color format.

    Args:
        color: The color in rgb format as a tuple of three floats from 0 to 1.

    Returns:
        The plotly formatted color.
    """
    from plotly.colors import convert_to_RGB_255, label_rgb

    rgb_255 = convert_to_RGB_255(color)
    return label_rgb(rgb_255)
def cmap_to_plotly(colormap: Colormap) -> list[str]:
    """Build a plotly colorscale from a matplotlib colormap.

    Args:
        colormap: A matplotlib colormap object.

    Returns:
        The equivalent plotly colorscale.
    """
    from plotly.colors import make_colorscale

    # sample the colormap at 255 evenly spaced points and drop the alpha channel
    sample_points = np.linspace(0, 1, 255)
    rgb_samples = colormap(sample_points)[:, :3]

    plotly_colors = list(map(rgb_to_plotly, rgb_samples))
    return make_colorscale(plotly_colors)
def cmap_to_mayavi(colormap: Colormap) -> np.ndarray:
    """Build a mayavi colormap table from a matplotlib colormap.

    Args:
        colormap: A matplotlib colormap object.

    Returns:
        The equivalent mayavi colormap, as a (255, 4) numpy array.
    """
    # sample the colormap at 255 evenly spaced points
    sample_points = np.linspace(0, 1, 255)
    rgba = colormap(sample_points)

    # rescale from the 0-1 range to 0-255 and truncate to integers
    scaled = rgba * 255
    return scaled.astype(int)
def _get_rotation(reciprocal_slice) -> np.ndarray: """Get a rotation matrix that aligns the longest slice length along the x axis. Args: reciprocal_slice: A reciprocal slice. Returns: The transformation matrix as 2x2 array. """ line_vectors = reciprocal_slice.lines[:, 0, :] - reciprocal_slice.lines[:, 1, :] line_lengths = np.linalg.norm(line_vectors, axis=-1) longest_line = line_vectors[np.argmax(line_lengths)] longest_line_norm = longest_line / np.linalg.norm(longest_line) dotp = np.dot(longest_line_norm, [1, 0]) angle = np.arccos(dotp) cos_angle = np.cos(angle) sin_angle = np.sin(angle) rotation = np.array([[cos_angle, -sin_angle], [sin_angle, cos_angle]]) return rotation.T def _is_notebook(): """Check if running in a jupyter notebook.""" try: from IPython import get_ipython if "IPKernelApp" not in get_ipython().config: raise ImportError("console") if "VSCODE_PID" in os.environ: raise ImportError("vscode") except Exception: return False else: return True