diff --git a/example/extract/extract_html.ipynb b/example/extract/extract_html.ipynb index a7e27aa5..0aac390d 100644 --- a/example/extract/extract_html.ipynb +++ b/example/extract/extract_html.ipynb @@ -67,13 +67,13 @@ { "data": { "text/plain": [ - "{'extract': ['ExtractImageFlow',\n", + "{'extract': ['ExtractHTMLFlow',\n", + " 'ExtractImageFlow',\n", " 'ExtractIpynbFlow',\n", " 'ExtractMarkdownFlow',\n", " 'ExtractPDFFlow',\n", " 'ExtractTxtFlow',\n", - " 'ExtractS3TxtFlow',\n", - " 'ExtractHTMLFlow'],\n", + " 'ExtractS3TxtFlow'],\n", " 'transform': ['TransformAzureOpenAIFlow',\n", " 'TransformCopyFlow',\n", " 'TransformHuggingFaceFlow',\n", @@ -116,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -132,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -141,21 +141,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/1 [00:00 20]\n", - "pprint.pprint(text[:20])" + "text = output[0]['output'][0]['text'][0:30]\n", + "text = [p for p in text if len(p) > 10]\n", + "pprint.pprint(text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## End of the notebook\n", + "\n", + "Check more Uniflow use cases in the [example folder](https://github.com/CambioML/uniflow/tree/main/example/model#examples)!\n", + "\n", + "\n", + " \n", + "" ] } ], diff --git a/example/extract/extract_pdf_with_recursive_splitter.ipynb b/example/extract/extract_pdf_with_recursive_splitter.ipynb new file mode 100644 index 00000000..d524e467 --- /dev/null +++ b/example/extract/extract_pdf_with_recursive_splitter.ipynb @@ -0,0 +1,265 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example of loading PDF using recursive splitter\n", + "\n", + "Recursive Splitter: Splitting text by recursively look at characters.\n", + "Recursively tries to split by different characters to find one that works." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Before running the code\n", + "\n", + "You will need to `uniflow` conda environment to run this notebook. You can set up the environment following the instruction: https://github.com/CambioML/uniflow/tree/main#installation. Furthermore, make sure you have the following packages installed:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# pip3 install nougat-ocr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load packages" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import sys\n", + "\n", + "sys.path.append(\".\")\n", + "sys.path.append(\"..\")\n", + "sys.path.append(\"../..\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/anaconda3/envs/uniflow/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import os\n", + "import pandas as pd\n", + "import pprint\n", + "from uniflow.flow.client import ExtractClient, TransformClient\n", + "from uniflow.flow.config import TransformOpenAIConfig, ExtractPDFConfig\n", + "from uniflow.op.model.model_config import OpenAIModelConfig, NougatModelConfig\n", + "from uniflow.op.prompt import PromptTemplate, Context\n", + "from uniflow.op.extract.split.splitter_factory import SplitterOpsFactory\n", + "from uniflow.op.extract.split.constants import RECURSIVE_CHARACTER_SPLITTER" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prepare the input data\n", + "\n", + "First, let's set current directory and input data directory, and load the raw data." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "dir_cur = os.getcwd()\n", + "pdf_file = \"1408.5882_page-1.pdf\"\n", + "input_file = os.path.join(f\"{dir_cur}/data/raw_input/\", pdf_file)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### List all the available splitters\n", + "These are the different splitters we can use to post-process the loaded PDF." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['ParagraphSplitter', 'MarkdownHeaderSplitter', 'RecursiveCharacterSplitter']" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "SplitterOpsFactory.list()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Load the pdf using recursive splitter" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/anaconda3/envs/uniflow/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)\n", + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" + ] + } + ], + "source": [ + "data = [\n", + " {\"filename\": input_file},\n", + "]\n", + "\n", + "config = ExtractPDFConfig(\n", + " model_config=NougatModelConfig(\n", + " model_name = \"0.1.0-small\",\n", + " batch_size = 1 # When batch_size>1, nougat will run on CUDA, otherwise it will run on CPU\n", + " ),\n", + " splitter=RECURSIVE_CHARACTER_SPLITTER,\n", + ")\n", + "nougat_client = ExtractClient(config)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/1 [00:00\n", + " \n", + "" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "uniflow", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/uniflow/op/extract/load/html_op.py b/uniflow/op/extract/load/html_op.py index 5e7dbc90..134a37c6 100644 --- a/uniflow/op/extract/load/html_op.py +++ b/uniflow/op/extract/load/html_op.py @@ -10,6 +10,19 @@ class ExtractHTMLOp(Op): """Extract HTML Op Class.""" + def __init__(self, name: str) -> None: + try: + import requests # pylint: disable=import-outside-toplevel + from bs4 import BeautifulSoup # pylint: disable=import-outside-toplevel + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "Please install bs4. You can use `pip install bs4` to install them." + ) from exc + + super().__init__(name) + self._requests_client = requests + self._beautiful_soup_parser = BeautifulSoup + def __call__(self, nodes: Sequence[Node]) -> Sequence[Node]: """Run Model Op. @@ -22,19 +35,32 @@ def __call__(self, nodes: Sequence[Node]) -> Sequence[Node]: output_nodes = [] for node in nodes: value_dict = copy.deepcopy(node.value_dict) + if "url" in value_dict: - import requests # pylint: disable=import-outside-toplevel + resp = self._requests_client.get(url=value_dict["url"], timeout=300) + if not resp.ok: + raise ValueError(f"URL return an error: {resp.status_code}") + + content_type = resp.headers.get("Content-Type", "") + if not content_type.startswith("text/html"): + raise ValueError( + f"Expected content type text/html. Got {content_type}." + ) - resp = requests.get(url=value_dict["url"], timeout=300) text = resp.text - else: + + elif "filename" in value_dict: with open( value_dict["filename"], "r", encoding=value_dict.get("encoding", "utf-8"), ) as f: text = f.read() - text = self.parse_html(text) + + else: + raise ValueError("Expected url or filename param.") + + text = self._parse_html(text) output_nodes.append( Node( name=self.unique_name(), @@ -44,23 +70,23 @@ def __call__(self, nodes: Sequence[Node]) -> Sequence[Node]: ) return output_nodes - def parse_html(self, text): - """Function Parse Html.""" - try: - from bs4 import BeautifulSoup # pylint: disable=import-outside-toplevel - except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "Please install bs4. You can use `pip install bs4` to install them." - ) from exc + def _parse_html(self, text: str) -> str: + """Function Parse Html. + + Args: + text (str): Raw html text. - soup = BeautifulSoup(text, "html.parser") + Returns: + str: Parsed html text. + """ + soup = self._beautiful_soup_parser(text, "html.parser") if soup.title: - title = str(soup.title.string) + title = str(soup.title.string) + "\n\n" else: title = "" - return title + "\n".join(soup.body.stripped_strings) + return title + "\n\n".join(soup.body.stripped_strings) class ProcessHTMLOp(Op): diff --git a/uniflow/op/extract/load/ipynb_op.py b/uniflow/op/extract/load/ipynb_op.py index babceefd..fd1b7113 100644 --- a/uniflow/op/extract/load/ipynb_op.py +++ b/uniflow/op/extract/load/ipynb_op.py @@ -10,15 +10,7 @@ class ExtractIpynbOp(Op): """Extract ipynb Op Class.""" - def __call__(self, nodes: Sequence[Node]) -> Sequence[Node]: - """Run Model Op. - - Args: - nodes (Sequence[Node]): Nodes to run. - - Returns: - Sequence[Node]: Nodes after running. - """ + def __init__(self, name: str) -> None: try: import nbformat # pylint: disable=import-outside-toplevel from nbconvert import ( # pylint: disable=import-outside-toplevel @@ -28,11 +20,26 @@ def __call__(self, nodes: Sequence[Node]) -> Sequence[Node]: raise ModuleNotFoundError( "Please install nbformat and nbconvert to load ipynb file. You can use `pip install nbformat nbconvert` to install them." ) from exc + + super().__init__(name) + self._nbformat = nbformat + self._markdown_exporter = MarkdownExporter + + def __call__(self, nodes: Sequence[Node]) -> Sequence[Node]: + """Run Model Op. + + Args: + nodes (Sequence[Node]): Nodes to run. + + Returns: + Sequence[Node]: Nodes after running. + """ + output_nodes = [] for node in nodes: value_dict = copy.deepcopy(node.value_dict) - nb = nbformat.read(value_dict["filename"], as_version=4) - md_exporter = MarkdownExporter() + nb = self._nbformat.read(value_dict["filename"], as_version=4) + md_exporter = self._markdown_exporter() (text, _) = md_exporter.from_notebook_node(nb) output_nodes.append( Node( diff --git a/uniflow/op/extract/split/constants.py b/uniflow/op/extract/split/constants.py index b751eef2..c7e9192c 100644 --- a/uniflow/op/extract/split/constants.py +++ b/uniflow/op/extract/split/constants.py @@ -2,3 +2,4 @@ PARAGRAPH_SPLITTER = "ParagraphSplitter" MARKDOWN_HEADER_SPLITTER = "MarkdownHeaderSplitter" +RECURSIVE_CHARACTER_SPLITTER = "RecursiveCharacterSplitter" diff --git a/uniflow/op/extract/split/recursive_character_splitter.py b/uniflow/op/extract/split/recursive_character_splitter.py new file mode 100644 index 00000000..1c90fc2b --- /dev/null +++ b/uniflow/op/extract/split/recursive_character_splitter.py @@ -0,0 +1,174 @@ +"""Recursive character split op.""" + +import copy +import re +from typing import Iterable, List, Optional, Sequence + +from uniflow.node import Node +from uniflow.op.op import Op + + +class RecursiveCharacterSplitter(Op): + """Recursive character splitter class.""" + + default_separators = ["\n\n", "\n", " ", ""] + + def __init__( + self, + name: str, + chunk_size: int = 1024, + chunk_overlap_size: int = 0, + separators: Optional[List[str]] = None, + ) -> None: + """Recursive Splitter Op Constructor + + This has the effect of trying to keep all paragraphs (and then sentences, and then words) together + as long as possible, as those would generically seem to be the strongest semantically related pieces of text. + + Args: + name (str): Name of the op. + chunk_size (int): Maximum size of chunks to return. + chunk_overlap_size (int): Overlap in characters between chunks. + separators (List[str]): Separators to use. + keep_separator: Whether to keep the separator in the chunks. + """ + super().__init__(name) + self._chunk_size = chunk_size + self._chunk_overlap_size = chunk_overlap_size + self._separators = separators or self.default_separators + + def __call__(self, nodes: Sequence[Node]) -> Sequence[Node]: + """Run Model Op. + + Args: + nodes (Sequence[Node]): Nodes to run. + + Returns: + Sequence[Node]: Nodes after running. + """ + output_nodes = [] + for node in nodes: + value_dict = copy.deepcopy(node.value_dict) + text = value_dict["text"] + text = self._recursive_splitter(text.strip(), self._separators) + output_nodes.append( + Node( + name=self.unique_name(), + value_dict={"text": text}, + prev_nodes=[node], + ) + ) + return output_nodes + + def _recursive_splitter(self, text: str, separators: List[str]) -> List[str]: + """Split incoming text and return chunks. + + It takes in the large text then tries to split it by the first character \n\n. If the first split by \n\n is + still large then it moves to the next character which is \n and tries to split by it. If it is still larger + than our specified chunk size it moves to the next character in the set until we get a split that is less than + our specified chunk size. The default separators list is ["\n\n", "\n", " ", ""]) + + Args: + text (str): Text to split. + separators(List[str]): separators for split. + + Returns: + List[str]: Chunks after split. + """ + final_chunks, next_separators = [], [] + + if len(separators) == 0: + return final_chunks + + # Get current and next separators + cur_separator = separators[-1] + for i, _s in enumerate(separators): + _separator = re.escape(_s) + if _s == "": + cur_separator = _s + break + if re.search(_separator, text): + cur_separator = _s + next_separators = separators[(i + 1) :] + break + + # Splited by current separator firstly + cur_separator = re.escape(cur_separator) + splits = [s for s in re.split(cur_separator, text) if s != ""] + + # Then go merging things, recursively splitting longer texts. + _tmp_splits, _separator = [], "" + for s in splits: + if len(s) < self._chunk_size: + _tmp_splits.append(s) + else: + # merge splitted texts into a chunk + if _tmp_splits: + merged_text = self._merge_splits(_tmp_splits, _separator) + final_chunks.extend(merged_text) + # reset tmp_splits + _tmp_splits = [] + + # recursively split using next separators + if not next_separators: + final_chunks.append(s) + else: + other_info = self._recursive_splitter(s, next_separators) + final_chunks.extend(other_info) + + if _tmp_splits: + merged_text = self._merge_splits(_tmp_splits, _separator) + final_chunks.extend(merged_text) + + return final_chunks + + def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: + """Combine these smaller pieces into medium size chunks. + + Args: + splits (Iterable[str]): Smaller pieces before merge. + separator (str): Separator for merge. + + Returns: + List[str]: Merged medium size chunks. + """ + separator_len = len(separator) + + docs, total = [], 0 + current_doc: List[str] = [] + for s in splits: + _len = len(s) + current_length = ( + total + _len + (separator_len if len(current_doc) > 0 else 0) + ) + + if current_length > self._chunk_size: + if total > self._chunk_size: + print( + f"Created a chunk of size {total}, " + f"which is longer than the specified {self._chunk_size}" + ) + if len(current_doc) > 0: + doc = separator.join(current_doc).strip() + if doc is not None: + docs.append(doc) + # Keep on popping if: + # - we have a larger chunk than in the chunk overlap + # - or if we still have any chunks and the length is long + while total > self._chunk_overlap_size or ( + current_length > self._chunk_size and total > 0 + ): + total -= len(current_doc[0]) + ( + separator_len if len(current_doc) > 1 else 0 + ) + current_doc = current_doc[1:] + + current_doc.append(s) + total += _len + (separator_len if len(current_doc) > 1 else 0) + + doc = separator.join(current_doc).strip() + + if doc is not None: + docs.append(doc) + + return docs diff --git a/uniflow/op/extract/split/splitter_factory.py b/uniflow/op/extract/split/splitter_factory.py index 3b805b7d..dfb50435 100644 --- a/uniflow/op/extract/split/splitter_factory.py +++ b/uniflow/op/extract/split/splitter_factory.py @@ -5,9 +5,13 @@ from uniflow.op.extract.split.constants import ( MARKDOWN_HEADER_SPLITTER, PARAGRAPH_SPLITTER, + RECURSIVE_CHARACTER_SPLITTER, ) from uniflow.op.extract.split.markdown_header_splitter import MarkdownHeaderSplitter from uniflow.op.extract.split.pattern_splitter_op import PatternSplitter +from uniflow.op.extract.split.recursive_character_splitter import ( + RecursiveCharacterSplitter, +) class SplitterOpsFactory: @@ -18,6 +22,9 @@ class SplitterOpsFactory: MARKDOWN_HEADER_SPLITTER: MarkdownHeaderSplitter( name="markdown_header_split_op" ), + RECURSIVE_CHARACTER_SPLITTER: RecursiveCharacterSplitter( + name="recursive_character_split_op" + ), } @staticmethod