Newer
Older
# Mantid Repository : https://github.com/mantidproject/mantid
#
# Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI,
# NScD Oak Ridge National Laboratory, European Spallation Source,
# Institut Laue - Langevin & CSNS, Institute of High Energy Physics, CAS
# SPDX - License - Identifier: GPL - 3.0 +
import numpy as np

from abins.constants import (COMPLEX_ID, FLOAT_ID, GAMMA_POINT, SMALL_K)
"""Class storing atomic frequencies and displacements at specific k-points
"weights" - weights of all k-points; weights.shape == (num_k,);
"k_vectors" - k_vectors of all k-points; k_vectors.shape == (num_k, 3)
"frequencies" - frequencies for all k-points; frequencies.shape == (num_k, num_freq)
"atomic_displacements" - atomic displacements for all k-points;
atomic_displacements.shape == (num_k, num_atoms, num_freq, 3)
"unit_cell" - lattice vectors (use zeros for open boundary conditions);
unit_cell.shape == (3, 3)
def __init__(self, *, frequencies: np.ndarray, atomic_displacements: np.ndarray,
weights: np.ndarray, k_vectors: np.ndarray, unit_cell: np.ndarray) -> None:
super().__init__()
self._data = {}
dim = 3
Krzysztof Dymkowski
committed
for arg in (frequencies, atomic_displacements, weights, k_vectors, unit_cell):
if not isinstance(arg, np.ndarray):
raise TypeError("All arguments to KpointsData should be numpy arrays")
Krzysztof Dymkowski
committed
# unit_cell
if not (unit_cell.shape == (dim, dim)
and unit_cell.dtype.num == FLOAT_ID):
Krzysztof Dymkowski
committed
raise ValueError("Invalid values of unit cell vectors.")
self.unit_cell = unit_cell
Krzysztof Dymkowski
committed
num_k = weights.size
if not (weights.dtype.num == FLOAT_ID
and np.allclose(weights, weights[weights >= 0])):
raise ValueError("Invalid value of weights.")
# "k_vectors"
if not (k_vectors.shape == (num_k, dim)
and k_vectors.dtype.num == FLOAT_ID):
raise ValueError("Invalid value of k_vectors.")
# "frequencies"
if not (frequencies.shape == (num_k, num_freq)
and frequencies.dtype.num == FLOAT_ID):
raise ValueError("Invalid value of frequencies.")
Krzysztof Dymkowski
committed
# "atomic_displacements"
if len(atomic_displacements.shape) != 4:
raise ValueError("atomic_displacements should have four dimensions")
num_atoms = atomic_displacements.shape[1]
if not (atomic_displacements.shape == (weights.size, num_atoms, num_freq, dim)
and atomic_displacements.dtype.num == COMPLEX_ID):
raise ValueError("Invalid value of atomic_displacements.")
# Repackage data into dicts
self._frequencies = {k_index: frequencies[k_index, :] for k_index in range(num_k)}
self._atomic_displacements = {k_index: atomic_displacements[k_index] for k_index in range(num_k)}
self._k_vectors = {k_index: k_vectors[k_index] for k_index in range(num_k)}
self._weights = {k_index: weights[k_index] for k_index in range(num_k)}
Krzysztof Dymkowski
committed
def get_gamma_point_data(self):
"""
Extracts k points data only for Gamma point.
Krzysztof Dymkowski
committed
:returns: dictionary with data only for Gamma point
Krzysztof Dymkowski
committed
"""
gamma_pkt_index = -1
# look for index of Gamma point
for k_index, k in self._k_vectors.items():
if np.linalg.norm(k) < SMALL_K:
gamma_pkt_index = k_index
Krzysztof Dymkowski
committed
break
Krzysztof Dymkowski
committed
raise ValueError("Gamma point not found.")
k_points = {"weights": {GAMMA_POINT: self._data["weights"][gamma_pkt_index]},
"k_vectors": {GAMMA_POINT: self._data["k_vectors"][gamma_pkt_index]},
"frequencies": {GAMMA_POINT: self._data["frequencies"][gamma_pkt_index]},
"atomic_displacements": {GAMMA_POINT: self._data["atomic_displacements"][gamma_pkt_index]},
"unit_cell": self.unit_cell}
Krzysztof Dymkowski
committed
return k_points
extracted = {"unit_cell": self.unit_cell,
"weights": {str(key): value for key, value in self._weights.items()},
"k_vectors": {str(key): value for key, value in self._k_vectors.items()},
"frequencies": {str(key): value for key, value in self._frequencies.items()},
"atomic_displacements": {str(key): value for key, value in self._atomic_displacements.items()}}
return extracted
def __str__(self):
return "K-points data"