Coverage for local_installation/dynasor/correlation_functions.py: 93%

294 statements  

coverage.py v7.3.2, created at 2024-12-21 12:02 +0000

1 import numba

2 import concurrent.futures

3 from functools import partial

4 from itertools import combinations_with_replacement

5 from typing import Tuple

6

7 import numpy as np

8 from ase import Atoms

9 from ase.units import fs

10 from numpy.typing import NDArray

11

12 from dynasor.logging_tools import logger

13 from dynasor.trajectory import Trajectory, WindowIterator

14 from dynasor.sample import DynamicSample, StaticSample

15 from dynasor.post_processing import fourier_cos_filon

16 from dynasor.core.time_averager import TimeAverager

17 from dynasor.core.reciprocal import calc_rho_q, calc_rho_j_q

18 from dynasor.qpoints.tools import get_index_offset

19 from dynasor.units import radians_per_fs_to_meV

20 

21 

22 def compute_dynamic_structure_factors(

23 traj: Trajectory, 

24 q_points: NDArray[float], 

25 dt: float, 

26 window_size: int, 

27 window_step: int = 1, 

28 calculate_currents: bool = False, 

29 calculate_incoherent: bool = False, 

30) -> DynamicSample: 

31 """Compute dynamic structure factors. The results are returned in the 

32 form of a :class:`DynamicSample <dynasor.sample.DynamicSample>` 

33 object. 

34 

35 Parameters 

36 ---------- 

37 traj 

38 Input trajectory 

39 q_points 

40 Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates 

41 dt 

42 Time difference in femtoseconds between two consecutive snapshots 

43 in the trajectory. Note that you should *not* change :attr:`dt` if you change 

44 :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`. 

45 window_size 

46 Length of the trajectory frame window to use for time correlation calculation. 

47 It is expressed in terms of the number of time lags to consider 

48 and thus determines the smallest frequency resolved. 

49 window_step 

50 Window step (or stride) given as the number of frames between consecutive trajectory 

51 windows. This parameter does *not* affect the time between consecutive frames in the 

52 calculation. If, e.g., :attr:`window_step` > :attr:`window_size`, some frames will not 

53 be used. 

54 calculate_currents 

55 Calculate the current correlations. Requires velocities to be available in :attr:`traj`. 

56 calculate_incoherent 

57 Calculate the incoherent part (self-part) of the intermediate scattering function.
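
Example
-------
The following is a minimal sketch. It assumes that a
:class:`Trajectory <dynasor.trajectory.Trajectory>` object :attr:`traj` has
already been created (its constructor arguments depend on the trajectory
format and are not shown here); the q-points, :attr:`dt`, and
:attr:`window_size` values are illustrative::

    import numpy as np
    from dynasor.correlation_functions import compute_dynamic_structure_factors

    q_points = np.array([[0.0, 0.0, q] for q in np.linspace(0.1, 2.0, 20)])  # rad/Å
    sample = compute_dynamic_structure_factors(
        traj, q_points, dt=5.0, window_size=128,
        calculate_currents=False, calculate_incoherent=True)
    # the returned DynamicSample contains, among other things, the partial and
    # total coherent F(q, t) and S(q, w), and here also the incoherent parts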

58 """ 

59 # sanity check input args 

60 if q_points.shape[1] != 3: 

61 raise ValueError('q-points array has the wrong shape.') 

62 if dt <= 0: 

63 raise ValueError(f'dt must be positive: dt= {dt}') 

64 if window_size <= 2: 

65 raise ValueError(f'window_size must be larger than 2: window_size= {window_size}') 

66 if window_size % 2 != 0: 

67 raise ValueError(f'window_size must be even: window_size= {window_size}') 

68 if window_step <= 0: 

69 raise ValueError(f'window_step must be positive: window_step= {window_step}') 

70 

71 # define internal parameters 

72 n_qpoints = q_points.shape[0] 

73 delta_t = traj.frame_step * dt 

74 N_tc = window_size + 1 

75 

76 # log all setup information 

77 dw = np.pi / (window_size * delta_t) 

78 w_max = dw * window_size 

79 w_N = 2 * np.pi / (2 * delta_t) # Nyquist angular frequency 

80 

81 logger.info(f'Spacing between samples (frame_step): {traj.frame_step}') 

82 logger.info(f'Time between consecutive frames in input trajectory (dt): {dt} fs') 

83 logger.info(f'Time between consecutive frames used (dt * frame_step): {delta_t} fs') 

84 logger.info(f'Time window size (dt * frame_step * window_size): {delta_t * window_size:.1f} fs') 

85 logger.info(f'Angular frequency resolution: dw = {dw:.6f} rad/fs = ' 

86 f'{dw * radians_per_fs_to_meV:.3f} meV') 

87 logger.info(f'Maximum angular frequency (dw * window_size):' 

88 f' {w_max:.6f} rad/fs = {w_max * radians_per_fs_to_meV:.3f} meV') 

89 logger.info(f'Nyquist angular frequency (2pi / frame_step / dt / 2):' 

90 f' {w_N:.6f} rad/fs = {w_N * radians_per_fs_to_meV:.3f} meV') 

91 

92 if calculate_currents: 

93 logger.info('Calculating current (velocity) correlations') 

94 if calculate_incoherent: 

95 logger.info('Calculating incoherent part (self-part) of correlations') 

96 

97 # log some info regarding q-points 

98 logger.info(f'Number of q-points: {n_qpoints}') 

99 

100 q_directions = q_points.copy() 

101 q_distances = np.linalg.norm(q_points, axis=1) 

102 nonzero = q_distances > 0 

103 q_directions[nonzero] /= q_distances[nonzero].reshape(-1, 1) 

104 

105 # setup functions to process frames 

106 def f2_rho(frame): 

107 rho_qs_dict = dict() 

108 for atom_type in frame.positions_by_type.keys(): 

109 x = frame.positions_by_type[atom_type] 

110 rho_qs_dict[atom_type] = calc_rho_q(x, q_points) 

111 frame.rho_qs_dict = rho_qs_dict 

112 return frame 

113 

114 def f2_rho_and_j(frame): 

115 rho_qs_dict = dict() 

116 jz_qs_dict = dict() 

117 jper_qs_dict = dict() 

118 

119 for atom_type in frame.positions_by_type.keys(): 

120 x = frame.positions_by_type[atom_type] 

121 v = frame.velocities_by_type[atom_type] 

122 rho_qs, j_qs = calc_rho_j_q(x, v, q_points) 

123 jz_qs = np.sum(j_qs * q_directions, axis=1) 

124 jper_qs = j_qs - (jz_qs[:, None] * q_directions) 

125 

126 rho_qs_dict[atom_type] = rho_qs 

127 jz_qs_dict[atom_type] = jz_qs 

128 jper_qs_dict[atom_type] = jper_qs 

129 

130 frame.rho_qs_dict = rho_qs_dict 

131 frame.jz_qs_dict = jz_qs_dict 

132 frame.jper_qs_dict = jper_qs_dict 

133 return frame 

134 

135 if calculate_currents: 

136 element_processor = f2_rho_and_j 

137 else: 

138 element_processor = f2_rho 

139 

140 # setup window iterator 

141 window_iterator = WindowIterator(traj, width=N_tc, window_step=window_step, 

142 element_processor=element_processor) 
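# each window holds N_tc consecutive processed frames; successive windows start window_step frames apart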

143 

144 # define all pairs 

145 pairs = list(combinations_with_replacement(traj.atom_types, r=2)) 

146 particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()} 

147 logger.debug('Considering pairs:') 

148 for pair in pairs: 

149 logger.debug(f' {pair}') 

150 

151 # setup all time averager instances 

152 F_q_t_averager = dict() 

153 for pair in pairs: 

154 F_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints) 

155 if calculate_currents: 

156 Cl_q_t_averager = dict() 

157 Ct_q_t_averager = dict() 

158 for pair in pairs: 

159 Cl_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints) 

160 Ct_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints) 

161 if calculate_incoherent: 

162 F_s_q_t_averager = dict() 

163 for pair in traj.atom_types: 

164 F_s_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints) 

165 

166 # define correlation function 

167 def calc_corr(window, time_i): 

168 # Calculate correlations between two frames in the window, without the 1/N normalization

169 f0 = window[0] 

170 fi = window[time_i] 
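# for two distinct species the partial correlation collects both orderings (s1, s2) and (s2, s1), hence the extra term when s1 != s2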

171 for s1, s2 in pairs: 

172 Fqt = np.real(f0.rho_qs_dict[s1] * fi.rho_qs_dict[s2].conjugate()) 

173 if s1 != s2: 

174 Fqt += np.real(f0.rho_qs_dict[s2] * fi.rho_qs_dict[s1].conjugate()) 

175 F_q_t_averager[(s1, s2)].add_sample(time_i, Fqt) 

176 

177 if calculate_currents: 

178 for s1, s2 in pairs: 

179 Clqt = np.real(f0.jz_qs_dict[s1] * fi.jz_qs_dict[s2].conjugate()) 

180 Ctqt = 0.5 * np.real(np.sum(f0.jper_qs_dict[s1] * 

181 fi.jper_qs_dict[s2].conjugate(), axis=1)) 

182 if s1 != s2: 

183 Clqt += np.real(f0.jz_qs_dict[s2] * fi.jz_qs_dict[s1].conjugate()) 

184 Ctqt += 0.5 * np.real(np.sum(f0.jper_qs_dict[s2] * 

185 fi.jper_qs_dict[s1].conjugate(), axis=1)) 

186 

187 Cl_q_t_averager[(s1, s2)].add_sample(time_i, Clqt) 

188 Ct_q_t_averager[(s1, s2)].add_sample(time_i, Ctqt) 

189 
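# the incoherent (self) part correlates each atom only with itself via the displacements x_i(t) - x_i(0)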

190 if calculate_incoherent: 

191 for atom_type in traj.atom_types: 

192 xi = fi.positions_by_type[atom_type] 

193 x0 = f0.positions_by_type[atom_type] 

194 Fsqt = np.real(calc_rho_q(xi - x0, q_points)) 

195 F_s_q_t_averager[atom_type].add_sample(time_i, Fsqt) 

196 

197 # run calculation 

198 logging_interval = 1000 

199 with concurrent.futures.ThreadPoolExecutor() as tpe: 

200 # This is the "main loop" over the trajectory 

201 for window in window_iterator: 

202 logger.debug(f'Processing window {window[0].frame_index} to {window[-1].frame_index}') 

203 

204 if window[0].frame_index % logging_interval == 0: 

205 logger.info(f'Processing window {window[0].frame_index} to {window[-1].frame_index}') # noqa 

206 

207 # The map conveniently applies calc_corr to all time lags. However,

208 # since everything is done in place nothing is returned, so to submit

209 # the tasks and wait for them to finish we simply iterate over the

210 # None values that are returned

211 for _ in tpe.map(partial(calc_corr, window), range(len(window))): 

212 pass 

213 

214 # collect results into dict with numpy arrays (n_qpoints, N_tc) 

215 data_dict_corr = dict() 

216 time = delta_t * np.arange(N_tc, dtype=float) 

217 data_dict_corr['q_points'] = q_points 

218 data_dict_corr['time'] = time 

219 
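# per-pair F(q, t), normalized by the total number of atoms, and its frequency-domain counterpart S(q, w); partials are accumulated into the totals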

220 F_q_t_tot = np.zeros((n_qpoints, N_tc)) 

221 S_q_w_tot = np.zeros((n_qpoints, N_tc)) 

222 for pair in pairs: 

223 key = '_'.join(pair) 

224 F_q_t = 1 / traj.n_atoms * F_q_t_averager[pair].get_average_all() 

225 w, S_q_w = fourier_cos_filon(F_q_t, delta_t) 

226 S_q_w = np.array(S_q_w) 

227 data_dict_corr['omega'] = w 

228 data_dict_corr[f'Fqt_coh_{key}'] = F_q_t 

229 data_dict_corr[f'Sqw_coh_{key}'] = S_q_w 

230 

231 # sum all partials to the total 

232 F_q_t_tot += F_q_t 

233 S_q_w_tot += S_q_w 

234 data_dict_corr['Fqt_coh'] = F_q_t_tot 

235 data_dict_corr['Sqw_coh'] = S_q_w_tot 

236 

237 if calculate_currents: 

238 Cl_q_t_tot = np.zeros((n_qpoints, N_tc)) 

239 Ct_q_t_tot = np.zeros((n_qpoints, N_tc)) 

240 Cl_q_w_tot = np.zeros((n_qpoints, N_tc)) 

241 Ct_q_w_tot = np.zeros((n_qpoints, N_tc)) 

242 for pair in pairs: 

243 key = '_'.join(pair) 

244 Cl_q_t = 1 / traj.n_atoms * Cl_q_t_averager[pair].get_average_all() 

245 Ct_q_t = 1 / traj.n_atoms * Ct_q_t_averager[pair].get_average_all() 

246 _, Cl_q_w = fourier_cos_filon(Cl_q_t, delta_t) 

247 _, Ct_q_w = fourier_cos_filon(Ct_q_t, delta_t) 

248 data_dict_corr[f'Clqt_{key}'] = Cl_q_t 

249 data_dict_corr[f'Ctqt_{key}'] = Ct_q_t 

250 data_dict_corr[f'Clqw_{key}'] = Cl_q_w 

251 data_dict_corr[f'Ctqw_{key}'] = Ct_q_w 

252 

253 # sum all partials to the total 

254 Cl_q_t_tot += Cl_q_t 

255 Ct_q_t_tot += Ct_q_t 

256 Cl_q_w_tot += Cl_q_w 

257 Ct_q_w_tot += Ct_q_w 

258 data_dict_corr['Clqt'] = Cl_q_t_tot 

259 data_dict_corr['Ctqt'] = Ct_q_t_tot 

260 data_dict_corr['Clqw'] = Cl_q_w_tot 

261 data_dict_corr['Ctqw'] = Ct_q_w_tot 

262 

263 if calculate_incoherent: 

264 Fs_q_t_tot = np.zeros((n_qpoints, N_tc)) 

265 Ss_q_w_tot = np.zeros((n_qpoints, N_tc)) 

266 for atom_type in traj.atom_types: 

267 Fs_q_t = 1 / traj.n_atoms * F_s_q_t_averager[atom_type].get_average_all() 

268 _, Ss_q_w = fourier_cos_filon(Fs_q_t, delta_t) 

269 data_dict_corr[f'Fqt_incoh_{atom_type}'] = Fs_q_t 

270 data_dict_corr[f'Sqw_incoh_{atom_type}'] = Ss_q_w 

271 

272 # sum all partials to the total 

273 Fs_q_t_tot += Fs_q_t 

274 Ss_q_w_tot += Ss_q_w 

275 

276 data_dict_corr['Fqt_incoh'] = Fs_q_t_tot 

277 data_dict_corr['Sqw_incoh'] = Ss_q_w_tot 

278 

279 data_dict_corr['Fqt'] = data_dict_corr['Fqt_coh'] + data_dict_corr['Fqt_incoh'] 

280 data_dict_corr['Sqw'] = data_dict_corr['Sqw_coh'] + data_dict_corr['Sqw_incoh'] 

281 else: 

282 data_dict_corr['Fqt'] = data_dict_corr['Fqt_coh'].copy() 

283 data_dict_corr['Sqw'] = data_dict_corr['Sqw_coh'].copy() 

284 

285 # finalize results with additional meta data 

286 result = DynamicSample(data_dict_corr, atom_types=traj.atom_types, pairs=pairs, 

287 particle_counts=particle_counts, cell=traj.cell, 

288 time_between_frames=delta_t, 

289 maximum_time_lag=delta_t * window_size, 

290 angular_frequency_resolution=dw, 

291 maximum_angular_frequency=w_max, 

292 number_of_frames=traj.number_of_frames_read) 

293 

294 return result 

295 

296 

297 def compute_static_structure_factors(

298 traj: Trajectory, 

299 q_points: NDArray[float], 

300) -> StaticSample: 

301 r"""Compute static structure factors. The results are returned in the 

302 form of a :class:`StaticSample <dynasor.sample.StaticSample>` 

303 object. 

304 

305 Parameters 

306 ---------- 

307 traj 

308 Input trajectory 

309 q_points 

310 Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates 
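
Example
-------
A minimal sketch, assuming a :class:`Trajectory <dynasor.trajectory.Trajectory>`
object :attr:`traj` is available; the q-points are illustrative::

    import numpy as np
    from dynasor.correlation_functions import compute_static_structure_factors

    q_points = np.array([[0.0, 0.0, q] for q in np.linspace(0.1, 2.0, 20)])  # rad/Å
    sample = compute_static_structure_factors(traj, q_points)
    # the returned StaticSample contains the partial and total S(q)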

311 """ 

312 # sanity check input args 

313 if q_points.shape[1] != 3: 

314 raise ValueError('q-points array has the wrong shape.') 

315 

316 n_qpoints = q_points.shape[0] 

317 logger.info(f'Number of q-points: {n_qpoints}') 

318 

319 # define all pairs 

320 pairs = list(combinations_with_replacement(traj.atom_types, r=2)) 

321 particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()} 

322 logger.debug('Considering pairs:') 

323 for pair in pairs: 

324 logger.debug(f' {pair}') 

325 

326 # processing function 

327 def f2_rho(frame): 

328 rho_qs_dict = dict() 

329 for atom_type in frame.positions_by_type.keys(): 

330 x = frame.positions_by_type[atom_type] 

331 rho_qs_dict[atom_type] = calc_rho_q(x, q_points) 

332 frame.rho_qs_dict = rho_qs_dict 

333 return frame 

334 

335 # setup averager 

336 Sq_averager = dict() 

337 for pair in pairs: 

338 Sq_averager[pair] = TimeAverager(1, n_qpoints) # time average with only timelag=0 

339 

340 # main loop 

341 for frame in traj: 

342 

343 # process_frame 

344 f2_rho(frame) 

345 logger.debug(f'Processing frame {frame.frame_index}') 

346 

347 for s1, s2 in pairs: 

348 # compute correlation 

349 Sq_pair = np.real(frame.rho_qs_dict[s1] * frame.rho_qs_dict[s2].conjugate()) 

350 if s1 != s2: 

351 Sq_pair += np.real(frame.rho_qs_dict[s2] * frame.rho_qs_dict[s1].conjugate()) 

352 Sq_averager[(s1, s2)].add_sample(0, Sq_pair) 

353 

354 # collect results 

355 data_dict = dict() 

356 data_dict['q_points'] = q_points 

357 S_q_tot = np.zeros((n_qpoints, 1)) 

358 for s1, s2 in pairs: 

359 Sq = 1 / traj.n_atoms * Sq_averager[(s1, s2)].get_average_at_timelag(0).reshape(-1, 1) 

360 data_dict[f'Sq_{s1}_{s2}'] = Sq 

361 S_q_tot += Sq 

362 data_dict['Sq'] = S_q_tot 

363 

364 # finalize results 

365 result = StaticSample(data_dict, atom_types=traj.atom_types, pairs=pairs, 

366 particle_counts=particle_counts, cell=traj.cell, 

367 number_of_frames=traj.number_of_frames_read) 

368 return result 

369 

370 

371 def compute_spectral_energy_density(

372 traj: Trajectory, 

373 ideal_supercell: Atoms, 

374 primitive_cell: Atoms, 

375 q_points: NDArray[float], 

376 dt: float, 

377 partial: bool = False 

378) -> Tuple[NDArray[float], NDArray[float]]: 

379 r""" 

380 Compute the spectral energy density (SED) at specific q-points. The results 

381 are returned in the form of a tuple, which comprises the angular 

382 frequencies in an array of length ``N_frequencies`` in units of rad/fs and the

383 SED in units of eV/(rad/fs) as an array of shape ``(N_qpoints, N_frequencies)``.

384 The normalization is chosen such that integrating the SED of a q-point

385 over the supplied angular frequencies omega (rad/fs) yields

386 :math:`k_B T / 2` times the number of bands, where the number of bands is ``3 * len(primitive_cell)``.

387 

388 More details can be found in Thomas *et al.*, Physical Review B **81**, 081411 (2010), 

389 which should be cited when using this function along with the dynasor reference. 

390 

391 **Note 1:** 

392 SED analysis is only suitable for crystalline materials without diffusion as 

393 atoms are assumed to move around fixed reference positions throughout the entire trajectory. 

394 

395 **Note 2:** 

396 This implementation reads the full trajectory and can thus consume a lot of memory. 

397 

398 Parameters 

399 ---------- 

400 traj 

401 Input trajectory 

402 ideal_supercell 

403 Ideal structure defining the reference positions. Do not convert the

404 masses in the ASE Atoms object to dynasor internal units; this conversion

405 is done internally.

406 primitive_cell 

407 Underlying primitive structure. Must be aligned correctly with :attr:`ideal_supercell`. 

408 q_points 

409 Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates 

410 dt 

411 Time difference in femtoseconds between two consecutive snapshots in 

412 the trajectory. Note that you should not change :attr:`dt` if you change 

413 :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`. 

414 partial 

415 If True, the SED is returned decomposed per basis atom and Cartesian direction.

416 The shape is ``(N_qpoints, N_frequencies, len(primitive_cell), 3)`` 
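
Example
-------
A minimal sketch, assuming a :class:`Trajectory <dynasor.trajectory.Trajectory>`
object :attr:`traj` with velocities as well as ASE :class:`Atoms` objects
:attr:`ideal_supercell` and :attr:`primitive_cell` are available; the q-points
and :attr:`dt` are illustrative::

    import numpy as np
    from dynasor.correlation_functions import compute_spectral_energy_density

    q_points = np.array([[0.0, 0.0, q] for q in np.linspace(0.0, 1.0, 11)])  # rad/Å
    w, sed = compute_spectral_energy_density(
        traj, ideal_supercell=ideal_supercell, primitive_cell=primitive_cell,
        q_points=q_points, dt=5.0)
    # w holds angular frequencies in rad/fs; sed has shape (N_qpoints, N_frequencies)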

417 """ 

418 

419 delta_t = traj.frame_step * dt 

420 

421 # logger 

422 logger.info('Running SED') 

423 logger.info(f'Time between consecutive frames (dt * frame_step): {delta_t} fs') 

424 logger.info(f'Number of atoms in primitive_cell: {len(primitive_cell)}') 

425 logger.info(f'Number of atoms in ideal_supercell: {len(ideal_supercell)}') 

426 logger.info(f'Number of q-points: {q_points.shape[0]}') 

427 

428 # check that the ideal supercell agrees with traj 

429 if traj.n_atoms != len(ideal_supercell): 

430 raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.') 

431 

432 if len(primitive_cell) >= len(ideal_supercell):  (coverage: branch 432 ↛ 433 was never taken, i.e., the condition on line 432 was never true)

433 raise ValueError('primitive_cell must contain fewer atoms than ideal_supercell.')

434 

435 # collect all velocities and scale with sqrt(masses)

436 masses = ideal_supercell.get_masses().reshape(-1, 1) / fs**2 # from Dalton to dynasor internal mass units

437 velocities = [] 

438 for it, frame in enumerate(traj): 

439 logger.debug(f'Reading frame {it}') 

440 if frame.velocities_by_type is None: 

441 raise ValueError(f'Could not read velocities from frame {it}') 

442 v = frame.get_velocities_as_array(traj.atomic_indices) # in Å/fs 

443 velocities.append(np.sqrt(masses) * v) 

444 logger.info(f'Number of snapshots: {len(velocities)}') 

445 

446 # Perform the FFT on the last axis for extra speed (maybe not needed) 

447 N_samples = len(velocities) 

448 velocities = np.array(velocities) 

449 # place the time index last and make a contiguous copy

450 velocities = velocities.transpose(1, 2, 0).copy() 

451 # #atoms in supercell x 3 directions x #frequencies 

452 velocities = np.fft.rfft(velocities, axis=2) 
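# np.fft.rfft keeps only the non-negative frequency components (N_samples // 2 + 1 of them)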

453 

454 # Calculate indices and offsets needed for the SED method

455 indices, offsets = get_index_offset(ideal_supercell, primitive_cell) 

456 

457 # Phase factors for the spatial FT; shape: #qpoints x #atoms in supercell

458 cell_positions = np.dot(offsets, primitive_cell.cell)

459 phase = np.dot(q_points, cell_positions.T) # #qpoints x #atoms in supercell

460 phase_factors = np.exp(1.0j * phase) 

461 

462 # This dict maps each offset to an index so that ndarrays can be indexed by

463 # (basis atom, offset) instead of by supercell atom

464 offset_dict = {off: n for n, off in enumerate(set(tuple(offset) for offset in offsets))} 

465 

466 # Pick out some shapes 

467 n_super, _, n_w = velocities.shape 

468 n_qpts = len(q_points) 

469 n_prim = len(primitive_cell) 

470 n_offsets = len(offset_dict) 

471 

472 # This new array is indexed by basis atom and offset instead (and also transposed)

473 new_velocities = np.zeros((n_w, 3, n_prim, n_offsets), dtype=velocities.dtype) 

474 

475 for i in range(n_super): 

476 j = indices[i] # atom with index i in the supercell is of basis type j ... 

477 n = offset_dict[tuple(offsets[i])] # and its offset has index n 

478 new_velocities[:, :, j, n] = velocities[i].T 

479 

480 velocities = new_velocities 

481 

482 # Same story with the spatial phase factors 

483 new_phase_factors = np.zeros((n_qpts, n_prim, n_offsets), dtype=phase_factors.dtype) 

484 

485 for i in range(n_super): 

486 j = indices[i] 

487 n = offset_dict[tuple(offsets[i])] 

488 new_phase_factors[:, j, n] = phase_factors[:, i] 

489 

490 phase_factors = new_phase_factors 

491 

492 # calculate the density in a numba function

493 density = _sed_inner_loop(phase_factors, velocities) 

494 

495 if not partial:  (coverage: branch 495 ↛ 502 was never taken, i.e., the condition on line 495 was never false)

496 density = np.sum(density, axis=(2, 3)) 

497 

498 # units

499 # the velocities are in Angstrom/fs; the masses were rescaled above so that mass * velocity**2 comes out in eV

500 

501 # the time delta in the fourier transform 

502 density = density * delta_t**2 

503 

504 # Divide by the length of the time signal 

505 density = density / (N_samples * delta_t) 

506 

507 # Divide by the number of primitive cells 

508 density = density / (n_super / n_prim) 

509 

510 # Factor so the sed can be integrated together with the returned omega 

511 # numpy fft works with ordinary/linear frequencies and not angular freqs 

512 density = density / (2*np.pi) 

513 

514 # angular frequencies 

515 w = 2 * np.pi * np.fft.rfftfreq(N_samples, delta_t) # rad/fs 

516 

517 return w, density 

518 

519 

520 @numba.njit(parallel=True, fastmath=True)

521 def _sed_inner_loop(phase_factors, velocities):

522 """This numba function calculates the spatial FT using precomputed phase factors 

523 

524 As the use case can be one or many q-points the parallelization is over the 

525 temporal frequency components instead. 

526 """ 

527 

528 n_qpts = phase_factors.shape[0] # q-point index 

529 n_prim = phase_factors.shape[1] # basis atom index 

530 n_super = phase_factors.shape[2] # unit cell index 

531 

532 n_freqs = velocities.shape[0] # frequency, direction, basis atom, unit cell 

533 

534 density = np.zeros((n_qpts, n_freqs, n_prim, 3), dtype=np.float64) 

535 

536 for w in numba.prange(n_freqs): 

537 for k in range(n_qpts): 

538 for a in range(3): 

539 for b in range(n_prim): 

540 tmp = 0.0j 

541 for n in range(n_super): 

542 tmp += phase_factors[k, b, n] * velocities[w, a, b, n] 

543 density[k, w, b, a] += np.abs(tmp)**2 

544 return density