From b423313ad0a851a4b42e7f4946d283ae01edf53f Mon Sep 17 00:00:00 2001
From: "Mads M. Pedersen" <mmpe@dtu.dk>
Date: Fri, 3 Dec 2021 12:28:22 +0000
Subject: [PATCH] automatic detect args4deficit if not set + add all extra
 surrogate variable to TensorflowSurrogate

---
 py_wake/deficit_models/deficit_model.py     |  5 +++++
 py_wake/deficit_models/gaussian.py          |  2 +-
 py_wake/utils/tensorflow_surrogate_utils.py | 13 ++++---------
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/py_wake/deficit_models/deficit_model.py b/py_wake/deficit_models/deficit_model.py
index f1e93acbf..3594d89fb 100644
--- a/py_wake/deficit_models/deficit_model.py
+++ b/py_wake/deficit_models/deficit_model.py
@@ -1,11 +1,16 @@
 from abc import ABC, abstractmethod
 import numpy as np
 from numpy import newaxis as na
+import inspect
 
 
 class DeficitModel(ABC):
     deficit_initalized = False
 
+    def __init__(self):
+        if not hasattr(self, 'args4deficit'):
+            self.args4deficit = set(inspect.getfullargspec(self.calc_deficit).args) - {'self'}
+
     def _calc_layout_terms(self, **_):
         """Calculate layout dependent terms, which is not updated during simulation"""
 
diff --git a/py_wake/deficit_models/gaussian.py b/py_wake/deficit_models/gaussian.py
index c35d5fbc6..0e8bf9bb4 100644
--- a/py_wake/deficit_models/gaussian.py
+++ b/py_wake/deficit_models/gaussian.py
@@ -14,11 +14,11 @@ class BastankhahGaussianDeficit(ConvectionDeficitModel):
     A new analytical model for wind-turbine wakes.
     J. Renew. Energy. 2014;70:116-23.
     """
-    args4deficit = ['WS_ilk', 'WS_eff_ilk', 'dw_ijlk', 'cw_ijlk', 'D_src_il', 'ct_ilk']
 
     def __init__(self, k=0.0324555, use_effective_ws=False):
         self._k = k
         self.use_effective_ws = use_effective_ws
+        ConvectionDeficitModel.__init__(self)
 
     def k_ilk(self, **_):
         return np.reshape(self._k, (1, 1, 1))
diff --git a/py_wake/utils/tensorflow_surrogate_utils.py b/py_wake/utils/tensorflow_surrogate_utils.py
index 8efba343e..306b04278 100644
--- a/py_wake/utils/tensorflow_surrogate_utils.py
+++ b/py_wake/utils/tensorflow_surrogate_utils.py
@@ -41,13 +41,8 @@ class TensorflowSurrogate():
         path = Path(path)
         with open(path / 'extra_data.json') as fid:
             extra_data = json.load(fid)
-
-        self.input_channel_names = extra_data['input_channel_names']
-        self.output_channel_name = extra_data['output_channel_name']
-        self.wind_speed_cut_in = extra_data['wind_speed_cut_in']
-        self.wind_speed_cut_out = extra_data['wind_speed_cut_out']
-        if 'wohler_exponent' in extra_data:
-            self.wohler_exponent = extra_data['wohler_exponent']
+        for k, v in extra_data.items():
+            setattr(self, k, v)
 
         # Create the MinMaxScaler scaler objects.
         def json2scaler(d):
@@ -56,8 +51,8 @@ class TensorflowSurrogate():
                 setattr(scaler, k, v)
             return scaler
 
-        self.input_scaler = json2scaler(extra_data['input_scalers'][set_name])
-        self.output_scaler = json2scaler(extra_data['output_scalers'][set_name])
+        self.input_scaler = json2scaler(self.input_scalers[set_name])
+        self.output_scaler = json2scaler(self.output_scalers[set_name])
 
         self.model = tf.keras.models.load_model(path / f'model_set_{set_name}.h5')
 
-- 
GitLab