Coverage for dynasor / trajectory / trajectory.py: 94%

160 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 12:31 +0000

1__all__ = ['Trajectory', 'WindowIterator'] 

2 

3from collections import deque 

4from itertools import islice, chain 

5from os.path import isfile 

6from typing import Callable, Optional, Union 

7 

8import numpy as np 

9from numpy.typing import NDArray 

10 

11from dynasor.trajectory.atomic_indices import parse_gromacs_index_file 

12from dynasor.trajectory.ase_trajectory_reader import ASETrajectoryReader 

13from dynasor.trajectory.extxyz_trajectory_reader import ExtxyzTrajectoryReader 

14from dynasor.trajectory.lammps_trajectory_reader import LammpsTrajectoryReader 

15from dynasor.trajectory.mdanalysis_trajectory_reader import MDAnalysisTrajectoryReader 

16from dynasor.trajectory.trajectory_frame import TrajectoryFrame 

17from dynasor.logging_tools import logger 

18 

19 

class Trajectory:
    """Instances of this class hold trajectories in a format suitable for
    the computation of correlation functions. They behave as
    iterators, where each step returns the next frame as a
    :class:`TrajectoryFrame` object. The latter hold information
    regarding atomic positions, types, and velocities.

    Parameters
    ----------
    filename
        Name of input file.
    trajectory_format
        Type of trajectory. Possible values are:
        ``'lammps_internal'``, ``'extxyz'``, ``'ase'`` or one of the formats supported by
        `MDAnalysis <https://www.mdanalysis.org/>`_ (except for ``'lammpsdump'``,
        which can be called by specifying ``'lammps_mdanalysis'`` to avoid ambiguity)
    atomic_indices
        Specify which indices belong to which atom type. Can be
        (1) a dictionary where the keys specify the species and the values
        are a list of atomic indices,
        (2) ``'read_from_trajectory'``, in which case the species are read from the trajectory or
        (3) the path to a gromacs index file.
        By default (``None``) all atoms are assigned to a single type ``'X'``.
    length_unit
        Length unit of trajectory (``'Angstrom'``, ``'nm'``, ``'pm'``, ``'fm'``). Necessary for
        correct conversion to internal dynasor units if the trajectory file does not contain unit
        information.
        If no length unit is specified and the reader cannot read units from the trajectory,
        Angstrom is assumed.
    time_unit
        Time unit of trajectory (``'fs'``, ``'ps'``, ``'ns'``). Necessary for correct conversion to
        internal dynasor units if the trajectory file does not contain unit information.
        If no time unit is specified and the reader cannot read units from the trajectory, fs is
        assumed.
    frame_start
        First frame to read; must be larger or equal ``0``. ``None`` is treated as ``0``.
    frame_stop
        Frame at which reading stops (exclusive, as for a slice).
        By default (``None``) the entire trajectory is read.
    frame_step
        Read every :attr:`frame_step`-th step of the input trajectory.
        By default (``1``) every frame is read. Must be larger than ``0``.
        ``None`` is treated as ``1``.

    Raises
    ------
    ValueError
        If :attr:`frame_start` or :attr:`frame_step` are out of range, if the cell differs
        between the first two frames, or if :attr:`atomic_indices` cannot be interpreted or
        contains invalid indices or species names.
    IOError
        If the file does not exist or the format ``'lammps'`` is requested, which is ambiguous.
    """
    def __init__(
        self,
        filename: str,
        trajectory_format: str,
        atomic_indices: Optional[Union[str, dict[str, list[int]]]] = None,
        length_unit: Optional[str] = None,
        time_unit: Optional[str] = None,
        frame_start: Optional[int] = 0,
        frame_stop: Optional[int] = None,
        frame_step: Optional[int] = 1
    ):

        frame_start, frame_step = self._validate_frame_range(frame_start, frame_step)
        self._frame_start = frame_start
        self._frame_step = frame_step
        self._frame_stop = frame_stop

        # setup trajectory reader
        if not isfile(filename):
            raise IOError(f'File {filename} does not exist')
        self._filename = filename

        reader, reader_format = self._select_reader(trajectory_format)
        logger.debug(f'Using trajectory reader: {reader.__name__}')
        if reader is MDAnalysisTrajectoryReader:
            # only the MDAnalysis reader needs to be told which format to parse
            self._reader_obj = reader(self._filename, reader_format,
                                      length_unit=length_unit, time_unit=time_unit)
        else:
            self._reader_obj = reader(self._filename, length_unit=length_unit, time_unit=time_unit)

        # Get (up to) two frames to set cell etc.
        # The second frame is optional so that a single-frame trajectory does
        # not leak a StopIteration out of the constructor.
        frame0 = next(self._reader_obj)
        frame1 = next(self._reader_obj, None)
        self._cell = frame0.cell
        self._n_atoms = frame0.n_atoms

        # Make sure cell is not changed during consecutive frames
        if frame1 is not None and not np.allclose(frame0.cell, frame1.cell):
            raise ValueError('The cell changes between the first and second frame. '
                             'The concept of q-points becomes muddy if the simulation cell is '
                             'changing, such as during NPT MD simulations, so trajectories where '
                             'the cell changes are not supported by dynasor.')

        # setup iterator slice; the frames already pulled from the reader are
        # put back in front of it so that slicing starts from the very first frame
        self.number_of_frames_read = 0
        self.current_frame_index = 0
        if frame1 is None:
            # reader is exhausted, do not touch it again
            all_frames = iter([frame0])
        else:
            all_frames = chain([frame0, frame1], self._reader_obj)
        self._reader_iter = islice(
            all_frames, self._frame_start, self._frame_stop, self._frame_step)

        # setup atomic indices
        self._atomic_indices = self._resolve_atomic_indices(atomic_indices, frame0, self.n_atoms)
        self._validate_atomic_indices(self._atomic_indices, self.n_atoms)

        # log info on trajectory and atom types etc
        logger.info(f'Trajectory file: {self.filename}')
        logger.info(f'Total number of particles: {self.n_atoms}')
        logger.info(f'Number of atom types: {len(self.atom_types)}')
        for atom_type, indices in self._atomic_indices.items():
            logger.info(f'Number of atoms of type {atom_type}: {len(indices)}')
        logger.info(f'Simulation cell (in Angstrom):\n{str(self._cell)}')

    @staticmethod
    def _validate_frame_range(frame_start, frame_step):
        """ Return ``(frame_start, frame_step)`` with ``None`` replaced by the defaults
        (``0`` and ``1``); raise :class:`ValueError` if either value is out of range. """
        if frame_start is None:
            frame_start = 0
        if frame_step is None:
            frame_step = 1
        if frame_start < 0:
            raise ValueError('frame_start should be positive')
        # a step of zero is rejected as well, it can never advance through the trajectory
        if frame_step < 1:
            raise ValueError('frame_step should be positive')
        return frame_start, frame_step

    @staticmethod
    def _select_reader(trajectory_format):
        """ Return the reader class for :attr:`trajectory_format` together with the
        format string that should be handed to that reader. """
        if trajectory_format == 'lammps_internal':
            return LammpsTrajectoryReader, trajectory_format
        if trajectory_format == 'extxyz':
            return ExtxyzTrajectoryReader, trajectory_format
        if trajectory_format == 'lammps_mdanalysis':
            # the MDAnalysis reader is given the format name 'lammpsdump'
            return MDAnalysisTrajectoryReader, 'lammpsdump'
        if trajectory_format == 'ase':
            return ASETrajectoryReader, trajectory_format
        if trajectory_format == 'lammps':
            raise IOError('Ambiguous trajectory format, '
                          'did you mean lammps_internal or lammps_mdanalysis?')
        # everything else is delegated to MDAnalysis
        return MDAnalysisTrajectoryReader, trajectory_format

    @staticmethod
    def _resolve_atomic_indices(atomic_indices, frame0, n_atoms):
        """ Turn the user supplied :attr:`atomic_indices` argument into a dictionary
        that maps each species name to its atomic indices. """
        if atomic_indices is None:  # Default behaviour
            return {'X': np.arange(0, n_atoms)}
        if isinstance(atomic_indices, str):  # Str input
            if atomic_indices != 'read_from_trajectory':
                # any other string is interpreted as the path to a gromacs index file
                return parse_gromacs_index_file(atomic_indices)
            if frame0.atom_types is None:
                raise ValueError('Could not read atomic indices from the trajectory.')
            return {str(atom_type): (frame0.atom_types == atom_type).nonzero()[0]
                    for atom_type in np.unique(frame0.atom_types)}
        if isinstance(atomic_indices, dict):  # dict input
            return atomic_indices
        raise ValueError('Could not understand atomic_indices.')

    @staticmethod
    def _validate_atomic_indices(atomic_indices, n_atoms):
        """ Sanity checks for atomic indices; raise :class:`ValueError` on failure. """
        for key, indices in atomic_indices.items():
            if np.size(indices) == 0:
                raise ValueError(f'atomic_indices for atom type {key} is empty')
            # valid indices are 0, ..., n_atoms - 1
            if np.max(indices) >= n_atoms:
                raise ValueError('maximum index in atomic_indices exceeds number of atoms')
            if np.min(indices) < 0:
                raise ValueError('minimum index in atomic_indices is negative')
            if '_' in key:
                # Since '_' is what we use to distinguish atom types in the results, e.g. Sqw_Cs_Pb
                raise ValueError('The char "_" is not allowed in atomic_indices.')

    def __iter__(self):
        return self

    def __next__(self):
        """ Return the next selected frame as a :class:`TrajectoryFrame`. """
        frame = next(self._reader_iter)
        logger.debug(f'Read frame #{frame.frame_index}')
        new_frame = TrajectoryFrame(
            self.atomic_indices, frame.frame_index, frame.positions, frame.velocities)
        self.number_of_frames_read += 1
        self.current_frame_index = frame.frame_index
        return new_frame

    def __str__(self) -> str:
        fields = [
            ('filename', self.filename),
            ('natoms', self.n_atoms),
            ('frame_start', self._frame_start),
            ('frame_stop', self._frame_stop),
            ('frame_step', self.frame_step),
            ('frame_index', self.current_frame_index),
        ]
        s = ['Trajectory']
        s += ['{:12} : {}'.format(label, value) for label, value in fields]
        s += ['{:12} : [{}\n {}\n {}]'
              .format('cell', self.cell[0], self.cell[1], self.cell[2])]
        return '\n'.join(s)

    def __repr__(self) -> str:
        return str(self)

    def _repr_html_(self) -> str:
        """ HTML table summarizing the trajectory. """
        rows = [
            ('File name', self.filename),
            ('Number of atoms', self.n_atoms),
            ('Cell metric', self.cell),
            ('Frame step', self.frame_step),
            ('Atom types', self.atom_types),
        ]
        s = [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Value</th></tr></thead>']
        s += ['<tbody>']
        s += [f'<tr><td style="text-align: left;">{label}</td><td>{value}</td></tr>'
              for label, value in rows]
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)

    @property
    def cell(self) -> NDArray[float]:
        """ Simulation cell """
        return self._cell

    @property
    def n_atoms(self) -> int:
        """ Number of atoms """
        return self._n_atoms

    @property
    def filename(self) -> str:
        """ The trajectory filename """
        return self._filename

    @property
    def atomic_indices(self) -> dict[str, list[int]]:
        """ Return copy of index arrays """
        return {name: inds.copy() for name, inds in self._atomic_indices.items()}

    @property
    def atom_types(self) -> list[str]:
        """ Sorted names of the atom types """
        return sorted(self._atomic_indices.keys())

    @property
    def frame_step(self) -> int:
        """ Frame to access, trajectory will return every :attr:`frame_step`-th snapshot. """
        return self._frame_step

