Coverage for dynasor / core / time_averager.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-10 06:28 +0000

1import numpy as np 

2from numpy.typing import NDArray 

3 

4 

def truncate_nans(data):
    """
    Cut the time axis off at the first column that holds a NaN.

    The last axis of the input is treated as time. As soon as any row
    contains a NaN in a given time column, that column and every column
    after it are dropped.

    The result is therefore guaranteed to be **free of NaN values**. The
    price is that valid entries located at or after the first NaN column
    (in other rows, or later in the same row) are discarded as well.

    Parameters
    ----------
    data : array_like, shape (N, T)
        Two-dimensional input; each row is an independent signal
        (e.g. one q-point) and the columns are time steps.

    Returns
    -------
    ndarray
        The input converted with ``np.asarray`` and sliced along the time
        axis so that no NaN remains. The number of rows is unchanged, the
        number of columns may be smaller. If the input holds no NaN at
        all, the converted array is returned without slicing.

    Notes
    -----
    Example::

        [[1.0, 0.8, 0.7, 0.6, 0.5, 0.3, nan],
         [1.0, 0.8, 0.7, 0.6, 0.5, nan, nan]]

    becomes::

        [[1.0, 0.8, 0.7, 0.6, 0.5],
         [1.0, 0.8, 0.7, 0.6, 0.5]]

    because the earliest NaN sits in column 5 (second row).
    """
    values = np.asarray(data)
    assert values.ndim == 2

    # Boolean mask over time columns: True where at least one row is NaN
    nan_in_column = np.isnan(values).any(axis=0)

    # Nothing to cut away, hand back the array as is
    if not nan_in_column.any():
        return values

    # argmax of a boolean mask is the position of its first True entry
    cutoff = nan_in_column.argmax()
    return values[:, :cutoff]

61 

62 

class TimeAverager:
    """Naive special purpose averager class used in dynasor to collect and time-average arrays
    obtained from sliding time-window averaging.

    It assists with keeping track of how many data samples have been added to each slot.

    It will time-average arrays of shape ``(Nq, time_window)`` where ``Nq`` is the
    number of q-points and ``time_window`` is the size of the time window.

    Parameters
    ----------
    time_window
        Size of the time window in which the time-average happens.
    array_length
        Length of the array to be averaged for each time-lag, i.e., number of q-points.
    """

    def __init__(self, time_window: int, array_length: int):
        assert time_window >= 1, 'time_window must be at least 1'
        self._time_window = time_window
        self._array_length = array_length

        # One slot per time lag: the number of samples added so far and their running sum
        self._counts = np.zeros(time_window, dtype=int)
        self._arrays = [np.zeros(array_length) for _ in range(time_window)]

    def add_sample(self, time_lag: int, sample: np.ndarray):
        """Adds one sample to the running sum of the given time lag.

        Parameters
        ----------
        time_lag
            Index of the time-lag slot the sample belongs to.
        sample
            Array of length ``array_length`` to be added to the slot.
        """
        assert len(sample) == self._array_length, 'sample length must equal array_length'
        # Accumulate before counting: if the in-place addition raises (e.g. a sample that
        # cannot be broadcast onto the slot), the counter must not have been incremented,
        # otherwise all later averages for this time lag would be scaled down.
        self._arrays[time_lag] += sample
        self._counts[time_lag] += 1

    def get_average_at_timelag(self, time_lag: int):
        """Returns the average of all samples added for the given time lag.

        If no sample has been added for this time lag, an array of length
        ``array_length`` filled with NaN is returned.
        """
        count = self._counts[time_lag]
        if count == 0:
            return np.full((self._array_length, ), np.nan)
        return self._arrays[time_lag] / count

    def get_average_all(self) -> NDArray[float]:
        """
        Returns an averaged array of shape ``(array_length, time_window)``.

        Columns belonging to time lags without any samples are filled with NaN.
        """
        averages = [self.get_average_at_timelag(t) for t in range(self._time_window)]
        # Stacking yields (time_window, array_length); transpose to put time on the last axis
        return np.array(averages).T