diff --git a/README.md b/README.md
deleted file mode 100644
index 67194e4..0000000
--- a/README.md
+++ /dev/null
@@ -1,133 +0,0 @@
-
-
-
- A unified toolkit for Deep Learning Based Document Image Analysis
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
----
-
-## What is LayoutParser
-
-
-
-LayoutParser aims to provide a wide range of tools that streamline Document Image Analysis (DIA) tasks. Please check the LayoutParser [demo video](https://youtu.be/8yA5xB4Dg8c) (1 min) or [full talk](https://www.youtube.com/watch?v=YG0qepPgyGY) (15 min) for details. Here are some key features:
-
-- LayoutParser provides a rich repository of deep learning models for layout detection as well as a set of unified APIs for using them. For example,
-
-
- Perform DL layout detection in 4 lines of code
-
- ```python
- import layoutparser as lp
-  model = lp.AutoLayoutModel('lp://EfficientDet/PubLayNet')
- # image = Image.open("path/to/image")
- layout = model.detect(image)
- ```
-
-
-
-- LayoutParser comes with a set of layout data structures with carefully designed APIs that are optimized for document image analysis tasks. For example,
-
-
- Selecting layout/textual elements in the left column of a page
-
- ```python
- image_width = image.size[0]
- left_column = lp.Interval(0, image_width/2, axis='x')
- layout.filter_by(left_column, center=True) # select objects in the left column
- ```
-
-
-
-
- Performing OCR for each detected Layout Region
-
- ```python
- ocr_agent = lp.TesseractAgent()
- for layout_region in layout:
- image_segment = layout_region.crop(image)
- text = ocr_agent.detect(image_segment)
- ```
-
-
-
-
- Flexible APIs for visualizing the detected layouts
-
- ```python
- lp.draw_box(image, layout, box_width=1, show_element_id=True, box_alpha=0.25)
- ```
-
-
-
-
-
-
- Loading layout data stored in json, csv, and even PDFs
-
- ```python
- layout = lp.load_json("path/to/json")
- layout = lp.load_csv("path/to/csv")
- pdf_layout = lp.load_pdf("path/to/pdf")
- ```
-
-
-
-- LayoutParser is also an open platform that enables the sharing of layout detection models and DIA pipelines among the community.
-
- Check the LayoutParser open platform
-
-
-
- Submit your models/pipelines to LayoutParser
-
-
-## Installation
-
-After several major updates, layoutparser provides various functionalities and deep learning models from different backends. But it is still easy to install layoutparser, and the installation is designed so that you can choose to install only the dependencies needed for your project:
-
-```bash
-pip install layoutparser # Install the base layoutparser library
-pip install "layoutparser[layoutmodels]" # Install DL layout model toolkit
-pip install "layoutparser[ocr]" # Install OCR toolkit
-```
-
-Extra steps are needed if you want to use Detectron2-based models. Please check [installation.md](installation.md) for additional details on layoutparser installation.
-
-## Examples
-
-We provide a series of examples to help you get started with the layout parser library:
-
-1. [Table OCR and Results Parsing](https://github.com/Layout-Parser/layout-parser/blob/main/examples/OCR%20Tables%20and%20Parse%20the%20Output.ipynb): `layoutparser` can be used to conveniently OCR documents and convert the output into structured data.
-
-2. [Deep Layout Parsing Example](https://github.com/Layout-Parser/layout-parser/blob/main/examples/Deep%20Layout%20Parsing.ipynb): With the help of Deep Learning, `layoutparser` supports the analysis of very complex documents and the processing of the hierarchical structure in the layouts.
-
-## Contributing
-
-We encourage you to contribute to Layout Parser! Please check out the [Contributing guidelines](.github/CONTRIBUTING.md) for details on how to proceed. Join us!
-
-## Citing `layoutparser`
-
-If you find `layoutparser` helpful to your work, please consider citing our tool and [paper](https://arxiv.org/pdf/2103.15348.pdf) using the following BibTeX entry.
-
-```
-@article{shen2021layoutparser,
- title={LayoutParser: A Unified Toolkit for Deep Learning Based Document Image Analysis},
- author={Shen, Zejiang and Zhang, Ruochen and Dell, Melissa and Lee, Benjamin Charles Germain and Carlson, Jacob and Li, Weining},
- journal={arXiv preprint arXiv:2103.15348},
- year={2021}
-}
-```
\ No newline at end of file
diff --git a/installation.md b/installation.md
deleted file mode 100644
index bb3778e..0000000
--- a/installation.md
+++ /dev/null
@@ -1,70 +0,0 @@
-# Installation
-
-## Install Python
-
-LayoutParser is a Python package that requires Python >= 3.6. If you do not have Python installed on your computer, you might want to follow [the official instructions](https://www.python.org/downloads/) to download and install the appropriate version of Python.
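-
-A quick way to confirm that your interpreter meets this requirement is a sketch like the following, run from any Python shell:
-
-```python
-import sys
-
-# LayoutParser requires Python 3.6 or newer.
-assert sys.version_info >= (3, 6), "Python >= 3.6 required, found " + sys.version
-```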
-
-
-
-## Install the LayoutParser library
-
-After several major updates, LayoutParser provides various functionalities and deep learning models from different backends. However, you might only need a fraction of the functions, and it would be redundant for you to install all the dependencies when they are not required. Therefore, we have designed highly customizable ways of installing the LayoutParser library:
-
-
-| Command | Description |
-| --- | --- |
-| `pip install layoutparser` | **Install the base LayoutParser Library**<br>It will support all key functions in LayoutParser, including:<br>1. Layout Data Structure and operations<br>2. Layout Visualization<br>3. Load/export the layout data |
-| `pip install "layoutparser[effdet]"` | **Install LayoutParser with Layout Detection Model Support**<br>It will install the LayoutParser base library as well as<br>supporting dependencies for the ***EfficientDet***-based layout detection models. |
-| `pip install layoutparser torchvision && pip install "git+https://github.com/facebookresearch/detectron2.git@v0.5#egg=detectron2"` | **Install LayoutParser with Layout Detection Model Support**<br>It will install the LayoutParser base library as well as<br>supporting dependencies for the ***Detectron2***-based layout detection models. See details in [Additional Instruction: Install Detectron2 Layout Model Backend](#additional-instruction-install-detectron2-layout-model-backend). |
-| `pip install "layoutparser[paddledetection]"` | **Install LayoutParser with Layout Detection Model Support**<br>It will install the LayoutParser base library as well as<br>supporting dependencies for the ***PaddleDetection***-based layout detection models. |
-| `pip install "layoutparser[ocr]"` | **Install LayoutParser with OCR Support**<br>It will install the LayoutParser base library as well as<br>supporting dependencies for performing OCRs. See details in [Additional Instruction: Install OCR utils](#additional-instruction-install-ocr-utils). |
-
-### Additional Instruction: Install Detectron2 Layout Model Backend
-
-#### For Mac OS and Linux Users
-
-If you would like to use the Detectron2 models for layout detection, you might need to run the following command:
-
-```bash
-pip install layoutparser torchvision && pip install "detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.5#egg=detectron2"
-```
-
-This might take some time as the command will *compile* the library. If you also want to install a Detectron2 version
-with GPU support or encounter some issues during the installation process, please refer to the official Detectron2
-[installation instruction](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md) for detailed
-information.
-
-#### For Windows users
-
-As reported by many users, the installation of Detectron2 can be rather tricky on Windows platforms. In our extensive tests, we found that it is nearly impossible to provide a one-line installation command for Windows users. As a workaround, for now we list the possible challenges of installing Detectron2 on Windows and attach helpful resources for solving them. We are also investigating other ways to use the pre-trained models without installing Detectron2. If you have any suggestions or ideas, please feel free to [submit an issue](https://github.com/Layout-Parser/layout-parser/issues) in our repo.
-
-1. Challenges for installing `pycocotools`
- - You can find detailed instructions on [this post](https://changhsinlee.com/pycocotools/) from Chang Hsin Lee.
-   - Another solution is to try installing `pycocotools-windows`; see https://github.com/cocodataset/cocoapi/issues/415.
-2. Challenges for installing `Detectron2`
- - [@ivanpp](https://github.com/ivanpp) curates a detailed description for installing `Detectron2` on Windows: [Detectron2 walkthrough (Windows)](https://ivanpp.cc/detectron2-walkthrough-windows/#step3installdetectron2)
-   - `Detectron2` maintainers claim that they won't provide official support for Windows (see [1](https://github.com/facebookresearch/detectron2/issues/9#issuecomment-540974288) and [2](https://detectron2.readthedocs.io/en/latest/tutorials/install.html)), but Detectron2 is continuously built on Windows with CircleCI (see [3](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md#common-installation-issues)). Hopefully this situation will be improved in the future.
-
-
-### Additional Instruction: Install OCR utils
-
-Layout Parser also comes with support for OCR functions. In order to use them, you need to install the OCR utils via:
-
-```bash
-pip install "layoutparser[ocr]"
-```
-
-Additionally, if you want to use the Tesseract-OCR engine, you also need to install it on your computer. Please check the
-[official documentation](https://tesseract-ocr.github.io/tessdoc/Installation.html) for detailed installation instructions.
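-
-If you are unsure whether the Tesseract engine itself is visible to Python, a small check along these lines (assuming `pytesseract` was installed by the `[ocr]` extra) will raise an error when the binary cannot be found:
-
-```python
-import pytesseract
-
-# Prints the installed Tesseract version, or raises TesseractNotFoundError
-# if the tesseract executable is not on your PATH.
-print(pytesseract.get_tesseract_version())
-```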
-
-## Known issues
-
-Error: instantiating `lp.GCVAgent.with_credential` raises `module 'google.cloud.vision' has no attribute 'types'`.
-
-
-In this case, you have a newer version of the google-cloud-vision library installed. Please consider downgrading it using:
-```bash
-pip install -U layoutparser[ocr]
-```
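-
-Alternatively, the import-error hint in `layoutparser`'s `file_utils` points at `google-cloud-vision==1`; pinning the client directly is another option (an assumption, in case the extra does not resolve the downgrade for you):
-
-```bash
-pip install "google-cloud-vision==1"
-```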
-
-
\ No newline at end of file
diff --git a/src/layoutparser/__init__.py b/src/effocr-layout/__init__.py
similarity index 93%
rename from src/layoutparser/__init__.py
rename to src/effocr-layout/__init__.py
index 512f123..38306f0 100644
--- a/src/layoutparser/__init__.py
+++ b/src/effocr-layout/__init__.py
@@ -23,6 +23,7 @@
is_effdet_available,
is_pytesseract_available,
is_gcv_available,
+ is_effocr_available,
)
_import_structure = {
@@ -51,6 +52,7 @@
"is_paddle_available",
"is_pytesseract_available",
"is_gcv_available",
+ "is_effocr_available",
"requires_backends"
],
"tools": [
@@ -80,6 +82,9 @@
if is_gcv_available():
_import_structure["ocr.gcv_agent"] = ["GCVAgent", "GCVFeatureType"]
+if is_effocr_available():
+ _import_structure["ocr.effocr_agent"] = ["EffOCRAgent", "EffOCRFeatureType"]
+
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
diff --git a/src/layoutparser/elements/__init__.py b/src/effocr-layout/elements/__init__.py
similarity index 100%
rename from src/layoutparser/elements/__init__.py
rename to src/effocr-layout/elements/__init__.py
diff --git a/src/layoutparser/elements/base.py b/src/effocr-layout/elements/base.py
similarity index 100%
rename from src/layoutparser/elements/base.py
rename to src/effocr-layout/elements/base.py
diff --git a/src/layoutparser/elements/errors.py b/src/effocr-layout/elements/errors.py
similarity index 100%
rename from src/layoutparser/elements/errors.py
rename to src/effocr-layout/elements/errors.py
diff --git a/src/layoutparser/elements/layout.py b/src/effocr-layout/elements/layout.py
similarity index 100%
rename from src/layoutparser/elements/layout.py
rename to src/effocr-layout/elements/layout.py
diff --git a/src/layoutparser/elements/layout_elements.py b/src/effocr-layout/elements/layout_elements.py
similarity index 100%
rename from src/layoutparser/elements/layout_elements.py
rename to src/effocr-layout/elements/layout_elements.py
diff --git a/src/layoutparser/elements/utils.py b/src/effocr-layout/elements/utils.py
similarity index 100%
rename from src/layoutparser/elements/utils.py
rename to src/effocr-layout/elements/utils.py
diff --git a/src/layoutparser/file_utils.py b/src/effocr-layout/file_utils.py
similarity index 91%
rename from src/layoutparser/file_utils.py
rename to src/effocr-layout/file_utils.py
index b10a747..9153379 100644
--- a/src/layoutparser/file_utils.py
+++ b/src/effocr-layout/file_utils.py
@@ -88,6 +88,14 @@
except ModuleNotFoundError:
_gcv_available = False
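+# EffOCR support needs onnxruntime, onnx, and faiss; mark it unavailable if any of them is missing.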
+try:
+ _effocr_available = importlib.util.find_spec("onnxruntime") is not None \
+ and importlib.util.find_spec("onnx") is not None \
+ and importlib.util.find_spec("faiss") is not None
+except ModuleNotFoundError:
+ _effocr_available = False
+
+
def is_torch_available():
return _torch_available
@@ -121,6 +129,9 @@ def is_pytesseract_available():
def is_gcv_available():
return _gcv_available
+def is_effocr_available():
+ return _effocr_available
+
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
@@ -154,6 +165,13 @@ def is_gcv_available():
`pip install google-cloud-vision==1`
"""
+EFFOCR_IMPORT_ERROR = """
+{0} requires the onnxruntime, onnx, and faiss libraries, but at least one was not found in your environment. You can install them with pip:
+`pip install onnxruntime onnx faiss`
+Note that `faiss` can be installed as either the CPU or the GPU version, but the GPU version requires CUDA. See
+https://github.com/facebookresearch/faiss/blob/main/INSTALL.md for more details.
+"""
+
BACKENDS_MAPPING = dict(
[
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
@@ -162,6 +180,7 @@ def is_gcv_available():
("effdet", (is_effdet_available, EFFDET_IMPORT_ERROR)),
("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
("google-cloud-vision", (is_gcv_available, GCV_IMPORT_ERROR)),
+ ("effocr", (is_effocr_available, EFFOCR_IMPORT_ERROR))
]
)
diff --git a/src/layoutparser/io/__init__.py b/src/effocr-layout/io/__init__.py
similarity index 100%
rename from src/layoutparser/io/__init__.py
rename to src/effocr-layout/io/__init__.py
diff --git a/src/layoutparser/io/basic.py b/src/effocr-layout/io/basic.py
similarity index 100%
rename from src/layoutparser/io/basic.py
rename to src/effocr-layout/io/basic.py
diff --git a/src/layoutparser/io/pdf.py b/src/effocr-layout/io/pdf.py
similarity index 100%
rename from src/layoutparser/io/pdf.py
rename to src/effocr-layout/io/pdf.py
diff --git a/src/layoutparser/misc/NotoSerifCJKjp-Regular.otf b/src/effocr-layout/misc/NotoSerifCJKjp-Regular.otf
similarity index 100%
rename from src/layoutparser/misc/NotoSerifCJKjp-Regular.otf
rename to src/effocr-layout/misc/NotoSerifCJKjp-Regular.otf
diff --git a/src/layoutparser/models/__init__.py b/src/effocr-layout/models/__init__.py
similarity index 100%
rename from src/layoutparser/models/__init__.py
rename to src/effocr-layout/models/__init__.py
diff --git a/src/layoutparser/models/auto_layoutmodel.py b/src/effocr-layout/models/auto_layoutmodel.py
similarity index 100%
rename from src/layoutparser/models/auto_layoutmodel.py
rename to src/effocr-layout/models/auto_layoutmodel.py
diff --git a/src/layoutparser/models/base_catalog.py b/src/effocr-layout/models/base_catalog.py
similarity index 100%
rename from src/layoutparser/models/base_catalog.py
rename to src/effocr-layout/models/base_catalog.py
diff --git a/src/layoutparser/models/base_layoutmodel.py b/src/effocr-layout/models/base_layoutmodel.py
similarity index 100%
rename from src/layoutparser/models/base_layoutmodel.py
rename to src/effocr-layout/models/base_layoutmodel.py
diff --git a/src/layoutparser/models/detectron2/__init__.py b/src/effocr-layout/models/detectron2/__init__.py
similarity index 100%
rename from src/layoutparser/models/detectron2/__init__.py
rename to src/effocr-layout/models/detectron2/__init__.py
diff --git a/src/layoutparser/models/detectron2/catalog.py b/src/effocr-layout/models/detectron2/catalog.py
similarity index 100%
rename from src/layoutparser/models/detectron2/catalog.py
rename to src/effocr-layout/models/detectron2/catalog.py
diff --git a/src/layoutparser/models/detectron2/layoutmodel.py b/src/effocr-layout/models/detectron2/layoutmodel.py
similarity index 100%
rename from src/layoutparser/models/detectron2/layoutmodel.py
rename to src/effocr-layout/models/detectron2/layoutmodel.py
diff --git a/src/layoutparser/models/effdet/__init__.py b/src/effocr-layout/models/effdet/__init__.py
similarity index 100%
rename from src/layoutparser/models/effdet/__init__.py
rename to src/effocr-layout/models/effdet/__init__.py
diff --git a/src/layoutparser/models/effdet/catalog.py b/src/effocr-layout/models/effdet/catalog.py
similarity index 100%
rename from src/layoutparser/models/effdet/catalog.py
rename to src/effocr-layout/models/effdet/catalog.py
diff --git a/src/layoutparser/models/effdet/layoutmodel.py b/src/effocr-layout/models/effdet/layoutmodel.py
similarity index 100%
rename from src/layoutparser/models/effdet/layoutmodel.py
rename to src/effocr-layout/models/effdet/layoutmodel.py
diff --git a/src/layoutparser/models/model_config.py b/src/effocr-layout/models/model_config.py
similarity index 100%
rename from src/layoutparser/models/model_config.py
rename to src/effocr-layout/models/model_config.py
diff --git a/src/layoutparser/models/paddledetection/__init__.py b/src/effocr-layout/models/paddledetection/__init__.py
similarity index 100%
rename from src/layoutparser/models/paddledetection/__init__.py
rename to src/effocr-layout/models/paddledetection/__init__.py
diff --git a/src/layoutparser/models/paddledetection/catalog.py b/src/effocr-layout/models/paddledetection/catalog.py
similarity index 100%
rename from src/layoutparser/models/paddledetection/catalog.py
rename to src/effocr-layout/models/paddledetection/catalog.py
diff --git a/src/layoutparser/models/paddledetection/layoutmodel.py b/src/effocr-layout/models/paddledetection/layoutmodel.py
similarity index 100%
rename from src/layoutparser/models/paddledetection/layoutmodel.py
rename to src/effocr-layout/models/paddledetection/layoutmodel.py
diff --git a/src/layoutparser/ocr/__init__.py b/src/effocr-layout/ocr/__init__.py
similarity index 92%
rename from src/layoutparser/ocr/__init__.py
rename to src/effocr-layout/ocr/__init__.py
index 66efd76..ab7be48 100644
--- a/src/layoutparser/ocr/__init__.py
+++ b/src/effocr-layout/ocr/__init__.py
@@ -13,4 +13,5 @@
# limitations under the License.
from .gcv_agent import GCVAgent, GCVFeatureType
-from .tesseract_agent import TesseractAgent, TesseractFeatureType
\ No newline at end of file
+from .tesseract_agent import TesseractAgent, TesseractFeatureType
+from .effocr_agent import EffOCRAgent, EffOCRFeatureType
\ No newline at end of file
diff --git a/src/layoutparser/ocr/base.py b/src/effocr-layout/ocr/base.py
similarity index 100%
rename from src/layoutparser/ocr/base.py
rename to src/effocr-layout/ocr/base.py
diff --git a/src/effocr-layout/ocr/effocr/__init__.py b/src/effocr-layout/ocr/effocr/__init__.py
new file mode 100644
index 0000000..5e1aac0
--- /dev/null
+++ b/src/effocr-layout/ocr/effocr/__init__.py
@@ -0,0 +1,3 @@
+from .engines import EffLineDetector, EffLocalizer, EffRecognizer
+from .utils import create_paired_transform, create_paired_transform_word, letterbox, non_max_suppression
+from .infer_transcripton import run_effocr_word
\ No newline at end of file
diff --git a/src/effocr-layout/ocr/effocr/engines/__init__.py b/src/effocr-layout/ocr/effocr/engines/__init__.py
new file mode 100644
index 0000000..602e943
--- /dev/null
+++ b/src/effocr-layout/ocr/effocr/engines/__init__.py
@@ -0,0 +1,3 @@
+from .localizer_engine import EffLocalizer
+from .recognizer_engine import EffRecognizer
+from .line_det_engine import EffLineDetector
\ No newline at end of file
diff --git a/src/effocr-layout/ocr/effocr/engines/line_det_engine.py b/src/effocr-layout/ocr/effocr/engines/line_det_engine.py
new file mode 100644
index 0000000..cde6ea5
--- /dev/null
+++ b/src/effocr-layout/ocr/effocr/engines/line_det_engine.py
@@ -0,0 +1,251 @@
+import os
+import sys
+# import mmcv  # uncomment if you use the 'mmdetection' backend; mmcv is otherwise not required
+import torch
+import numpy as np
+import onnxruntime as ort
+import torchvision
+from torchvision.ops import nms
+import cv2
+import onnx
+from math import floor, ceil
+
+from .ops import non_max_suppression as yolov8_nms
+from .ops import get_onnx_input_name
+from ..utils import letterbox, non_max_suppression
+
+DEFAULT_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
+DEFAULT_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)
+
+class EffLineDetector:
+ """
+    Class for running the EffOCR line detection model. Essentially a wrapper for the onnxruntime
+    inference session based on the model, with some additional postprocessing, especially regarding splitting and
+    recombining especially tall layout regions.
+ """
+
+ def __init__(self, model_path, iou_thresh = 0.15, conf_thresh = 0.20,
+ num_cores = None, providers=None, input_shape = (640, 640), model_backend='yolo',
+ min_seg_ratio = 2, visualize = None):
+ """Instantiates the object, including setting up the wrapped ONNX InferenceSession
+
+ Args:
+ model_path (str): Path to ONNX model that will be used
+ iou_thresh (float, optional): IOU filter for line detection NMS. Defaults to 0.15.
+ conf_thresh (float, optional): Confidence filter for line detection NMS. Defaults to 0.20.
+            num_cores (int, optional): Number of intra-op threads to use during inference. Defaults to None, meaning no intra-op thread limit.
+            providers (list, optional): Any particular ONNX execution providers to use. Defaults to None, meaning the results of ort.get_available_providers() will be used.
+            input_shape (tuple, optional): Shape of input images. Defaults to (640, 640).
+            model_backend (str, optional): Original model backend being used. Defaults to 'yolo'. Options are mmdetection, detectron2, yolo, yolov8.
+            min_seg_ratio (int, optional): Height-to-width ratio above which a layout region is split into overlapping crops. Defaults to 2.
+            visualize (str, optional): Unused in the constructor; pass `visualize` to run()/__call__ to write a line-detection visualization. Defaults to None.
+ """
+
+
+ # Set up and instantiate a ort InfernceSession
+ sess_options = ort.SessionOptions()
+ if num_cores is not None:
+ sess_options.intra_op_num_threads = num_cores
+
+ if providers is None:
+ providers = ort.get_available_providers()
+
+ self._eng_net = ort.InferenceSession(
+ model_path,
+ sess_options,
+ providers=providers,
+ )
+
+ # Load in the model as a standard ONNX model to get the input shape and name
+ base_model = onnx.load(model_path)
+ self._input_name = get_onnx_input_name(base_model)
+ self._model_input_shape = self._eng_net.get_inputs()[0].shape
+
+ # Rest of the params
+ self._iou_thresh = iou_thresh
+ self._conf_thresh = conf_thresh
+
+ if isinstance(self._model_input_shape[-1], int) and isinstance(self._model_input_shape[-2], int):
+ self._input_shape = (self._model_input_shape[-2], self._model_input_shape[-1])
+ else:
+ self._input_shape = input_shape
+ self._model_backend = model_backend
+ self.min_seg_ratio = min_seg_ratio # Ratio that determines at what point the model will split a region into two
+
+
+
+ def __call__(self, imgs, visualize = None):
+ """Wraps the run method, allowing the object to be called directly
+
+ Args:
+ imgs (list or str or np.ndarray): List of image paths, list of images as np.ndarrays, or single image path, or single image as np.ndarray
+
+ Returns:
+            tuple: (line_crops, line_coords) as produced by run(): the per-line image crops and their (y0, x0, y1, x1) coordinates.
+ """
+ return self.run(imgs, visualize = visualize)
+
+ def run(self, imgs, visualize = None):
+        orig_img = imgs.copy() if isinstance(imgs, np.ndarray) else imgs
+ if isinstance(imgs, list):
+ if all(isinstance(img, str) for img in imgs):
+ imgs = [self.load_line_img(img, self._input_shape, backend=self._model_backend) for img in imgs]
+ elif all(isinstance(img, np.ndarray) for img in imgs):
+ imgs = [self.get_crops_from_layout_image(img) for img in imgs]
+ imgs = [self.format_line_img(img, self._input_shape, backend=self._model_backend) for img in imgs]
+ else:
+                raise ValueError('Invalid combination of input types in Line Detection list! Must be all str or all np.ndarray')
+ elif isinstance(imgs, str):
+ imgs = [self.load_line_img(imgs, self._input_shape, backend=self._model_backend)]
+ elif isinstance(imgs, np.ndarray):
+ imgs = self.get_crops_from_layout_image(imgs)
+ orig_shapes = [img.shape for img in imgs]
+ imgs = [self.format_line_img(img, self._input_shape, backend=self._model_backend) for img in imgs]
+ else:
+ raise ValueError('Input type {} is not implemented'.format(type(imgs)))
+
+ results = [self._eng_net.run(None, {self._input_name: img}) for img in imgs]
+ return self._postprocess(results, imgs, orig_shapes, orig_img, viz_lines = visualize)
+
+ def _postprocess(self, results, imgs, orig_shapes, orig_img, viz_lines = None):
+ #YOLO NMS is carried out now, other backends will filter by bbox confidence score later
+ if self._model_backend == 'yolo':
+ preds = [torch.from_numpy(pred[0]) for pred in results]
+ preds = [non_max_suppression(pred, conf_thres = self._conf_thresh, iou_thres=self._iou_thresh, max_det=100)[0] for pred in preds]
+
+ elif self._model_backend == 'yolov8':
+ preds = [torch.from_numpy(pred[0]) for pred in results]
+ preds = [yolov8_nms(pred, conf_thres = self._conf_thresh, iou_thres=self._iou_thresh, max_det=100)[0] for pred in preds]
+
+ elif self._model_backend == 'detectron2' or self._model_backend == 'mmdetection':
+ return results
+
+ preds = self.adjust_line_preds(preds, imgs, orig_shapes)
+        final_preds = self.readjust_line_predictions(preds, orig_img.shape[1])
+
+ line_crops, line_coords = [], []
+ for i, line_proj_crop in enumerate(final_preds):
+ x0, y0, x1, y1 = map(round, line_proj_crop)
+ line_crop = orig_img[y0:y1, x0:x1]
+ if line_crop.shape[0] == 0 or line_crop.shape[1] == 0:
+ continue
+
+ # Line crops becomes a list of tuples (bbox_id, line_crop [the image itself], line_proj_crop [the coordinates of the line in the layout image])
+ line_crops.append(np.array(line_crop).astype(np.float32))
+ line_coords.append((y0, x0, y1, x1))
+
+ # If asked to visualize the line detections, draw a rectangle representing each line crop on the original image
+ if viz_lines is not None:
+ cv2.rectangle(orig_img, (x0, y0), (x1, y1), (255, 0, 0), 2)
+
+ # If asked to visualize, output the image with the line detections drawn on it
+ if viz_lines is not None:
+ cv2.imwrite(viz_lines, orig_img)
+
+ return line_crops, line_coords
+
+
+ def adjust_line_preds(self, preds, imgs, orig_shapes):
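+        # Map each 640x640 letterboxed prediction back into its original crop's coordinate
+        # system: undo the vertical letterbox padding and stretch every line box to the
+        # full crop width.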
+ adjusted_preds = []
+
+ for pred, shape in zip(preds, orig_shapes):
+ line_predictions = pred[pred[:, 1].sort()[1]]
+ line_bboxes, line_confs, line_labels = line_predictions[:, :4], line_predictions[:, -2], line_predictions[:, -1]
+
+ im_width, im_height = shape[1], shape[0]
+ if im_width > im_height:
+ h_ratio = (im_height / im_width) * 640
+ h_trans = 640 * ((1 - (im_height / im_width)) / 2)
+ else:
+ h_trans = 0
+ h_ratio = 640
+
+ line_proj_crops = []
+ for line_bbox in line_bboxes:
+ x0, y0, x1, y1 = torch.round(line_bbox)
+ x0, y0, x1, y1 = 0, int(floor((y0.item() - h_trans) * im_height / h_ratio)), \
+ im_width, int(ceil((y1.item() - h_trans) * im_height / h_ratio))
+
+ line_proj_crops.append((x0, y0, x1, y1))
+
+ adjusted_preds.append((line_proj_crops, line_confs, line_labels))
+
+ return adjusted_preds
+
+ def readjust_line_predictions(self, line_preds, orig_img_width):
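+        # Stitch predictions from the overlapping vertical crops back into one coordinate
+        # system by offsetting each crop's boxes, then run NMS across crops so the overlap
+        # between neighbouring crops does not yield duplicate lines.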
+ y0 = 0
+ dif = int(orig_img_width * 1.5)
+ all_preds, final_preds = [], []
+ for j in range(len(line_preds)):
+ preds, probs, labels = line_preds[j]
+ for i, pred in enumerate(preds):
+ all_preds.append((pred[0], pred[1] + y0, pred[2], pred[3] + y0, probs[i]))
+ y0 += dif
+
+ all_preds = torch.tensor(all_preds)
+ if all_preds.dim() > 1:
+ keep_preds = nms(all_preds[:, :4], all_preds[:, -1], iou_threshold=0.15)
+ filtered_preds = all_preds[keep_preds, :4]
+ filtered_preds = filtered_preds[filtered_preds[:, 1].sort()[1]]
+ for pred in filtered_preds:
+ x0, y0, x1, y1 = torch.round(pred)
+ x0, y0, x1, y1 = x0.item(), y0.item(), x1.item(), y1.item()
+ final_preds.append((x0, y0, x1, y1))
+ return final_preds
+ else:
+ return []
+
+ def format_line_img(self, img, input_shape, backend='yolo'):
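+        # Convert a raw BGR crop into the input format the chosen backend expects:
+        # letterboxed CHW float arrays for yolo/yolov8/detectron2 (scaled to [0, 1] and
+        # batched for yolo), mmcv rescale/pad/normalize for mmdetection.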
+ if backend == 'yolo' or backend == 'yolov8':
+ im = letterbox(img, input_shape, stride=32, auto=False)[0] # padded resize
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ im = np.ascontiguousarray(im) # contiguous
+ im = im.astype(np.float32) / 255.0 # 0 - 255 to 0.0 - 1.0
+ if im.ndim == 3:
+ im = np.expand_dims(im, 0)
+
+ elif backend == 'detectron2':
+ im = letterbox(img, input_shape, stride=32, auto=False)[0] # padded resize
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ im = np.ascontiguousarray(im) # contiguous
+ im = im.astype(np.float32)
+
+ elif backend == 'mmdetection':
+ im = mmcv.imrescale(img, (input_shape[0], input_shape[1]))
+ im = mmcv.impad(im, shape = input_shape, pad_val=0)
+ im = mmcv.imnormalize(im, DEFAULT_MEAN, DEFAULT_STD, to_rgb=True)
+ im = im.transpose(2, 0, 1)
+ if im.ndim == 3:
+ im = np.expand_dims(im, 0)
+
+
+ else:
+ raise NotImplementedError('Backend {} is not implemented'.format(backend))
+
+ return im
+
+ def load_line_img(self, input_path, input_shape, backend='yolo'):
+ if backend == 'yolo' or backend == 'yolov8' or backend == 'detectron2':
+ im0 = cv2.imread(input_path)
+ im0 = self.get_crops_from_layout_image(im0)
+ return [self.format_line_img(im, input_shape, backend=backend) for im in im0]
+ elif backend == 'mmdetection':
+ one_img = mmcv.imread(input_path)
+ one_img = self.get_crops_from_layout_image(one_img)
+ return [self.format_line_img(one_im, input_shape, backend=backend) for one_im in one_img]
+ else:
+ raise NotImplementedError('Backend {} is not implemented'.format(backend))
+
+    def get_crops_from_layout_image(self, image):
+        # Split overly tall regions (height > min_seg_ratio * width) into overlapping
+        # horizontal bands so that each band keeps a detector-friendly aspect ratio.
+        im_height, im_width = image.shape[0], image.shape[1]
+        if im_height <= im_width * self.min_seg_ratio:
+            return [image]
+        else:
+            y0 = 0
+            y1 = int(im_width * self.min_seg_ratio)
+            crops = []
+            while y1 <= im_height:
+                crops.append(image[y0:y1, :])
+                y0 += int(im_width * self.min_seg_ratio * 0.75) # .75 factor ensures there is overlap between crops
+                y1 += int(im_width * self.min_seg_ratio * 0.75)
+
+            crops.append(image[y0:im_height, :])
+            return crops
\ No newline at end of file
diff --git a/src/effocr-layout/ocr/effocr/engines/localizer_engine.py b/src/effocr-layout/ocr/effocr/engines/localizer_engine.py
new file mode 100644
index 0000000..8c71399
--- /dev/null
+++ b/src/effocr-layout/ocr/effocr/engines/localizer_engine.py
@@ -0,0 +1,320 @@
+import os
+import sys
+# import mmcv  # uncomment if you use the 'mmdetection' backend; mmcv is otherwise not required
+import torch
+import numpy as np
+import onnxruntime as ort
+import torchvision
+import cv2
+import onnx
+
+from .ops import non_max_suppression as yolov8_nms
+
+DEFAULT_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
+DEFAULT_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)
+
+class EffLocalizer:
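+    """ONNX Runtime wrapper for the EffOCR localizer model. Follows the same pattern as
+    EffLineDetector: load an exported detection model, format inputs for the configured
+    backend ('yolo', 'yolov8', 'detectron2', or 'mmdetection'), run inference, and apply
+    non-maximum suppression to the raw outputs."""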
+
+ def __init__(self, model_path, iou_thresh = 0.01, conf_thresh = 0.30, vertical = False,
+ num_cores = None, providers=None, input_shape = (640, 640), model_backend='yolo'):
+ sess_options = ort.SessionOptions()
+ if num_cores is not None:
+ sess_options.intra_op_num_threads = num_cores
+
+ if providers is None:
+ providers = ort.get_available_providers()
+
+ self._eng_net = ort.InferenceSession(
+ model_path,
+ sess_options,
+ providers=providers,
+ )
+
+ base_model = onnx.load(model_path)
+ self._input_name = EffLocalizer.get_onnx_input_name(base_model)
+ self._model_input_shape = self._eng_net.get_inputs()[0].shape
+ self._iou_thresh = iou_thresh
+ self._conf_thresh = conf_thresh
+ self._vertical = vertical
+
+ if isinstance(self._model_input_shape[-1], int) and isinstance(self._model_input_shape[-2], int):
+ self._input_shape = (self._model_input_shape[-2], self._model_input_shape[-1])
+ else:
+ self._input_shape = input_shape
+ self._model_backend = model_backend
+
+
+
+ def __call__(self, imgs):
+ return self.run(imgs)
+
+ def run(self, imgs):
+ if isinstance(imgs, list):
+ if isinstance(imgs[0], str):
+ imgs = [EffLocalizer.load_localizer_img(img, self._input_shape, backend=self._model_backend) for img in imgs]
+ else:
+ imgs = [EffLocalizer.format_localizer_img(img, self._input_shape, backend=self._model_backend) for img in imgs]
+ elif isinstance(imgs, str):
+ imgs = [EffLocalizer.load_localizer_img(imgs, self._input_shape, backend=self._model_backend)]
+ elif isinstance(imgs, np.ndarray):
+ imgs = [EffLocalizer.format_localizer_img(imgs, self._input_shape, backend=self._model_backend)]
+ else:
+ raise NotImplementedError('Input type {} is not implemented'.format(type(imgs)))
+
+ results = [self._eng_net.run(None, {self._input_name: img}) for img in imgs]
+ return self._postprocess(results)
+
+ def _postprocess(self, results):
+ #YOLO NMS is carried out now, other backends will filter by bbox confidence score later
+ if self._model_backend == 'yolo':
+
+ preds = [torch.from_numpy(pred[0]) for pred in results]
+ preds = [self.non_max_suppression(pred, conf_thres = self._conf_thresh, iou_thres=self._iou_thresh, max_det=1000)[0] for pred in preds]
+ return preds
+
+ elif self._model_backend == 'yolov8':
+ preds = [torch.from_numpy(pred[0]) for pred in results]
+ preds = [yolov8_nms(pred, conf_thres = self._conf_thresh, iou_thres=self._iou_thresh, max_det=50)[0] for pred in preds]
+ return preds
+
+ elif self._model_backend == 'detectron2' or self._model_backend == 'mmdetection':
+ return results
+
+ @staticmethod
+ def get_onnx_input_name(model):
+ input_all = [node.name for node in model.graph.input]
+ input_initializer = [node.name for node in model.graph.initializer]
+ net_feed_input = list(set(input_all) - set(input_initializer))
+ return net_feed_input[0]
+
+ @staticmethod
+ def format_localizer_img(img, input_shape, backend='yolo'):
+ if backend == 'yolo' or backend == 'yolov8':
+ im = EffLocalizer.letterbox(img, input_shape, stride=32, auto=False)[0] # padded resize
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ im = np.ascontiguousarray(im) # contiguous
+ im = im.astype(np.float32) / 255.0 # 0 - 255 to 0.0 - 1.0
+ if im.ndim == 3:
+ im = np.expand_dims(im, 0)
+ return im
+ elif backend == 'detectron2':
+ im = EffLocalizer.letterbox(img, input_shape, stride=32, auto=False)[0] # padded resize
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ im = np.ascontiguousarray(im) # contiguous
+ im = im.astype(np.float32)
+ return im
+ elif backend == 'mmdetection':
+ one_img = mmcv.imrescale(img, (input_shape[0], input_shape[1]))
+ one_img = mmcv.impad(one_img, shape = input_shape, pad_val=0)
+ one_img = mmcv.imnormalize(one_img, DEFAULT_MEAN, DEFAULT_STD, to_rgb=True)
+ one_img = one_img.transpose(2, 0, 1)
+ if one_img.ndim == 3:
+ one_img = np.expand_dims(one_img, 0)
+
+ return one_img
+ else:
+ raise NotImplementedError('Backend {} is not implemented'.format(backend))
+
+ @staticmethod
+ def load_localizer_img(input_path, input_shape, backend='yolo'):
+ if backend == 'yolo' or backend == 'yolov8':
+ im0 = cv2.imread(input_path)
+ im = EffLocalizer.letterbox(im0, input_shape, stride=32, auto=False)[0] # padded resize
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ im = np.ascontiguousarray(im) # contiguous
+ im = im.astype(np.float32) / 255.0 # 0 - 255 to 0.0 - 1.0
+ if im.ndim == 3:
+ im = np.expand_dims(im, 0)
+ return im
+ elif backend == 'detectron2':
+ im0 = cv2.imread(input_path)
+ im = EffLocalizer.letterbox(im0, input_shape, stride=32, auto=False)[0] # padded resize
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ im = np.ascontiguousarray(im) # contiguous
+ im = im.astype(np.float32)
+ return im
+ elif backend == 'mmdetection':
+ one_img = mmcv.imread(input_path)
+ one_img = mmcv.imrescale(one_img, (input_shape[0], input_shape[1]))
+ one_img = mmcv.impad(one_img, shape = input_shape, pad_val=0)
+ one_img = mmcv.imnormalize(one_img, DEFAULT_MEAN, DEFAULT_STD, to_rgb=True)
+ one_img = one_img.transpose(2, 0, 1)
+ if one_img.ndim == 3:
+ one_img = np.expand_dims(one_img, 0)
+
+ return one_img
+ else:
+ raise NotImplementedError('Backend {} is not implemented'.format(backend))
+
+
+ @staticmethod
+ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
+ # Resize and pad image while meeting stride-multiple constraints
+ shape = im.shape[:2] # current shape [height, width]
+ if isinstance(new_shape, int):
+ new_shape = (new_shape, new_shape)
+
+ # Scale ratio (new / old)
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
+ r = min(r, 1.0)
+
+ # Compute padding
+ ratio = r, r # width, height ratios
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
+ if auto: # minimum rectangle
+ dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
+ elif scaleFill: # stretch
+ dw, dh = 0.0, 0.0
+ new_unpad = (new_shape[1], new_shape[0])
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
+
+ dw /= 2 # divide padding into 2 sides
+ dh /= 2
+
+ if shape[::-1] != new_unpad: # resize
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
+ return im, ratio, (dw, dh)
+
+ @staticmethod
+ def xywh2xyxy(x):
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
+ return y
+
+ @staticmethod
+ def box_iou(box1, box2, eps=1e-7):
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+ """
+ Return intersection-over-union (Jaccard index) of boxes.
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+ Arguments:
+ box1 (Tensor[N, 4])
+ box2 (Tensor[M, 4])
+ Returns:
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
+ IoU values for every element in boxes1 and boxes2
+ """
+
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+ (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
+ inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
+
+ # IoU = inter / (area1 + area2 - inter)
+ return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
+
+ @staticmethod
+ def non_max_suppression(
+ prediction,
+ conf_thres=0.25,
+ iou_thres=0.45,
+ classes=None,
+ agnostic=False,
+ multi_label=False,
+ labels=(),
+ max_det=300,
+ nm=0, ):
+
+        if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation mode, output = (inference_out, loss_out)
+ prediction = prediction[0] # select only inference output
+
+ device = prediction.device
+ mps = 'mps' in device.type # Apple MPS
+ if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
+ prediction = prediction.cpu()
+ bs = prediction.shape[0] # batch size
+ nc = prediction.shape[2] - nm - 5 # number of classes
+ xc = prediction[..., 4] > conf_thres # candidates
+
+ # Checks
+ assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
+ assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
+
+ # Settings
+ # min_wh = 2 # (pixels) minimum box width and height
+ max_wh = 7680 # (pixels) maximum box width and height
+ max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
+ time_limit = 0.5 + 0.05 * bs # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ mi = 5 + nc # mask start index
+ output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
+ for xi, x in enumerate(prediction): # image index, image inference
+ # Apply constraints
+ # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
+ x = x[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ lb = labels[xi]
+ v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
+ v[:, :4] = lb[:, 1:5] # box
+ v[:, 4] = 1.0 # conf
+ v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Compute conf
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
+
+ # Box/Mask
+ box = EffLocalizer.xywh2xyxy(x[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2)
+ mask = x[:, mi:] # zero columns if no masks
+
+ # Detections matrix nx6 (xyxy, conf, cls)
+ if multi_label:
+ i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
+ else: # best class only
+ conf, j = x[:, 5:mi].max(1, keepdim=True)
+ x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # Apply finite constraint
+ # if not torch.isfinite(x).all():
+ # x = x[torch.isfinite(x).all(1)]
+
+ # Check shape
+ n = x.shape[0] # number of boxes
+ if not n: # no boxes
+ continue
+ elif n > max_nms: # excess boxes
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
+ else:
+ x = x[x[:, 4].argsort(descending=True)] # sort by confidence
+
+ # Batched NMS
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+ if i.shape[0] > max_det: # limit detections
+ i = i[:max_det]
+ if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = EffLocalizer.box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if mps:
+ output[xi] = output[xi].to(device)
+
+ return output
\ No newline at end of file
diff --git a/src/effocr-layout/ocr/effocr/engines/ops.py b/src/effocr-layout/ocr/effocr/engines/ops.py
new file mode 100644
index 0000000..6daabc9
--- /dev/null
+++ b/src/effocr-layout/ocr/effocr/engines/ops.py
@@ -0,0 +1,723 @@
+import contextlib
+import math
+import re
+import time
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+
+def get_onnx_input_name(model):
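+    # The model's feed input is whatever graph input is not an initializer (i.e. not a stored weight).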
+ input_all = [node.name for node in model.graph.input]
+ input_initializer = [node.name for node in model.graph.initializer]
+ net_feed_input = list(set(input_all) - set(input_initializer))
+ return net_feed_input[0]
+
+class Profile(contextlib.ContextDecorator):
+ """
+ YOLOv8 Profile class.
+ Usage: as a decorator with @Profile() or as a context manager with 'with Profile():'
+ """
+
+ def __init__(self, t=0.0):
+ """
+ Initialize the Profile class.
+
+ Args:
+ t (float): Initial time. Defaults to 0.0.
+ """
+ self.t = t
+ self.cuda = torch.cuda.is_available()
+
+ def __enter__(self):
+ """
+ Start timing.
+ """
+ self.start = self.time()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ """
+ Stop timing.
+ """
+ self.dt = self.time() - self.start # delta-time
+ self.t += self.dt # accumulate dt
+
+ def time(self):
+ """
+ Get current time.
+ """
+ if self.cuda:
+ torch.cuda.synchronize()
+ return time.time()
+
+
+def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
+ # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
+ # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
+ # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
+ # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
+ # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
+ return [
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
+ 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
+ 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
+
+
+def segment2box(segment, width=640, height=640):
+ """
+ Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
+
+ Args:
+ segment (torch.Tensor): the segment label
+ width (int): the width of the image. Defaults to 640
+ height (int): The height of the image. Defaults to 640
+
+ Returns:
+ (np.ndarray): the minimum and maximum x and y values of the segment.
+ """
+ # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
+ x, y = segment.T # segment xy
+ inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
+ x, y, = x[inside], y[inside]
+ return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros(
+ 4, dtype=segment.dtype) # xyxy
+
+
+def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
+ """
+ Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
+ (img1_shape) to the shape of a different image (img0_shape).
+
+ Args:
+ img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
+ boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
+ img0_shape (tuple): the shape of the target image, in the format of (height, width).
+ ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
+ calculated based on the size difference between the two images.
+
+ Returns:
+ boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
+ """
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
+ boxes[..., [0, 2]] -= pad[0] # x padding
+ boxes[..., [1, 3]] -= pad[1] # y padding
+ boxes[..., :4] /= gain
+ clip_boxes(boxes, img0_shape)
+ return boxes
+
+
+def make_divisible(x, divisor):
+ """
+ Returns the nearest number that is divisible by the given divisor.
+
+ Args:
+ x (int): The number to make divisible.
+ divisor (int) or (torch.Tensor): The divisor.
+
+ Returns:
+ (int): The nearest number divisible by the divisor.
+ """
+ if isinstance(divisor, torch.Tensor):
+ divisor = int(divisor.max()) # to int
+ return math.ceil(x / divisor) * divisor
+
+
+def box_iou(box1, box2, eps=1e-7):
+ """
+ Calculate intersection-over-union (IoU) of boxes.
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+ Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+
+ Args:
+ box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
+ box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
+ eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+ Returns:
+ (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
+ """
+
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+ (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
+ inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
+
+ # IoU = inter / (area1 + area2 - inter)
+ return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
+
+def non_max_suppression(
+ prediction,
+ conf_thres=0.25,
+ iou_thres=0.45,
+ classes=None,
+ agnostic=False,
+ multi_label=False,
+ labels=(),
+ max_det=300,
+ nc=0, # number of classes (optional)
+ max_time_img=0.05,
+ max_nms=30000,
+ max_wh=7680,
+):
+ """
+ Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
+
+ Arguments:
+ prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
+ containing the predicted boxes, classes, and masks. The tensor should be in the format
+ output by a model, such as YOLO.
+ conf_thres (float): The confidence threshold below which boxes will be filtered out.
+ Valid values are between 0.0 and 1.0.
+ iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
+ Valid values are between 0.0 and 1.0.
+ classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
+ agnostic (bool): If True, the model is agnostic to the number of classes, and all
+ classes will be considered as one.
+ multi_label (bool): If True, each box may have multiple labels.
+ labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
+ list contains the apriori labels for a given image. The list should be in the format
+ output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
+ max_det (int): The maximum number of boxes to keep after NMS.
+ nc (int): (optional) The number of classes output by the model. Any indices after this will be considered masks.
+ max_time_img (float): The maximum time (seconds) for processing one image.
+ max_nms (int): The maximum number of boxes into torchvision.ops.nms().
+ max_wh (int): The maximum box width and height in pixels
+
+ Returns:
+ (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
+ shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
+ (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
+ """
+
+ # Checks
+ assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
+ assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
+    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
+ prediction = prediction[0] # select only inference output
+
+ device = prediction.device
+ mps = 'mps' in device.type # Apple MPS
+ if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
+ prediction = prediction.cpu()
+ bs = prediction.shape[0] # batch size
+ nc = nc or (prediction.shape[1] - 4) # number of classes
+ nm = prediction.shape[1] - nc - 4
+ mi = 4 + nc # mask start index
+ xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
+
+ # Settings
+ # min_wh = 2 # (pixels) minimum box width and height
+ time_limit = 0.5 + max_time_img * bs # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label = False # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ t = time.time()
+ output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
+ for xi, x in enumerate(prediction): # image index, image inference
+ # Apply constraints
+ # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
+ x = x.transpose(0, -1)[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ lb = labels[xi]
+ v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
+ v[:, :4] = lb[:, 1:5] # box
+ v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Detections matrix nx6 (xyxy, conf, cls)
+ box, cls, mask = x.split((4, nc, nm), 1)
+ box = xywh2xyxy(box) # center_x, center_y, width, height) to (x1, y1, x2, y2)
+ if multi_label:
+ i, j = (cls > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
+ else: # best class only
+ conf, j = cls.max(1, keepdim=True)
+ x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # Apply finite constraint
+ # if not torch.isfinite(x).all():
+ # x = x[torch.isfinite(x).all(1)]
+
+ # Check shape
+ n = x.shape[0] # number of boxes
+ if not n: # no boxes
+ continue
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
+
+ # Batched NMS
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+ i = i[:max_det] # limit detections
+ if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
+ # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if mps:
+ output[xi] = output[xi].to(device)
+
+ return output
+
+
+def clip_boxes(boxes, shape):
+ """
+ It takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the
+ shape
+
+ Args:
+ boxes (torch.Tensor): the bounding boxes to clip
+ shape (tuple): the shape of the image
+ """
+ if isinstance(boxes, torch.Tensor): # faster individually
+ boxes[..., 0].clamp_(0, shape[1]) # x1
+ boxes[..., 1].clamp_(0, shape[0]) # y1
+ boxes[..., 2].clamp_(0, shape[1]) # x2
+ boxes[..., 3].clamp_(0, shape[0]) # y2
+ else: # np.array (faster grouped)
+ boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
+ boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
+
+
+def clip_coords(coords, shape):
+ """
+ Clip line coordinates to the image boundaries.
+
+ Args:
+ coords (torch.Tensor) or (numpy.ndarray): A list of line coordinates.
+ shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
+
+ Returns:
+ (None): The function modifies the input `coordinates` in place, by clipping each coordinate to the image boundaries.
+ """
+ if isinstance(coords, torch.Tensor): # faster individually
+ coords[..., 0].clamp_(0, shape[1]) # x
+ coords[..., 1].clamp_(0, shape[0]) # y
+ else: # np.array (faster grouped)
+ coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x
+ coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y
+
+
+def scale_image(masks, im0_shape, ratio_pad=None):
+ """
+ Takes a mask, and resizes it to the original image size
+
+ Args:
+ masks (torch.Tensor): resized and padded masks/images, [h, w, num]/[h, w, 3].
+ im0_shape (tuple): the original image shape
+ ratio_pad (tuple): the ratio of the padding to the original image.
+
+ Returns:
+ masks (torch.Tensor): The masks that are being returned.
+ """
+ # Rescale coordinates (xyxy) from im1_shape to im0_shape
+ im1_shape = masks.shape
+ if im1_shape[:2] == im0_shape[:2]:
+ return masks
+ if ratio_pad is None: # calculate from im0_shape
+ gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new
+ pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+ top, left = int(pad[1]), int(pad[0]) # y, x
+ bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
+
+ if len(masks.shape) < 2:
+ raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
+ masks = masks[top:bottom, left:right]
+ # masks = masks.permute(2, 0, 1).contiguous()
+ # masks = F.interpolate(masks[None], im0_shape[:2], mode='bilinear', align_corners=False)[0]
+ # masks = masks.permute(1, 2, 0).contiguous()
+ masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
+ if len(masks.shape) == 2:
+ masks = masks[:, :, None]
+
+ return masks
+
+
+def xyxy2xywh(x):
+ """
+ Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format.
+
+ Args:
+ x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
+ Returns:
+ y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
+ y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
+ y[..., 2] = x[..., 2] - x[..., 0] # width
+ y[..., 3] = x[..., 3] - x[..., 1] # height
+ return y
+
+
+def xywh2xyxy(x):
+ """
+ Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
+ top-left corner and (x2, y2) is the bottom-right corner.
+
+ Args:
+ x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
+ Returns:
+ y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
+ y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
+ y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
+ y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
+ return y
+
+
+def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
+ """
+ Convert normalized bounding box coordinates to pixel coordinates.
+
+ Args:
+ x (np.ndarray) or (torch.Tensor): The bounding box coordinates.
+ w (int): Width of the image. Defaults to 640
+ h (int): Height of the image. Defaults to 640
+ padw (int): Padding width. Defaults to 0
+ padh (int): Padding height. Defaults to 0
+ Returns:
+ y (np.ndarray) or (torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
+ x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
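+
+        Example (illustrative): with w=640, h=640 and no padding, the normalized box (0.5, 0.5, 0.25, 0.25)
+        maps to pixel corners (240, 240) and (400, 400).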
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
+ y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
+ y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
+ y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
+ return y
+
+
+def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
+ """
+ Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format.
+ x, y, width and height are normalized to image dimensions
+
+ Args:
+ x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
+ w (int): The width of the image. Defaults to 640
+ h (int): The height of the image. Defaults to 640
+ clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False
+ eps (float): The minimum value of the box's width and height. Defaults to 0.0
+ Returns:
+ y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
+ """
+ if clip:
+ clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
+ y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
+ y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
+ y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
+ return y
+
+
+def xyn2xy(x, w=640, h=640, padw=0, padh=0):
+ """
+ Convert normalized coordinates to pixel coordinates of shape (n,2)
+
+ Args:
+ x (np.ndarray) or (torch.Tensor): The input tensor of normalized bounding box coordinates
+ w (int): The width of the image. Defaults to 640
+ h (int): The height of the image. Defaults to 640
+ padw (int): The width of the padding. Defaults to 0
+ padh (int): The height of the padding. Defaults to 0
+ Returns:
+      y (np.ndarray) or (torch.Tensor): The x and y coordinates in pixel space, offset by (padw, padh)
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[..., 0] = w * x[..., 0] + padw # top left x
+ y[..., 1] = h * x[..., 1] + padh # top left y
+ return y
+
+
+def xywh2ltwh(x):
+ """
+ Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.
+
+ Args:
+ x (np.ndarray) or (torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
+ Returns:
+ y (np.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ return y
+
+
+def xyxy2ltwh(x):
+ """
+ Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right
+
+ Args:
+ x (np.ndarray) or (torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
+ Returns:
+ y (np.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format.
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 2] = x[:, 2] - x[:, 0] # width
+ y[:, 3] = x[:, 3] - x[:, 1] # height
+ return y
+
+
+def ltwh2xywh(x):
+ """
+ Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
+
+ Args:
+ x (torch.Tensor): the input tensor
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = x[:, 0] + x[:, 2] / 2 # center x
+ y[:, 1] = x[:, 1] + x[:, 3] / 2 # center y
+ return y
+
+
+def ltwh2xyxy(x):
+ """
+ It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+
+ Args:
+      x (np.ndarray) or (torch.Tensor): the input bounding box coordinates in ltwh format
+
+ Returns:
+ y (np.ndarray) or (torch.Tensor): the xyxy coordinates of the bounding boxes.
+ """
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 2] = x[:, 2] + x[:, 0] # bottom-right x
+    y[:, 3] = x[:, 3] + x[:, 1] # bottom-right y
+ return y
+
+
+def segments2boxes(segments):
+ """
+ It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
+
+ Args:
+ segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
+
+ Returns:
+ (np.ndarray): the xywh coordinates of the bounding boxes.
+ """
+ boxes = []
+ for s in segments:
+ x, y = s.T # segment xy
+ boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
+ return xyxy2xywh(np.array(boxes)) # cls, xywh
+
+
+def resample_segments(segments, n=1000):
+ """
+    Takes a list of (m,2) segment arrays and returns the segments up-sampled to n points each.
+
+ Args:
+ segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
+ n (int): number of points to resample the segment to. Defaults to 1000
+
+ Returns:
+ segments (list): the resampled segments.
+ """
+ for i, s in enumerate(segments):
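+        # Append the first point so the segment is treated as a closed polygon before resampling.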
+ s = np.concatenate((s, s[0:1, :]), axis=0)
+ x = np.linspace(0, len(s) - 1, n)
+ xp = np.arange(len(s))
+ segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)],
+ dtype=np.float32).reshape(2, -1).T # segment xy
+ return segments
+
+
+def crop_mask(masks, boxes):
+ """
+ It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box
+
+ Args:
+      masks (torch.Tensor): [n, h, w] tensor of masks
+ boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form
+
+ Returns:
+      (torch.Tensor): The masks cropped to their bounding boxes.
+ """
+ n, h, w = masks.shape
+ x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
+    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # x coordinates, shape(1,1,w)
+    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # y coordinates, shape(1,h,1)
+
+ return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
+
+
+def process_mask_upsample(protos, masks_in, bboxes, shape):
+ """
+ It takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
+ quality but is slower.
+
+ Args:
+ protos (torch.Tensor): [mask_dim, mask_h, mask_w]
+ masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
+ bboxes (torch.Tensor): [n, 4], n is number of masks after nms
+ shape (tuple): the size of the input image (h,w)
+
+ Returns:
+ (torch.Tensor): The upsampled masks.
+ """
+ c, mh, mw = protos.shape # CHW
+ masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
+ masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
+ masks = crop_mask(masks, bboxes) # CHW
+ return masks.gt_(0.5)
+
+
+def process_mask(protos, masks_in, bboxes, shape, upsample=False):
+ """
+ Apply masks to bounding boxes using the output of the mask head.
+
+ Args:
+ protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
+ masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
+ bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
+ shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
+ upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False.
+
+ Returns:
+ (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
+ are the height and width of the input image. The mask is applied to the bounding boxes.
+ """
+
+ c, mh, mw = protos.shape # CHW
+ ih, iw = shape
+ masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
+
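+    # Scale boxes from input-image space (ih, iw) down to the prototype-mask resolution (mh, mw)
+    # so that crop_mask operates in the same coordinate space as the masks.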
+ downsampled_bboxes = bboxes.clone()
+ downsampled_bboxes[:, 0] *= mw / iw
+ downsampled_bboxes[:, 2] *= mw / iw
+ downsampled_bboxes[:, 3] *= mh / ih
+ downsampled_bboxes[:, 1] *= mh / ih
+
+ masks = crop_mask(masks, downsampled_bboxes) # CHW
+ if upsample:
+ masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
+ return masks.gt_(0.5)
+
+
+def process_mask_native(protos, masks_in, bboxes, shape):
+ """
+ It takes the output of the mask head, and crops it after upsampling to the bounding boxes.
+
+ Args:
+ protos (torch.Tensor): [mask_dim, mask_h, mask_w]
+ masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
+ bboxes (torch.Tensor): [n, 4], n is number of masks after nms
+ shape (tuple): the size of the input image (h,w)
+
+ Returns:
+ masks (torch.Tensor): The returned masks with dimensions [h, w, n]
+ """
+ c, mh, mw = protos.shape # CHW
+ masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
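+    # Undo the letterbox padding: recompute the resize gain and padding, crop the padded border
+    # from the prototype-resolution masks, then upsample to the target shape before box cropping.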
+ gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
+ pad = (mw - shape[1] * gain) / 2, (mh - shape[0] * gain) / 2 # wh padding
+ top, left = int(pad[1]), int(pad[0]) # y, x
+ bottom, right = int(mh - pad[1]), int(mw - pad[0])
+ masks = masks[:, top:bottom, left:right]
+
+ masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
+ masks = crop_mask(masks, bboxes) # CHW
+ return masks.gt_(0.5)
+
+
+def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False):
+ """
+ Rescale segment coordinates (xyxy) from img1_shape to img0_shape
+
+ Args:
+ img1_shape (tuple): The shape of the image that the coords are from.
+ coords (torch.Tensor): the coords to be scaled
+ img0_shape (tuple): the shape of the image that the segmentation is being applied to
+ ratio_pad (tuple): the ratio of the image size to the padded image size.
+ normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False
+
+ Returns:
+ coords (torch.Tensor): the segmented image.
+ """
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
+ coords[..., 0] -= pad[0] # x padding
+ coords[..., 1] -= pad[1] # y padding
+ coords[..., 0] /= gain
+ coords[..., 1] /= gain
+ clip_coords(coords, img0_shape)
+ if normalize:
+ coords[..., 0] /= img0_shape[1] # width
+ coords[..., 1] /= img0_shape[0] # height
+ return coords
+
+
+def masks2segments(masks, strategy='largest'):
+ """
+ It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
+
+ Args:
+      masks (torch.Tensor): binary masks of shape (n, h, w), e.g. a model output of shape (batch_size, 160, 160)
+ strategy (str): 'concat' or 'largest'. Defaults to largest
+
+ Returns:
+ segments (List): list of segment masks
+ """
+ segments = []
+ for x in masks.int().cpu().numpy().astype('uint8'):
+ c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
+ if c:
+ if strategy == 'concat': # concatenate all segments
+ c = np.concatenate([x.reshape(-1, 2) for x in c])
+ elif strategy == 'largest': # select largest segment
+ c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
+ else:
+ c = np.zeros((0, 2)) # no segments found
+ segments.append(c.astype('float32'))
+ return segments
+
+
+def clean_str(s):
+ """
+ Cleans a string by replacing special characters with underscore _
+
+ Args:
+ s (str): a string needing special characters replaced
+
+ Returns:
+ (str): a string with special characters replaced by an underscore _
+ """
+ return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
diff --git a/src/effocr-layout/ocr/effocr/engines/recognizer_engine.py b/src/effocr-layout/ocr/effocr/engines/recognizer_engine.py
new file mode 100644
index 0000000..5819b67
--- /dev/null
+++ b/src/effocr-layout/ocr/effocr/engines/recognizer_engine.py
@@ -0,0 +1,41 @@
+import os
+import sys
+import torch
+import onnxruntime as ort
+import numpy as np
+
+
+class EffRecognizer:
+
+ def __init__(self, model, transform = None, num_cores = None, providers=None, char=True):
+
+ sess_options = ort.SessionOptions()
+ if num_cores is not None:
+ sess_options.intra_op_num_threads = num_cores
+
+ if providers is None:
+ providers = ort.get_available_providers()
+
+ self.transform = transform
+ # null_input = torch.zeros((3, 224, 224)) if char else torch.zeros((1, 224, 224))
+ self._eng_net = ort.InferenceSession(
+ model,
+ sess_options,
+ providers=providers,
+ )
+
+ def __call__(self, imgs):
+ return self.run(imgs)
+
+ def run(self, imgs):
+ trans_imgs = []
+ for img in imgs:
+ try:
+ trans_imgs.append(self.transform(img.astype(np.uint8))[0])
+ except Exception as e:
+ trans_imgs.append(torch.zeros((3, 224, 224)))
+
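+        # Zero-pad the stacked batch along the batch dimension up to 64 images, matching the
+        # fixed batch size the exported ONNX recognizer appears to expect.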
+ onnx_input = torch.nn.functional.pad(torch.stack(trans_imgs), (0, 0, 0, 0, 0, 0, 0, 64 - len(imgs))).numpy()
+
+ return self._eng_net.run(None, {'imgs': onnx_input})
+
diff --git a/src/effocr-layout/ocr/effocr/infer_line_detection b/src/effocr-layout/ocr/effocr/infer_line_detection
new file mode 100644
index 0000000..e69de29
diff --git a/src/effocr-layout/ocr/effocr/infer_transcripton.py b/src/effocr-layout/ocr/effocr/infer_transcripton.py
new file mode 100644
index 0000000..bed4a17
--- /dev/null
+++ b/src/effocr-layout/ocr/effocr/infer_transcripton.py
@@ -0,0 +1,527 @@
+import logging
+import torch
+from torchvision import transforms as T
+import numpy as np
+import queue
+from collections import defaultdict
+import threading
+from glob import glob
+import os
+import sys
+from PIL import Image, ImageDraw
+import time
+
+sys.path.insert(0, "../")
+# from utils.datasets_utils import *
+# from datasets.effocr_datasets import *
+# from utils.localizer_utils import *
+# from utils.coco_utils import *
+# from utils.spell_check_utils import *
+
+LARGE_NUMBER = 1000000000
+PARAGRAPH_BREAK = "\n\n"
+PARA_WEIGHT_L = 3
+PARA_WEIGHT_R = 1
+PARA_THRESH = 5
+ERROR_TEXT = 'XXX_ERROR_XXX'
+END_PUNCTUATION = '.?!,;:"'
+
+def check_any_overlap(bbox_1, bbox_2):
+ """Check if two bboxes overlap, we do this by checking all four corners of bbox_1 against bbox_2"""
+
+ x1, y1, x2, y2 = bbox_1
+ x3, y3, x4, y4 = bbox_2
+
+ return x1 < x4 and x2 > x3 and y1 < y4 and y2 > y3
+
+def check_hoi_overlap(bbox_1, bbox_2):
+ """ Check if two bboxes overlap horizontally, we do this by checking the right and left sides of bbox_1 against bbox_2"""
+
+ x1, y1, x2, y2 = bbox_1
+ x3, y3, x4, y4 = bbox_2
+
+ return x1 < x4 and x2 > x3
+
+def blank_layout_response():
+ return {-1: ''}
+
+def blank_dists_response():
+ return {'l_dists': {}, 'r_dists': {}}
+
+def ord_str_to_word(ord_str):
+ return ''.join([chr(int(o)) for o in ord_str.split('_')])
+
+def add_paragraph_breaks_to_dict(inference_assembly, side_dists):
+ l_list, r_list = [], []
+ im_ids = sorted(list(side_dists['l_dists'].keys()))
+ for i in im_ids:
+ l_list.append(side_dists['l_dists'][i])
+ r_list.append(side_dists['r_dists'][i])
+
+ try:
+ l_avg = sum(filter(None, l_list)) / (len(l_list) - l_list.count(None))
+ r_avg = sum(filter(None, r_list)) / (len(r_list) - r_list.count(None))
+ except ZeroDivisionError:
+ print("ZeroDivisionError: l_list: {}, r_list: {}".format(l_list, r_list))
+ print(f'side_dists: {side_dists}')
+ print(f'im_ids: {im_ids}')
+        print('l_avg / r_avg could not be computed (no non-None side distances)')
+ return inference_assembly
+
+ l_list = [l_avg if l is None else l for l in l_list]
+ r_list = [r_avg if r is None else r for r in r_list]
+ r_max = max(r_list)
+ r_avg = r_max - r_avg
+
+ l_list = [l / l_avg for l in l_list]
+ try:
+ r_list = [(r_max - r) / r_avg for r in r_list]
+ except ZeroDivisionError:
+ r_list = [0] * len(r_list)
+
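+    # Heuristic: a paragraph break is likely when the next line starts with a large left indent
+    # and/or the current line ends well short of the right margin; the two signals are weighted
+    # by PARA_WEIGHT_L and PARA_WEIGHT_R and compared against PARA_THRESH.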
+ for i in range(len(l_list) - 1):
+ score = l_list[i + 1] * PARA_WEIGHT_L + r_list[i] * PARA_WEIGHT_R
+ if score > PARA_THRESH:
+ inference_assembly[im_ids[i]]['text'] += PARAGRAPH_BREAK
+
+ return inference_assembly
+
+def find_overlaps(sorted_chars, sorted_words):
+ # For each word, find all chars it overlaps with:
+ word_char_idx = [[] for _ in range(len(sorted_words))]
+ word_idx = 0
+ for char_idx, char_bbox in enumerate(sorted_chars):
+ if word_idx >= len(sorted_words):
+ break
+
+ orig_idx = word_idx
+ while word_idx < len(sorted_words) and not check_any_overlap(sorted_words[word_idx], char_bbox):
+ word_idx += 1
+
+ # If the detected character is oddly positioned, like squished into the top of the screen or similar, we will
+ # not find an overlapping word and will need to reset.
+ if word_idx >= len(sorted_words):
+ word_idx = orig_idx
+ else:
+ word_char_idx[word_idx].append(char_idx)
+
+ return word_char_idx
+
+
+def en_preprocess(bboxes_char, bboxes_word, vertical=False):
+
+ sorted_bboxes_char = sorted(bboxes_char, key=lambda x: x[1] if vertical else x[0])
+ sorted_bboxes_word = sorted(bboxes_word, key=lambda x: x[1] if vertical else x[0])
+
+ # Find all overlaps between chars and words
+ word_char_idx = find_overlaps(sorted_bboxes_char, sorted_bboxes_word)
+
+ # # For each word, find all chars it overlaps with
+ # word_char_idx = []
+ # for word_bbox in sorted_bboxes_word:
+ # word_char_idx.append([])
+ # for char_idx, char_bbox in enumerate(sorted_bboxes_char):
+ # if check_any_overlap(word_bbox, char_bbox):
+ # word_char_idx[-1].append(char_idx)
+
+ # If there are no overlapping chars for a word, append the word bounding box to the list of chars as a char
+ redo_list, to_remove = False, []
+ for i, word_bbox in enumerate(sorted_bboxes_word):
+ if len(word_char_idx[i]) == 0:
+ remove = False
+ for j, comp_word_bbox in enumerate(sorted_bboxes_word):
+ if i != j and check_hoi_overlap(word_bbox, comp_word_bbox):
+ remove = True
+ to_remove.append(i)
+ break
+
+ if not remove:
+ sorted_bboxes_char.append(word_bbox)
+ redo_list = True
+
+ for i in sorted(to_remove, reverse=True):
+ del sorted_bboxes_word[i]
+ del word_char_idx[i]
+
+
+ # If we found a word with no overlapping chars, we now need to resort the char list and recreate the word_char_idx list
+ if redo_list:
+ # Resort the sorted_bboxes_char list and adjust the word_char_idx list accordingly
+ sorted_bboxes_char = sorted(sorted_bboxes_char, key=lambda x: x[1] if vertical else x[0])
+ word_char_idx = find_overlaps(sorted_bboxes_char, sorted_bboxes_word)
+
+ if any([len(w) == 0 for w in word_char_idx]):
+ print('Error: word_char_idx contains a list with no elements')
+ print(word_char_idx)
+ print(sorted_bboxes_char)
+ print(sorted_bboxes_word)
+ print(bboxes_char)
+ print(bboxes_word)
+ print(redo_list)
+ raise ValueError('word_char_idx contains a list with no elements')
+ # Return the lists of chars, words, and overlaps
+ return sorted_bboxes_char, sorted_bboxes_word, word_char_idx
+
+def create_batches(data, batch_size = 64, transform = None):
+ """Create batches for inference"""
+
+ batches = []
+ batch = []
+ for i, d in enumerate(data):
+ if d is not None:
+ batch.append(d)
+ else:
+ batch.append(np.zeros((33, 33, 3), dtype=np.int8))
+ if (i+1) % batch_size == 0:
+ batches.append(batch)
+ batch = []
+ if len(batch) > 0:
+ batches.append(batch)
+ return [b for b in batches]
+
+def get_crop_embeddings(recognizer_engine, crops, num_streams=4):
+ # Create batches of word crops
+ crop_batches = create_batches(crops)
+
+ input_queue = queue.Queue()
+ for i, batch in enumerate(crop_batches):
+ input_queue.put((i, batch))
+ output_queue = queue.Queue()
+ threads = []
+
+ for thread in range(num_streams):
+ threads.append(RecognizerEngineExecutorThread(recognizer_engine, input_queue, output_queue))
+
+ for thread in threads:
+ thread.start()
+
+ for thread in threads:
+ thread.join()
+
+ embeddings = [None] * len(crop_batches)
+ while not output_queue.empty():
+ i, result = output_queue.get()
+ embeddings[i] = result[0][0]
+
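+    # L2-normalize each batch of embeddings so the downstream index search behaves like cosine
+    # similarity (this assumes the index was built over normalized vectors).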
+ embeddings = [torch.nn.functional.normalize(torch.from_numpy(embedding), p=2, dim=1) for embedding in embeddings]
+ return embeddings
+
+def iteration(model, input):
+ output = model.run(input)
+ return output, output
+
+''' Threaded Localizer Inference'''
+class LocalizerEngineExecutorThread(threading.Thread):
+ def __init__(
+ self,
+ model,
+ input_queue: queue.Queue,
+ output_queue: queue.Queue,
+ ):
+ super(LocalizerEngineExecutorThread, self).__init__()
+ self._model = model
+ self._input_queue = input_queue
+ self._output_queue = output_queue
+
+ def run(self):
+ while not self._input_queue.empty():
+ img_idx, img = self._input_queue.get()
+ output = iteration(self._model, [img])
+ self._output_queue.put((img_idx, output))
+
+'''Threaded Recognizer Inference'''
+class RecognizerEngineExecutorThread(threading.Thread):
+ def __init__(
+ self,
+ model,
+ input_queue: queue.Queue,
+ output_queue: queue.Queue,
+ ):
+ super(RecognizerEngineExecutorThread, self).__init__()
+ self._model = model
+ self._input_queue = input_queue
+ self._output_queue = output_queue
+
+ def run(self):
+ while not self._input_queue.empty():
+ i, batch = self._input_queue.get()
+ output = iteration(self._model, batch)
+ self._output_queue.put((i, output))
+
+''' Main Function for Running EffOCR on a set of textline images'''
+def run_effocr_word(textline_images, localizer_engine, recognizer_engine, char_recognizer_engine, candidate_chars, candidate_words, lang,
+ word_index, char_index, num_streams=4, vertical=False, localizer_output = None, conf_thres=0.5, recognizer_thresh = 0.5,
+ bbox_output = False, punc_padding = 0, insert_paragraph_breaks = True):
+
+ # textline_images = textline_images[:10]
+ inference_results = {}
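+    # Textlines that never receive a result keep the default {-1: ''} entry; the -1 sentinel key
+    # is stripped again before the final assembly below.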
+ inference_assembly = defaultdict(blank_layout_response)
+ inference_bboxes = defaultdict(dict)
+ image_id, anno_id = 0, 0
+
+ # print(len(textline_images))
+ input_queue = queue.Queue()
+ for im_idx, p in enumerate(textline_images):
+ input_queue.put((im_idx, p))
+ if bbox_output: # Start detections with empty list for each textline image
+ inference_bboxes[im_idx] = {'detections': {'words': [], 'chars': []}}
+ output_queue = queue.Queue()
+ threads = []
+ start = time.time()
+
+ for thread in range(num_streams):
+ threads.append(LocalizerEngineExecutorThread(localizer_engine, input_queue, output_queue))
+
+ for thread in threads:
+ thread.start()
+
+ for thread in threads:
+ thread.join()
+
+ logging.info(f'Localizer inference time: {time.time() - start}')
+ big_start = time.time()
+ start = time.time()
+ word_crops, char_crops, n_words, n_chars = [], [], [], []
+ all_word_bboxes, word_rec_types, coco_new_order = [None] * len(textline_images), [None] * len(textline_images), []
+ word_char_overlaps, last_char_crops = [], []
+ side_dists = blank_dists_response()
+ parse_time, boxes_time, output_time, word_time, last_char_time = 0, 0, 0, 0, 0
+ logging.info('Init time: {}'.format(time.time() - start))
+ while not output_queue.empty():
+ start = time.time()
+ im_idx, result = output_queue.get()
+ coco_new_order.append(im_idx)
+ im = textline_images[im_idx]
+
+ if localizer_output:
+            os.makedirs(os.path.join(localizer_output, str(im_idx)), exist_ok=True)
+
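+        # Unpack the localizer output into (bboxes, labels); parsing depends on the model backend,
+        # and detectron2 / mmdetection outputs are additionally thresholded against conf_thres here.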
+ if localizer_engine._model_backend == 'yolo' or localizer_engine._model_backend == 'yolov8':
+ result = result[0][0]
+ bboxes, labels = result[:, :4], result[:, -1]
+
+ elif localizer_engine._model_backend == 'detectron2':
+ result = result[0][0]
+ bboxes, labels = result[0][result[3] > conf_thres], result[1][result[3] > conf_thres]
+ bboxes, labels = torch.from_numpy(bboxes), torch.from_numpy(labels)
+ elif localizer_engine._model_backend == 'mmdetection':
+ result = result[0][0]
+ bboxes, labels = result[0][result[0][:, -1] > conf_thres], result[1][result[0][:, -1] > conf_thres]
+ bboxes = bboxes[:, :-1]
+ bboxes, labels = torch.from_numpy(bboxes), torch.from_numpy(labels)
+
+ parse_time += time.time() - start
+ start = time.time()
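+        # For English, localizer label 0 marks character boxes and label 1 marks word boxes.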
+ if lang == "en":
+ char_bboxes, word_bboxes = bboxes[labels == 0], bboxes[labels == 1]
+ if len(word_bboxes) != 0:
+ char_bboxes, word_bboxes, word_char_overlap = en_preprocess(char_bboxes, word_bboxes)
+ word_char_overlaps.append(word_char_overlap)
+ n_words.append(len(word_bboxes))
+ else:
+ n_words.append(0)
+ word_char_overlaps.append([])
+
+ if len(char_bboxes) != 0:
+ l_dist, r_dist = char_bboxes[0][0].item(), char_bboxes[-1][-2].item()
+ side_dists['l_dists'][im_idx] = l_dist # Store distances for paragraph detection
+ side_dists['r_dists'][im_idx] = r_dist
+ n_chars.append(len(char_bboxes))
+ else:
+ n_chars.append(0)
+ side_dists['l_dists'][im_idx] = None; side_dists['r_dists'][im_idx] = None
+
+ boxes_time += time.time() - start
+ start = time.time()
+
+ output_time += time.time() - start
+ start = time.time()
+ im_height, im_width = im.shape[0], im.shape[1]
+ # print(len(word_bboxes))
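+        # The localizer works in a 640-pixel input space; rescale each word box back to the textline
+        # crop's pixel size, keeping the full crop height (or full width, for vertical text).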
+ for i, bbox in enumerate(word_bboxes):
+ x0, y0, x1, y1 = torch.round(bbox)
+ # print(x0, y0, x1, y1)
+ if vertical:
+ x0, y0, x1, y1 = 0, int(round(y0.item() * im_height / 640)), im_width, int(round(y1.item() * im_height / 640))
+ else:
+ x0, y0, x1, y1 = int(round(x0.item() * im_width / 640)), 0, int(round(x1.item() * im_width / 640)), im_height
+
+ # Verify that the crop is not empty
+ if x0 == x1 or y0 == y1 or x0 < 0:
+ # If so, eliminate the corresponding entry in the word_char_overlaps list
+ word_char_overlaps[-1].pop(i)
+ n_words[-1] -= 1
+ else:
+ word_crops.append(im[y0:y1, x0:x1, :])
+
+ word_time += time.time() - start
+ start = time.time()
+ last_chars = [overlaps[-1] for overlaps in word_char_overlaps[-1]]
+ for i, bbox in enumerate(char_bboxes):
+ x0, y0, x1, y1 = torch.round(bbox)
+ if vertical:
+ x0, y0, x1, y1 = 0, int(round(y0.item() * im_height / 640)), im_width, int(round(y1.item() * im_height / 640))
+ else:
+ x0, y0, x1, y1 = int(round(x0.item() * im_width / 640)), 0, int(round(x1.item() * im_width / 640)), im_height
+
+ char_crops.append(im[y0:y1, x0:x1, :])
+
+ # if i in last_chars:
+ # last_char_crops.append(im[y0 - punc_padding:y1 + punc_padding, x0-punc_padding:x1+punc_padding, :])
+
+ last_char_time += time.time() - start
+ # print(len(word_crops))
+ # exit()
+
+ print('Word crops: ', len(word_crops))
+ print('Char crops: ', len(char_crops))
+ logging.info('Localizer results processing: {} seconds'.format(time.time() - big_start))
+ logging.info('Breakdown:')
+ logging.info('Parse: {} seconds'.format(parse_time))
+ logging.info('Boxes: {} seconds'.format(boxes_time))
+ logging.info('Output: {} seconds'.format(output_time))
+ logging.info('Word crops: {} seconds'.format(word_time))
+ logging.info('Last char crops: {} seconds'.format(last_char_time))
+ ''' --- Last Character Recognition ---'''
+    # This is a cheap way to substantially improve punctuation accuracy: we check the last character of every word
+    # (where punctuation is much more likely to appear) to see if it is a punctuation mark. If so, we trim the word
+    # crop and save the punctuation mark to be appended to the word later on.
+
+ # Collect the last character crop from each word
+ start = time.time()
+ last_chars = [[overlap[-1] for overlap in word_char_overlap] for word_char_overlap in word_char_overlaps]
+ last_char_crops, char_idx = [], 0
+ for i, n in enumerate(n_chars):
+ for last in last_chars[i]:
+ last_char_crops.append(char_crops[char_idx + last])
+ char_idx += n
+ logging.info('Number of word crops: {}'.format(len(word_crops)))
+ logging.info('Number of last characters: {}'.format(len(last_char_crops)))
+ logging.info('Time to get last char crops: {}'.format(time.time() - start))
+ start = time.time()
+ # Create batches of last character crops
+ embeddings = get_crop_embeddings(char_recognizer_engine, last_char_crops, num_streams=num_streams)
+ logging.info('Time to get last char embeddings: {}'.format(time.time() - start))
+ start = time.time()
+ # Get the nearest neighbor of each last character crop
+ embeddings = torch.cat(embeddings, dim=0)
+ indices = char_index.search(embeddings, 1)[1]
+ nn_outputs_last_chars = [candidate_chars[idx[0]] for idx in indices][:len(word_crops)]
+
+ # If the nearest neighbor is a punctuation mark, we adjust the word bounding box and save the punctuation mark
+ found_end_punctuation, cur_line = [], 0
+ for i, nn_output in enumerate(nn_outputs_last_chars):
+ if nn_output in END_PUNCTUATION:
+ found_end_punctuation.append((i, nn_output))
+ word_crops[i] = word_crops[i][:, :(-1 * last_char_crops[i].shape[1])]
+
+ logging.info('Time to get last char nn: {}'.format(time.time() - start))
+
+ ''' Word level recognition '''
+ # Get recognizer embeddings of word crops
+ start = time.time()
+ embeddings = get_crop_embeddings(recognizer_engine, word_crops, num_streams=num_streams)
+ logging.info('Time to get word embeddings: {}'.format(time.time() - start))
+ start = time.time()
+ # Get the nearest neighbor of each word crop
+ embeddings = torch.cat(embeddings, dim=0)
+ distances, indices = word_index.search(embeddings, 1)
+ distances_and_indices = [(distance, index[0]) for distance, index in zip(distances, indices)]
+ nn_outputs, rec_types = [], []
+ logging.info('Time to get word nn: {}'.format(time.time() - start))
+
+ start = time.time()
+ # If the nearest neighbor is closer than the threshold, we recognize the word. Otherwise, we pass the word to char level recognition
+ for (distance, idx) in distances_and_indices:
+ if distance > recognizer_thresh:
+ nn_outputs.append(ord_str_to_word(candidate_words[idx]))
+ rec_types.append('word')
+ else:
+ nn_outputs.append("WORD_LEVEL")
+ rec_types.append('char')
+
+ # Add punctuation marks to the end of words recognized by the word recognizer
+ for (i, punctuation) in found_end_punctuation:
+ if nn_outputs[i] != 'WORD_LEVEL':
+ nn_outputs[i] += punctuation
+
+ ''' Char level recognition'''
+ # Collect char crops from words that are not recognized
+ char_crops_to_recognize, word_lens = [], []
+ word_idx, char_idx = 0, 0
+ for i, (n_c, n_w) in enumerate(zip(n_chars, n_words)):
+ for j in range(word_idx, word_idx + n_w):
+ if nn_outputs[j] == "WORD_LEVEL":
+ for k in word_char_overlaps[i][j - word_idx]:
+ char_crops_to_recognize.append(char_crops[char_idx + k])
+
+ word_lens.append(len(word_char_overlaps[i][j - word_idx]))
+
+ word_idx += n_w
+ char_idx += n_c
+
+ logging.info('Time to find char crops to recognize: {}'.format(time.time() - start))
+ logging.info('Number of char crops to recognize: {}'.format(len(char_crops_to_recognize)))
+ # Get char recognizer embeddings of char crops
+ start = time.time()
+ embeddings = get_crop_embeddings(char_recognizer_engine, char_crops_to_recognize, num_streams=num_streams)
+ logging.info('Time to get char embeddings: {}'.format(time.time() - start))
+
+ start = time.time()
+ if len(embeddings) < 1:
+ logging.info('No char crops to recognize')
+ return {}, {}
+
+ embeddings = torch.cat(embeddings, dim=0)
+ indices = char_index.search(embeddings, 1)[1]
+ nn_outputs_chars = [candidate_chars[idx[0]] for idx in indices]
+ word_idx, char_idx = 0, 0
+ logging.info('Time to get char nn: {}'.format(time.time() - start))
+
+ start = time.time()
+ # Summing only up to the total number of words to avoid running into padded examples
+ # at the end of the last batch
+ for i in range(sum(n_words)):
+ if nn_outputs[i] == "WORD_LEVEL":
+ textline = nn_outputs_chars[char_idx:char_idx + word_lens[word_idx]]
+ nn_outputs[i] = "".join(x[0] for x in textline).strip()
+ char_idx += word_lens[word_idx]
+ word_idx += 1
+
+ #Now run postprocessing to create full textlines
+ idx, textline_outputs, textline_rec_types = 0, [], []
+ for l in n_words:
+ textline_outputs.append(nn_outputs[idx:idx+l])
+ textline_rec_types.append(rec_types[idx:idx+l])
+ idx += l
+
+ outputs = [" ".join(x for x in textline).strip() for textline in textline_outputs]
+
+ # Postprocess textlines, saving to inference_assembly
+ if lang == "en":
+ for i, im_idx in enumerate(coco_new_order):
+ inference_assembly[im_idx] = {}
+ inference_assembly[im_idx]['text'] = outputs[i]
+ inference_assembly[im_idx]['rec_types'] = textline_rec_types[i]
+ if inference_assembly[im_idx] is None:
+ inference_assembly[im_idx]['text'] = " "
+ inference_assembly[im_idx]['rec_types'] = []
+
+ # Remove the -1 key from inference_assembly
+ for bbox_idx in inference_assembly.keys():
+ if -1 in inference_assembly[bbox_idx].keys():
+ del inference_assembly[bbox_idx][-1]
+
+ if insert_paragraph_breaks:
+ inference_assembly = add_paragraph_breaks_to_dict(inference_assembly, side_dists)
+
+ try:
+ inference_results = '\n'.join([inference_assembly[i]['text'] for i in sorted([int(x) for x in inference_assembly.keys()])])
+
+ except TypeError as e:
+ print(e)
+ print(inference_assembly)
+ inference_results = ''
+
+ logging.info('Time to postprocess: {}'.format(time.time() - start))
+ return inference_results, inference_bboxes
\ No newline at end of file
diff --git a/src/effocr-layout/ocr/effocr/utils/__init__.py b/src/effocr-layout/ocr/effocr/utils/__init__.py
new file mode 100644
index 0000000..7b16038
--- /dev/null
+++ b/src/effocr-layout/ocr/effocr/utils/__init__.py
@@ -0,0 +1,2 @@
+from .dataset_utils import create_paired_transform, create_paired_transform_word
+from .image_utils import letterbox, non_max_suppression
\ No newline at end of file
diff --git a/src/effocr-layout/ocr/effocr/utils/dataset_utils.py b/src/effocr-layout/ocr/effocr/utils/dataset_utils.py
new file mode 100644
index 0000000..5b0bc21
--- /dev/null
+++ b/src/effocr-layout/ocr/effocr/utils/dataset_utils.py
@@ -0,0 +1,79 @@
+import numpy as np
+from torchvision import transforms as T
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from PIL import Image
+import time
+from timeit import default_timer as timer
+
+
+def chunks(lst, n):
+ """Yield successive n-sized chunks from lst."""
+ for i in range(0, len(lst), n):
+ yield lst[i:i + n]
+
+class MedianPadWord:
+ """This padding preserves the aspect ratio of the image. It also pads the image with the median value of the border pixels.
+ Note how it also centres the ROI in the padded image."""
+ def __init__(self, override=None,aspect_cutoff=0):
+ self.override = override
+ self.aspect_cutoff=aspect_cutoff
+ def __call__(self, image):
+ ##Convert to RGB
+ image = image.convert("RGB") if isinstance(image, Image.Image) else image
+ image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
+ max_side = max(image.size)
+ aspect_ratio = image.size[0] / image.size[1]
+        if aspect_ratio < self.aspect_cutoff:
+    xc = prediction[..., 4] > conf_thres # candidates
+
+ # Checks
+ assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
+ assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
+
+ # Settings
+ # min_wh = 2 # (pixels) minimum box width and height
+ max_wh = 7680 # (pixels) maximum box width and height
+ max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
+ time_limit = 0.5 + 0.05 * bs # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ mi = 5 + nc # mask start index
+ output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
+ for xi, x in enumerate(prediction): # image index, image inference
+ # Apply constraints
+ # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
+ x = x[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ lb = labels[xi]
+ v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
+ v[:, :4] = lb[:, 1:5] # box
+ v[:, 4] = 1.0 # conf
+ v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Compute conf
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
+
+ # Box/Mask
+        box = xywh2xyxy(x[:, :4]) # (center_x, center_y, width, height) to (x1, y1, x2, y2)
+ mask = x[:, mi:] # zero columns if no masks
+
+ # Detections matrix nx6 (xyxy, conf, cls)
+ if multi_label:
+ i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
+ else: # best class only
+ conf, j = x[:, 5:mi].max(1, keepdim=True)
+ x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # Check shape
+ n = x.shape[0] # number of boxes
+ if not n: # no boxes
+ continue
+ elif n > max_nms: # excess boxes
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
+ else:
+ x = x[x[:, 4].argsort(descending=True)] # sort by confidence
+
+ # Batched NMS
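+        # Offsetting boxes by class index * max_wh keeps detections of different classes from
+        # suppressing one another (the offset is skipped for class-agnostic NMS).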
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+ if i.shape[0] > max_det: # limit detections
+ i = i[:max_det]
+ if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if mps:
+ output[xi] = output[xi].to(device)
+
+ return output
+
diff --git a/src/effocr-layout/ocr/effocr_agent.py b/src/effocr-layout/ocr/effocr_agent.py
new file mode 100644
index 0000000..9a383b9
--- /dev/null
+++ b/src/effocr-layout/ocr/effocr_agent.py
@@ -0,0 +1,264 @@
+import io
+import os
+import json
+import warnings
+
+import numpy as np
+from cv2 import imencode
+import multiprocessing
+import faiss
+from huggingface_hub import hf_hub_download
+import joblib
+
+from .base import BaseOCRAgent, BaseOCRElementType
+from .effocr import EffLocalizer, EffRecognizer, EffLineDetector, \
+ run_effocr_word, create_paired_transform, create_paired_transform_word
+
+EFFOCR_DEFAULT_CONFIG = {
+ "line_model": "",
+ "line_backend": "yolov8",
+ "line_input_shape": (640, 640),
+ "localizer_model": "",
+ "localizer_backend": "yolov8",
+ "localizer_input_shape": (640, 640),
+ "word_recognizer_model": "./src/layoutparser/models/effocr/word_recognizer/enc.onnx",
+ "word_index": "./src/layoutparser/models/effocr/word_recognizer/word_index.index",
+ "word_ref": "./src/layoutparser/models/effocr/word_recognizer/word_ref.txt",
+ "char_recognizer_model": "./src/layoutparser/models/effocr/char_recognizer/enc.onnx",
+ "char_index": "./src/layoutparser/models/effocr/char_recognizer/char_index.index",
+ "char_ref": "./src/layoutparser/models/effocr/char_recognizer/char_ref.txt",
+ "localizer_iou_thresh": 0.10,
+ "localizer_conf_thresh": 0.20,
+ "line_iou_thresh": 0.05,
+ "line_conf_thresh": 0.50,
+ "word_dist_thresh": 0.90,
+ "lang": "en",
+}
+
+HUGGINGFACE_MODEL_MAP = {
+ 'line_model': 'line.onnx',
+ 'localizer_model': 'localizer.onnx',
+ 'word_recognizer_model': 'word_recognizer/enc.onnx',
+ 'char_recognizer_model': 'char_recognizer/enc.onnx',
+ 'word_index': 'word_recognizer/word_index.index',
+ 'word_ref': 'word_recognizer/word_ref.txt',
+ 'char_index': 'char_recognizer/char_index.index',
+ 'char_ref': 'char_recognizer/char_ref.txt'
+}
+
+HUGGINGFACE_REPO_NAME = 'dell-research-harvard/effocr_en'
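+# Any model or index path in EFFOCR_DEFAULT_CONFIG that is missing locally is downloaded from this
+# Hugging Face repo when the agent is constructed (see _check_and_download_models / _check_and_download_indices).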
+
+class EffOCRFeatureType(BaseOCRElementType):
+ """
+ The element types from EffOCR
+ """
+
+ PAGE = 0
+ PARA = 1
+ LINE = 2
+ WORD = 3
+ CHAR = 4
+
+ @property
+ def attr_name(self):
+ name_cvt = {
+            EffOCRFeatureType.PAGE: "pages",
+ EffOCRFeatureType.PARA: "paragraphs",
+ EffOCRFeatureType.LINE: "lines",
+ EffOCRFeatureType.WORD: "words",
+ }
+ return name_cvt[self]
+
+ @property
+ def child_level(self):
+ child_cvt = {
+            EffOCRFeatureType.PAGE: EffOCRFeatureType.PARA,
+ EffOCRFeatureType.PARA: EffOCRFeatureType.LINE,
+ EffOCRFeatureType.LINE: EffOCRFeatureType.WORD,
+ EffOCRFeatureType.WORD: None,
+ }
+ return child_cvt[self]
+
+
+
+class EffOCRAgent(BaseOCRAgent):
+ """EffOCR Inference -- Implements method described in https://scholar.harvard.edu/sites/scholar.harvard.edu/files/dell/files/effocr.pdf
+
+ Note:
+ TODO: Fill in with info once implemented
+ """
+
+ # TODO: Fill in with package dependencies
+ DEPENDENCIES = ["effocr"]
+
+ def __init__(self, languages="eng", **kwargs):
+ """Create a EffOCR Agent.
+
+ Args:
+ languages (:obj:`list` or :obj:`str`, optional):
+                You can specify the language code(s) of the documents to determine the
+                language EffOCR uses when transcribing them. As of 7/24, the only option is
+                English, but Japanese EffOCR will be implemented soon.
+                Defaults to 'eng'.
+ """
+ if languages != 'eng':
+ raise NotImplementedError("EffOCR only supports English at this time.")
+
+ self.lang = languages if isinstance(languages, str) else "+".join(languages)
+
+ self.config = EFFOCR_DEFAULT_CONFIG
+ for key, value in kwargs.items():
+ if key in self.config.keys():
+ self.config[key] = value
+ else:
+ warnings.warn(f"Unknown config parameter {key} for {self.__class__.__name__}. Ignoring it.")
+
+ self._check_and_download_models()
+ self._check_and_download_indices()
+ self._load_models()
+ self._load_indices()
+ print(self.config)
+
+ def _check_and_download_models(self):
+ '''
+        Checks whether the line, localizer, word recognizer, and char recognizer models are available
+        locally, then downloads any that are not.
+ '''
+
+ model_keys = ['line_model', 'localizer_model', 'word_recognizer_model', 'char_recognizer_model']
+ for key in model_keys:
+ if not os.path.exists(self.config[key]) or not self.config[key].endswith('.onnx'):
+ self.config[key] = hf_hub_download(HUGGINGFACE_REPO_NAME, HUGGINGFACE_MODEL_MAP[key])
+ # TODO: replace FileNotFoundError with download code
+
+ def _check_and_download_indices(self):
+ '''
+        Checks if the word and character recognizers' indices and reference files are downloaded,
+ then downloads them if they are not.
+ '''
+
+ index_keys = ['word_index', 'char_index']
+ ref_keys = ['word_ref', 'char_ref']
+
+ for key in index_keys:
+ if not os.path.exists(self.config[key]):
+ self.config[key] = hf_hub_download(HUGGINGFACE_REPO_NAME, HUGGINGFACE_MODEL_MAP[key])
+
+ for key in ref_keys:
+ if not os.path.exists(self.config[key]):
+ self.config[key] = hf_hub_download(HUGGINGFACE_REPO_NAME, HUGGINGFACE_MODEL_MAP[key])
+
+ def _load_models(self):
+ '''
+ Function to instantiate each of the line model,
+ localizer model, word recognizer model, and char recognizer model.
+ '''
+
+ self.localizer_engine = EffLocalizer(
+ self.config['localizer_model'],
+ iou_thresh = self.config['localizer_iou_thresh'],
+ conf_thresh = self.config['localizer_conf_thresh'],
+ vertical = False if self.config['lang'] == "en" else True,
+ num_cores = multiprocessing.cpu_count(),
+ model_backend = self.config['localizer_backend'],
+ input_shape = self.config['localizer_input_shape']
+ )
+
+ # TODO: Fix imports for paired_transforms
+ char_transform = create_paired_transform(lang='en')
+ word_transform = create_paired_transform_word(lang='en')
+
+ self.word_recognizer_engine = EffRecognizer(
+ model = self.config['word_recognizer_model'],
+ transform = char_transform,
+ num_cores=multiprocessing.cpu_count(),
+ )
+
+ self.char_recognizer_engine = EffRecognizer(
+ model = self.config['char_recognizer_model'],
+ transform = char_transform,
+ num_cores=multiprocessing.cpu_count(),
+ )
+
+ self.line_detector_engine = EffLineDetector(
+ self.config['line_model'],
+ iou_thresh = self.config['line_iou_thresh'],
+ conf_thresh = self.config['line_conf_thresh'],
+ num_cores = multiprocessing.cpu_count(),
+ model_backend = self.config['line_backend'],
+ input_shape = self.config['line_input_shape']
+ )
+
+ def _load_indices(self):
+ '''
+ Function to instantiate the faiss indices for each of the word and character recognizers.
+        Indices are responsible for storing base vectors for each word/character and performing
+ similarity search on unknown symbols.
+ '''
+
+ # char index
+ self.char_index = faiss.read_index(self.config['char_index'])
+ with open(self.config['char_ref']) as ref_file:
+ self.candidate_chars = ref_file.read().split()
+
+ # word index
+ self.word_index = faiss.read_index(self.config['word_index'])
+ with open(self.config['word_ref']) as ref_file:
+ self.candidate_words = ref_file.read().split()
+
+ def _detect(self, image, viz_lines_path=None):
+ '''
+ Function to detect text in an image using EffOCR.
+
+        Each of the two main parts, line detection and line transcription, is abstracted out here.
+ '''
+
+ # Line Detection
+ line_crops, line_coords = self.line_detector_engine(image)
+
+ # Line Transcription
+ text_results = run_effocr_word(line_crops, self.localizer_engine, self.word_recognizer_engine, self.char_recognizer_engine, self.candidate_chars,
+ self.candidate_words, self.config['lang'], self.word_index, self.char_index, num_streams=multiprocessing.cpu_count(), vertical=False,
+ localizer_output = None, conf_thres=self.config['localizer_conf_thresh'], recognizer_thresh = self.config['word_dist_thresh'],
+ bbox_output = False, punc_padding = 0, insert_paragraph_breaks = True)
+
+ return text_results
+
+ def detect(self, image, return_response=False, return_only_text=True, agg_output_level=None, viz_lines_path = None):
+ """Send the input image for OCR by the EffOCR agent.
+
+ Args:
+ image (:obj:`np.ndarray` or :obj:`str`):
+ The input image array or the name of the image file
+            return_response (:obj:`bool`, optional):
+                Whether to directly return the EffOCR output.
+                Defaults to `False`.
+            return_only_text (:obj:`bool`, optional):
+                Whether to return only the text in the OCR results.
+                Defaults to `True`.
+ agg_output_level (:obj:`~EffOCRFeatureType`, optional):
+ When set, aggregate the EffOCR output with respect to the
+ specified aggregation level. Defaults to `None`.
+
+ Returns:
+ :obj:`dict` or :obj:`str`:
+ The OCR results in the specified format.
+ """
+
+ res = self._detect(image, viz_lines_path = viz_lines_path)
+
+ if return_response:
+ return res
+
+ if return_only_text:
+ return res["text"]
+
+ if agg_output_level is not None:
+ return self.gather_data(res, agg_output_level)
+
+ return res["text"]
+
+
+if __name__ == '__main__':
+ agent = EffOCRAgent()
+ img_path = r'C:\Users\bryan\Documents\NBER\layout-parser\tests\fixtures\ocr\test_effocr_image.jpg'
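+    # Illustrative only: running detection on the fixture image would look roughly like
+    #   print(agent.detect(cv2.imread(img_path)))
+    # (assumes an `import cv2` at module level, which this file does not currently include).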
diff --git a/src/layoutparser/ocr/gcv_agent.py b/src/effocr-layout/ocr/gcv_agent.py
similarity index 100%
rename from src/layoutparser/ocr/gcv_agent.py
rename to src/effocr-layout/ocr/gcv_agent.py
diff --git a/src/layoutparser/ocr/tesseract_agent.py b/src/effocr-layout/ocr/tesseract_agent.py
similarity index 100%
rename from src/layoutparser/ocr/tesseract_agent.py
rename to src/effocr-layout/ocr/tesseract_agent.py
diff --git a/src/layoutparser/tools/__init__.py b/src/effocr-layout/tools/__init__.py
similarity index 100%
rename from src/layoutparser/tools/__init__.py
rename to src/effocr-layout/tools/__init__.py
diff --git a/src/layoutparser/tools/shape_operations.py b/src/effocr-layout/tools/shape_operations.py
similarity index 100%
rename from src/layoutparser/tools/shape_operations.py
rename to src/effocr-layout/tools/shape_operations.py
diff --git a/src/layoutparser/visualization.py b/src/effocr-layout/visualization.py
similarity index 100%
rename from src/layoutparser/visualization.py
rename to src/effocr-layout/visualization.py
diff --git a/tests/fixtures/ocr/line_dets.png b/tests/fixtures/ocr/line_dets.png
new file mode 100644
index 0000000..1600a5c
Binary files /dev/null and b/tests/fixtures/ocr/line_dets.png differ
diff --git a/tests/fixtures/ocr/test_effocr_image.jpg b/tests/fixtures/ocr/test_effocr_image.jpg
new file mode 100644
index 0000000..b9b60b0
Binary files /dev/null and b/tests/fixtures/ocr/test_effocr_image.jpg differ
diff --git a/tests/test_ocr.py b/tests/test_ocr.py
index 0c42cfc..99b32f0 100644
--- a/tests/test_ocr.py
+++ b/tests/test_ocr.py
@@ -17,10 +17,13 @@
GCVFeatureType,
TesseractAgent,
TesseractFeatureType,
+ EffOCRAgent,
+ EffOCRFeatureType,
)
import json, cv2, os
image = cv2.imread("tests/fixtures/ocr/test_gcv_image.jpg")
+effocr_image = cv2.imread("tests/fixtures/ocr/test_effocr_image.jpg")
def test_gcv_agent(test_detect=False):
@@ -76,4 +79,29 @@ def test_tesseract(test_detect=False):
assert r2 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.BLOCK)
assert r3 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.PARA)
assert r4 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.LINE)
- assert r5 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.WORD)
\ No newline at end of file
+ assert r5 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.WORD)
+
+'''
+Test the EffOCRAgent, which implements EffOCR -- https://scholar.harvard.edu/sites/scholar.harvard.edu/files/dell/files/effocr.pdf
+'''
+def test_effocr(test_detect=True):
+ ocr_agent = EffOCRAgent()
+
+ # res = ocr_agent.load_response("tests/fixtures/ocr/test_effocr_response.json")
+ # r0 = ocr_agent.gather_text_annotations(res)
+ # r1 = ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.BLOCK)
+ # r2 = ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.PARA)
+ # r3 = ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.LINE)
+ # r4 = ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.WORD)
+ # r5 = ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.CHAR)
+
+ if test_detect:
+ res = ocr_agent.detect(effocr_image, return_response=True)
+ assert "The tug boat Alice" in res[0]
+ # assert r0 == res["text"]
+ # assert r1 == ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.BLOCK)
+ # assert r2 == ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.PARA)
+ # assert r3 == ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.LINE)
+ # assert r4 == ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.WORD)
+ # assert r5 == ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.CHAR)