[Improvement] fix discourse loader to avoid rate limit (#953)
Co-authored-by: Deven Patel <deven298@yahoo.com>
@@ -22,7 +22,7 @@ os.environ["OPENAI_API_KEY"] = "sk-xxx"
 
 app = App()
 
-app.add("openai", data_type="discourse", loader=dicourse_loader)
+app.add("openai after:2023-10-1", data_type="discourse", loader=dicourse_loader)
 
 question = "Where can I find the OpenAI API status page?"
 app.query(question)
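The docs change above narrows the example query with Discourse's `after:` search filter so the loader only pulls recent posts. For context, a minimal end-to-end sketch of the snippet (the DiscourseLoader import path and config dict are assumptions based on embedchain's loader conventions, not shown in this diff; `dicourse_loader` is spelled as in the source docs):

import os

from embedchain import App
from embedchain.loaders.discourse import DiscourseLoader

os.environ["OPENAI_API_KEY"] = "sk-xxx"

# Assumed config shape: point the loader at a Discourse forum domain.
dicourse_loader = DiscourseLoader(config={"domain": "https://community.openai.com/"})

app = App()

# "after:2023-10-1" is Discourse search syntax restricting results to newer posts.
app.add("openai after:2023-10-1", data_type="discourse", loader=dicourse_loader)

question = "Where can I find the OpenAI API status page?"
app.query(question)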
@@ -1,6 +1,6 @@
-import concurrent.futures
 import hashlib
 import logging
+import time
 from typing import Any, Dict, Optional
 
 import requests
@@ -32,7 +32,11 @@ class DiscourseLoader(BaseLoader):
     def _load_post(self, post_id):
         post_url = f"{self.domain}posts/{post_id}.json"
         response = requests.get(post_url)
-        response.raise_for_status()
+        try:
+            response.raise_for_status()
+        except Exception as e:
+            logging.error(f"Failed to load post {post_id}: {e}")
+            return
         response_data = response.json()
         post_contents = clean_string(response_data.get("raw"))
         meta_data = {
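With this hunk, `_load_post` logs a failed request and returns `None` instead of letting the error propagate, so one bad post no longer aborts a whole load. `raise_for_status()` raises `requests.exceptions.HTTPError` on 4xx/5xx responses, including 429 Too Many Requests from a rate limiter. A standalone sketch of the pattern (hypothetical helper, not from the commit; the loader itself catches the broader `Exception`):

import logging

import requests


def get_json_or_none(url: str):
    """Return parsed JSON for a 2xx response, or None after logging the failure."""
    response = requests.get(url)
    try:
        response.raise_for_status()  # raises requests.exceptions.HTTPError on 4xx/5xx
    except requests.exceptions.HTTPError as e:
        logging.error(f"Request to {url} failed: {e}")
        return None
    return response.json()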
@@ -55,18 +59,19 @@ class DiscourseLoader(BaseLoader):
         logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
         search_url = f"{self.domain}search.json?q={query}"
         response = requests.get(search_url)
-        response.raise_for_status()
+        try:
+            response.raise_for_status()
+        except Exception as e:
+            raise ValueError(f"Failed to search query {query}: {e}")
         response_data = response.json()
         post_ids = response_data.get("grouped_search_result").get("post_ids")
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            future_to_post_id = {executor.submit(self._load_post, post_id): post_id for post_id in post_ids}
-            for future in concurrent.futures.as_completed(future_to_post_id):
-                post_id = future_to_post_id[future]
-                try:
-                    post_data = future.result()
-                    data.append(post_data)
-                except Exception as e:
-                    logging.error(f"Failed to load post {post_id}: {e}")
+        for id in post_ids:
+            post_data = self._load_post(id)
+            if post_data:
+                data.append(post_data)
+                data_contents.append(post_data.get("content"))
+            # Sleep for 0.4 sec, to avoid rate limiting. Check `https://meta.discourse.org/t/api-rate-limits/208405/6`
+            time.sleep(0.4)
         doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
         response_data = {"doc_id": doc_id, "data": data}
         return response_data
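This hunk is the core of the fix: the `ThreadPoolExecutor` fan-out, which parallelized the post fetches, is replaced by a sequential loop that sleeps 0.4 s after each post, bounding the loader at 2.5 requests per second (see the rate-limit thread linked in the comment). The same pattern as a minimal standalone sketch (`fetch_posts` and its defaults are illustrative, not part of the commit):

import time

import requests


def fetch_posts(domain: str, post_ids: list, delay: float = 0.4) -> list:
    """Fetch Discourse posts one at a time, pausing between requests to respect rate limits."""
    posts = []
    for post_id in post_ids:
        response = requests.get(f"{domain}posts/{post_id}.json")
        if response.ok:
            posts.append(response.json())
        # A fixed pause caps the request rate at 1/delay requests per second.
        time.sleep(delay)
    return posts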
@@ -1,6 +1,7 @@
-import time
 import hashlib
 import logging
+import time
 
 import requests
 
 from embedchain.helper.json_serializable import register_deserializable
@@ -19,7 +19,7 @@ class YoutubeVideoLoader(BaseLoader):
         doc = loader.load()
         output = []
         if not len(doc):
-            raise ValueError("No data found")
+            raise ValueError(f"No data found for url: {url}")
         content = doc[0].page_content
         content = clean_string(content)
         meta_data = doc[0].metadata
@@ -66,7 +66,7 @@ def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeyp
     assert "meta_data" in post_data
 
 
-def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monkeypatch):
+def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monkeypatch, caplog):
     def mock_get(*args, **kwargs):
         class MockResponse:
             def raise_for_status(self):
@@ -76,8 +76,9 @@ def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monke
 
     monkeypatch.setattr(requests, "get", mock_get)
 
-    with pytest.raises(Exception, match="Test error"):
-        discourse_loader._load_post(123)
+    discourse_loader._load_post(123)
+
+    assert "Failed to load post" in caplog.text
 
 
 def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
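The test change mirrors the loader change: `_load_post` no longer raises on a bad post id, so instead of `pytest.raises` the test asserts on the error log through pytest's `caplog` fixture. A self-contained illustration of that pattern (hypothetical function and test, not from this commit):

import logging


def load_or_none(post_id):
    # Hypothetical stand-in for a loader method that logs and returns None on failure.
    logging.error(f"Failed to load post {post_id}: simulated error")
    return None


def test_logs_error_instead_of_raising(caplog):
    with caplog.at_level(logging.ERROR):
        assert load_or_none(123) is None
    assert "Failed to load post" in caplog.text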