diff --git a/py_wake/utils/tensorflow_surrogate_utils.py b/py_wake/utils/tensorflow_surrogate_utils.py index 71cff543b5e85df093182219e983bbccc77d0ff8..8efba343eeb54c2d2e86b0a00734f0fbaf1089b0 100644 --- a/py_wake/utils/tensorflow_surrogate_utils.py +++ b/py_wake/utils/tensorflow_surrogate_utils.py @@ -101,7 +101,7 @@ class TensorflowSurrogate(): mi, ma = self.input_scaler.data_min_[i], self.input_scaler.data_max_[i] warnings.warn(f"Input, {k}, with value, {max_v} outside range {mi}-{ma}") - return self.output_scaler.inverse_transform(self.model.predict(x_scaled)) + return self.output_scaler.inverse_transform(self.model.predict(x_scaled, batch_size=x.shape[0])) @property def input_space(self):