Coverage for local_installation/dynasor/correlation_functions.py: 93%

290 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2025-04-16 06:13 +0000

1import numba 

2import concurrent 

3from functools import partial 

4from itertools import combinations_with_replacement 

5from typing import Tuple 

6 

7import numpy as np 

8from ase import Atoms 

9from ase.units import fs 

10from numpy.typing import NDArray 

11 

12from dynasor.logging_tools import logger 

13from dynasor.trajectory import Trajectory, WindowIterator 

14from dynasor.sample import DynamicSample, StaticSample 

15from dynasor.post_processing import fourier_cos_filon 

16from dynasor.core.time_averager import TimeAverager 

17from dynasor.core.reciprocal import calc_rho_q, calc_rho_j_q 

18from dynasor.qpoints.tools import get_index_offset 

19from dynasor.units import radians_per_fs_to_meV 

20 

21 

def compute_dynamic_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
    dt: float,
    window_size: int,
    window_step: int = 1,
    calculate_currents: bool = False,
    calculate_incoherent: bool = False,
) -> DynamicSample:
    r"""Compute dynamic structure factors. The results are returned in the
    form of a :class:`DynamicSample <dynasor.sample.DynamicSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots
        in the trajectory. Note that you should *not* change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    window_size
        Length of the trajectory frame window to use for time correlation calculation.
        It is expressed in terms of the number of time lags to consider
        and thus determines the smallest frequency resolved.
    window_step
        Window step (or stride) given as the number of frames between consecutive trajectory
        windows. This parameter does *not* affect the time between consecutive frames in the
        calculation. If, e.g., :attr:`window_step` > :attr:`window_size`, some frames will not
        be used.
    calculate_currents
        Calculate the current correlations. Requires velocities to be available in :attr:`traj`.
    calculate_incoherent
        Calculate the incoherent (self) part of the intermediate scattering function,
        :math:`F_\mathrm{incoh}`.

    Raises
    ------
    ValueError
        If ``q_points`` does not have shape ``(N_qpoints, 3)``; if ``dt`` is not
        positive; if ``window_size`` is not an even number larger than 2; if
        ``window_step`` is not positive; or if ``calculate_currents`` is requested
        but a frame carries no velocities.
    """
    # ``import concurrent`` alone does not load the ``futures`` submodule, so
    # import the executor explicitly instead of relying on someone else having done so.
    from concurrent.futures import ThreadPoolExecutor

    # sanity check input args
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')
    if window_size <= 2:
        raise ValueError(f'window_size must be larger than 2: window_size= {window_size}')
    if window_size % 2 != 0:
        raise ValueError(f'window_size must be even: window_size= {window_size}')
    if window_step <= 0:
        raise ValueError(f'window_step must be positive: window_step= {window_step}')

    # define internal parameters
    n_qpoints = q_points.shape[0]
    delta_t = traj.frame_step * dt
    N_tc = window_size + 1  # number of time lags, including lag 0

    # log all setup information
    dw = np.pi / (window_size * delta_t)
    w_max = dw * window_size
    w_N = 2 * np.pi / (2 * delta_t)  # Nyquist angular frequency

    logger.info(f'Spacing between samples (frame_step): {traj.frame_step}')
    logger.info(f'Time between consecutive frames in input trajectory (dt): {dt} fs')
    logger.info(f'Time between consecutive frames used (dt * frame_step): {delta_t} fs')
    logger.info(f'Time window size (dt * frame_step * window_size): {delta_t * window_size:.1f} fs')
    logger.info(f'Angular frequency resolution: dw = {dw:.6f} rad/fs = '
                f'{dw * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Maximum angular frequency (dw * window_size):'
                f' {w_max:.6f} rad/fs = {w_max * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Nyquist angular frequency (2pi / frame_step / dt / 2):'
                f' {w_N:.6f} rad/fs = {w_N * radians_per_fs_to_meV:.3f} meV')

    if calculate_currents:
        logger.info('Calculating current (velocity) correlations')
    if calculate_incoherent:
        logger.info('Calculating incoherent part (self-part) of correlations')

    # log some info regarding q-points
    logger.info(f'Number of q-points: {n_qpoints}')

    # Unit vectors along each q-point (q = 0 stays the zero vector). Work in
    # float so that integer-typed q-points do not break the in-place division.
    q_directions = np.array(q_points, dtype=float)
    q_distances = np.linalg.norm(q_directions, axis=1)
    nonzero = q_distances > 0
    q_directions[nonzero] /= q_distances[nonzero].reshape(-1, 1)

    # setup functions to process frames
    def f2_rho(frame):
        # attach per-type density Fourier components rho(q) to the frame
        frame.rho_qs_dict = {atom_type: calc_rho_q(x, q_points)
                             for atom_type, x in frame.positions_by_type.items()}
        return frame

    def f2_rho_and_j(frame):
        # attach rho(q) plus the longitudinal (along q) and transverse parts of j(q)
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {frame.frame_index}')
        rho_qs_dict = dict()
        jz_qs_dict = dict()
        jper_qs_dict = dict()

        for atom_type, x in frame.positions_by_type.items():
            v = frame.velocities_by_type[atom_type]
            rho_qs, j_qs = calc_rho_j_q(x, v, q_points)
            # projection of the current onto the q direction ...
            jz_qs = np.sum(j_qs * q_directions, axis=1)
            # ... and the remainder perpendicular to it
            jper_qs = j_qs - (jz_qs[:, None] * q_directions)

            rho_qs_dict[atom_type] = rho_qs
            jz_qs_dict[atom_type] = jz_qs
            jper_qs_dict[atom_type] = jper_qs

        frame.rho_qs_dict = rho_qs_dict
        frame.jz_qs_dict = jz_qs_dict
        frame.jper_qs_dict = jper_qs_dict
        return frame

    element_processor = f2_rho_and_j if calculate_currents else f2_rho

    # setup window iterator
    window_iterator = WindowIterator(traj, width=N_tc, window_step=window_step,
                                     element_processor=element_processor)

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f'  {pair}')

    # setup all time averager instances
    F_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
    if calculate_currents:
        Cl_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
        Ct_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
    if calculate_incoherent:
        F_s_q_t_averager = {atom_type: TimeAverager(N_tc, n_qpoints)
                            for atom_type in traj.atom_types}

    # define correlation function
    def calc_corr(window, time_i):
        # Correlate the first frame of the window with frame ``time_i``
        # (time lag ``time_i``) without the 1/N normalization.
        f0 = window[0]
        fi = window[time_i]
        for s1, s2 in pairs:
            Fqt = np.real(f0.rho_qs_dict[s1] * fi.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                # cross pairs are stored once, so include both orderings
                Fqt += np.real(f0.rho_qs_dict[s2] * fi.rho_qs_dict[s1].conjugate())
            F_q_t_averager[(s1, s2)].add_sample(time_i, Fqt)

        if calculate_currents:
            for s1, s2 in pairs:
                Clqt = np.real(f0.jz_qs_dict[s1] * fi.jz_qs_dict[s2].conjugate())
                # 0.5: average over the two transverse directions
                Ctqt = 0.5 * np.real(np.sum(f0.jper_qs_dict[s1] *
                                            fi.jper_qs_dict[s2].conjugate(), axis=1))
                if s1 != s2:
                    Clqt += np.real(f0.jz_qs_dict[s2] * fi.jz_qs_dict[s1].conjugate())
                    Ctqt += 0.5 * np.real(np.sum(f0.jper_qs_dict[s2] *
                                                 fi.jper_qs_dict[s1].conjugate(), axis=1))

                Cl_q_t_averager[(s1, s2)].add_sample(time_i, Clqt)
                Ct_q_t_averager[(s1, s2)].add_sample(time_i, Ctqt)

        if calculate_incoherent:
            for atom_type in traj.atom_types:
                xi = fi.positions_by_type[atom_type]
                x0 = f0.positions_by_type[atom_type]
                Fsqt = np.real(calc_rho_q(xi - x0, q_points))
                F_s_q_t_averager[atom_type].add_sample(time_i, Fsqt)

    # run calculation
    logging_interval = 1000
    with ThreadPoolExecutor() as executor:
        # This is the "main loop" over the trajectory
        for window in window_iterator:
            logger.debug(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')

            if window[0].frame_index % logging_interval == 0:
                logger.info(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')  # noqa

            # calc_corr works in place and returns None; draining the map both
            # waits for all time lags to finish and re-raises worker exceptions.
            for _ in executor.map(partial(calc_corr, window), range(len(window))):
                pass

    # collect results into dict with numpy arrays (n_qpoints, N_tc)
    data_dict_corr = dict()
    time = delta_t * np.arange(N_tc, dtype=float)
    data_dict_corr['q_points'] = q_points
    data_dict_corr['time'] = time

    F_q_t_tot = np.zeros((n_qpoints, N_tc))
    S_q_w_tot = np.zeros((n_qpoints, N_tc))
    for pair in pairs:
        key = '_'.join(pair)
        F_q_t = 1 / traj.n_atoms * F_q_t_averager[pair].get_average_all()
        w, S_q_w = fourier_cos_filon(F_q_t, delta_t)
        S_q_w = np.array(S_q_w)
        data_dict_corr['omega'] = w
        data_dict_corr[f'Fqt_coh_{key}'] = F_q_t
        data_dict_corr[f'Sqw_coh_{key}'] = S_q_w

        # sum all partials to the total
        F_q_t_tot += F_q_t
        S_q_w_tot += S_q_w
    data_dict_corr['Fqt_coh'] = F_q_t_tot
    data_dict_corr['Sqw_coh'] = S_q_w_tot

    if calculate_currents:
        Cl_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_t_tot = np.zeros((n_qpoints, N_tc))
        Cl_q_w_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_w_tot = np.zeros((n_qpoints, N_tc))
        for pair in pairs:
            key = '_'.join(pair)
            Cl_q_t = 1 / traj.n_atoms * Cl_q_t_averager[pair].get_average_all()
            Ct_q_t = 1 / traj.n_atoms * Ct_q_t_averager[pair].get_average_all()
            _, Cl_q_w = fourier_cos_filon(Cl_q_t, delta_t)
            _, Ct_q_w = fourier_cos_filon(Ct_q_t, delta_t)
            data_dict_corr[f'Clqt_{key}'] = Cl_q_t
            data_dict_corr[f'Ctqt_{key}'] = Ct_q_t
            data_dict_corr[f'Clqw_{key}'] = Cl_q_w
            data_dict_corr[f'Ctqw_{key}'] = Ct_q_w

            # sum all partials to the total
            Cl_q_t_tot += Cl_q_t
            Ct_q_t_tot += Ct_q_t
            Cl_q_w_tot += Cl_q_w
            Ct_q_w_tot += Ct_q_w
        data_dict_corr['Clqt'] = Cl_q_t_tot
        data_dict_corr['Ctqt'] = Ct_q_t_tot
        data_dict_corr['Clqw'] = Cl_q_w_tot
        data_dict_corr['Ctqw'] = Ct_q_w_tot

    if calculate_incoherent:
        Fs_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ss_q_w_tot = np.zeros((n_qpoints, N_tc))
        for atom_type in traj.atom_types:
            Fs_q_t = 1 / traj.n_atoms * F_s_q_t_averager[atom_type].get_average_all()
            _, Ss_q_w = fourier_cos_filon(Fs_q_t, delta_t)
            data_dict_corr[f'Fqt_incoh_{atom_type}'] = Fs_q_t
            data_dict_corr[f'Sqw_incoh_{atom_type}'] = Ss_q_w

            # sum all partials to the total
            Fs_q_t_tot += Fs_q_t
            Ss_q_w_tot += Ss_q_w

        data_dict_corr['Fqt_incoh'] = Fs_q_t_tot
        data_dict_corr['Sqw_incoh'] = Ss_q_w_tot

    # finalize results with additional meta data
    result = DynamicSample(data_dict_corr, atom_types=traj.atom_types, pairs=pairs,
                           particle_counts=particle_counts, cell=traj.cell,
                           time_between_frames=delta_t,
                           maximum_time_lag=delta_t * window_size,
                           angular_frequency_resolution=dw,
                           maximum_angular_frequency=w_max,
                           number_of_frames=traj.number_of_frames_read)

    return result

289 

290 

def compute_static_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
) -> StaticSample:
    r"""Compute static structure factors. The results are returned in the
    form of a :class:`StaticSample <dynasor.sample.StaticSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates

    Raises
    ------
    ValueError
        If ``q_points`` does not have shape ``(N_qpoints, 3)``.
    """
    # sanity check input args (checking ndim first avoids an IndexError for 1-D input)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    n_qpoints = q_points.shape[0]
    logger.info(f'Number of q-points: {n_qpoints}')

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f'  {pair}')

    # processing function: attach per-type density Fourier components to the frame
    def f2_rho(frame):
        frame.rho_qs_dict = {atom_type: calc_rho_q(x, q_points)
                             for atom_type, x in frame.positions_by_type.items()}
        return frame

    # setup averager (time average with only timelag=0)
    Sq_averager = {pair: TimeAverager(1, n_qpoints) for pair in pairs}

    # main loop
    for frame in traj:

        # process_frame
        f2_rho(frame)
        logger.debug(f'Processing frame {frame.frame_index}')

        for s1, s2 in pairs:
            # compute correlation; cross pairs are stored once, so add both orderings
            Sq_pair = np.real(frame.rho_qs_dict[s1] * frame.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                Sq_pair += np.real(frame.rho_qs_dict[s2] * frame.rho_qs_dict[s1].conjugate())
            Sq_averager[(s1, s2)].add_sample(0, Sq_pair)

    # collect results as column vectors of shape (n_qpoints, 1)
    data_dict = dict()
    data_dict['q_points'] = q_points
    S_q_tot = np.zeros((n_qpoints, 1))
    for s1, s2 in pairs:
        Sq = 1 / traj.n_atoms * Sq_averager[(s1, s2)].get_average_at_timelag(0).reshape(-1, 1)
        data_dict[f'Sq_{s1}_{s2}'] = Sq
        S_q_tot += Sq
    data_dict['Sq'] = S_q_tot

    # finalize results
    result = StaticSample(data_dict, atom_types=traj.atom_types, pairs=pairs,
                          particle_counts=particle_counts, cell=traj.cell,
                          number_of_frames=traj.number_of_frames_read)
    return result

363 

364 

def compute_spectral_energy_density(
    traj: Trajectory,
    ideal_supercell: Atoms,
    primitive_cell: Atoms,
    q_points: NDArray[float],
    dt: float,
    partial: bool = False
) -> Tuple[NDArray[float], NDArray[float]]:
    r"""
    Compute the spectral energy density (SED) at specific q-points. The results
    are returned in the form of a tuple, which comprises the angular
    frequencies in an array of length ``N_times`` in units of rad/fs and the
    SED in units of eV/(rad/fs) as an array of shape ``(N_qpoints, N_times)``.
    The normalization is chosen such that integrating the SED of a q-point
    over the returned angular frequencies omega (rad/fs) yields
    :math:`k_B T / 2` times the number of bands (where the number of bands
    is ``len(primitive_cell) * 3``).

    More details can be found in Thomas *et al.*, Physical Review B **81**, 081411 (2010),
    which should be cited when using this function along with the dynasor reference.

    **Note 1:**
    SED analysis is only suitable for crystalline materials without diffusion as
    atoms are assumed to move around fixed reference positions throughout the entire trajectory.

    **Note 2:**
    This implementation reads the full trajectory and can thus consume a lot of memory.

    Parameters
    ----------
    traj
        Input trajectory
    ideal_supercell
        Ideal structure defining the reference positions. Do not change the
        masses in the ASE atoms objects to dynasor internal units, this will be
        done internally
    primitive_cell
        Underlying primitive structure. Must be aligned correctly with :attr:`ideal_supercell`.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots in
        the trajectory. Note that you should not change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    partial
        If True the SED will be returned decomposed per basis and Cartesian direction.
        The shape is ``(N_qpoints, N_frequencies, len(primitive_cell), 3)``

    Raises
    ------
    ValueError
        If ``q_points`` does not have shape ``(N_qpoints, 3)``, if ``dt`` is not
        positive, if the atom counts of ``traj``, ``ideal_supercell`` and
        ``primitive_cell`` are inconsistent, if a frame lacks velocities, or if
        no frames could be read.
    """
    # sanity check input args
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')

    delta_t = traj.frame_step * dt

    # logger
    logger.info('Running SED')
    logger.info(f'Time between consecutive frames (dt * frame_step): {delta_t} fs')
    logger.info(f'Number of atoms in primitive_cell: {len(primitive_cell)}')
    logger.info(f'Number of atoms in ideal_supercell: {len(ideal_supercell)}')
    logger.info(f'Number of q-points: {q_points.shape[0]}')

    # check that the ideal supercell agrees with traj
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    if len(primitive_cell) >= len(ideal_supercell):
        raise ValueError('primitive_cell must contain fewer atoms than ideal_supercell.')

    # collect all velocities, scaled with sqrt(masses)
    masses = ideal_supercell.get_masses().reshape(-1, 1) / fs**2  # From Dalton to dmu
    sqrt_masses = np.sqrt(masses)  # loop invariant
    velocities = []
    for it, frame in enumerate(traj):
        logger.debug(f'Reading frame {it}')
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {it}')
        v = frame.get_velocities_as_array(traj.atomic_indices)  # in Å/fs
        velocities.append(sqrt_masses * v)
    logger.info(f'Number of snapshots: {len(velocities)}')
    if not velocities:
        # otherwise the transpose below fails with an unrelated error
        raise ValueError('No frames could be read from the trajectory.')

    # Perform the FFT on the last axis for extra speed (maybe not needed)
    N_samples = len(velocities)
    # (#frames, #atoms, 3) -> (#atoms, 3, #frames); copy for contiguity
    velocities = np.array(velocities).transpose(1, 2, 0).copy()
    # #atoms in supercell x 3 directions x #frequencies
    velocities = np.fft.rfft(velocities, axis=2)

    # Calculate indices and offsets needed for the sed method
    indices, offsets = get_index_offset(ideal_supercell, primitive_cell)

    # Phase factor for use in FT. #qpoints x #atoms in supercell
    cell_positions = np.dot(offsets, primitive_cell.cell)
    phase = np.dot(q_points, cell_positions.T)  # #qpoints x #atoms in supercell
    phase_factors = np.exp(1.0j * phase)

    # This dict maps each distinct offset to an index so that the ndarrays can be
    # indexed by (basis index, offset) instead of atoms in the supercell
    offset_dict = {off: n for n, off in enumerate(set(tuple(offset) for offset in offsets))}

    # Pick out some shapes
    n_super, _, n_w = velocities.shape
    n_qpts = len(q_points)
    n_prim = len(primitive_cell)
    n_offsets = len(offset_dict)

    # Re-index velocities (also transposed) and phase factors by basis index and offset
    new_velocities = np.zeros((n_w, 3, n_prim, n_offsets), dtype=velocities.dtype)
    new_phase_factors = np.zeros((n_qpts, n_prim, n_offsets), dtype=phase_factors.dtype)
    for i in range(n_super):
        j = indices[i]  # atom with index i in the supercell is of basis type j ...
        n = offset_dict[tuple(offsets[i])]  # ... and its offset has index n
        new_velocities[:, :, j, n] = velocities[i].T
        new_phase_factors[:, j, n] = phase_factors[:, i]
    velocities = new_velocities
    phase_factors = new_phase_factors

    # calculate the density in a numba function
    density = _sed_inner_loop(phase_factors, velocities)

    if not partial:
        density = np.sum(density, axis=(2, 3))

    # units
    # make so that the velocities were originally in Angstrom / fs to be compatible with eV and Da

    # the time delta in the fourier transform
    density = density * delta_t**2

    # Divide by the length of the time signal
    density = density / (N_samples * delta_t)

    # Divide by the number of primitive cells
    density = density / (n_super / n_prim)

    # Factor so the sed can be integrated together with the returned omega
    # numpy fft works with ordinary/linear frequencies and not angular freqs
    density = density / (2*np.pi)

    # angular frequencies
    w = 2 * np.pi * np.fft.rfftfreq(N_samples, delta_t)  # rad/fs

    return w, density

512 

513 

@numba.njit(parallel=True, fastmath=True)
def _sed_inner_loop(phase_factors, velocities):
    """Spatial Fourier transform of the velocities using precomputed phase factors.

    Parameters
    ----------
    phase_factors
        Complex array indexed as ``[q-point, basis atom, unit cell]``.
    velocities
        Complex array indexed as ``[frequency, Cartesian direction, basis atom, unit cell]``.

    Returns
    -------
    Real array of shape ``(n_qpoints, n_frequencies, n_basis, 3)`` holding
    the squared modulus of the transform summed over unit cells.

    Since the use case may involve only one or a few q-points, the
    parallel (``prange``) loop runs over the temporal frequency components.
    """
    n_q, n_basis, n_cells = phase_factors.shape
    n_freq = velocities.shape[0]

    out = np.zeros((n_q, n_freq, n_basis, 3), dtype=np.float64)

    for iw in numba.prange(n_freq):
        for iq in range(n_q):
            for alpha in range(3):
                for basis in range(n_basis):
                    acc = 0.0j
                    for cell in range(n_cells):
                        acc += phase_factors[iq, basis, cell] * velocities[iw, alpha, basis, cell]
                    out[iq, iw, basis, alpha] += np.abs(acc)**2
    return out
538 return density