Coverage for dynasor / post_processing / atomic_weighting.py: 99%
75 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 12:31 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 12:31 +0000
1from copy import deepcopy
2from typing import Optional
3from warnings import warn
4import numpy as np
5from dynasor.post_processing.weights import Weights
6from dynasor.sample import Sample, StaticSample, DynamicSample
7from numpy.typing import NDArray
def get_weighted_sample(sample: Sample,
                        weights: Weights,
                        atom_type_map: Optional[dict[str, str]] = None) -> Sample:
    r"""
    Weights correlation functions with atomic weighting factors

    The weighting of a partial dynamic structure factor
    :math:`S_\mathrm{AB}(\boldsymbol{q}, \omega)`
    for atom types :math:`A` and :math:`B` is carried out as

    .. math::

        S_\mathrm{AB}(\boldsymbol{q}, \omega)
        = f_\mathrm{A}(\boldsymbol{q}) f_\mathrm{B}(\boldsymbol{q})
        S_\mathrm{AB}(\boldsymbol{q}, \omega)

    :math:`f_\mathrm{A}(\boldsymbol{q})` and :math:`f_\mathrm{B}(\boldsymbol{q})`
    are atom-type and :math:`\boldsymbol{q}`-point dependent weights.

    If sample has incoherent correlation functions, but :attr:`weights` does not contain
    information on how to weight the incoherent part, then it will be dropped from the
    returned :attr:`Sample` object (and analogously for current correlation functions).

    Parameters
    ----------
    sample
        Input sample to be weighted.
    weights
        Object containing the weights :math:`f_\mathrm{X}(\boldsymbol{q})`.
    atom_type_map
        Map between the atom types in the :class:`Sample` and the ones used in
        the :class:`Weights` object, e.g., Ba → Ba(2+).
        Atom types of the sample that do not appear in the map are looked up
        in :attr:`weights` under their own name.

    Returns
    -------
    A :class:`Sample` instance with the weighted partial and total structure factors.
    """

    # check input arguments
    if sample.has_incoherent and not weights.supports_incoherent:
        warn('The Weights class does not support incoherent scattering, dropping the latter '
             'from the weighted sample.')

    if sample.has_currents and not weights.supports_currents:
        warn('The Weights class does not support current correlations, dropping the latter '
             'from the weighted sample.')

    use_incoherent = sample.has_incoherent and weights.supports_incoherent
    use_currents = sample.has_currents and weights.supports_currents

    # setup new input dicts for new Sample
    data_dict = dict()
    for key in sample.dimensions:
        data_dict[key] = sample[key]

    # Map the atom types in the sample to the types in the weights object.
    # Useful, for instance, when using weights for charged atomic species.
    # The iteration must run over the atom types of the *sample* (not over the
    # keys of the map): every sample atom type needs a weight, and types that
    # are absent from the map fall back to their own name.
    if atom_type_map is None:
        atom_types = [(at, at) for at in sample.atom_types]
    else:
        atom_types = [(at, atom_type_map.get(at, at)) for at in sample.atom_types]

    # generate atomic weights for each q-point and compile to arrays
    if 'q_norms' in sample.dimensions:
        q_norms = sample.q_norms
    else:
        q_norms = np.linalg.norm(sample.q_points, axis=1)

    # Column vectors (one row per q-point), keyed by the atom type used in the sample.
    weights_coh = dict()
    for at, weight_at in atom_types:
        weights_coh[at] = np.reshape(
            [weights.get_weight_coh(weight_at, q) for q in q_norms], (-1, 1))
    if use_incoherent:
        weights_incoh = dict()
        for at, weight_at in atom_types:
            weights_incoh[at] = np.reshape(
                [weights.get_weight_incoh(weight_at, q) for q in q_norms], (-1, 1))

    # weighting of correlation functions
    if isinstance(sample, StaticSample):
        data_dict.update(_compute_weighting_coherent(sample, 'Sq', weights_coh))
    elif isinstance(sample, DynamicSample):
        # coherent
        for name in ('Fqt_coh', 'Sqw_coh'):
            data_dict.update(_compute_weighting_coherent(sample, name, weights_coh))

        # incoherent
        if use_incoherent:
            for name in ('Fqt_incoh', 'Sqw_incoh'):
                data_dict.update(_compute_weighting_incoherent(sample, name, weights_incoh))

        # currents (weighted with the coherent weights)
        if use_currents:
            for name in ('Clqt', 'Clqw', 'Ctqt', 'Ctqw'):
                data_dict.update(_compute_weighting_coherent(sample, name, weights_coh))

    # Build a sample of the same class; metadata is deep-copied so that the
    # input sample is not affected by the history entry appended below.
    new_sample = sample.__class__(
        data_dict,
        simulation_data=deepcopy(sample.simulation_data),
        history=deepcopy(sample.history))
    new_sample._append_history(
        'get_weighted_sample',
        dict(
            atom_type_map=atom_type_map,
            weights_class=weights.__class__.__name__,
            weights_parameters=weights.parameters.to_dict(),
        ))

    return new_sample
def _compute_weighting_coherent(
    sample: Sample,
    name: str,
    weight_dict: dict,
) -> dict[str, NDArray[float]]:
    """
    Apply pairwise weights to the partial coherent correlation functions
    stored under ``name`` and accumulate them into a weighted total.

    For each pair ``(a, b)`` in ``sample.pairs`` the entry ``{name}_{a}_{b}``
    is multiplied by ``Re(conj(weight_dict[a]) * weight_dict[b])``.
    The returned dict holds every weighted partial under its original key
    and the sum of all weighted partials under ``name``.
    """
    weighted = {}
    # The accumulator takes its shape from the unweighted total in the sample.
    accumulated = np.zeros(sample[name].shape)
    for type_a, type_b in sample.pairs:
        pair_key = f'{name}_{type_a}_{type_b}'
        prefactor = np.real(np.conjugate(weight_dict[type_a]) * weight_dict[type_b])
        weighted_partial = prefactor * sample[pair_key]
        weighted[pair_key] = weighted_partial
        accumulated += weighted_partial
    weighted[name] = accumulated
    return weighted
def _compute_weighting_incoherent(
    sample: Sample,
    name: str,
    weight_dict: dict,
) -> dict[str, NDArray[float]]:
    """
    Apply per-atom-type weights to the partial incoherent correlation
    functions stored under ``name`` and accumulate them into a weighted total.

    For each atom type ``a`` in ``sample.atom_types`` the entry ``{name}_{a}``
    is multiplied by ``Re(conj(weight_dict[a]) * weight_dict[a])``.
    The returned dict holds every weighted partial under its original key
    and the sum of all weighted partials under ``name``.
    """
    weighted = {}
    # The accumulator takes its shape from the unweighted total in the sample.
    accumulated = np.zeros(sample[name].shape)
    for atom_type in sample.atom_types:
        entry = f'{name}_{atom_type}'
        prefactor = np.real(np.conjugate(weight_dict[atom_type]) * weight_dict[atom_type])
        weighted_partial = prefactor * sample[entry]
        weighted[entry] = weighted_partial
        accumulated += weighted_partial
    weighted[name] = accumulated
    return weighted