Skip to content

Commit

Permalink
chore: update files
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Feb 14, 2025
1 parent ee28011 commit 6de7194
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# engine = engine = RapidOCR(
# params={"Global.with_onnx": True, "EngineConfig.onnxruntime.use_cuda": True}
# )
engine = RapidOCR(params={"Global.with_paddle": True, "Global.lang": "ch"})
engine = RapidOCR(params={"Global.with_openvino": True, "Global.lang": "ch"})
vis = VisRes()

image_path = "tests/test_files/ch_en_num.jpg"
Expand Down
4 changes: 2 additions & 2 deletions python/rapidocr/inference_engine/onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from onnxruntime import (
Expand All @@ -28,7 +28,7 @@ class EP(Enum):


class OrtInferSession(InferSession):
def __init__(self, config: Dict[str, Any]):
def __init__(self, config: Dict[str, Any], mode: Optional[str] = None):
self.logger = Logger(logger_name=__name__).get_log()

model_path = config.get("model_path", None)
Expand Down
7 changes: 6 additions & 1 deletion python/rapidocr/inference_engine/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import traceback
from pathlib import Path
from typing import Optional

import numpy as np
from omegaconf import DictConfig
Expand All @@ -13,8 +14,9 @@


class OpenVINOInferSession(InferSession):
def __init__(self, config: DictConfig):
def __init__(self, config: DictConfig, mode: Optional[str] = None):
super().__init__(config)
self.mode = mode

core = Core()

Expand All @@ -23,6 +25,9 @@ def __init__(self, config: DictConfig):
default_model_url = self.get_model_url(
config.engine_name, config.task_type, config.lang
)
if self.mode == "rec":
default_model_url = default_model_url.model_dir

model_path = self.DEFAULT_MODE_PATH / Path(default_model_url).name
self.download_file(default_model_url, model_path)

Expand Down
9 changes: 9 additions & 0 deletions python/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,24 @@ def get_engine(params: Optional[Dict[str, Any]] = None):
return engine


def test_lang():
engine = get_engine(params={"Global.lang": "ch", "Global.with_openvino": True})
result = engine(img_path)
assert result.txts is not None
assert result.txts[0] == "正品促销"


def test_engine_openvino():
engine = get_engine(params={"Global.with_openvino": True})
result = engine(img_path)
assert result.txts is not None
assert result.txts[0] == "正品促销"


def test_engine_paddle():
engine = RapidOCR(params={"Global.with_paddle": True})
result = engine(img_path)
assert result.txts is not None
assert result.txts[0] == "正品促销"


Expand Down

0 comments on commit 6de7194

Please # to comment.