Skip to content

Commit

Permalink
chore: update files
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Feb 11, 2025
1 parent f10eab5 commit 5bf7702
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 18 deletions.
7 changes: 4 additions & 3 deletions python/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
# @Contact: liekkaskono@163.com
import cv2

from rapidocr import RapidOCR, VisRes

# from rapidocr_onnxruntime import RapidOCR, VisRes
from rapidocr_torch import RapidOCR, VisRes

# from rapidocr import RapidOCR, VisRes


# from rapidocr_paddle import RapidOCR, VisRes
Expand All @@ -14,7 +15,7 @@
# yaml_path = "tests/test_files/config.yaml"
# engine = RapidOCR(config_path=yaml_path)

engine = RapidOCR(params={"Cls.model_path": "1.onnx"})
engine = RapidOCR(params={"Global.with_paddle": True})
vis = VisRes()

image_path = "tests/test_files/ch_en_num.jpg"
Expand Down
9 changes: 6 additions & 3 deletions python/rapidocr/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,14 @@ EngineConfig:
gpu_id: 0
gpu_mem: 500

openvino:
inference_num_threads: -1

torch:
use_cuda: false

Det:
model_path: models/ch_PP-OCRv4_det_infer.onnx
paddle_model_dir: models/paddle/ch_PP-OCRv4_det_infer

limit_side_len: 736
limit_type: min
Expand All @@ -47,7 +52,6 @@ Det:

Cls:
model_path: models/ch_ppocr_mobile_v2.0_cls_infer.onnx
paddle_model_dir: models/paddle/ch_ppocr_mobile_v2_cls_infer

cls_image_shape: [3, 48, 192]
cls_batch_num: 6
Expand All @@ -56,7 +60,6 @@ Cls:

Rec:
model_path: models/ch_PP-OCRv4_rec_infer.onnx
paddle_model_dir: models/paddle/ch_PP-OCRv4_rec_infer

rec_img_shape: [3, 48, 320]
rec_batch_num: 6
3 changes: 2 additions & 1 deletion python/rapidocr/inference_engine/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import traceback

import numpy as np
from omegaconf import DictConfig
from openvino.runtime import Core

from .base import InferSession


class OpenVINOInferSession(InferSession):
def __init__(self, config):
def __init__(self, config: DictConfig):
core = Core()

self._verify_model(config["model_path"])
Expand Down
2 changes: 1 addition & 1 deletion python/rapidocr/inference_engine/paddlepaddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, config, mode: Optional[str] = None) -> None:
self.logger = get_logger("PaddleInferSession")
self.mode = mode

model_dir = Path(config.padde_model_dir)
model_dir = Path(config.model_path)
pdmodel_path = model_dir / "inference.pdmodel"
pdiparams_path = model_dir / "inference.pdiparams"

Expand Down
4 changes: 3 additions & 1 deletion python/rapidocr/inference_engine/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,17 @@ def read_yaml(yaml_path: Union[str, Path]) -> Dict[str, Dict]:

class TorchInferSession:
def __init__(self, config, mode: Optional[str] = None) -> None:
self.logger = get_logger("TorchInferSession")

all_arch_config = read_yaml(DEFAULT_CFG_PATH)

self.logger = get_logger("TorchInferSession")
self.mode = mode
model_path = Path(config["model_path"])
self._verify_model(model_path)
file_name = model_path.stem
if file_name not in all_arch_config:
raise ValueError(f"architecture {file_name} is not in config.yaml")

