Coverage for dynasor / correlation_functions.py: 94%

299 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-10 06:28 +0000

1import concurrent 

2from functools import partial 

3from itertools import combinations_with_replacement 

4from typing import Optional 

5 

6import numba 

7import numpy as np 

8from ase import Atoms 

9from ase.units import fs 

10from numpy.typing import NDArray 

11 

12from dynasor.logging_tools import logger 

13from dynasor.trajectory import Trajectory, WindowIterator 

14from dynasor.sample import DynamicSample, StaticSample 

15 

16from dynasor.tools.acfs import psd_from_acf_2d 

17from dynasor.core.time_averager import TimeAverager, truncate_nans 

18from dynasor.core.reciprocal import calc_rho_q, calc_rho_j_q 

19from dynasor.tools.structures import get_offset_index 

20from dynasor.units import radians_per_fs_to_meV 

21 

22 

def compute_dynamic_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
    dt: float,
    window_size: int,
    window_step: Optional[int] = 1,
    calculate_currents: Optional[bool] = False,
    calculate_incoherent: Optional[bool] = False,
    logging_interval: Optional[int] = 1000,
) -> DynamicSample:
    r"""Compute the dynamic structure factors. The results are returned in the
    form of a :class:`DynamicSample <dynasor.sample.DynamicSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates.
    dt
        Time difference in femtoseconds between two consecutive snapshots
        in the trajectory. Note that you should *not* change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    window_size
        Length of the trajectory frame window to use for time correlation calculation.
        It is expressed in terms of the number of time lags to consider
        and thus determines the smallest frequency resolved.
    window_step
        Window step (or stride) given as the number of frames between consecutive trajectory
        windows. This parameter does *not* affect the time between consecutive frames in the
        calculation. If, e.g., :attr:`window_step` > :attr:`window_size`, some frames will not
        be used.
    calculate_currents
        Calculate the current correlations. Requires velocities to be available in :attr:`traj`.
    calculate_incoherent
        Calculate the incoherent part (self-part) of :math:`F_\text{incoh}`.
    logging_interval
        Log progress at ``INFO`` level every this many windows. Set to ``0`` to disable
        progress logging.

    Returns
    -------
    DynamicSample
        Sample holding ``q_points``, ``time``, ``omega`` and the coherent correlation
        functions (``Fqt_coh``, ``Sqw_coh`` and one ``Fqt_coh_<a>_<b>`` / ``Sqw_coh_<a>_<b>``
        entry per pair of atom types). If :attr:`calculate_currents` is set the current
        correlations (``Clqt``, ``Ctqt``, ``Clqw``, ``Ctqw`` and their partials) are included,
        and if :attr:`calculate_incoherent` is set the incoherent parts
        (``Fqt_incoh``, ``Sqw_incoh`` and their partials) are included.

    Raises
    ------
    ValueError
        If :attr:`q_points` is not a two-dimensional array with three columns, if
        :attr:`dt` or :attr:`window_step` is not positive, if :attr:`window_size` is not
        larger than 2, or if :attr:`calculate_currents` is set but a frame has no velocities.
    """
    # Import the submodule explicitly: a bare ``import concurrent`` does not guarantee
    # that ``concurrent.futures`` has been loaded.
    from concurrent.futures import ThreadPoolExecutor

    # sanity check input args
    # (checking ndim first turns a 1-D input into a ValueError instead of an IndexError)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')
    if window_size <= 2:
        raise ValueError(f'window_size must be larger than 2: window_size= {window_size}')
    if window_step <= 0:
        raise ValueError(f'window_step must be positive: window_step= {window_step}')

    # define internal parameters
    n_qpoints = q_points.shape[0]
    delta_t = traj.frame_step * dt
    N_tc = window_size + 1  # number of frames per window: time lags 0 .. window_size

    # log all setup information
    dw = np.pi / (window_size * delta_t)
    w_max = dw * window_size
    w_N = 2 * np.pi / (2 * delta_t)  # Nyquist angular frequency

    logger.info(f'Spacing between samples (frame_step): {traj.frame_step}')
    logger.info(f'Time between consecutive frames in input trajectory (dt): {dt} fs')
    logger.info(f'Time between consecutive frames used (dt * frame_step): {delta_t} fs')
    logger.info(f'Time window size (dt * frame_step * window_size): {delta_t * window_size:.1f} fs')
    logger.info(f'Angular frequency resolution: dw = {dw:.6f} rad/fs = '
                f'{dw * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Maximum angular frequency (dw * window_size):'
                f' {w_max:.6f} rad/fs = {w_max * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Nyquist angular frequency (2pi / frame_step / dt / 2):'
                f' {w_N:.6f} rad/fs = {w_N * radians_per_fs_to_meV:.3f} meV')

    if calculate_currents:
        logger.info('Calculating current (velocity) correlations')
    if calculate_incoherent:
        logger.info('Calculating incoherent part (self-part) of correlations')

    # log some info regarding q-points
    logger.info(f'Number of q-points: {n_qpoints}')

    # unit vectors along each q-point; q = 0 is left as the zero vector
    q_directions = q_points.copy()
    q_distances = np.linalg.norm(q_points, axis=1)
    nonzero = q_distances > 0
    q_directions[nonzero] /= q_distances[nonzero].reshape(-1, 1)

    # setup functions to process frames
    def f2_rho(frame):
        # attach the density rho(q) of every atom type to the frame
        frame.rho_qs_dict = {
            atom_type: calc_rho_q(x, q_points)
            for atom_type, x in frame.positions_by_type.items()
        }
        return frame

    def f2_rho_and_j(frame):
        # attach densities and longitudinal/transverse currents of every atom type
        if frame.velocities_by_type is None:
            raise ValueError(
                f'Could not read velocities from frame {frame.frame_index}; '
                'velocities are required when calculate_currents is True.')

        rho_qs_dict = dict()
        jz_qs_dict = dict()
        jper_qs_dict = dict()

        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            v = frame.velocities_by_type[atom_type]
            rho_qs, j_qs = calc_rho_j_q(x, v, q_points)
            # component of the current along q, and the remainder perpendicular to q
            jz_qs = np.sum(j_qs * q_directions, axis=1)
            jper_qs = j_qs - (jz_qs[:, None] * q_directions)

            rho_qs_dict[atom_type] = rho_qs
            jz_qs_dict[atom_type] = jz_qs
            jper_qs_dict[atom_type] = jper_qs

        frame.rho_qs_dict = rho_qs_dict
        frame.jz_qs_dict = jz_qs_dict
        frame.jper_qs_dict = jper_qs_dict
        return frame

    element_processor = f2_rho_and_j if calculate_currents else f2_rho

    # setup window iterator
    window_iterator = WindowIterator(traj, width=N_tc, window_step=window_step,
                                     element_processor=element_processor)

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # set up all time averager instances
    F_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
    if calculate_currents:
        Cl_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
        Ct_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
    if calculate_incoherent:
        F_s_q_t_averager = {
            atom_type: TimeAverager(N_tc, n_qpoints) for atom_type in traj.atom_types
        }

    # define correlation function
    def calc_corr(window, time_i):
        # Calculate correlations between two frames in the window without normalization 1/N
        f0 = window[0]
        fi = window[time_i]
        for s1, s2 in pairs:
            Fqt = np.real(f0.rho_qs_dict[s1] * fi.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                # unlike pairs appear only once in ``pairs``: add the mirrored term
                Fqt += np.real(f0.rho_qs_dict[s2] * fi.rho_qs_dict[s1].conjugate())
            F_q_t_averager[(s1, s2)].add_sample(time_i, Fqt)

        if calculate_currents:
            for s1, s2 in pairs:
                Clqt = np.real(f0.jz_qs_dict[s1] * fi.jz_qs_dict[s2].conjugate())
                Ctqt = 0.5 * np.real(np.sum(f0.jper_qs_dict[s1] *
                                            fi.jper_qs_dict[s2].conjugate(), axis=1))
                if s1 != s2:
                    Clqt += np.real(f0.jz_qs_dict[s2] * fi.jz_qs_dict[s1].conjugate())
                    Ctqt += 0.5 * np.real(np.sum(f0.jper_qs_dict[s2] *
                                                 fi.jper_qs_dict[s1].conjugate(), axis=1))

                Cl_q_t_averager[(s1, s2)].add_sample(time_i, Clqt)
                Ct_q_t_averager[(s1, s2)].add_sample(time_i, Ctqt)

        if calculate_incoherent:
            for atom_type in traj.atom_types:
                xi = fi.positions_by_type[atom_type]
                x0 = f0.positions_by_type[atom_type]
                Fsqt = np.real(calc_rho_q(xi - x0, q_points))
                F_s_q_t_averager[atom_type].add_sample(time_i, Fsqt)

    # run calculation
    with ThreadPoolExecutor() as tpe:
        # This is the "main loop" over the trajectory
        for window in window_iterator:
            if logging_interval and window[0].frame_index % logging_interval == 0:
                logger.info(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')  # noqa
            else:
                logger.debug(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')  # noqa

            # The map conveniently applies calc_corr to all time-lags. However,
            # as everything is done in place nothing gets returned so in order
            # to start and wait for the processes to finish we must iterate
            # over the None values returned
            for _ in tpe.map(partial(calc_corr, window), range(len(window))):
                pass

    # collect results into dict with numpy arrays (n_qpoints, N_tc)
    data_dict_corr = dict()
    data_dict_corr['q_points'] = q_points

    time = None
    for pair in pairs:
        key = '_'.join(pair)
        F_q_t = 1 / traj.n_atoms * truncate_nans(F_q_t_averager[pair].get_average_all())
        w, S_q_w = psd_from_acf_2d(F_q_t, delta_t)
        S_q_w = np.array(S_q_w)

        if time is None:
            # Determine the actual length of time signal after truncation (if needed)
            time = delta_t * np.arange(F_q_t.shape[1], dtype=float)
            N_tc_actual = len(time)
            F_q_t_tot = np.zeros((n_qpoints, N_tc_actual))
            S_q_w_tot = np.zeros((n_qpoints, N_tc_actual))
        else:
            assert F_q_t.shape[1] == len(time)

        data_dict_corr['omega'] = w
        data_dict_corr[f'Fqt_coh_{key}'] = F_q_t
        data_dict_corr[f'Sqw_coh_{key}'] = S_q_w

        # sum all partials to the total
        F_q_t_tot += F_q_t
        S_q_w_tot += S_q_w

    if N_tc_actual < N_tc:
        logger.warning('Truncating ACF due to NaNs, likely time_window is longer than length of Trajectory')  # noqa
    data_dict_corr['time'] = time
    data_dict_corr['Fqt_coh'] = F_q_t_tot
    data_dict_corr['Sqw_coh'] = S_q_w_tot

    if calculate_currents:
        Cl_q_t_tot = np.zeros((n_qpoints, N_tc_actual))
        Ct_q_t_tot = np.zeros((n_qpoints, N_tc_actual))
        Cl_q_w_tot = np.zeros((n_qpoints, N_tc_actual))
        Ct_q_w_tot = np.zeros((n_qpoints, N_tc_actual))
        for pair in pairs:
            key = '_'.join(pair)
            Cl_q_t = 1 / traj.n_atoms * truncate_nans(Cl_q_t_averager[pair].get_average_all())
            Ct_q_t = 1 / traj.n_atoms * truncate_nans(Ct_q_t_averager[pair].get_average_all())
            _, Cl_q_w = psd_from_acf_2d(Cl_q_t, delta_t)
            _, Ct_q_w = psd_from_acf_2d(Ct_q_t, delta_t)
            data_dict_corr[f'Clqt_{key}'] = Cl_q_t
            data_dict_corr[f'Ctqt_{key}'] = Ct_q_t
            data_dict_corr[f'Clqw_{key}'] = Cl_q_w
            data_dict_corr[f'Ctqw_{key}'] = Ct_q_w

            # sum all partials to the total
            Cl_q_t_tot += Cl_q_t
            Ct_q_t_tot += Ct_q_t
            Cl_q_w_tot += Cl_q_w
            Ct_q_w_tot += Ct_q_w
        data_dict_corr['Clqt'] = Cl_q_t_tot
        data_dict_corr['Ctqt'] = Ct_q_t_tot
        data_dict_corr['Clqw'] = Cl_q_w_tot
        data_dict_corr['Ctqw'] = Ct_q_w_tot

    if calculate_incoherent:
        Fs_q_t_tot = np.zeros((n_qpoints, N_tc_actual))
        Ss_q_w_tot = np.zeros((n_qpoints, N_tc_actual))
        for atom_type in traj.atom_types:
            Fs_q_t = 1 / traj.n_atoms * truncate_nans(F_s_q_t_averager[atom_type].get_average_all())
            _, Ss_q_w = psd_from_acf_2d(Fs_q_t, delta_t)
            data_dict_corr[f'Fqt_incoh_{atom_type}'] = Fs_q_t
            data_dict_corr[f'Sqw_incoh_{atom_type}'] = Ss_q_w

            # sum all partials to the total
            Fs_q_t_tot += Fs_q_t
            Ss_q_w_tot += Ss_q_w

        data_dict_corr['Fqt_incoh'] = Fs_q_t_tot
        data_dict_corr['Sqw_incoh'] = Ss_q_w_tot

    # finalize results with additional metadata
    new_sample = DynamicSample(
        data_dict_corr,
        simulation_data=dict(
            atom_types=traj.atom_types, pairs=pairs,
            particle_counts=particle_counts,
            cell=traj.cell,
            time_between_frames=delta_t,
            maximum_time_lag=delta_t * window_size,
            angular_frequency_resolution=dw,
            maximum_angular_frequency=w_max,
            number_of_frames=traj.number_of_frames_read,
        ))
    new_sample._append_history(
        'compute_dynamic_structure_factors',
        dict(
            dt=dt,
            window_size=window_size,
            window_step=window_step,
            calculate_currents=calculate_currents,
            calculate_incoherent=calculate_incoherent,
        ))

    return new_sample

316 

317 

def compute_static_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
    logging_interval: Optional[int] = 1000,
) -> StaticSample:
    r"""Compute the static structure factors. The results are returned in the
    form of a :class:`StaticSample <dynasor.sample.StaticSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates.
    logging_interval
        Log progress at ``INFO`` level every this many frames. Set to ``0`` to disable
        progress logging.

    Returns
    -------
    StaticSample
        Sample holding ``q_points``, the total structure factor ``Sq`` and one partial
        ``Sq_<a>_<b>`` per pair of atom types, each with shape ``(N_qpoints, 1)``.

    Raises
    ------
    ValueError
        If :attr:`q_points` is not a two-dimensional array with three columns.
    """
    # sanity check input args
    # (checking ndim first turns a 1-D input into a ValueError instead of an IndexError)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    n_qpoints = q_points.shape[0]
    logger.info(f'Number of q-points: {n_qpoints}')

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # processing function
    def f2_rho(frame):
        # attach the density rho(q) of every atom type to the frame
        frame.rho_qs_dict = {
            atom_type: calc_rho_q(x, q_points)
            for atom_type, x in frame.positions_by_type.items()
        }
        return frame

    # setup averager; time average with only timelag=0
    Sq_averager = {pair: TimeAverager(1, n_qpoints) for pair in pairs}

    # main loop
    for frame in traj:

        # process_frame
        f2_rho(frame)
        if logging_interval and frame.frame_index % logging_interval == 0:
            logger.info(f'Processing frame {frame.frame_index}')
        else:
            logger.debug(f'Processing frame {frame.frame_index}')

        for s1, s2 in pairs:
            # compute correlation
            Sq_pair = np.real(frame.rho_qs_dict[s1] * frame.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                # unlike pairs appear only once in ``pairs``: add the mirrored term
                Sq_pair += np.real(frame.rho_qs_dict[s2] * frame.rho_qs_dict[s1].conjugate())
            Sq_averager[(s1, s2)].add_sample(0, Sq_pair)

    # collect results
    data_dict = dict()
    data_dict['q_points'] = q_points
    S_q_tot = np.zeros((n_qpoints, 1))
    for s1, s2 in pairs:
        Sq = 1 / traj.n_atoms * Sq_averager[(s1, s2)].get_average_at_timelag(0).reshape(-1, 1)
        data_dict[f'Sq_{s1}_{s2}'] = Sq
        S_q_tot += Sq
    data_dict['Sq'] = S_q_tot

    # finalize results
    new_sample = StaticSample(
        data_dict,
        simulation_data=dict(
            atom_types=traj.atom_types,
            pairs=pairs,
            particle_counts=particle_counts,
            cell=traj.cell,
            number_of_frames=traj.number_of_frames_read,
        ))
    new_sample._append_history('compute_static_structure_factors')

    return new_sample

405 

406 

def compute_spectral_energy_density(
    traj: Trajectory,
    ideal_supercell: Atoms,
    primitive_cell: Atoms,
    q_points: NDArray[float],
    dt: float,
    partial: Optional[bool] = False,
    logging_interval: Optional[int] = 1000,
) -> tuple[NDArray[float], NDArray[float]]:
    r"""
    Compute the spectral energy density (SED) at specific q-points. The results
    are returned in the form of a tuple, which comprises the angular
    frequencies in an array of length ``N_frequencies`` in units of rad/fs and the
    SED in units of eV/(rad/fs) as an array of shape ``(N_qpoints, N_frequencies)``.
    Here ``N_frequencies = N_frames // 2 + 1`` since a real-input FFT is used.
    The normalization is chosen such that integrating the SED of a q-point
    together with the supplied angular frequencies omega (rad/fs) yields
    `1/2 kB T` * number of bands (where number of bands = `len(primitive_cell) * 3`)

    More details can be found in Thomas *et al.*, Physical Review B **81**, 081411 (2010),
    which should be cited when using this function along with the dynasor reference.

    **Note 1:**
    SED analysis is only suitable for crystalline materials without diffusion as
    atoms are assumed to move around fixed reference positions throughout the entire trajectory.

    **Note 2:**
    This implementation reads the full trajectory and can thus consume a lot of memory.

    Parameters
    ----------
    traj
        Input trajectory.
    ideal_supercell
        Ideal structure defining the reference positions. Do not change the masses
        in the ASE :class:`Atoms` objects to dynasor internal units, this will be
        done internally.
    primitive_cell
        Underlying primitive structure. Must be aligned correctly with :attr:`ideal_supercell`.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates.
    dt
        Time difference in femtoseconds between two consecutive snapshots in
        the trajectory. Note that you should not change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    partial
        If True the SED will be returned decomposed per basis and Cartesian direction.
        The shape is ``(N_qpoints, N_frequencies, len(primitive_cell), 3)``.
    logging_interval
        Log progress at ``INFO`` level every this many frames. Set to ``0`` to disable
        progress logging.

    Raises
    ------
    ValueError
        If :attr:`q_points` is not a two-dimensional array with three columns, if
        :attr:`dt` is not positive, if the number of atoms in :attr:`ideal_supercell`
        does not match the trajectory, if :attr:`primitive_cell` is not smaller than
        :attr:`ideal_supercell`, if a frame has no velocities, or if the trajectory
        yields no frames.
    """
    # NOTE: the ``partial`` argument shadows ``functools.partial`` inside this function.

    # sanity check input args (same checks as the structure factor functions)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')

    delta_t = traj.frame_step * dt

    # logger
    logger.info('Running SED')
    logger.info(f'Time between consecutive frames (dt * frame_step): {delta_t} fs')
    logger.info(f'Number of atoms in primitive_cell: {len(primitive_cell)}')
    logger.info(f'Number of atoms in ideal_supercell: {len(ideal_supercell)}')
    logger.info(f'Number of q-points: {q_points.shape[0]}')

    # check that the ideal supercell agrees with traj
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    # NOTE(review): equal sizes are rejected as well although the message says
    # "more atoms" -- confirm whether that is intended.
    if len(primitive_cell) >= len(ideal_supercell):
        raise ValueError('primitive_cell contains more atoms than ideal_supercell.')

    # collect all velocities, and scale with sqrt(masses)
    masses = ideal_supercell.get_masses().reshape(-1, 1) / fs**2  # From Dalton to dmu
    sqrt_masses = np.sqrt(masses)  # loop invariant, computed once
    velocities = []
    for it, frame in enumerate(traj):
        if logging_interval and it % logging_interval == 0:
            logger.info(f'Reading frame {it}')
        else:
            logger.debug(f'Reading frame {it}')
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {it}')
        v = frame.get_velocities_as_array(traj.atomic_indices)  # in Å/fs
        velocities.append(sqrt_masses * v)
    logger.info(f'Number of snapshots: {len(velocities)}')

    N_samples = len(velocities)
    if N_samples == 0:
        # would otherwise fail further down with an unrelated transpose error
        raise ValueError('No frames could be read from the trajectory.')

    # Perform the FFT on the last axis for extra speed (maybe not needed)
    velocities = np.array(velocities)
    # places time index last and makes a copy for contiguity
    velocities = velocities.transpose(1, 2, 0).copy()
    # #atoms in supercell x 3 directions x #frequencies
    velocities = np.fft.rfft(velocities, axis=2)

    # Calculate indices and offsets needed for the sed method
    offsets, indices = get_offset_index(primitive_cell, ideal_supercell)

    # Phase factor for use in FT. #qpoints x #atoms in supercell
    cell_positions = np.dot(offsets, primitive_cell.cell)
    phase = np.dot(q_points, cell_positions.T)
    phase_factors = np.exp(1.0j * phase)

    # This dict maps the offsets to an index so ndarrays can be over
    # offset,index instead of atoms in supercell
    offset_dict = {off: n for n, off in enumerate(set(tuple(offset) for offset in offsets))}

    # Pick out some shapes
    n_super, _, n_w = velocities.shape
    n_qpts = len(q_points)
    n_prim = len(primitive_cell)
    n_offsets = len(offset_dict)

    # These new arrays are indexed by basis index and offset instead of by atom in the
    # supercell (the velocities are also transposed)
    new_velocities = np.zeros((n_w, 3, n_prim, n_offsets), dtype=velocities.dtype)
    new_phase_factors = np.zeros((n_qpts, n_prim, n_offsets), dtype=phase_factors.dtype)

    for i in range(n_super):
        j = indices[i]  # atom with index i in the supercell is of basis type j ...
        n = offset_dict[tuple(offsets[i])]  # and its offset has index n
        new_velocities[:, :, j, n] = velocities[i].T
        new_phase_factors[:, j, n] = phase_factors[:, i]

    velocities = new_velocities
    phase_factors = new_phase_factors

    # calculate the density in a numba function
    density = _sed_inner_loop(phase_factors, velocities)

    if not partial:
        # sum over basis atoms and Cartesian directions
        density = np.sum(density, axis=(2, 3))

    # units
    # make so that the velocities were originally in Angstrom / fs to be compatible with eV and Da

    # the time delta in the fourier transform
    density = density * delta_t**2

    # Divide by the length of the time signal
    density = density / (N_samples * delta_t)

    # Divide by the number of primitive cells
    density = density / (n_super / n_prim)

    # Factor so the sed can be integrated together with the returned omega
    # numpy fft works with ordinary/linear frequencies and not angular freqs
    density = density / (2*np.pi)

    # angular frequencies
    w = 2 * np.pi * np.fft.rfftfreq(N_samples, delta_t)  # rad/fs

    return w, density

561 

562 

@numba.njit(parallel=True, fastmath=True)
def _sed_inner_loop(phase_factors, velocities):
    """Evaluate the spatial Fourier transform of the velocities with
    precomputed phase factors and return its squared magnitude.

    ``phase_factors`` is indexed as (q-point, basis atom, unit cell) and
    ``velocities`` as (frequency, direction, basis atom, unit cell). The
    returned array is indexed as (q-point, frequency, basis atom, direction).

    The number of q-points may be as small as one, hence the outer
    (parallel) loop runs over the temporal frequency components.
    """
    n_qpts, n_prim, n_cells = phase_factors.shape
    n_freqs = velocities.shape[0]

    density = np.zeros((n_qpts, n_freqs, n_prim, 3), dtype=np.float64)

    for i_freq in numba.prange(n_freqs):
        for i_q in range(n_qpts):
            for i_dir in range(3):
                for i_basis in range(n_prim):
                    # sum over all unit cells for this basis atom and direction
                    acc = 0.0j
                    for i_cell in range(n_cells):
                        acc += (phase_factors[i_q, i_basis, i_cell]
                                * velocities[i_freq, i_dir, i_basis, i_cell])
                    density[i_q, i_freq, i_basis, i_dir] += np.abs(acc)**2
    return density

588 return density