From c1cddc59d3a6a301b9ebc9a85af24e07ec760421 Mon Sep 17 00:00:00 2001
From: David Robert Verelst <dave@dtu.dk>
Date: Fri, 8 Feb 2019 09:45:01 +0100
Subject: [PATCH] hawc2.StFile: add method to save st file (plus test)

---
 wetb/hawc2/st_file.py            | 40 +++++++++++++++++++++++++-------
 wetb/hawc2/tests/test_st_file.py | 23 +++++++++++++++---
 2 files changed, 52 insertions(+), 11 deletions(-)

diff --git a/wetb/hawc2/st_file.py b/wetb/hawc2/st_file.py
index d999671c..2d062d43 100644
--- a/wetb/hawc2/st_file.py
+++ b/wetb/hawc2/st_file.py
@@ -15,6 +15,10 @@ standard_library.install_aliases()
 
 import numpy as np
 
+
+stcols = "r m x_cg y_cg ri_x ri_y x_sh y_sh E G I_x I_y I_p k_x k_y A pitch x_e y_e"
+
+
 class StFile(object):
     """Read HAWC2 St (beam element structural data) file
 
@@ -69,7 +73,7 @@ class StFile(object):
         with open (filename) as fid:
             txt = fid.read()
 #         Some files starts with first set ("#1...") with out specifying number of sets
-#         no_maindata_sets = int(txt.strip()[0]) 
+#         no_maindata_sets = int(txt.strip()[0])
 #         assert no_maindata_sets == txt.count("#")
         self.main_data_sets = {}
         for mset in txt.split("#")[1:]:
@@ -83,8 +87,9 @@ class StFile(object):
                 set_data_dict[set_nr] = np.array([set_lines[i].split() for i in range(1, no_rows + 1)], dtype=np.float)
             self.main_data_sets[mset_nr] = set_data_dict
 
-        for i, name in enumerate("r m x_cg y_cg ri_x ri_y x_sh y_sh E G I_x I_y I_p k_x k_y A pitch x_e y_e".split()):
-            setattr(self, name, lambda radius=None, mset=1, set=1, column=i: self._value(radius, column, mset, set))
+        for i, name in enumerate(stcols.split()):
+            setattr(self, name, lambda radius=None, mset=1, set=1,
+                    column=i: self._value(radius, column, mset, set))
 
     def _value(self, radius, column, mset_nr=1, set_nr=1):
         st_data = self.main_data_sets[mset_nr][set_nr]
@@ -98,14 +103,35 @@ class StFile(object):
             return r
         return r[np.argmin(np.abs(r - radius))]
 
-    def to_str(self, mset=1, set=1):
+    def to_str(self, mset=1, set=1, precision='%12.5e '):
         d = self.main_data_sets[mset][set]
-        return "\n".join([("%12.5e "*d.shape[1]) % tuple(row) for row in d])
+        return '\n'.join([(precision*d.shape[1]) % tuple(row) for row in d])
+
+    def save(self, filename, precision='%15.07e', encoding='utf-8'):
+        """Save all data defined in main_data_sets to st file.
+        """
+        colwidth = len(precision % 1)
+        sep = '='*colwidth*len(stcols) + '\n'
+        colhead = ''.join([k.center(colwidth) for k in stcols.split()]) + '\n'
+
+        nsets = len(self.main_data_sets)
+
+        with open(filename, 'w', encoding=encoding) as fid:
+            fid.write('%i ; number of sets, Nset\n' % nsets)
+            for mset, set_data_dict in self.main_data_sets.items():
+                fid.write('#%i ; set number\n' % mset)
+                for set, set_array in set_data_dict.items():
+                    dstr = self.to_str(mset=mset, set=set, precision=precision)
+                    npoints = self.main_data_sets[mset][set].shape[0]
+                    fid.write(sep + colhead + sep)
+                    fid.write('$%i %i\n' % (set, npoints))
+                    fid.write(dstr + '\n')
 
 
 if __name__ == "__main__":
     import os
-    st = StFile(os.path.dirname(__file__) + r"/tests/test_files/DTU_10MW_RWT_Blade_st.dat")
+    cwd = os.path.dirname(__file__)
+    st = StFile(os.path.join(cwd, r'tests/test_files/DTU_10MW_RWT_Blade_st.dat'))
     print (st.m())
     print (st.E(radius=36, mset=1, set=1))  # Elastic blade
     print (st.E(radius=36, mset=1, set=2))  # stiff blade
@@ -122,7 +148,5 @@ if __name__ == "__main__":
     print (st.pitch(67.8883 - 0.01687))
     print (st.pitch(23.2446))
 
-
-
     #print (st.)
     #print (st.)
diff --git a/wetb/hawc2/tests/test_st_file.py b/wetb/hawc2/tests/test_st_file.py
index 8b678b65..5961b003 100644
--- a/wetb/hawc2/tests/test_st_file.py
+++ b/wetb/hawc2/tests/test_st_file.py
@@ -10,15 +10,19 @@ from __future__ import absolute_import
 from future import standard_library
 standard_library.install_aliases()
 import unittest
-from wetb.hawc2.st_file import StFile
 import os
 
+from numpy import testing
+
+from wetb.hawc2.st_file import StFile
+
+
 testfilepath = os.path.join(os.path.dirname(__file__), 'test_files/')  # test file path
 class TestStFile(unittest.TestCase):
 
 
     def test_stfile(self):
-        st = StFile(testfilepath + "DTU_10MW_RWT_Blade_st.dat")
+        st = StFile(testfilepath + 'DTU_10MW_RWT_Blade_st.dat')
         self.assertEqual(st.radius_st()[2], 3.74238)
         self.assertEqual(st.radius_st(3), 3.74238)
         self.assertEqual(st.x_e(67.7351), 4.4320990737400E-01)
@@ -27,10 +31,23 @@ class TestStFile(unittest.TestCase):
 
 
     def test_stfile_interpolate(self):
-        st = StFile(testfilepath + "DTU_10MW_RWT_Blade_st.dat")
+        st = StFile(testfilepath + 'DTU_10MW_RWT_Blade_st.dat')
         self.assertAlmostEqual(st.x_e(72.2261), 0.381148048)
         self.assertAlmostEqual(st.y_e(72.2261), 0.016692967)
 
+    def test_save(self):
+        fname = os.path.join(testfilepath, 'DTU_10MW_RWT_Blade_st.dat')
+        fname2 = os.path.join(testfilepath, 'DTU_10MW_RWT_Blade_st2.dat')
+        st = StFile(fname)
+        st.save(fname2, encoding='utf-8', precision='%20.12e')
+        st2 = StFile(fname2)
+        self.assertEqual(len(st.main_data_sets), len(st2.main_data_sets))
+        self.assertEqual(len(st.main_data_sets[1]), len(st2.main_data_sets[1]))
+        for k in st.main_data_sets[1]:
+            testing.assert_almost_equal(st.main_data_sets[1][k],
+                                        st2.main_data_sets[1][k], decimal=12)
+        os.remove(fname2)
+
 
 if __name__ == "__main__":
     #import sys;sys.argv = ['', 'Test.testName']
-- 
GitLab