diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 3dfd4f9a..f8ffa89d 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -372,13 +372,21 @@ class EmbedChain: def reset(self): """ Resets the database. Deletes all embeddings irreversibly. - `App` has to be reinitialized after using this method. + `App` does not have to be reinitialized after using this method. """ # Send anonymous telemetry thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",)) thread_telemetry.start() - + + collection_name = self.collection.name self.db.reset() + self.collection = self.config.db._get_or_create_collection(collection_name) + # TODO: Automatically recreating a collection with the same name cannot be the best way to handle a reset. + # A downside of this implementation is that, if you have two instances, + # the other instance will not get the updated `self.collection` attribute. + # A better way would be to create the collection if it is called again after being reset. + # That means checking whether the collection exists in the db-consuming methods, and creating it if it doesn't. + # That's an extra step for every use, just to satisfy a niche use case in a niche method. For now, this will do. 
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None): @@ -397,4 +405,4 @@ class EmbedChain: metadata.update(extra_metadata) response = requests.post(url, json={"metadata": metadata}) - response.raise_for_status() + response.raise_for_status() \ No newline at end of file diff --git a/tests/embedchain/test_embedchain.py b/tests/embedchain/test_embedchain.py index bc483d2b..5673ffbd 100644 --- a/tests/embedchain/test_embedchain.py +++ b/tests/embedchain/test_embedchain.py @@ -37,3 +37,25 @@ class TestChromaDbHostsLoglevel(unittest.TestCase): app.chat("What text did I give you?") self.assertEqual(mock_ec_get_llm_model_answer.call_args[1]["documents"], [knowledge]) + + def test_add_after_reset(self): + """ + Test if the `App` instance is correctly reconstructed after a reset. + """ + app = App() + app.reset() + + # Make sure the client is still healthy + app.db.client.heartbeat() + # Make sure the collection exists, and can be added to + app.collection.add( + embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]], + metadatas=[ + {"chapter": "3", "verse": "16"}, + {"chapter": "3", "verse": "5"}, + {"chapter": "29", "verse": "11"}, + ], + ids=["id1", "id2", "id3"], + ) + + app.reset() diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index 37252e4b..bb14ef6b 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -245,8 +245,7 @@ class TestChromaDbCollection(unittest.TestCase): # Resetting the first one should reset them all. 
app1.reset() - # Reinstantiate them - app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False)) + # Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319) app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False)) app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3", collect_metrics=False)) app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3", collect_metrics=False))