Coverage for local_installation/dynasor/trajectory/trajectory_frame.py: 99%
66 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-08-05 09:53 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-08-05 09:53 +0000
from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
from numpy.typing import NDArray
@dataclass
class ReaderFrame:
    """Trivial data struct holding MD-data for one time frame

    Parameters
    ----------
    frame_index
        Trajectory index of the snapshot (frame)
    cell
        Simulation cell as 3 row vectors (Å)
    n_atoms
        number of atoms
    positions
        particle positions as 3xn_atoms array (Å)
    velocities
        particle velocities as 3xn_atoms array (Å/fs);
        may not be available, depending on reader and trajectory file format,
        in which case it is ``None``
    atom_types
        Array with the type of each atom;
        may not be available, depending on reader and trajectory file format,
        in which case it is ``None``
    """
    frame_index: int
    cell: np.ndarray
    n_atoms: int
    positions: np.ndarray
    # The two fields below are optional: readers that cannot provide them
    # leave the default ``None`` in place, so the annotations say so explicitly.
    velocities: Optional[np.ndarray] = None
    atom_types: Optional[NDArray[str]] = None
class TrajectoryFrame:
    """
    Class holding positions and optionally velocities split by atom type
    for one snapshot (frame) in a trajectory

    Attributes
    ----------
    positions_by_type
        Dictionary mapping each atom type to a copy of the positions of the
        atoms of that type
    velocities_by_type
        Dictionary mapping each atom type to a copy of the velocities of the
        atoms of that type; ``None`` if no velocities were provided

    such that e.g.
    positions_by_type['Cs'] numpy array with shape (n_atoms_Cs, 3)
    positions_by_type['Pb'] numpy array with shape (n_atoms_Pb, 3)

    Parameters
    ----------
    atomic_indices
        Dictionary specifying which indices (dict values) belong to which atom type (dict keys)
    frame_index
        Trajectory index of the snapshot (frame)
    positions
        Positions as an array with shape ``(n_atoms, 3)``
    velocities
        Velocities as an array with shape ``(n_atoms, 3)``; defaults to ``None``
    """

    def __init__(self,
                 atomic_indices: Dict[str, List[int]],
                 frame_index: int,
                 positions: np.ndarray,
                 velocities: np.ndarray = None):
        self._frame_index = frame_index
        self.positions_by_type = self._split_by_type(positions, atomic_indices)
        if velocities is None:
            self.velocities_by_type = None
        else:
            self.velocities_by_type = self._split_by_type(velocities, atomic_indices)

    @staticmethod
    def _split_by_type(data: np.ndarray,
                       atomic_indices: Dict[str, List[int]]) -> Dict[str, np.ndarray]:
        """ Return a dict with an independent copy of the rows of ``data`` for each atom type. """
        # the explicit copy decouples the frame from the caller's array
        return {atom_type: data[indices, :].copy()
                for atom_type, indices in atomic_indices.items()}

    @staticmethod
    def _merge_by_type(data_by_type: Dict[str, np.ndarray],
                       atomic_indices: Dict[str, List[int]]) -> np.ndarray:
        """ Reassemble per-type arrays into one array with shape ``(n_atoms, 3)``.

        Raises
        ------
        ValueError
            If the indices in ``atomic_indices`` do not number exactly
            ``max(index) + 1`` distinct values.
        """
        # Flatten first so that atom types without any atoms (empty index
        # lists) do not break the computation of the largest index.
        all_inds = [i for indices in atomic_indices.values() for i in indices]
        n_atoms = max(all_inds) + 1 if all_inds else 0

        # check that atomic_indices is complete
        if len(all_inds) != n_atoms or len(set(all_inds)) != n_atoms:
            raise ValueError('atomic_indices is incomplete')

        # collect the per-type data into a single array
        merged = np.empty((n_atoms, 3))
        for atom_type, indices in atomic_indices.items():
            merged[indices, :] = data_by_type[atom_type]
        return merged

    def get_positions_as_array(self, atomic_indices: Dict[str, List[int]]) -> np.ndarray:
        """
        Construct the full positions array with shape ``(n_atoms, 3)``.

        Parameters
        ----------
        atomic_indices
            Dictionary specifying which indices (dict values) belong to which atom type (dict keys)

        Raises
        ------
        ValueError
            If ``atomic_indices`` is incomplete
        """
        return self._merge_by_type(self.positions_by_type, atomic_indices)

    def get_velocities_as_array(self, atomic_indices: Dict[str, List[int]]) -> np.ndarray:
        """
        Construct the full velocities array with shape ``(n_atoms, 3)``.

        Parameters
        ----------
        atomic_indices
            Dictionary specifying which indices (dict values) belong to which atom type (dict keys)

        Raises
        ------
        ValueError
            If the frame holds no velocities or if ``atomic_indices`` is incomplete
        """
        # Fail with a clear message instead of an opaque
        # "'NoneType' object is not subscriptable" further down.
        if self.velocities_by_type is None:
            raise ValueError('velocities are not available for this frame')
        return self._merge_by_type(self.velocities_by_type, atomic_indices)

    @property
    def frame_index(self) -> int:
        """ Index of the frame. """
        return self._frame_index

    def __str__(self) -> str:
        s = [f'Frame index {self.frame_index}']
        for key, val in self.positions_by_type.items():
            s.append(f' positions : {key} shape : {val.shape}')
        if self.velocities_by_type is not None:
            for key, val in self.velocities_by_type.items():
                s.append(f' velocities : {key} shape : {val.shape}')
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        s = [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th>'
              '<th>Value/Shape</th></tr></thead>']
        s += ['<tbody>']
        s += [f'<tr><td style="text-align: left;">Index</td><td>{self.frame_index}</td></tr>']
        for key, val in self.positions_by_type.items():
            s += [f'<tr><td style="text-align: left;">Positions {key}</td>'
                  f'<td>{val.shape}</td></tr>']
        if self.velocities_by_type is not None:
            for key, val in self.velocities_by_type.items():
                s += [f'<tr><td style="text-align: left;">Velocities {key}</td>'
                      f'<td>{val.shape}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)