Coverage for dynasor / correlation_functions.py: 94%

292 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 12:31 +0000

1import concurrent 

2from functools import partial 

3from itertools import combinations_with_replacement 

4from typing import Optional 

5 

6import numba 

7import numpy as np 

8from ase import Atoms 

9from ase.units import fs 

10from numpy.typing import NDArray 

11 

12from dynasor.logging_tools import logger 

13from dynasor.trajectory import Trajectory, WindowIterator 

14from dynasor.sample import DynamicSample, StaticSample 

15 

16from dynasor.post_processing import fourier_cos_filon 

17from dynasor.core.time_averager import TimeAverager 

18from dynasor.core.reciprocal import calc_rho_q, calc_rho_j_q 

19from dynasor.tools.structures import get_offset_index 

20from dynasor.units import radians_per_fs_to_meV 

21 

22 

def compute_dynamic_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
    dt: float,
    window_size: int,
    window_step: Optional[int] = 1,
    calculate_currents: Optional[bool] = False,
    calculate_incoherent: Optional[bool] = False,
) -> DynamicSample:
    r"""Compute the dynamic structure factors. The results are returned in the
    form of a :class:`DynamicSample <dynasor.sample.DynamicSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates.
    dt
        Time difference in femtoseconds between two consecutive snapshots
        in the trajectory. Note that you should *not* change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    window_size
        Length of the trajectory frame window to use for time correlation calculation.
        It is expressed in terms of the number of time lags to consider
        and thus determines the smallest frequency resolved.
    window_step
        Window step (or stride) given as the number of frames between consecutive trajectory
        windows. This parameter does *not* affect the time between consecutive frames in the
        calculation. If, e.g., :attr:`window_step` > :attr:`window_size`, some frames will not
        be used.
    calculate_currents
        Calculate the current correlations. Requires velocities to be available in :attr:`traj`.
    calculate_incoherent
        Calculate the incoherent part (self-part) of :math:`F_\text{incoh}`.

    Raises
    ------
    ValueError
        If ``q_points`` does not have shape ``(N_qpoints, 3)``, if ``dt`` or
        ``window_step`` is not positive, if ``window_size`` is not an even
        number larger than 2, or (when ``calculate_currents`` is set) if a
        processed frame carries no velocities.
    """
    # ``import concurrent`` alone does not load the ``concurrent.futures``
    # submodule; import the executor explicitly rather than relying on some
    # other module having imported the submodule first.
    from concurrent.futures import ThreadPoolExecutor

    # sanity check input args; checking ndim first turns a 1-D array into a
    # ValueError instead of an IndexError on ``shape[1]``
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')
    if window_size <= 2:
        raise ValueError(f'window_size must be larger than 2: window_size= {window_size}')
    if window_size % 2 != 0:
        raise ValueError(f'window_size must be even: window_size= {window_size}')
    if window_step <= 0:
        raise ValueError(f'window_step must be positive: window_step= {window_step}')

    # define internal parameters
    n_qpoints = q_points.shape[0]
    delta_t = traj.frame_step * dt
    N_tc = window_size + 1  # number of time lags including lag 0

    # log all setup information
    dw = np.pi / (window_size * delta_t)
    w_max = dw * window_size
    w_N = 2 * np.pi / (2 * delta_t)  # Nyquist angular frequency

    logger.info(f'Spacing between samples (frame_step): {traj.frame_step}')
    logger.info(f'Time between consecutive frames in input trajectory (dt): {dt} fs')
    logger.info(f'Time between consecutive frames used (dt * frame_step): {delta_t} fs')
    logger.info(f'Time window size (dt * frame_step * window_size): {delta_t * window_size:.1f} fs')
    logger.info(f'Angular frequency resolution: dw = {dw:.6f} rad/fs = '
                f'{dw * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Maximum angular frequency (dw * window_size):'
                f' {w_max:.6f} rad/fs = {w_max * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Nyquist angular frequency (2pi / frame_step / dt / 2):'
                f' {w_N:.6f} rad/fs = {w_N * radians_per_fs_to_meV:.3f} meV')

    if calculate_currents:
        logger.info('Calculating current (velocity) correlations')
    if calculate_incoherent:
        logger.info('Calculating incoherent part (self-part) of correlations')

    # log some info regarding q-points
    logger.info(f'Number of q-points: {n_qpoints}')

    # Unit vectors along each q-point; q = 0 rows are left as zero vectors so
    # that the longitudinal current vanishes there and no division by zero occurs.
    q_directions = q_points.copy()
    q_distances = np.linalg.norm(q_points, axis=1)
    nonzero = q_distances > 0
    q_directions[nonzero] /= q_distances[nonzero].reshape(-1, 1)

    # setup functions to process frames
    def f2_rho(frame):
        # attach the per-type Fourier-transformed densities to the frame
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    def f2_rho_and_j(frame):
        # attach densities plus longitudinal/transverse currents to the frame
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {frame.frame_index}')

        rho_qs_dict = dict()
        jz_qs_dict = dict()
        jper_qs_dict = dict()

        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            v = frame.velocities_by_type[atom_type]
            rho_qs, j_qs = calc_rho_j_q(x, v, q_points)
            # projection of the current onto q (longitudinal part) ...
            jz_qs = np.sum(j_qs * q_directions, axis=1)
            # ... and the remainder perpendicular to q (transverse part)
            jper_qs = j_qs - (jz_qs[:, None] * q_directions)

            rho_qs_dict[atom_type] = rho_qs
            jz_qs_dict[atom_type] = jz_qs
            jper_qs_dict[atom_type] = jper_qs

        frame.rho_qs_dict = rho_qs_dict
        frame.jz_qs_dict = jz_qs_dict
        frame.jper_qs_dict = jper_qs_dict
        return frame

    if calculate_currents:
        element_processor = f2_rho_and_j
    else:
        element_processor = f2_rho

    # setup window iterator
    window_iterator = WindowIterator(traj, width=N_tc, window_step=window_step,
                                     element_processor=element_processor)

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # set up all time averager instances
    F_q_t_averager = dict()
    for pair in pairs:
        F_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
    if calculate_currents:
        Cl_q_t_averager = dict()
        Ct_q_t_averager = dict()
        for pair in pairs:
            Cl_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
            Ct_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
    if calculate_incoherent:
        # the self-part is per atom type, not per pair
        F_s_q_t_averager = dict()
        for atom_type in traj.atom_types:
            F_s_q_t_averager[atom_type] = TimeAverager(N_tc, n_qpoints)

    # define correlation function
    def calc_corr(window, time_i):
        # Calculate correlations between two frames in the window without normalization 1/N
        f0 = window[0]
        fi = window[time_i]
        for s1, s2 in pairs:
            Fqt = np.real(f0.rho_qs_dict[s1] * fi.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                # mixed pairs are stored once, so add the (s2, s1) term too
                Fqt += np.real(f0.rho_qs_dict[s2] * fi.rho_qs_dict[s1].conjugate())
            F_q_t_averager[(s1, s2)].add_sample(time_i, Fqt)

        if calculate_currents:
            for s1, s2 in pairs:
                Clqt = np.real(f0.jz_qs_dict[s1] * fi.jz_qs_dict[s2].conjugate())
                # transverse correlation summed over the Cartesian components and halved
                Ctqt = 0.5 * np.real(np.sum(f0.jper_qs_dict[s1] *
                                            fi.jper_qs_dict[s2].conjugate(), axis=1))
                if s1 != s2:
                    Clqt += np.real(f0.jz_qs_dict[s2] * fi.jz_qs_dict[s1].conjugate())
                    Ctqt += 0.5 * np.real(np.sum(f0.jper_qs_dict[s2] *
                                                 fi.jper_qs_dict[s1].conjugate(), axis=1))

                Cl_q_t_averager[(s1, s2)].add_sample(time_i, Clqt)
                Ct_q_t_averager[(s1, s2)].add_sample(time_i, Ctqt)

        if calculate_incoherent:
            for atom_type in traj.atom_types:
                xi = fi.positions_by_type[atom_type]
                x0 = f0.positions_by_type[atom_type]
                Fsqt = np.real(calc_rho_q(xi - x0, q_points))
                F_s_q_t_averager[atom_type].add_sample(time_i, Fsqt)

    # run calculation
    logging_interval = 1000
    with ThreadPoolExecutor() as tpe:
        # This is the "main loop" over the trajectory
        for window in window_iterator:
            logger.debug(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')

            if window[0].frame_index % logging_interval == 0:
                logger.info(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')  # noqa

            # The map conveniently applies calc_corr to all time lags. As
            # calc_corr works in place and returns None, the results are only
            # consumed to wait for all tasks to finish; consuming them also
            # re-raises any exception raised inside a worker.
            for _ in tpe.map(partial(calc_corr, window), range(len(window))):
                pass

    # collect results into dict with numpy arrays (n_qpoints, N_tc)
    data_dict_corr = dict()
    time = delta_t * np.arange(N_tc, dtype=float)
    data_dict_corr['q_points'] = q_points
    data_dict_corr['time'] = time

    F_q_t_tot = np.zeros((n_qpoints, N_tc))
    S_q_w_tot = np.zeros((n_qpoints, N_tc))
    for pair in pairs:
        key = '_'.join(pair)
        F_q_t = 1 / traj.n_atoms * F_q_t_averager[pair].get_average_all()
        w, S_q_w = fourier_cos_filon(F_q_t, delta_t)
        S_q_w = np.array(S_q_w)
        data_dict_corr['omega'] = w
        data_dict_corr[f'Fqt_coh_{key}'] = F_q_t
        data_dict_corr[f'Sqw_coh_{key}'] = S_q_w

        # sum all partials to the total
        F_q_t_tot += F_q_t
        S_q_w_tot += S_q_w
    data_dict_corr['Fqt_coh'] = F_q_t_tot
    data_dict_corr['Sqw_coh'] = S_q_w_tot

    if calculate_currents:
        Cl_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_t_tot = np.zeros((n_qpoints, N_tc))
        Cl_q_w_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_w_tot = np.zeros((n_qpoints, N_tc))
        for pair in pairs:
            key = '_'.join(pair)
            Cl_q_t = 1 / traj.n_atoms * Cl_q_t_averager[pair].get_average_all()
            Ct_q_t = 1 / traj.n_atoms * Ct_q_t_averager[pair].get_average_all()
            _, Cl_q_w = fourier_cos_filon(Cl_q_t, delta_t)
            _, Ct_q_w = fourier_cos_filon(Ct_q_t, delta_t)
            data_dict_corr[f'Clqt_{key}'] = Cl_q_t
            data_dict_corr[f'Ctqt_{key}'] = Ct_q_t
            data_dict_corr[f'Clqw_{key}'] = Cl_q_w
            data_dict_corr[f'Ctqw_{key}'] = Ct_q_w

            # sum all partials to the total
            Cl_q_t_tot += Cl_q_t
            Ct_q_t_tot += Ct_q_t
            Cl_q_w_tot += Cl_q_w
            Ct_q_w_tot += Ct_q_w
        data_dict_corr['Clqt'] = Cl_q_t_tot
        data_dict_corr['Ctqt'] = Ct_q_t_tot
        data_dict_corr['Clqw'] = Cl_q_w_tot
        data_dict_corr['Ctqw'] = Ct_q_w_tot

    if calculate_incoherent:
        Fs_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ss_q_w_tot = np.zeros((n_qpoints, N_tc))
        for atom_type in traj.atom_types:
            Fs_q_t = 1 / traj.n_atoms * F_s_q_t_averager[atom_type].get_average_all()
            _, Ss_q_w = fourier_cos_filon(Fs_q_t, delta_t)
            data_dict_corr[f'Fqt_incoh_{atom_type}'] = Fs_q_t
            data_dict_corr[f'Sqw_incoh_{atom_type}'] = Ss_q_w

            # sum all partials to the total
            Fs_q_t_tot += Fs_q_t
            Ss_q_w_tot += Ss_q_w

        data_dict_corr['Fqt_incoh'] = Fs_q_t_tot
        data_dict_corr['Sqw_incoh'] = Ss_q_w_tot

    # finalize results with additional metadata
    new_sample = DynamicSample(
        data_dict_corr,
        simulation_data=dict(
            atom_types=traj.atom_types, pairs=pairs,
            particle_counts=particle_counts,
            cell=traj.cell,
            time_between_frames=delta_t,
            maximum_time_lag=delta_t * window_size,
            angular_frequency_resolution=dw,
            maximum_angular_frequency=w_max,
            number_of_frames=traj.number_of_frames_read,
        ))
    new_sample._append_history(
        'compute_dynamic_structure_factors',
        dict(
            dt=dt,
            window_size=window_size,
            window_step=window_step,
            calculate_currents=calculate_currents,
            calculate_incoherent=calculate_incoherent,
        ))

    return new_sample

304 

305 

def compute_static_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
) -> StaticSample:
    r"""Compute the static structure factors. The results are returned in the
    form of a :class:`StaticSample <dynasor.sample.StaticSample>`
    object.

    Each partial ``Sq_<s1>_<s2>`` is the frame average of the (symmetrized)
    product of the Fourier-transformed densities of types ``s1`` and ``s2``,
    divided by the total number of atoms; ``Sq`` is the sum of all partials.
    All are stored with shape ``(N_qpoints, 1)``.

    Parameters
    ----------
    traj
        Input trajectory.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates.

    Raises
    ------
    ValueError
        If ``q_points`` does not have shape ``(N_qpoints, 3)``.
    """
    # sanity check input args; checking ndim first turns a 1-D array into a
    # ValueError instead of an IndexError on ``shape[1]``
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    n_qpoints = q_points.shape[0]
    logger.info(f'Number of q-points: {n_qpoints}')

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # processing function: attach the per-type Fourier-transformed densities
    def f2_rho(frame):
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    # setup averager
    Sq_averager = dict()
    for pair in pairs:
        Sq_averager[pair] = TimeAverager(1, n_qpoints)  # time average with only timelag=0

    # main loop
    for frame in traj:

        # process_frame
        f2_rho(frame)
        logger.debug(f'Processing frame {frame.frame_index}')

        for s1, s2 in pairs:
            # compute correlation; mixed pairs are stored once, so add the
            # (s2, s1) term as well
            Sq_pair = np.real(frame.rho_qs_dict[s1] * frame.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                Sq_pair += np.real(frame.rho_qs_dict[s2] * frame.rho_qs_dict[s1].conjugate())
            Sq_averager[(s1, s2)].add_sample(0, Sq_pair)

    # collect results
    data_dict = dict()
    data_dict['q_points'] = q_points
    S_q_tot = np.zeros((n_qpoints, 1))
    for s1, s2 in pairs:
        Sq = 1 / traj.n_atoms * Sq_averager[(s1, s2)].get_average_at_timelag(0).reshape(-1, 1)
        data_dict[f'Sq_{s1}_{s2}'] = Sq
        S_q_tot += Sq
    data_dict['Sq'] = S_q_tot

    # finalize results
    new_sample = StaticSample(
        data_dict,
        simulation_data=dict(
            atom_types=traj.atom_types,
            pairs=pairs,
            particle_counts=particle_counts,
            cell=traj.cell,
            number_of_frames=traj.number_of_frames_read,
        ))
    new_sample._append_history('compute_static_structure_factors')

    return new_sample

386 

387 

def compute_spectral_energy_density(
    traj: Trajectory,
    ideal_supercell: Atoms,
    primitive_cell: Atoms,
    q_points: NDArray[float],
    dt: float,
    partial: Optional[bool] = False
) -> tuple[NDArray[float], NDArray[float]]:
    r"""
    Compute the spectral energy density (SED) at specific q-points. The results
    are returned in the form of a tuple, which comprises the angular
    frequencies in an array of length ``N_frequencies = N_times // 2 + 1`` in
    units of rad/fs (the non-negative frequencies of a real FFT over the
    ``N_times`` frames read from :attr:`traj`) and the SED in units of
    eV/(rad/fs) as an array of shape ``(N_qpoints, N_frequencies)``.
    The normalization is chosen such that integrating the SED of a q-point
    together with the supplied angular frequencies omega (rad/fs) yields
    ``1/2 kB T`` * number of bands (where number of bands = ``len(prim) * 3``)

    More details can be found in Thomas *et al.*, Physical Review B **81**, 081411 (2010),
    which should be cited when using this function along with the dynasor reference.

    **Note 1:**
    SED analysis is only suitable for crystalline materials without diffusion as
    atoms are assumed to move around fixed reference positions throughout the entire trajectory.

    **Note 2:**
    This implementation reads the full trajectory and can thus consume a lot of memory.

    Parameters
    ----------
    traj
        Input trajectory.
    ideal_supercell
        Ideal structure defining the reference positions. Do not change the masses
        in the ASE :class:`Atoms` objects to dynasor internal units, this will be
        done internally.
    primitive_cell
        Underlying primitive structure. Must be aligned correctly with :attr:`ideal_supercell`.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates.
    dt
        Time difference in femtoseconds between two consecutive snapshots in
        the trajectory. Note that you should not change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    partial
        If True the SED will be returned decomposed per basis and Cartesian direction.
        The shape is ``(N_qpoints, N_frequencies, len(primitive_cell), 3)``.

    Raises
    ------
    ValueError
        If ``q_points`` does not have shape ``(N_qpoints, 3)``, if ``dt`` is not
        positive, if ``ideal_supercell`` and ``traj`` disagree in the number of
        atoms, if ``primitive_cell`` does not contain fewer atoms than
        ``ideal_supercell``, if a frame provides no velocities, or if no
        frames could be read from ``traj``.
    """
    # sanity check input args (consistent with the structure factor functions)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')

    delta_t = traj.frame_step * dt

    # logger
    logger.info('Running SED')
    logger.info(f'Time between consecutive frames (dt * frame_step): {delta_t} fs')
    logger.info(f'Number of atoms in primitive_cell: {len(primitive_cell)}')
    logger.info(f'Number of atoms in ideal_supercell: {len(ideal_supercell)}')
    logger.info(f'Number of q-points: {q_points.shape[0]}')

    # check that the ideal supercell agrees with traj
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    if len(primitive_cell) >= len(ideal_supercell):
        raise ValueError('primitive_cell contains more atoms than ideal_supercell.')

    # collect all velocities, and scale with sqrt(masses)
    masses = ideal_supercell.get_masses().reshape(-1, 1) / fs**2  # From Dalton to dmu
    sqrt_masses = np.sqrt(masses)  # loop-invariant, so compute it only once
    velocities = []
    for it, frame in enumerate(traj):
        logger.debug(f'Reading frame {it}')
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {it}')
        v = frame.get_velocities_as_array(traj.atomic_indices)  # in Å/fs
        velocities.append(sqrt_masses * v)
    logger.info(f'Number of snapshots: {len(velocities)}')
    if not velocities:
        # without this the transpose below fails with an obscure axes error
        raise ValueError('No frames could be read from the trajectory.')

    # Perform the FFT on the last axis for extra speed (maybe not needed)
    N_samples = len(velocities)
    velocities = np.array(velocities)
    # places time index last and makes a copy for continuity
    velocities = velocities.transpose(1, 2, 0).copy()
    # #atoms in supercell x 3 directions x #frequencies (N_samples // 2 + 1)
    velocities = np.fft.rfft(velocities, axis=2)

    # Calculate indices and offsets needed for the sed method
    offsets, indices = get_offset_index(primitive_cell, ideal_supercell)

    # Phase factor for use in FT. #qpoints x #atoms in supercell
    cell_positions = np.dot(offsets, primitive_cell.cell)
    phase = np.dot(q_points, cell_positions.T)  # #qpoints x #atoms in supercell
    phase_factors = np.exp(1.0j * phase)

    # This dict maps the offsets to an index so ndarrays can be indexed by
    # (basis index, offset index) instead of atom in supercell
    offset_dict = {off: n for n, off in enumerate(set(tuple(offset) for offset in offsets))}

    # Pick out some shapes
    n_super, _, n_w = velocities.shape
    n_qpts = len(q_points)
    n_prim = len(primitive_cell)
    n_offsets = len(offset_dict)

    # Re-index velocities (transposed to frequency-first) and spatial phase
    # factors by (basis index, offset index) in a single pass over the atoms
    new_velocities = np.zeros((n_w, 3, n_prim, n_offsets), dtype=velocities.dtype)
    new_phase_factors = np.zeros((n_qpts, n_prim, n_offsets), dtype=phase_factors.dtype)
    for i in range(n_super):
        j = indices[i]  # atom with index i in the supercell is of basis type j ...
        n = offset_dict[tuple(offsets[i])]  # and its offset has index n
        new_velocities[:, :, j, n] = velocities[i].T
        new_phase_factors[:, j, n] = phase_factors[:, i]
    velocities = new_velocities
    phase_factors = new_phase_factors

    # calculate the density in a numba function
    density = _sed_inner_loop(phase_factors, velocities)

    if not partial:
        density = np.sum(density, axis=(2, 3))

    # units
    # make so that the velocities were originally in Angstrom / fs to be compatible with eV and Da

    # the time delta in the fourier transform
    density = density * delta_t**2

    # Divide by the length of the time signal
    density = density / (N_samples * delta_t)

    # Divide by the number of primitive cells
    density = density / (n_super / n_prim)

    # Factor so the sed can be integrated together with the returned omega
    # numpy fft works with ordinary/linear frequencies and not angular freqs
    density = density / (2*np.pi)

    # angular frequencies
    w = 2 * np.pi * np.fft.rfftfreq(N_samples, delta_t)  # rad/fs

    return w, density

535 

536 

@numba.njit(parallel=True, fastmath=True)
def _sed_inner_loop(phase_factors, velocities):
    """Evaluate the spatial Fourier transform from precomputed phase factors.

    For each frequency ``freq``, q-point ``q``, Cartesian direction ``alpha``
    and basis atom ``basis``, the products
    ``phase_factors[q, basis, cell] * velocities[freq, alpha, basis, cell]``
    are summed over the unit-cell index ``cell``; the squared modulus of that
    sum is stored in ``density[q, freq, basis, alpha]``.

    Parameters
    ----------
    phase_factors
        Complex array indexed as ``(q-point, basis atom, unit cell)``.
    velocities
        Complex array indexed as ``(frequency, direction, basis atom, unit cell)``;
        only the first three entries of the direction axis are used.

    Returns
    -------
    Float64 array of shape ``(n_qpoints, n_frequencies, n_basis, 3)``.

    Since a call may involve only a single q-point, the parallel (``prange``)
    loop runs over the frequency components rather than over q-points.
    """
    n_qpts = phase_factors.shape[0]
    n_basis = phase_factors.shape[1]
    n_cells = phase_factors.shape[2]  # number of distinct unit-cell offsets
    n_freqs = velocities.shape[0]

    density = np.zeros((n_qpts, n_freqs, n_basis, 3), dtype=np.float64)

    for freq in numba.prange(n_freqs):
        for q in range(n_qpts):
            for alpha in range(3):
                for basis in range(n_basis):
                    acc = 0.0j
                    for cell in range(n_cells):
                        acc += phase_factors[q, basis, cell] * velocities[freq, alpha, basis, cell]
                    density[q, freq, basis, alpha] += np.abs(acc)**2
    return density