Coverage for dynasor / core / time_averager.py: 100%
29 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-10 06:28 +0000
1import numpy as np
2from numpy.typing import NDArray
def truncate_nans(data):
    """
    Cut a 2D time series at the first time column that contains a NaN.

    The last axis of ``data`` is treated as time. The first column in which
    *any* row holds a NaN is located, and that column together with every
    later column is dropped.

    The result is therefore guaranteed to be NaN-free, at the price of
    discarding finite values in one row that happen to sit at or after a
    NaN in another row.

    Parameters
    ----------
    data : array_like, shape (N, T)
        Input array where rows represent independent signals (e.g. q-points)
        and the last axis represents time.

    Returns
    -------
    ndarray
        Array with the same number of rows as the input and at most ``T``
        time columns, containing no NaN values. If the input contains no
        NaN at all, the array from ``np.asarray(data)`` is returned unchanged.

    Notes
    -----
    Example::

        [[1.0, 0.8, 0.7, 0.6, 0.5, 0.3, nan],
         [1.0, 0.8, 0.7, 0.6, 0.5, nan, nan]]

    becomes::

        [[1.0, 0.8, 0.7, 0.6, 0.5],
         [1.0, 0.8, 0.7, 0.6, 0.5]]

    because the first NaN appears in column 5 (0-based).
    """
    values = np.asarray(data)
    assert values.ndim == 2

    # One boolean per time column: True if at least one row is NaN there.
    nan_in_column = np.isnan(values).any(axis=0)

    # Guard clause: nothing to cut when no column contains a NaN.
    if not nan_in_column.any():
        return values

    # On a boolean mask, argmax returns the index of the first True entry,
    # i.e. the first NaN-containing column. Keep everything strictly before it.
    cutoff = np.argmax(nan_in_column)
    return values[:, :cutoff]
class TimeAverager:
    """Naive special purpose averager class used in dynasor to collect and time-average arrays
    obtained from sliding time-window averaging.

    For every time lag it keeps a running sum of the samples added so far and
    a count of how many samples went into that sum, so the mean can be
    produced on demand.

    It will time-average arrays of shape ``(Nq, time_window)`` where ``Nq`` is the
    number of q-points and ``time_window`` is the size of the time window.

    Parameters
    ----------
    time_window
        Size of the time window in which the time-average happens, i.e. the
        number of time-lag slots. Valid time lags satisfy
        ``0 <= time_lag < time_window``.
    array_length
        Length of the array to be averaged for each time-lag, i.e., number of q-points.
    """

    def __init__(self, time_window: int, array_length: int):
        assert time_window >= 1
        self._time_window = time_window
        self._array_length = array_length

        # Number of samples accumulated so far in each time-lag slot.
        self._counts = np.zeros(time_window, dtype=int)
        # Running sum of the samples for each time-lag slot (float64, the
        # np.zeros default dtype).
        self._arrays = [np.zeros(array_length) for _ in range(time_window)]

    def _check_time_lag(self, time_lag: int) -> None:
        """Raise ``IndexError`` unless ``0 <= time_lag < time_window``.

        Out-of-range positive lags already fail with ``IndexError`` when
        indexing. A negative lag, however, would silently wrap around via
        negative indexing and read or corrupt the last slots, so it is
        rejected explicitly with the same exception type.
        """
        if not 0 <= time_lag < self._time_window:
            raise IndexError(
                f'time_lag {time_lag} is out of range for time_window {self._time_window}')

    def add_sample(self, time_lag: int, sample: np.ndarray):
        """Accumulate ``sample`` into the slot for ``time_lag``.

        Parameters
        ----------
        time_lag
            Index of the time-lag slot, ``0 <= time_lag < time_window``.
        sample
            Sequence of length ``array_length``. It is added in place to a
            float64 accumulator, so its values must be castable to float.

        Raises
        ------
        IndexError
            If ``time_lag`` is outside ``[0, time_window)``.
        """
        assert len(sample) == self._array_length
        self._check_time_lag(time_lag)
        self._counts[time_lag] += 1
        self._arrays[time_lag] += sample

    def get_average_at_timelag(self, time_lag: int):
        """Return the mean of all samples added at ``time_lag``.

        Parameters
        ----------
        time_lag
            Index of the time-lag slot, ``0 <= time_lag < time_window``.

        Returns
        -------
        A new array of length ``array_length``. It is filled with NaN if no
        sample has been added to this slot yet.

        Raises
        ------
        IndexError
            If ``time_lag`` is outside ``[0, time_window)``.
        """
        self._check_time_lag(time_lag)
        count = self._counts[time_lag]
        if count == 0:
            # Empty slot: NaN signals "no data" rather than a misleading 0.
            return np.full((self._array_length, ), np.nan)
        return self._arrays[time_lag] / count

    def get_average_all(self) -> NDArray[float]:
        """
        Returns an averaged array of shape ``(array_length, time_window)``.

        Column ``t`` holds the average for time lag ``t``. Columns for lags
        that never received a sample are filled with NaN.
        """
        per_lag = [self.get_average_at_timelag(t) for t in range(self._time_window)]
        # Rows are time lags here; transpose so that time is the last axis.
        return np.array(per_lag).T