Newer
Older
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from future import standard_library
standard_library.install_aliases()
# Fixture directory next to this module. The trailing slash lets tests build
# paths by plain string concatenation, e.g. tfp + "data/user_shear.dat".
tfp = os.path.join(
    os.path.dirname(__file__),
    'test_files/',
)
def test_shearfile_save(self):
    """Check the exact text written by shear_file.save.

    The same u profile is saved twice: first as an explicit array with one
    column per v position, then as a 1-D array. Both calls must produce the
    identical file contents given in ``expected``.
    """
    f = tfp + "tmp_shearfile1.dat"
    # NOTE(review): leading whitespace inside this literal may have been lost
    # when this file was reformatted; confirm against real save() output.
    expected = """ # autogenerated shear file
2 3
# shear v component
0.0000000000 0.0000000000
0.0000000000 0.0000000000
0.0000000000 0.0000000000
# shear u component
0.7000000000 0.7000000000
1.0000000000 1.0000000000
1.3000000000 1.3000000000
# shear w component
0.0000000000 0.0000000000
0.0000000000 0.0000000000
0.0000000000 0.0000000000
30.0000000000
100.0000000000
160.0000000000
"""
    u_inputs = (
        # shape (3, 2): one column per v position (-55 and 55)
        np.array([[0.7, 1, 1.3], [0.7, 1, 1.3]]).T,
        # 1-D profile (.T is a no-op on 1-D); expected to give the same file
        np.array([0.7, 1, 1.3]).T,
    )
    for u in u_inputs:
        shear_file.save(f, [-55, 55], [30, 100, 160], u=u)
        try:
            with open(f) as fid:
                self.assertEqual(fid.read(), expected)
        finally:
            # Remove the temp file even when the comparison fails, so a stale
            # file is not left behind for later runs.
            os.remove(f)
def test_shear_makedirs(self):
    """shear_file.save should create a missing parent directory for its target."""
    shear_dir = tfp + "shear"
    # The original never assigned `f` here (NameError). The target lives in
    # tfp + "shear", the same directory this test removes afterwards.
    f = shear_dir + "/tmp_shearfile1.dat"
    if os.path.isdir(shear_dir):
        # Start without the directory so save() really has to create it.
        shutil.rmtree(shear_dir)
    try:
        shear_file.save(f, [-55, 55], [30, 100, 160], u=np.array([0.7, 1, 1.3]).T)
        self.assertTrue(os.path.isfile(f))
    finally:
        # Clean up even if save() or the assertion fails.
        if os.path.isdir(shear_dir):
            shutil.rmtree(shear_dir)
def test_shear_load(self):
    """Load shear data from a .dat file and from an .htc file and check values at z=65."""
    # Local name chosen so it does not shadow the shear_file module used elsewhere.
    sf = ShearFile.load(tfp + "data/user_shear.dat")
    np.testing.assert_array_equal(sf.w_positions, [30, 100, 160])
    # z=65 lies between stored positions, so these are computed floats;
    # compare with a tolerance rather than exact equality.
    np.testing.assert_almost_equal(sf.uvw(0, 65)[0], .85)
    np.testing.assert_almost_equal(sf.uvw(-55, 65)[0], .9)
    np.testing.assert_array_almost_equal(sf.uvw([0, -55], [65, 65])[0], [.85, .9])

    sf = ShearFile.load_from_htc(tfp + "htcfiles/test.htc")
    np.testing.assert_array_equal(sf.w_positions, [30, 100, 160])
    # Expected values are the .dat values mapped by 10*x + 8.860807038;
    # presumably a scaling/offset configured in test.htc -- confirm there.
    np.testing.assert_array_almost_equal(
        sf.uvw([0, -55], [65, 65])[0],
        np.array([.85, .9]) * 10 + 8.860807038,
    )
if __name__ == "__main__":
    # Run this module's test cases when it is executed as a script.
    unittest.main(module="__main__")