diff --git a/topfarm/_topfarm.py b/topfarm/_topfarm.py index 1a9a861434d12c48c43a451ef4a7d652895bf8c7..d3de26ec9fef144aad7691a50fa8481faaec62e3 100644 --- a/topfarm/_topfarm.py +++ b/topfarm/_topfarm.py @@ -75,9 +75,13 @@ class TopFarmProblem(Problem): self.recorder = NestedTopFarmListRecorder(self.cost_comp, self.record_id) else: self.recorder = TopFarmListRecorder(self.record_id) - if state == {} and len(self.recorder.driver_iteration_lst) > 0: - state = {k: self.recorder.get(k)[-1] for k in self.state.keys()} self.update_state(state) + if len(self.recorder.driver_iteration_lst) > 0: + try: + self.update_state({k: self.recorder.get(k)[-1] for k in self.state.keys() if k not in state}) + except ValueError: + pass # loaded state does not fit into dimension of current state + self.driver.add_recorder(self.recorder) self.run_driver() self.cleanup() diff --git a/topfarm/recorders.py b/topfarm/recorders.py index c3a1eba21aaf0b835b4faf1ef66001000de466da..216dcc1ccb80f9aac158a2a175dd2bdbbf3c13dc 100644 --- a/topfarm/recorders.py +++ b/topfarm/recorders.py @@ -116,6 +116,9 @@ class ListRecorder(BaseRecorder): if res.shape[-1] == 1: res = res[:, 0] return res + + def __getitem__(self, key): + return self.get(key) def keys(self): return list(np.unique(['counter', 'iteration_coordinate', 'timestamp', 'success', 'msg'] + @@ -264,6 +267,8 @@ class TopFarmListRecorder(ListRecorder): plt.axis('equal') return ln + init() + def update(frame): title.set_text("%f (%.2f%%)" % (cost[frame], (cost[frame] - cost[0]) / cost[0] * 100)) diff --git a/topfarm/tests/test_topfarm_utils/test_recorders.py b/topfarm/tests/test_topfarm_utils/test_recorders.py index 49ca7b74bbfcdb9b56460911d6e9efde0164ee42..c52afd42e78d052d19b0b512c54c78db9a9bd012 100644 --- a/topfarm/tests/test_topfarm_utils/test_recorders.py +++ b/topfarm/tests/test_topfarm_utils/test_recorders.py @@ -70,6 +70,8 @@ def test_ListRecorder(): cases = recorder.driver_cases assert cases.num_cases == 4 npt.assert_array_equal(recorder.get('counter'), range(1, 5)) + npt.assert_array_equal(recorder['counter'], range(1, 5)) + npt.assert_array_almost_equal(recorder.get(['x', 'y', 'f_xy']), xyf, 4) for xyf, k in zip(xyf[0], ['x', 'y', 'f_xy']): npt.assert_allclose(cases.get_case(0).outputs[k][0], xyf)