Coverage for dynasor/correlation_functions.py: 94%
299 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-04-10 06:28 +0000
1import concurrent
2from functools import partial
3from itertools import combinations_with_replacement
4from typing import Optional
6import numba
7import numpy as np
8from ase import Atoms
9from ase.units import fs
10from numpy.typing import NDArray
12from dynasor.logging_tools import logger
13from dynasor.trajectory import Trajectory, WindowIterator
14from dynasor.sample import DynamicSample, StaticSample
16from dynasor.tools.acfs import psd_from_acf_2d
17from dynasor.core.time_averager import TimeAverager, truncate_nans
18from dynasor.core.reciprocal import calc_rho_q, calc_rho_j_q
19from dynasor.tools.structures import get_offset_index
20from dynasor.units import radians_per_fs_to_meV
def compute_dynamic_structure_factors(
        traj: Trajectory,
        q_points: NDArray[float],
        dt: float,
        window_size: int,
        window_step: Optional[int] = 1,
        calculate_currents: Optional[bool] = False,
        calculate_incoherent: Optional[bool] = False,
        logging_interval: Optional[int] = 1000,
) -> DynamicSample:
    r"""Compute the dynamic structure factors. The results are returned in the
    form of a :class:`DynamicSample <dynasor.sample.DynamicSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates.
    dt
        Time difference in femtoseconds between two consecutive snapshots
        in the trajectory. Note that you should *not* change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    window_size
        Length of the trajectory frame window to use for time correlation calculation.
        It is expressed in terms of the number of time lags to consider
        and thus determines the smallest frequency resolved.
    window_step
        Window step (or stride) given as the number of frames between consecutive trajectory
        windows. This parameter does *not* affect the time between consecutive frames in the
        calculation. If, e.g., :attr:`window_step` > :attr:`window_size`, some frames will not
        be used.
    calculate_currents
        Calculate the current correlations. Requires velocities to be available in :attr:`traj`.
    calculate_incoherent
        Calculate the incoherent part (self-part) of :math:`F_\text{incoh}`.
    logging_interval
        Log progress at ``INFO`` level every this many windows. Set to ``0`` to disable
        progress logging.

    Returns
    -------
        A :class:`DynamicSample <dynasor.sample.DynamicSample>` holding the total and
        partial (per atom-type pair) correlation functions and spectra.
    """
    # sanity check input args
    if q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')
    if window_size <= 2:
        raise ValueError(f'window_size must be larger than 2: window_size= {window_size}')
    if window_step <= 0:
        raise ValueError(f'window_step must be positive: window_step= {window_step}')

    # define internal parameters
    n_qpoints = q_points.shape[0]
    delta_t = traj.frame_step * dt  # effective time between frames actually used
    N_tc = window_size + 1  # number of time lags, including lag 0

    # log all setup information
    dw = np.pi / (window_size * delta_t)  # angular frequency resolution
    w_max = dw * window_size
    w_N = 2 * np.pi / (2 * delta_t)  # Nyquist angular frequency

    logger.info(f'Spacing between samples (frame_step): {traj.frame_step}')
    logger.info(f'Time between consecutive frames in input trajectory (dt): {dt} fs')
    logger.info(f'Time between consecutive frames used (dt * frame_step): {delta_t} fs')
    logger.info(f'Time window size (dt * frame_step * window_size): {delta_t * window_size:.1f} fs')
    logger.info(f'Angular frequency resolution: dw = {dw:.6f} rad/fs = '
                f'{dw * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Maximum angular frequency (dw * window_size):'
                f' {w_max:.6f} rad/fs = {w_max * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Nyquist angular frequency (2pi / frame_step / dt / 2):'
                f' {w_N:.6f} rad/fs = {w_N * radians_per_fs_to_meV:.3f} meV')

    if calculate_currents:
        logger.info('Calculating current (velocity) correlations')
    if calculate_incoherent:
        logger.info('Calculating incoherent part (self-part) of correlations')

    # log some info regarding q-points
    logger.info(f'Number of q-points: {n_qpoints}')

    # unit vectors along each q-point; zero-length q-points are left untouched
    # to avoid division by zero
    q_directions = q_points.copy()
    q_distances = np.linalg.norm(q_points, axis=1)
    nonzero = q_distances > 0
    q_directions[nonzero] /= q_distances[nonzero].reshape(-1, 1)

    # setup functions to process frames
    def f2_rho(frame):
        # attach the Fourier-transformed particle density per atom type to the frame
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    def f2_rho_and_j(frame):
        # as f2_rho, but additionally attach the longitudinal (jz) and
        # transverse (jper) current components per atom type
        rho_qs_dict = dict()
        jz_qs_dict = dict()
        jper_qs_dict = dict()

        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            v = frame.velocities_by_type[atom_type]
            rho_qs, j_qs = calc_rho_j_q(x, v, q_points)
            # project current onto q-direction (longitudinal part) ...
            jz_qs = np.sum(j_qs * q_directions, axis=1)
            # ... and subtract it to obtain the transverse part
            jper_qs = j_qs - (jz_qs[:, None] * q_directions)

            rho_qs_dict[atom_type] = rho_qs
            jz_qs_dict[atom_type] = jz_qs
            jper_qs_dict[atom_type] = jper_qs

        frame.rho_qs_dict = rho_qs_dict
        frame.jz_qs_dict = jz_qs_dict
        frame.jper_qs_dict = jper_qs_dict
        return frame

    if calculate_currents:
        element_processor = f2_rho_and_j
    else:
        element_processor = f2_rho

    # setup window iterator
    window_iterator = WindowIterator(traj, width=N_tc, window_step=window_step,
                                     element_processor=element_processor)

    # define all pairs (unordered, with like-like pairs included)
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # set up all time averager instances
    F_q_t_averager = dict()
    for pair in pairs:
        F_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
    if calculate_currents:
        Cl_q_t_averager = dict()
        Ct_q_t_averager = dict()
        for pair in pairs:
            Cl_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
            Ct_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
    if calculate_incoherent:
        F_s_q_t_averager = dict()
        for pair in traj.atom_types:
            F_s_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)

    # define correlation function
    def calc_corr(window, time_i):
        # Calculate correlations between two frames in the window without normalization 1/N
        f0 = window[0]
        fi = window[time_i]
        for s1, s2 in pairs:
            Fqt = np.real(f0.rho_qs_dict[s1] * fi.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                # cross terms appear twice for distinct species
                Fqt += np.real(f0.rho_qs_dict[s2] * fi.rho_qs_dict[s1].conjugate())
            F_q_t_averager[(s1, s2)].add_sample(time_i, Fqt)

        if calculate_currents:
            for s1, s2 in pairs:
                Clqt = np.real(f0.jz_qs_dict[s1] * fi.jz_qs_dict[s2].conjugate())
                # factor 0.5 averages over the two transverse directions
                Ctqt = 0.5 * np.real(np.sum(f0.jper_qs_dict[s1] *
                                            fi.jper_qs_dict[s2].conjugate(), axis=1))
                if s1 != s2:
                    Clqt += np.real(f0.jz_qs_dict[s2] * fi.jz_qs_dict[s1].conjugate())
                    Ctqt += 0.5 * np.real(np.sum(f0.jper_qs_dict[s2] *
                                                 fi.jper_qs_dict[s1].conjugate(), axis=1))

                Cl_q_t_averager[(s1, s2)].add_sample(time_i, Clqt)
                Ct_q_t_averager[(s1, s2)].add_sample(time_i, Ctqt)

        if calculate_incoherent:
            for atom_type in traj.atom_types:
                xi = fi.positions_by_type[atom_type]
                x0 = f0.positions_by_type[atom_type]
                # self-part: correlate each atom only with itself via displacement
                Fsqt = np.real(calc_rho_q(xi - x0, q_points))
                F_s_q_t_averager[atom_type].add_sample(time_i, Fsqt)

    # run calculation
    with concurrent.futures.ThreadPoolExecutor() as tpe:
        # This is the "main loop" over the trajectory
        for window in window_iterator:
            if logging_interval and window[0].frame_index % logging_interval == 0:
                logger.info(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')  # noqa
            else:
                logger.debug(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')  # noqa

            # The map conveniently applies calc_corr to all time-lags. However,
            # as everything is done in place nothing gets returned so in order
            # to start and wait for the processes to finish we must iterate
            # over the None values returned
            for _ in tpe.map(partial(calc_corr, window), range(len(window))):
                pass

    # collect results into dict with numpy arrays (n_qpoints, N_tc)
    data_dict_corr = dict()
    data_dict_corr['q_points'] = q_points

    time = None
    for pair in pairs:
        key = '_'.join(pair)
        F_q_t = 1 / traj.n_atoms * truncate_nans(F_q_t_averager[pair].get_average_all())
        w, S_q_w = psd_from_acf_2d(F_q_t, delta_t)
        S_q_w = np.array(S_q_w)

        if time is None:
            # Determine the actual length of time signal after truncation (if needed)
            time = delta_t * np.arange(F_q_t.shape[1], dtype=float)
            N_tc_actual = len(time)
            F_q_t_tot = np.zeros((n_qpoints, N_tc_actual))
            S_q_w_tot = np.zeros((n_qpoints, N_tc_actual))
        else:
            # all pairs must have been truncated to the same length
            assert F_q_t.shape[1] == len(time)

        data_dict_corr['omega'] = w
        data_dict_corr[f'Fqt_coh_{key}'] = F_q_t
        data_dict_corr[f'Sqw_coh_{key}'] = S_q_w

        # sum all partials to the total
        F_q_t_tot += F_q_t
        S_q_w_tot += S_q_w

    if N_tc_actual < N_tc:
        logger.warning('Truncating ACF due to NaNs, likely time_window is longer than length of Trajectory')  # noqa
    data_dict_corr['time'] = time
    data_dict_corr['Fqt_coh'] = F_q_t_tot
    data_dict_corr['Sqw_coh'] = S_q_w_tot

    if calculate_currents:
        Cl_q_t_tot = np.zeros((n_qpoints, N_tc_actual))
        Ct_q_t_tot = np.zeros((n_qpoints, N_tc_actual))
        Cl_q_w_tot = np.zeros((n_qpoints, N_tc_actual))
        Ct_q_w_tot = np.zeros((n_qpoints, N_tc_actual))
        for pair in pairs:
            key = '_'.join(pair)
            Cl_q_t = 1 / traj.n_atoms * truncate_nans(Cl_q_t_averager[pair].get_average_all())
            Ct_q_t = 1 / traj.n_atoms * truncate_nans(Ct_q_t_averager[pair].get_average_all())
            _, Cl_q_w = psd_from_acf_2d(Cl_q_t, delta_t)
            _, Ct_q_w = psd_from_acf_2d(Ct_q_t, delta_t)
            data_dict_corr[f'Clqt_{key}'] = Cl_q_t
            data_dict_corr[f'Ctqt_{key}'] = Ct_q_t
            data_dict_corr[f'Clqw_{key}'] = Cl_q_w
            data_dict_corr[f'Ctqw_{key}'] = Ct_q_w

            # sum all partials to the total
            Cl_q_t_tot += Cl_q_t
            Ct_q_t_tot += Ct_q_t
            Cl_q_w_tot += Cl_q_w
            Ct_q_w_tot += Ct_q_w
        data_dict_corr['Clqt'] = Cl_q_t_tot
        data_dict_corr['Ctqt'] = Ct_q_t_tot
        data_dict_corr['Clqw'] = Cl_q_w_tot
        data_dict_corr['Ctqw'] = Ct_q_w_tot

    if calculate_incoherent:
        Fs_q_t_tot = np.zeros((n_qpoints, N_tc_actual))
        Ss_q_w_tot = np.zeros((n_qpoints, N_tc_actual))
        for atom_type in traj.atom_types:
            Fs_q_t = 1 / traj.n_atoms * truncate_nans(F_s_q_t_averager[atom_type].get_average_all())
            _, Ss_q_w = psd_from_acf_2d(Fs_q_t, delta_t)
            data_dict_corr[f'Fqt_incoh_{atom_type}'] = Fs_q_t
            data_dict_corr[f'Sqw_incoh_{atom_type}'] = Ss_q_w

            # sum all partials to the total
            Fs_q_t_tot += Fs_q_t
            Ss_q_w_tot += Ss_q_w

        data_dict_corr['Fqt_incoh'] = Fs_q_t_tot
        data_dict_corr['Sqw_incoh'] = Ss_q_w_tot

    # finalize results with additional metadata
    new_sample = DynamicSample(
        data_dict_corr,
        simulation_data=dict(
            atom_types=traj.atom_types, pairs=pairs,
            particle_counts=particle_counts,
            cell=traj.cell,
            time_between_frames=delta_t,
            maximum_time_lag=delta_t * window_size,
            angular_frequency_resolution=dw,
            maximum_angular_frequency=w_max,
            number_of_frames=traj.number_of_frames_read,
        ))
    new_sample._append_history(
        'compute_dynamic_structure_factors',
        dict(
            dt=dt,
            window_size=window_size,
            window_step=window_step,
            calculate_currents=calculate_currents,
            calculate_incoherent=calculate_incoherent,
        ))

    return new_sample
def compute_static_structure_factors(
        traj: Trajectory,
        q_points: NDArray[float],
        logging_interval: Optional[int] = 1000,
) -> StaticSample:
    r"""Compute the static structure factors. The results are returned in the
    form of a :class:`StaticSample <dynasor.sample.StaticSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates.
    logging_interval
        Log progress at ``INFO`` level every this many frames. Set to ``0`` to disable
        progress logging.
    """
    # sanity check input args
    if q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    n_q = q_points.shape[0]
    logger.info(f'Number of q-points: {n_q}')

    # all unordered pairs of atom types, including like-like pairs
    type_pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    counts = {species: len(members) for species, members in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for tp in type_pairs:
        logger.debug(f' {tp}')

    def attach_rho(frame):
        # store the Fourier-transformed density per atom type on the frame
        frame.rho_qs_dict = {
            species: calc_rho_q(frame.positions_by_type[species], q_points)
            for species in frame.positions_by_type.keys()
        }
        return frame

    # one averager per pair; only time lag 0 is needed for a static quantity
    averagers = {tp: TimeAverager(1, n_q) for tp in type_pairs}

    # main loop over the trajectory
    for frame in traj:
        attach_rho(frame)
        if logging_interval and frame.frame_index % logging_interval == 0:
            logger.info(f'Processing frame {frame.frame_index}')
        else:
            logger.debug(f'Processing frame {frame.frame_index}')

        for s1, s2 in type_pairs:
            rho1 = frame.rho_qs_dict[s1]
            rho2 = frame.rho_qs_dict[s2]
            sample = np.real(rho1 * rho2.conjugate())
            if s1 != s2:
                # distinct species contribute the cross term twice
                sample += np.real(rho2 * rho1.conjugate())
            averagers[(s1, s2)].add_sample(0, sample)

    # collect results: partial structure factors and their sum
    data_dict = {'q_points': q_points}
    total = np.zeros((n_q, 1))
    for s1, s2 in type_pairs:
        partial_sq = 1 / traj.n_atoms * averagers[(s1, s2)].get_average_at_timelag(0).reshape(-1, 1)
        data_dict[f'Sq_{s1}_{s2}'] = partial_sq
        total += partial_sq
    data_dict['Sq'] = total

    # finalize results
    new_sample = StaticSample(
        data_dict,
        simulation_data=dict(
            atom_types=traj.atom_types,
            pairs=type_pairs,
            particle_counts=counts,
            cell=traj.cell,
            number_of_frames=traj.number_of_frames_read,
        ))
    new_sample._append_history('compute_static_structure_factors')

    return new_sample
def compute_spectral_energy_density(
        traj: Trajectory,
        ideal_supercell: Atoms,
        primitive_cell: Atoms,
        q_points: NDArray[float],
        dt: float,
        partial: Optional[bool] = False,
        logging_interval: Optional[int] = 1000,
) -> tuple[NDArray[float], NDArray[float]]:
    r"""
    Compute the spectral energy density (SED) at specific q-points. The results
    are returned in the form of a tuple, which comprises the angular
    frequencies in an array of length ``N_times`` in units of rad/fs and the
    SED in units of eV/(rad/fs) as an array of shape ``(N_qpoints, N_times)``.
    The normalization is chosen such that integrating the SED of a q-point
    together with the supplied angular frequencies omega (rad/fs) yields
    `1/2 kB T` * number of bands (where number of bands = `len(prim) * 3`)

    More details can be found in Thomas *et al.*, Physical Review B **81**, 081411 (2010),
    which should be cited when using this function along with the dynasor reference.

    **Note 1:**
    SED analysis is only suitable for crystalline materials without diffusion as
    atoms are assumed to move around fixed reference positions throughout the entire trajectory.

    **Note 2:**
    This implementation reads the full trajectory and can thus consume a lot of memory.

    Parameters
    ----------
    traj
        Input trajectory.
    ideal_supercell
        Ideal structure defining the reference positions. Do not change the masses
        in the ASE :class:`Atoms` objects to dynasor internal units, this will be
        done internally.
    primitive_cell
        Underlying primitive structure. Must be aligned correctly with :attr:`ideal_supercell`.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates.
    dt
        Time difference in femtoseconds between two consecutive snapshots in
        the trajectory. Note that you should not change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    partial
        If True the SED will be returned decomposed per basis and Cartesian direction.
        The shape is ``(N_qpoints, N_frequencies, len(primitive_cell), 3)``.
        NOTE: this parameter shadows :func:`functools.partial` (imported at module
        level) inside this function; kept for backward compatibility.
    logging_interval
        Log progress at ``INFO`` level every this many frames. Set to ``0`` to disable
        progress logging.
    """

    delta_t = traj.frame_step * dt  # effective time between frames actually used

    # logger
    logger.info('Running SED')
    logger.info(f'Time between consecutive frames (dt * frame_step): {delta_t} fs')
    logger.info(f'Number of atoms in primitive_cell: {len(primitive_cell)}')
    logger.info(f'Number of atoms in ideal_supercell: {len(ideal_supercell)}')
    logger.info(f'Number of q-points: {q_points.shape[0]}')

    # check that the ideal supercell agrees with traj
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    if len(primitive_cell) >= len(ideal_supercell):
        raise ValueError('primitive_cell contains more atoms than ideal_supercell.')

    # collect all velocities, and scale with sqrt(masses)
    masses = ideal_supercell.get_masses().reshape(-1, 1) / fs**2  # From Dalton to dmu
    velocities = []
    for it, frame in enumerate(traj):
        if logging_interval and it % logging_interval == 0:
            logger.info(f'Reading frame {it}')
        else:
            logger.debug(f'Reading frame {it}')
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {it}')
        v = frame.get_velocities_as_array(traj.atomic_indices)  # in Å/fs
        velocities.append(np.sqrt(masses) * v)
    logger.info(f'Number of snapshots: {len(velocities)}')

    # Perform the FFT on the last axis for extra speed (maybe not needed)
    N_samples = len(velocities)
    velocities = np.array(velocities)
    # places time index last and makes a copy for contiguity
    velocities = velocities.transpose(1, 2, 0).copy()
    # #atoms in supercell x 3 directions x #frequencies
    velocities = np.fft.rfft(velocities, axis=2)

    # Calculate indices and offsets needed for the sed method
    offsets, indices = get_offset_index(primitive_cell, ideal_supercell)

    # Phase factor for use in FT. #qpoints x #atoms in supercell
    cell_positions = np.dot(offsets, primitive_cell.cell)
    phase = np.dot(q_points, cell_positions.T)  # #qpoints x #unit cells
    phase_factors = np.exp(1.0j * phase)

    # This dict maps the offsets to an index so ndarrays can be over
    # offset,index instead of atoms in supercell
    offset_dict = {off: n for n, off in enumerate(set(tuple(offset) for offset in offsets))}

    # Pick out some shapes
    n_super, _, n_w = velocities.shape
    n_qpts = len(q_points)
    n_prim = len(primitive_cell)
    n_offsets = len(offset_dict)

    # This new array will be indexed by index and offset instead (and also transposed)
    new_velocities = np.zeros((n_w, 3, n_prim, n_offsets), dtype=velocities.dtype)

    for i in range(n_super):
        j = indices[i]  # atom with index i in the supercell is of basis type j ...
        n = offset_dict[tuple(offsets[i])]  # and its offset has index n
        new_velocities[:, :, j, n] = velocities[i].T

    velocities = new_velocities

    # Same story with the spatial phase factors
    new_phase_factors = np.zeros((n_qpts, n_prim, n_offsets), dtype=phase_factors.dtype)

    for i in range(n_super):
        j = indices[i]
        n = offset_dict[tuple(offsets[i])]
        new_phase_factors[:, j, n] = phase_factors[:, i]

    phase_factors = new_phase_factors

    # calculate the density in a numba function
    density = _sed_inner_loop(phase_factors, velocities)

    if not partial:
        # collapse basis-atom and Cartesian axes into the total SED per q-point
        density = np.sum(density, axis=(2, 3))

    # units
    # make so that the velocities were originally in Angstrom / fs to be compatible with eV and Da

    # the time delta in the fourier transform
    density = density * delta_t**2

    # Divide by the length of the time signal
    density = density / (N_samples * delta_t)

    # Divide by the number of primitive cells
    density = density / (n_super / n_prim)

    # Factor so the sed can be integrated together with the returned omega
    # numpy fft works with ordinary/linear frequencies and not angular freqs
    density = density / (2*np.pi)

    # angular frequencies
    w = 2 * np.pi * np.fft.rfftfreq(N_samples, delta_t)  # rad/fs

    return w, density
@numba.njit(parallel=True, fastmath=True)
def _sed_inner_loop(phase_factors, velocities):
    """Carry out the spatial Fourier transform of the frequency-resolved
    velocities using precomputed phase factors.

    Since the number of q-points may be small, the parallel loop runs over
    the temporal frequency components rather than over q-points.
    """

    # phase_factors: (q-point, basis atom, unit-cell offset)
    n_qpoints, n_basis, n_cells = phase_factors.shape
    # velocities: (frequency, direction, basis atom, unit-cell offset)
    n_freqs = velocities.shape[0]

    density = np.zeros((n_qpoints, n_freqs, n_basis, 3), dtype=np.float64)

    for w in numba.prange(n_freqs):
        for k in range(n_qpoints):
            for b in range(n_basis):
                for a in range(3):
                    # coherent sum over unit cells, then take the power
                    acc = 0.0j
                    for c in range(n_cells):
                        acc += phase_factors[k, b, c] * velocities[w, a, b, c]
                    density[k, w, b, a] = np.abs(acc)**2
    return density