From 9e2524dde26e82bb7203414c829b66d07c798951 Mon Sep 17 00:00:00 2001 From: "Mads M. Pedersen" <mmpe@dtu.dk> Date: Tue, 21 Aug 2018 15:27:09 +0200 Subject: [PATCH] A few improvements --- topfarm/_topfarm.py | 8 ++++++-- topfarm/recorders.py | 5 +++++ topfarm/tests/test_topfarm_utils/test_recorders.py | 2 ++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/topfarm/_topfarm.py b/topfarm/_topfarm.py index 1a9a8614..d3de26ec 100644 --- a/topfarm/_topfarm.py +++ b/topfarm/_topfarm.py @@ -75,9 +75,13 @@ class TopFarmProblem(Problem): self.recorder = NestedTopFarmListRecorder(self.cost_comp, self.record_id) else: self.recorder = TopFarmListRecorder(self.record_id) - if state == {} and len(self.recorder.driver_iteration_lst) > 0: - state = {k: self.recorder.get(k)[-1] for k in self.state.keys()} self.update_state(state) + if len(self.recorder.driver_iteration_lst) > 0: + try: + self.update_state({k: self.recorder.get(k)[-1] for k in self.state.keys() if k not in state}) + except ValueError: + pass # loaded state does not fit into dimension of current state + self.driver.add_recorder(self.recorder) self.run_driver() self.cleanup() diff --git a/topfarm/recorders.py b/topfarm/recorders.py index c3a1eba2..216dcc1c 100644 --- a/topfarm/recorders.py +++ b/topfarm/recorders.py @@ -116,6 +116,9 @@ class ListRecorder(BaseRecorder): if res.shape[-1] == 1: res = res[:, 0] return res + + def __getitem__(self, key): + return self.get(key) def keys(self): return list(np.unique(['counter', 'iteration_coordinate', 'timestamp', 'success', 'msg'] + @@ -264,6 +267,8 @@ class TopFarmListRecorder(ListRecorder): plt.axis('equal') return ln + init() + def update(frame): title.set_text("%f (%.2f%%)" % (cost[frame], (cost[frame] - cost[0]) / cost[0] * 100)) diff --git a/topfarm/tests/test_topfarm_utils/test_recorders.py b/topfarm/tests/test_topfarm_utils/test_recorders.py index 49ca7b74..c52afd42 100644 --- a/topfarm/tests/test_topfarm_utils/test_recorders.py +++ b/topfarm/tests/test_topfarm_utils/test_recorders.py @@ -70,6 +70,8 @@ def test_ListRecorder(): cases = recorder.driver_cases assert cases.num_cases == 4 npt.assert_array_equal(recorder.get('counter'), range(1, 5)) + npt.assert_array_equal(recorder['counter'], range(1, 5)) + npt.assert_array_almost_equal(recorder.get(['x', 'y', 'f_xy']), xyf, 4) for xyf, k in zip(xyf[0], ['x', 'y', 'f_xy']): npt.assert_allclose(cases.get_case(0).outputs[k][0], xyf) -- GitLab