Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/apps/sdk/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ async def update(tenant_id, chat_id):
req["llm_id"] = llm.pop("model_name")
if req.get("llm_id") is not None:
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
model_type = llm.pop("model_type")
model_type = llm.get("model_type")
model_type = model_type if model_type in ["chat", "image2text"] else "chat"
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
Expand Down
4 changes: 2 additions & 2 deletions api/db/services/connector_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from common.constants import TaskStatus
from common.time_utils import current_timestamp, timestamp_to_date


class ConnectorService(CommonService):
model = Connector

Expand Down Expand Up @@ -202,14 +201,15 @@ def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True):
return None

class FileObj(BaseModel):
    """In-memory stand-in for an uploaded file, shaped so FileService.upload_document
    can treat connector-fetched documents like regular file uploads.

    Attributes are plain pydantic fields:
    - id: pre-assigned document id (presumably a hash of the source doc id —
      TODO confirm against the caller that builds these objects)
    - filename: display name, with the source extension appended when missing
    - blob: raw file content
    """
    id: str
    filename: str
    blob: bytes

    def read(self) -> bytes:
        # Mimics the file-like read() API expected by the upload path;
        # returns the whole payload in one call.
        return self.blob

errs = []
files = [FileObj(filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
doc_ids = []
err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
errs.extend(err)
Expand Down
10 changes: 9 additions & 1 deletion api/db/services/file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,15 @@ def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str

err, files = [], []
for file in file_objs:
doc_id = file.id if hasattr(file, "id") else get_uuid()
e, doc = DocumentService.get_by_id(doc_id)
if e:
blob = file.read()
settings.STORAGE_IMPL.put(kb.id, doc.location, blob, kb.tenant_id)
doc.size = len(blob)
doc = doc.to_dict()
DocumentService.update_by_id(doc["id"], doc)
continue
try:
DocumentService.check_doc_health(kb.tenant_id, file.filename)
filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id)
Expand All @@ -455,7 +464,6 @@ def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str
blob = read_potential_broken_pdf(blob)
settings.STORAGE_IMPL.put(kb.id, location, blob)

doc_id = get_uuid()

img = thumbnail_img(filename, blob)
thumbnail_location = ""
Expand Down
5 changes: 5 additions & 0 deletions api/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import xxhash


def string_to_bytes(string):
return string if isinstance(
Expand All @@ -22,3 +24,6 @@ def string_to_bytes(string):
def bytes_to_string(byte):
    """Decode a UTF-8 encoded bytes object back into a str."""
    return str(byte, encoding="utf-8")

# 128 bit = 32 character
def hash128(data: str) -> str:
    """Return the xxHash-128 digest of *data* as a 32-character lowercase hex string.

    xxh128 is a fast non-cryptographic hash; used here to derive stable,
    fixed-length document ids from source identifiers. Do not use for
    security-sensitive hashing.
    """
    return xxhash.xxh128(data).hexdigest()
3 changes: 2 additions & 1 deletion rag/svr/sync_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from flask import json

from api.utils.common import hash128
from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common import settings
Expand Down Expand Up @@ -126,7 +127,7 @@ async def _run_task_logic(self, task: dict):
docs = []
for doc in document_batch:
d = {
"id": doc.id,
"id": hash128(doc.id),
"connector_id": task["connector_id"],
"source": self.SOURCE_NAME,
"semantic_identifier": doc.semantic_identifier,
Expand Down