diff --git a/camelot/handlers.py b/camelot/handlers.py index 9ec10bb..eca05bd 100644 --- a/camelot/handlers.py +++ b/camelot/handlers.py @@ -6,7 +6,7 @@ import sys from PyPDF2 import PdfFileReader, PdfFileWriter from .core import TableList -from .parsers import Stream, Lattice +from .parsers import Lattice, Stream, LatticeOCR, StreamOCR from .utils import ( TemporaryDirectory, get_page_layout, @@ -163,14 +163,19 @@ class PDFHandler(object): List of tables found in PDF. """ + parsers = { + "lattice": Lattice, + "stream": Stream, + "lattice_ocr": LatticeOCR, + "stream_ocr": StreamOCR, + } + tables = [] with TemporaryDirectory() as tempdir: for p in self.pages: self._save_page(self.filepath, p, tempdir) - pages = [ - os.path.join(tempdir, f"page-{p}.pdf") for p in self.pages - ] - parser = Lattice(**kwargs) if flavor == "lattice" else Stream(**kwargs) + pages = [os.path.join(tempdir, f"page-{p}.pdf") for p in self.pages] + parser = parsers[flavor](**kwargs) for p in pages: t = parser.extract_tables( p, suppress_stdout=suppress_stdout, layout_kwargs=layout_kwargs diff --git a/camelot/io.py b/camelot/io.py index a27a7c6..66f8104 100644 --- a/camelot/io.py +++ b/camelot/io.py @@ -98,9 +98,10 @@ def read_pdf( tables : camelot.core.TableList """ - if flavor not in ["lattice", "stream"]: + if flavor not in ["lattice", "stream", "lattice_ocr", "stream_ocr"]: raise NotImplementedError( - "Unknown flavor specified." " Use either 'lattice' or 'stream'" + "Unknown flavor specified. Use one of the following: " + "'lattice', 'stream', 'lattice_ocr', 'stream_ocr'" ) with warnings.catch_warnings(): diff --git a/camelot/parsers/__init__.py b/camelot/parsers/__init__.py index 5cc6605..5648917 100644 --- a/camelot/parsers/__init__.py +++ b/camelot/parsers/__init__.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- -from .stream import Stream from .lattice import Lattice +from .stream import Stream +from .lattice_ocr import LatticeOCR +from .stream_ocr import StreamOCR diff --git a/camelot/parsers/lattice_ocr.py b/camelot/parsers/lattice_ocr.py new file mode 100644 index 0000000..bc235fb --- /dev/null +++ b/camelot/parsers/lattice_ocr.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- + +import os +import copy +import logging +import subprocess + +try: + import easyocr +except ImportError: + _HAS_EASYOCR = False +else: + _HAS_EASYOCR = True + +import pandas as pd +from PIL import Image + +from .base import BaseParser +from ..core import Table +from ..utils import TemporaryDirectory, merge_close_lines, scale_image, segments_in_bbox +from ..image_processing import ( + adaptive_threshold, + find_lines, + find_contours, + find_joints, +) + + +logger = logging.getLogger("camelot") + + +class LatticeOCR(BaseParser): + def __init__( + self, + table_areas=None, + line_scale=15, + line_tol=2, + joint_tol=2, + threshold_blocksize=15, + threshold_constant=-2, + iterations=0, + resolution=300, + ): + self.table_areas = table_areas + self.line_scale = line_scale + self.line_tol = line_tol + self.joint_tol = joint_tol + self.threshold_blocksize = threshold_blocksize + self.threshold_constant = threshold_constant + self.iterations = iterations + self.resolution = resolution + + if _HAS_EASYOCR: + self.reader = easyocr.Reader(['en'], gpu=False) + else: + raise ImportError("easyocr is required to run OCR on image-based PDFs.") + + def _generate_image(self): + from ..ext.ghostscript import Ghostscript + + self.imagename = "".join([self.rootname, ".png"]) + gs_call = "-q -sDEVICE=png16m -o {} -r300 {}".format( + self.imagename, self.filename + ) + gs_call = gs_call.encode().split() + null = open(os.devnull, "wb") + with Ghostscript(*gs_call, stdout=null) as gs: + pass + null.close() + + def _generate_table_bbox(self): + self.image, self.threshold = adaptive_threshold( + self.imagename, blocksize=self.threshold_blocksize, c=self.threshold_constant + ) + + image_width = self.image.shape[1] + image_height = self.image.shape[0] + + vertical_mask, vertical_segments = find_lines( + self.threshold, + direction="vertical", + line_scale=self.line_scale, + iterations=self.iterations, + ) + horizontal_mask, horizontal_segments = find_lines( + self.threshold, + direction="horizontal", + line_scale=self.line_scale, + iterations=self.iterations, + ) + + if self.table_areas is not None: + areas = [] + for area in self.table_areas: + x1, y1, x2, y2 = area.split(",") + x1 = int(float(x1)) + y1 = int(float(y1)) + x2 = int(float(x2)) + y2 = int(float(y2)) + areas.append((x1, y1, abs(x2 - x1), abs(y2 - y1))) + table_bbox = find_joints(areas, vertical_mask, horizontal_mask) + else: + contours = find_contours(vertical_mask, horizontal_mask) + table_bbox = find_joints(contours, vertical_mask, horizontal_mask) + + self.table_bbox_unscaled = copy.deepcopy(table_bbox) + + self.table_bbox = table_bbox + self.vertical_segments = vertical_segments + self.horizontal_segments = horizontal_segments + + def _generate_columns_and_rows(self, table_idx, tk): + cols, rows = zip(*self.table_bbox[tk]) + cols, rows = list(cols), list(rows) + cols.extend([tk[0], tk[2]]) + rows.extend([tk[1], tk[3]]) + # sort horizontal and vertical segments + cols = merge_close_lines(sorted(cols), line_tol=self.line_tol) + rows = merge_close_lines(sorted(rows), line_tol=self.line_tol) + # make grid using x and y coord of shortlisted rows and cols + cols = [(cols[i], cols[i + 1]) for i in range(0, len(cols) - 1)] + rows = [(rows[i], rows[i + 1]) for i in range(0, len(rows) - 1)] + + return cols, rows + + + def _generate_table(self, table_idx, cols, rows, **kwargs): + table = Table(cols, rows) + # set table edges to True using ver+hor lines + table = table.set_edges(self.vertical_segments, self.horizontal_segments, joint_tol=self.joint_tol) + # set table border edges to True + table = table.set_border() + # set spanning cells to True + table = table.set_span() + + for r_idx in range(len(table.cells)): + for c_idx in range(len(table.cells[r_idx])): + x1 = int(table.cells[r_idx][c_idx].x1) + y1 = int(table.cells[r_idx][c_idx].y1) + x2 = int(table.cells[r_idx][c_idx].x2) + y2 = int(table.cells[r_idx][c_idx].y2) + + with TemporaryDirectory() as tempdir: + temp_image_path = os.path.join(tempdir, f"{table_idx}_{r_idx}_{c_idx}.png") + + cell_image = Image.fromarray(self.image[y2:y1, x1:x2]) + cell_image.save(temp_image_path) + + text = self.reader.readtext(temp_image_path, detail=0) + text = " ".join(text) + + table.cells[r_idx][c_idx].text = text + + data = table.data + table.df = pd.DataFrame(data) + table.shape = table.df.shape + + table.flavor = "lattice_ocr" + table.accuracy = 0 + table.whitespace = 0 + table.order = table_idx + 1 + table.page = int(os.path.basename(self.rootname).replace("page-", "")) + + # for plotting + table._text = None + table._image = (self.image, self.table_bbox_unscaled) + table._segments = (self.vertical_segments, self.horizontal_segments) + table._textedges = None + + return table + + def extract_tables(self, filename, suppress_stdout=False, layout_kwargs={}): + self._generate_layout(filename, layout_kwargs) + if not suppress_stdout: + logger.info("Processing {}".format(os.path.basename(self.rootname))) + + self._generate_image() + self._generate_table_bbox() + + _tables = [] + # sort tables based on y-coord + for table_idx, tk in enumerate( + sorted(self.table_bbox.keys(), key=lambda x: x[1], reverse=True) + ): + cols, rows = self._generate_columns_and_rows(table_idx, tk) + table = self._generate_table(table_idx, cols, rows) + table._bbox = tk + _tables.append(table) + + return _tables diff --git a/camelot/parsers/stream_ocr.py b/camelot/parsers/stream_ocr.py new file mode 100644 index 0000000..1ae6094 --- /dev/null +++ b/camelot/parsers/stream_ocr.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .base import BaseParser + + +class StreamOCR(BaseParser): + pass diff --git a/setup.py b/setup.py index b1ac666..63db662 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,10 @@ cv_requires = [ 'opencv-python>=3.4.2.17' ] +ocr_requires = [ + 'easyocr>=1.1.10' +] + plot_requires = [ 'matplotlib>=2.2.3', ] @@ -40,7 +44,7 @@ dev_requires = [ 'Sphinx>=3.1.2' ] -all_requires = cv_requires + plot_requires +all_requires = cv_requires + ocr_requires + plot_requires dev_requires = dev_requires + all_requires @@ -57,10 +61,11 @@ def setup_package(): packages=find_packages(exclude=('tests',)), install_requires=requires, extras_require={ - 'all': all_requires, 'cv': cv_requires, + 'ocr': ocr_requires, + 'plot': plot_requires, + 'all': all_requires, 'dev': dev_requires, - 'plot': plot_requires }, entry_points={ 'console_scripts': [ diff --git a/tests/files/foo_image.pdf b/tests/files/foo_image.pdf new file mode 100644 index 0000000..c73244b Binary files /dev/null and b/tests/files/foo_image.pdf differ