Coverage for dynasor/modes/project_modes.py: 100%
35 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-18 09:03 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-18 09:03 +0000
1from typing import Optional
3import numpy as np
5from ase import Atoms
6from numpy.typing import NDArray
8from dynasor.logging_tools import logger
9from dynasor.trajectory import Trajectory
10from dynasor.tools.structures import get_displacements_from_u
def project_modes(
    traj: Trajectory,
    modes: NDArray[float],
    ideal_supercell: Atoms,
    check_mic: Optional[bool] = True,
    logging_interval: Optional[int] = 1000,
) -> tuple[NDArray[float], NDArray[float]]:
    """Projects an atomic trajectory onto a set of phonon modes.

    For every frame the displacements ``u`` (positions minus the ideal positions,
    optionally wrapped) and, when available, the velocities ``v`` are contracted with
    the modes over the atom and Cartesian axes::

        Q[m] = sum_{n, x} modes[m, n, x] * u[n, x]
        P[m] = sum_{n, x} conj(modes[m, n, x]) * v[n, x]

    Parameters
    ----------
    traj
        Input trajectory.
    modes
        Modes to project on, as an array with shape ``(..., N, 3)`` where ``N`` is the
        number of atoms in the supercell and the leading dimensions define the output shape.
    ideal_supercell
        Ideal supercell used to find atomic displacements. It should correspond to the ideal
        structure, with the atoms in the same order as in the trajectory (positions are
        subtracted row by row, so a permutation silently gives wrong displacements).
    check_mic
        Whether to wrap the displacements or not, faster if no wrap.
    logging_interval
        Log progress at ``INFO`` level every this many frames; all other frames are logged
        at ``DEBUG`` level. Set to ``0`` (or ``None``) to disable ``INFO`` progress logging.

    Returns
    -------
    A tuple comprising ``(Q, P)`` where ``Q`` are the mode coordinates and ``P`` are the
    mode momenta. Both arrays have shape ``(n_frames, ...)``, where ``...`` are the leading
    dimensions of ``modes`` (``modes.shape[:-2]``); for modes of shape ``(M, N, 3)`` this is
    ``(n_frames, M)``. The arrays are complex if ``modes`` is complex. For frames without
    velocities the corresponding entries of ``P`` are zero.

    Raises
    ------
    ValueError
        If ``modes`` is not of shape ``(..., N, 3)``, if ``N`` differs from the number of
        atoms in the trajectory, or if ``ideal_supercell`` has a different number of atoms
        than the trajectory.
    """
    logger.info('Running mode projection')

    # Only read from below, so no defensive copy is needed.
    modes = np.asarray(modes)

    # Validate the full shape up front: previously a 1-D input raised IndexError and a
    # wrong last axis was silently reshaped into bogus modes before failing much later.
    if modes.ndim < 2 or modes.shape[-1] != 3:
        raise ValueError(f'modes must have shape (..., N, 3), got {modes.shape}')
    if modes.shape[-2] != traj.n_atoms:
        raise ValueError('Second dim in modes must be same len as number of atoms in trajectory')
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    # Collapse all leading mode dimensions into one axis for the contraction; the
    # original leading shape is restored on the output at the end.
    leading_shape = modes.shape[:-2]
    modes = modes.reshape((-1, modes.shape[-2], 3))
    modes_conj = modes.conj()

    # The ideal structure is the same for every frame; fetch it once.
    ideal_positions = ideal_supercell.positions
    cell = ideal_supercell.cell

    Q_traj, P_traj = [], []
    for it, frame in enumerate(traj):
        if logging_interval and it % logging_interval == 0:
            logger.info(f'Reading frame {it}')
        else:
            logger.debug(f'Reading frame {it}')

        # Make positions into displacements relative to the ideal structure
        x = frame.get_positions_as_array(traj._atomic_indices)
        u = get_displacements_from_u(x - ideal_positions, cell, check_mic=check_mic)

        # Calculate Q: contract modes with displacements over atoms and Cartesian axes
        Q = np.einsum('mnx,nx->m', modes, u, optimize=True)

        # Calculate P with the *conjugated* modes. NOTE(review): presumably chosen so that
        # P is the momentum canonically conjugate to complex Q (P_k = dQ_{-k}/dt) — confirm.
        if frame.velocities_by_type is not None:
            v = frame.get_velocities_as_array(traj._atomic_indices)
            P = np.einsum('mnx,nx->m', modes_conj, v, optimize=True)
        else:
            # No velocities stored for this frame: momenta are reported as zero.
            P = np.zeros_like(Q)

        Q_traj.append(Q)
        P_traj.append(P)

    n_frames = len(Q_traj)
    Q_traj = np.array(Q_traj).reshape((n_frames, *leading_shape))
    P_traj = np.array(P_traj).reshape((n_frames, *leading_shape))

    return Q_traj, P_traj