Coverage for local_installation/dynasor/trajectory/trajectory_frame.py: 99%

66 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-05 09:53 +0000

from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
from numpy.typing import NDArray

5 

6 

@dataclass
class ReaderFrame:
    """Trivial data struct holding MD-data for one time frame

    Parameters
    ----------
    frame_index
        Trajectory index of the snapshot (frame)
    cell
        Simulation cell as 3 row vectors (Å)
    n_atoms
        number of atoms
    positions
        particle positions as 3xn_atoms array (Å)
    velocities
        particle velocities as 3xn_atoms array (Å/fs);
        may not be available, depending on reader and trajectory file format,
        in which case it is ``None`` (the default)
    atom_types
        Array with the type of each atom;
        may not be available, depending on reader and trajectory file format,
        in which case it is ``None`` (the default)
    """
    frame_index: int
    cell: np.ndarray
    n_atoms: int
    positions: np.ndarray
    # Both optional fields default to None, so their annotations say so explicitly.
    velocities: Optional[np.ndarray] = None
    atom_types: Optional[NDArray[str]] = None

34 

35 

class TrajectoryFrame:
    """
    Class holding positions and optionally velocities split by atom type
    for one snapshot (frame) in a trajectory

    Attributes
    ----------
    * positions_by_type
    * velocities_by_type

    such that e.g.
    positions_by_type['Cs'] numpy array with shape (n_atoms_Cs, 3)
    positions_by_type['Pb'] numpy array with shape (n_atoms_Pb, 3)

    ``velocities_by_type`` is ``None`` if no velocities were provided.

    Parameters
    ----------
    atomic_indices
        Dictionary specifying which indices (dict values) belong to which atom type (dict keys)
    frame_index
        Trajectory index of the snapshot (frame)
    positions
        Positions as an array with shape ``(n_atoms, 3)``
    velocities
        Velocities as an array with shape ``(n_atoms, 3)``; defaults to ``None``
    """

    def __init__(self,
                 atomic_indices: Dict[str, List[int]],
                 frame_index: int,
                 positions: np.ndarray,
                 velocities: Optional[np.ndarray] = None):
        self._frame_index = frame_index

        # Copy the selected rows so that the frame owns its data.
        self.positions_by_type = {
            atom_type: positions[indices, :].copy()
            for atom_type, indices in atomic_indices.items()
        }

        if velocities is not None:
            self.velocities_by_type = {
                atom_type: velocities[indices, :].copy()
                for atom_type, indices in atomic_indices.items()
            }
        else:
            self.velocities_by_type = None

    @staticmethod
    def _collect_as_array(data_by_type: Dict[str, np.ndarray],
                          atomic_indices: Dict[str, List[int]]) -> np.ndarray:
        """
        Merge per-type arrays into one array with shape ``(n_atoms, 3)``.

        Parameters
        ----------
        data_by_type
            Dictionary mapping each atom type to an array with one row per atom of that type
        atomic_indices
            Dictionary specifying which indices (dict values) belong to which atom type (dict keys)

        Raises
        ------
        ValueError
            If the indices in ``atomic_indices`` do not cover ``0, ..., n_atoms - 1``
            exactly once each.
        """
        # check that atomic_indices is complete
        per_type = [np.asarray(indices, dtype=int).ravel() for indices in atomic_indices.values()]
        all_inds = np.concatenate(per_type) if per_type else np.empty(0, dtype=int)
        n_atoms = all_inds.size
        # Comparing against arange rejects gaps, duplicates and negative indices alike;
        # a negative index would alias another row and leave a row of the result unset.
        if n_atoms == 0 or not np.array_equal(np.sort(all_inds), np.arange(n_atoms)):
            raise ValueError('atomic_indices is incomplete')

        # collect the per-type rows into a single array
        merged = np.empty((n_atoms, 3))
        for atom_type, indices in atomic_indices.items():
            merged[indices, :] = data_by_type[atom_type]
        return merged

    def get_positions_as_array(self, atomic_indices: Dict[str, List[int]]) -> np.ndarray:
        """
        Construct the full positions array with shape ``(n_atoms, 3)``.

        Parameters
        ----------
        atomic_indices
            Dictionary specifying which indices (dict values) belong to which atom type (dict keys)

        Raises
        ------
        ValueError
            If ``atomic_indices`` is incomplete.
        """
        return self._collect_as_array(self.positions_by_type, atomic_indices)

    def get_velocities_as_array(self, atomic_indices: Dict[str, List[int]]) -> np.ndarray:
        """
        Construct the full velocities array with shape ``(n_atoms, 3)``.

        Parameters
        ----------
        atomic_indices
            Dictionary specifying which indices (dict values) belong to which atom type (dict keys)

        Raises
        ------
        ValueError
            If the frame holds no velocities or if ``atomic_indices`` is incomplete.
        """
        if self.velocities_by_type is None:
            # Fail with a clear message instead of an opaque error from indexing None.
            raise ValueError('No velocities available for this frame')
        return self._collect_as_array(self.velocities_by_type, atomic_indices)

    @property
    def frame_index(self) -> int:
        """ Index of the frame. """
        return self._frame_index

    def __str__(self) -> str:
        s = [f'Frame index {self.frame_index}']
        for key, val in self.positions_by_type.items():
            s.append(f' positions : {key} shape : {val.shape}')
        if self.velocities_by_type is not None:
            for key, val in self.velocities_by_type.items():
                s.append(f' velocities : {key} shape : {val.shape}')
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        """ Return an HTML table listing the frame index and the per-type array shapes. """
        s = [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th>'
              '<th>Value/Shape</th></tr></thead>']
        s += ['<tbody>']
        s += [f'<tr><td style="text-align: left;">Index</td><td>{self.frame_index}</td></tr>']
        for key, val in self.positions_by_type.items():
            s += [f'<tr><td style="text-align: left;">Positions {key}</td>'
                  f'<td>{val.shape}</td></tr>']
        if self.velocities_by_type is not None:
            for key, val in self.velocities_by_type.items():
                s += [f'<tr><td style="text-align: left;">Velocities {key}</td>'
                      f'<td>{val.shape}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)