From 8c00274adfe46e5d5e22452bb71567e7c1f968ad Mon Sep 17 00:00:00 2001
From: "Mads M. Pedersen" <mmpe@dtu.dk>
Date: Thu, 2 Feb 2017 14:54:35 +0100
Subject: [PATCH] Changed amplitude bins to range 0-max(ampl[weights>0]) in
 fatigue-cycle_matrix to avoid empty bins

---
 wetb/fatigue_tools/fatigue.py            | 17 +++++------------
 wetb/fatigue_tools/tests/test_fatigue.py |  7 +++++++
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/wetb/fatigue_tools/fatigue.py b/wetb/fatigue_tools/fatigue.py
index bde2d08..27b4c2c 100644
--- a/wetb/fatigue_tools/fatigue.py
+++ b/wetb/fatigue_tools/fatigue.py
@@ -116,8 +116,8 @@ def cycle_matrix(signals, ampl_bins=10, mean_bins=10, rainflow_func=rainflow_win
     Parameters
     ----------
     Signals : array-like or list of tuples
-        if array-like, the raw signal
-        if list of tuples, list of (weight, signal), e.g. [(0.1,sig1), (0.8,sig2), (.1,sig3)]
+        - if array-like, the raw signal\n
+        - if list of tuples, list of (weight, signal), e.g. [(0.1,sig1), (0.8,sig2), (.1,sig3)]\n
     ampl_bins : int or array-like, optional
         if int, Number of amplitude value bins (default is 10)
         if array-like, the bin edges for amplitude
@@ -149,19 +149,12 @@ def cycle_matrix(signals, ampl_bins=10, mean_bins=10, rainflow_func=rainflow_win
     """
 
     if isinstance(signals[0], tuple):
-        ampls = np.empty((0,), dtype=np.float64)
-        means = np.empty((0,), dtype=np.float64)
-        weights = np.empty((0,), dtype=np.float64)
-        for w, signal in signals:
-            a, m = rainflow_func(signal[:])
-            ampls = np.r_[ampls, a]
-            means = np.r_[means, m]
-            weights = np.r_[weights, (np.zeros_like(a) + w)]
+        weights, ampls, means = np.array([(np.zeros_like(ampl)+weight,ampl,mean) for weight, signal in signals for ampl,mean in rainflow_func(signal[:]).T], dtype=np.float64).T
     else:
         ampls, means = rainflow_func(signals[:])
         weights = np.ones_like(ampls)
     if isinstance(ampl_bins, int):
-        ampl_bins = np.linspace(0, 1, num=ampl_bins + 1) * ampls.max()
+        ampl_bins = np.linspace(0, 1, num=ampl_bins + 1) * ampls[weights>0].max()
     cycles, ampl_edges, mean_edges = np.histogram2d(ampls, means, [ampl_bins, mean_bins], weights=weights)
 
     ampl_bin_sum = np.histogram2d(ampls, means, [ampl_bins, mean_bins], weights=weights * ampls)[0]
@@ -169,7 +162,7 @@ def cycle_matrix(signals, ampl_bins=10, mean_bins=10, rainflow_func=rainflow_win
     mask = (cycles > 0)
     ampl_bin_mean[mask] = ampl_bin_sum[mask] / cycles[mask]
     mean_bin_sum = np.histogram2d(ampls, means, [ampl_bins, mean_bins], weights=weights * means)[0]
-    mean_bin_mean = np.zeros_like(cycles)
+    mean_bin_mean = np.zeros_like(cycles)+np.nan
     mean_bin_mean[cycles > 0] = mean_bin_sum[cycles > 0] / cycles[cycles > 0]
     cycles = cycles / 2  # to get full cycles
     return cycles, ampl_bin_mean, ampl_edges, mean_bin_mean, mean_edges
diff --git a/wetb/fatigue_tools/tests/test_fatigue.py b/wetb/fatigue_tools/tests/test_fatigue.py
index 93dd8ce..0ef2fef 100644
--- a/wetb/fatigue_tools/tests/test_fatigue.py
+++ b/wetb/fatigue_tools/tests/test_fatigue.py
@@ -73,6 +73,13 @@ class TestFatigueTools(unittest.TestCase):
                                                                                                            [  0., 1., 4., 0.],
                                                                                                            [  0., 0., 0., 0.],
                                                                                                            [  0., 1., 2., 0.]]) / 2, 0.001)
+        
+    def test_astm_weighted(self):
+        data = Hawc2io.ReadHawc2(testfilepath + "test").ReadBinary([2]).flatten()
+        np.testing.assert_allclose(cycle_matrix([(1, data),(1,data)], 4, 4, rainflow_func=rainflow_astm)[0], np.array([[ 24., 83., 53., 26.],
+                                                                                                           [  0., 1., 4., 0.],
+                                                                                                           [  0., 0., 0., 0.],
+                                                                                                           [  0., 1., 2., 0.]]) , 0.001)
 
 if __name__ == "__main__":
     #import sys;sys.argv = ['', 'Test.testName']
-- 
GitLab