Coverage for dynasor/modes/project_modes.py: 100%

35 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-18 09:03 +0000

1from typing import Optional 

2 

3import numpy as np 

4 

5from ase import Atoms 

6from numpy.typing import NDArray 

7 

8from dynasor.logging_tools import logger 

9from dynasor.trajectory import Trajectory 

10from dynasor.tools.structures import get_displacements_from_u 

11 

12 

def project_modes(
    traj: Trajectory,
    modes: NDArray[float],
    ideal_supercell: Atoms,
    check_mic: Optional[bool] = True,
    logging_interval: Optional[int] = 1000,
) -> tuple[NDArray[float], NDArray[float]]:
    """Projects an atomic trajectory onto a set of phonon modes.

    For every frame, the displacements relative to ``ideal_supercell`` are
    contracted with ``modes`` to give the mode coordinates ``Q``. If the frame
    carries velocities, they are contracted with the complex conjugate of
    ``modes`` to give the mode momenta ``P``.

    Parameters
    ----------
    traj
        Input trajectory.
    modes
        Modes to project on, as an array with shape ``(..., N, 3)`` where ``N`` is the
        number of atoms in the supercell and the leading dimensions define the output shape.
    ideal_supercell
        Ideal supercell used to find atomic displacements. It should correspond to the ideal
        structure. Be careful not to mess up the permutation.
    check_mic
        Whether to wrap the displacements or not, faster if no wrap.
    logging_interval
        Log progress at ``INFO`` level every this many frames. Set to ``0`` (or ``None``)
        to disable progress logging.

    Returns
    -------
    A tuple comprising ``(Q, P)``, where ``Q`` are the mode coordinates and ``P`` are
    the mode momenta. Both arrays have shape ``(n_frames, *modes.shape[:-2])``, i.e. one
    row per trajectory frame followed by the leading dimensions of ``modes``. They are
    complex if ``modes`` is complex. ``P`` is zero for frames without velocities.

    Raises
    ------
    ValueError
        If ``modes`` does not have shape ``(..., N, 3)`` with ``N`` equal to the number of
        atoms in the trajectory, or if ``ideal_supercell`` does not contain the same number
        of atoms as the trajectory.
    """
    logger.info('Running mode projection')

    modes = np.asarray(modes)
    original_mode_shape = modes.shape

    # Validate the layout up front. Without this check, 1D input hits an IndexError
    # below. A trailing dimension other than 3 would be silently reinterpreted by the
    # reshape into a different number of "modes".
    if modes.ndim < 2 or modes.shape[-1] != 3:
        raise ValueError(f'modes must have shape (..., N, 3), got {modes.shape}')
    if modes.shape[-2] != traj.n_atoms:
        raise ValueError('Second dim in modes must be same len as number of atoms in trajectory')
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    # Flatten the modes to an (n_modes, 3N) matrix so each projection is a single
    # matrix-vector product. n_modes is computed explicitly rather than with -1,
    # so the reshape also works for zero-size mode sets.
    output_shape = original_mode_shape[:-2]
    n_atoms = original_mode_shape[-2]
    n_modes = int(np.prod(output_shape, dtype=int))
    modes_flat = modes.reshape(n_modes, 3 * n_atoms)
    modes_flat_conj = modes_flat.conj()

    # Loop invariants.
    ideal_positions = ideal_supercell.positions
    ideal_cell = ideal_supercell.cell

    Q_traj, P_traj = [], []
    for it, frame in enumerate(traj):
        if logging_interval and it % logging_interval == 0:
            logger.info(f'Reading frame {it}')
        else:
            logger.debug(f'Reading frame {it}')

        # Make positions into displacements (optionally wrapped via check_mic).
        x = frame.get_positions_as_array(traj._atomic_indices)
        u = get_displacements_from_u(x - ideal_positions, ideal_cell, check_mic=check_mic)

        # Calculate Q. Flattening (N, 3) -> (3N,) in C order matches modes_flat's columns.
        Q = modes_flat @ np.ravel(u)

        # Calculate P from the conjugated modes; zeros when the frame has no velocities.
        if frame.velocities_by_type is not None:
            v = frame.get_velocities_as_array(traj._atomic_indices)
            P = modes_flat_conj @ np.ravel(v)
        else:
            P = np.zeros_like(Q)

        Q_traj.append(Q)
        P_traj.append(P)

    Q_traj = np.array(Q_traj).reshape((len(Q_traj), *output_shape))
    P_traj = np.array(P_traj).reshape((len(P_traj), *output_shape))

    return Q_traj, P_traj