Coverage for local_installation/dynasor/trajectory/ase_trajectory_reader.py: 88%

27 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-06-13 16:00 +0000

1import numpy as np 

2from ase import io 

3from dynasor.trajectory.abstract_trajectory_reader import AbstractTrajectoryReader 

4from dynasor.trajectory.trajectory_frame import ReaderFrame 

5from itertools import count 

6 

7 

class ASETrajectoryReader(AbstractTrajectoryReader):
    """Read ASE trajectory file

    Frames are read lazily, one at a time, through :func:`ase.io.iread`
    (with ``index=':'``, i.e. all frames). Each frame is returned as a
    :class:`ReaderFrame`. Its cell and positions are multiplied by the
    length factor ``x_factor``. Its velocities are multiplied by
    ``v_factor = x_factor / t_factor`` and are only set when the frame
    carries momenta; otherwise they are ``None``.

    Parameters
    ----------
    filename
        Name of input file.
    length_unit
        Unit of length for the input trajectory (``'Angstrom'``, ``'nm'``, ``'pm'``, ``'fm'``).
    time_unit
        Unit of time for the input trajectory (``'fs'``, ``'ps'``, ``'ns'``).

    Raises
    ------
    ValueError
        If ``length_unit`` is not a key of ``lengthunits_to_nm_table`` or
        ``time_unit`` is not a key of ``timeunits_to_fs_table`` (both tables
        are provided by the base class).
    """

    def __init__(
        self,
        filename: str,
        length_unit: str = 'Angstrom',
        time_unit: str = 'fs',
    ):
        # Validate the unit arguments before creating the reader so that bad
        # arguments fail fast without leaving a half-initialised reader behind.
        if length_unit not in self.lengthunits_to_nm_table:
            raise ValueError(f'Specified length unit {length_unit} is not an available option.')
        if time_unit not in self.timeunits_to_fs_table:
            raise ValueError(f'Specified time unit {time_unit} is not an available option.')

        # Per-frame conversion factors, looked up in the base-class tables.
        self.x_factor = self.lengthunits_to_nm_table[length_unit]
        self.t_factor = self.timeunits_to_fs_table[time_unit]
        self.v_factor = self.x_factor / self.t_factor

        # Running index assigned to each frame returned by __next__.
        self._frame_index = count(0)
        # Lazy iterator over every frame in the file.
        self._atoms = io.iread(filename, index=':')

    def __iter__(self):
        """Return the reader itself; it is its own iterator."""
        return self

    def close(self) -> None:
        """Stop reading and release the underlying frame iterator.

        ``ase.io.iread`` is a generator. Closing it lets it run its cleanup
        even when the trajectory was not read to the end, instead of waiting
        for garbage collection. Calling this more than once is harmless.
        After closing, :meth:`__next__` raises :class:`StopIteration`.
        """
        # Guarded so that an iterator without a close() method is tolerated.
        close_atoms = getattr(self._atoms, 'close', None)
        if close_atoms is not None:
            close_atoms()

    def __next__(self) -> ReaderFrame:
        """Read the next frame and return it as a :class:`ReaderFrame`.

        Raises
        ------
        StopIteration
            When the trajectory has no more frames (or the reader was closed).
        """
        # Read the frame first, so a frame index is only consumed when a
        # frame actually exists.
        a = next(self._atoms)
        ind = next(self._frame_index)

        if 'momenta' in a.arrays:
            # NOTE(review): ASE's get_velocities() returns velocities in ASE's
            # internal units (Angstrom per ASE time unit, not per fs); confirm
            # that v_factor is meant to account for this.
            vel = self.v_factor * a.get_velocities()
        else:
            vel = None

        return ReaderFrame(
            frame_index=ind,
            n_atoms=len(a),
            # Fortran-ordered copy of the 3x3 cell matrix, scaled to the target length unit.
            cell=self.x_factor * a.cell.array.copy('F'),
            positions=self.x_factor * a.get_positions(),
            velocities=vel,
            atom_types=np.array(list(a.symbols)),
        )