239 

240 

def consume(iterator, n):
    """ Advance the iterator by :attr:`n` steps. If :attr:`n` is ``None``, consume entirely. """
    # A zero-length deque discards everything it is fed. Feeding it the first
    # n items of the iterator (or all of them when n is None) therefore
    # advances the iterator without storing anything.
    deque(islice(iterator, n), maxlen=0)

248 

249 

class WindowIterator:
    """Sliding window iterator.

    Produces consecutive windows from an input iterator, where each
    window is handed out as a list of objects. Once the input runs dry
    the windows get shorter, and iteration ends when a window would be empty.

    Parameters
    ----------
    itraj
        Trajectory object.
    width
        Length of window (``window_size`` + 1).
    window_step
        Distance between the start of two consecutive window frames.
    element_processor
        Optional function applied to each frame before it is stored in the window.
        Useful for pre-computing per-frame quantities (e.g., reciprocal-space densities)
        so that the work is not repeated for frames shared between consecutive windows.
    """
    def __init__(self,
                 itraj: Trajectory,
                 width: int,
                 window_step: Optional[int] = 1,
                 element_processor: Optional[Callable] = None):

        # Keep a handle on the unprocessed stream; frames that fall between
        # two windows are skipped on it so that they are never processed.
        self._raw_it = itraj
        self._it = map(element_processor, itraj) if element_processor else itraj
        assert window_step >= 1
        assert width >= 1
        self.width = width
        self.window_step = window_step
        self._window = None

    def __iter__(self):
        return self

    def __next__(self):
        """ Returns next element in sequence. """
        window = self._window
        if window is None:
            # very first window: fill a bounded deque with up to `width` frames
            window = self._window = deque(islice(self._it, self.width), self.width)
        else:
            step, width = self.window_step, self.width
            if step < width:
                # overlapping windows: drop the frames that slide out on the left
                for _ in range(min(step, len(window))):
                    window.popleft()
            else:
                # disjoint windows: start afresh and skip the frames in the gap
                window.clear()
                gap = step - width
                next(islice(self._raw_it, gap, gap), None)
            window.extend(islice(self._it, min(step, width)))

        if not window:
            raise StopIteration

        return list(window)