arch_config = all_arch_config[file_name]
self.predictor = BaseModel(arch_config)
self.predictor.load_state_dict(torch.load(model_path, weights_only=True))
Expand Down
1 change: 0 additions & 1 deletion python/rapidocr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def __init__(

engine_name = get_engine_name(config)

# 根据选定的语言加载对应的模型
det_lang, rec_lang = parse_lang(config.Global.lang)

self.print_verbose = config.Global.print_verbose
Expand Down
1 change: 1 addition & 0 deletions python/rapidocr/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
from .parse_parameters import ParseParams, init_args, parse_lang
from .process_img import add_round_letterbox, increase_min_side, reduce_max_side
from .typings import RapidOCROutput
from .utils import download_file
from .vis_res import VisRes
28 changes: 28 additions & 0 deletions python/rapidocr/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from pathlib import Path
from typing import Union

import requests
from tqdm import tqdm


def download_file(url: str, save_path: Union[str, Path]):
    """Stream the file at ``url`` to ``save_path`` while showing a progress bar.

    Args:
        url: Address of the file to download.
        save_path: Destination path. It is opened in binary write mode, so an
            existing file at that path is overwritten.

    Raises:
        DownloadFileError: If the server answers with a status code other
            than 200.
    """
    # With stream=True the connection is only handed back once the response
    # is closed, so use it as a context manager to release it on every path
    # (including the error raised below).
    with requests.get(url, stream=True, timeout=60) as response:
        status_code = response.status_code

        if status_code != 200:
            raise DownloadFileError("Something went wrong while downloading models")

        # Servers may omit Content-Length (e.g. chunked transfer). In that
        # case pass total=None so tqdm shows a plain byte counter instead of
        # a bar computed against a made-up size.
        content_length = response.headers.get("content-length")
        total_size_in_bytes = int(content_length) if content_length else None
        block_size = 1024  # 1 Kibibyte
        with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as pb:
            with open(save_path, "wb") as file:
                for data in response.iter_content(block_size):
                    pb.update(len(data))
                    file.write(data)


class DownloadFileError(Exception):
    """Raised by ``download_file`` when the server response status is not 200."""
12 changes: 7 additions & 5 deletions python/rapidocr_torch/utils/infer_engine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os
import platform
from pathlib import Path
from typing import Optional, Union, Dict
from typing import Dict, Optional, Union

import numpy as np
import torch
Expand All @@ -13,18 +11,20 @@
root_dir = Path(__file__).resolve().parent.parent
DEFAULT_CFG_PATH = root_dir / "arch_config.yaml"


def read_yaml(yaml_path: Union[str, Path]) -> Dict[str, Dict]:
    """Load the YAML file at ``yaml_path`` and return the parsed content.

    The file is opened in binary mode and parsed with ``yaml.Loader``.
    """
    with open(yaml_path, "rb") as stream:
        return yaml.load(stream, Loader=yaml.Loader)

from .logger import get_logger

from rapidocr_torch.modeling.architectures.base_model import BaseModel

from .logger import get_logger


class TorchInferSession:
def __init__(self, config, mode: Optional[str] = None) -> None:

all_arch_config = read_yaml(DEFAULT_CFG_PATH)

self.logger = get_logger("TorchInferSession")
Expand All @@ -42,6 +42,7 @@ def __init__(self, config, mode: Optional[str] = None) -> None:
if config["use_cuda"]:
self.predictor.cuda()
self.use_gpu = True

def __call__(self, img: np.ndarray):
with torch.no_grad():
inp = torch.from_numpy(img)
Expand All @@ -50,6 +51,7 @@ def __call__(self, img: np.ndarray):
# 适配跟onnx对齐取值逻辑
outputs = self.predictor(inp).unsqueeze(0)
return outputs.cpu().numpy()

@staticmethod
def _verify_model(model_path):
model_path = Path(model_path)
Expand Down
22 changes: 19 additions & 3 deletions python/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,29 @@
sys.path.append(str(root_dir))

from rapidocr import LoadImageError, RapidOCR

from .base_module import download_file
from rapidocr.utils import download_file

engine = RapidOCR()
tests_dir = root_dir / "tests" / "test_files"
img_path = tests_dir / "ch_en_num.jpg"
package_name = "rapidocr_onnxruntime"


def test_engine_openvino():
    """Check the first recognized text with ``Global.with_openvino`` enabled."""
    openvino_engine = RapidOCR(params={"Global.with_openvino": True})
    ocr_output = openvino_engine(img_path)
    first_text = ocr_output.txts[0]
    assert first_text == "正品促销"


def test_engine_paddle():
    """Check the first recognized text with ``Global.with_paddle`` enabled."""
    # NOTE(review): both model paths end in .onnx although the paddle engine
    # is switched on -- confirm these are the paths the paddle backend expects.
    paddle_params = {
        "Global.with_paddle": True,
        "Det.model_path": "tests/test_files/ch_ppocr_server_v2.0_det_infer.onnx",
        "Rec.model_path": "tests/test_files/ch_ppocr_server_v2.0_rec_infer.onnx",
    }
    paddle_engine = RapidOCR(params=paddle_params)
    ocr_output = paddle_engine(img_path)
    first_text = ocr_output.txts[0]
    assert first_text == "正品促销"


def test_long_img():
Expand Down

0 comments on commit 5bf7702

Please sign in to comment.