Compare commits
5 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 6612a2f58d | |
| | 1287abeeaf | |
| | eecc0df1ac | |
| | 0183f8f462 | |
| | 674b5f4336 | |
@@ -6,7 +6,7 @@ import sys
 from PyPDF2 import PdfFileReader, PdfFileWriter

 from .core import TableList
-from .parsers import Stream, Lattice
+from .parsers import Lattice, Stream, LatticeOCR, StreamOCR
 from .utils import (
     TemporaryDirectory,
     get_page_layout,
@@ -163,14 +163,19 @@ class PDFHandler(object):
             List of tables found in PDF.

         """
+        parsers = {
+            "lattice": Lattice,
+            "stream": Stream,
+            "lattice_ocr": LatticeOCR,
+            "stream_ocr": StreamOCR,
+        }
+
         tables = []
         with TemporaryDirectory() as tempdir:
             for p in self.pages:
                 self._save_page(self.filepath, p, tempdir)
-            pages = [
-                os.path.join(tempdir, f"page-{p}.pdf") for p in self.pages
-            ]
-            parser = Lattice(**kwargs) if flavor == "lattice" else Stream(**kwargs)
+            pages = [os.path.join(tempdir, f"page-{p}.pdf") for p in self.pages]
+            parser = parsers[flavor](**kwargs)
             for p in pages:
                 t = parser.extract_tables(
                     p, suppress_stdout=suppress_stdout, layout_kwargs=layout_kwargs
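The `PDFHandler.parse` hunk above swaps the hard-coded Lattice/Stream choice for a lookup table, so each supported flavor string maps straight to a parser class. A minimal sketch of the same dispatch pattern, assuming this branch is installed (the page path is a placeholder):

```python
from camelot.parsers import Lattice, LatticeOCR, Stream, StreamOCR

# Same mapping the handler builds internally on this branch.
PARSERS = {
    "lattice": Lattice,
    "stream": Stream,
    "lattice_ocr": LatticeOCR,
    "stream_ocr": StreamOCR,
}

def parse_single_page(page_path, flavor="lattice", **kwargs):
    # Instantiate the parser for the requested flavor and run it on one
    # single-page PDF, mirroring what PDFHandler.parse does per page.
    parser = PARSERS[flavor](**kwargs)
    return parser.extract_tables(page_path)

tables = parse_single_page("page-1.pdf", flavor="lattice_ocr")  # hypothetical file
```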
@@ -98,9 +98,10 @@ def read_pdf(
     tables : camelot.core.TableList

     """
-    if flavor not in ["lattice", "stream"]:
+    if flavor not in ["lattice", "stream", "lattice_ocr", "stream_ocr"]:
         raise NotImplementedError(
-            "Unknown flavor specified." " Use either 'lattice' or 'stream'"
+            "Unknown flavor specified. Use one of the following: "
+            "'lattice', 'stream', 'lattice_ocr', 'stream_ocr'"
         )

     with warnings.catch_warnings():
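With the widened flavor check in `read_pdf`, the OCR flavors become reachable from the public API. A hedged usage sketch (the file name is a placeholder; `lattice_ocr` exists only on this branch):

```python
import camelot

# "lattice_ocr" rasterizes each page with Ghostscript and OCRs every detected
# cell with easyocr, so it targets scanned/image-based PDFs.
tables = camelot.read_pdf("scanned.pdf", flavor="lattice_ocr", pages="1")
print(tables[0].df)

# Any other flavor string now raises NotImplementedError listing all four options.
```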
@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-

-from .stream import Stream
 from .lattice import Lattice
+from .stream import Stream
+from .lattice_ocr import LatticeOCR
+from .stream_ocr import StreamOCR
@@ -0,0 +1,243 @@
+# -*- coding: utf-8 -*-
+
+import os
+import copy
+import logging
+import subprocess
+
+try:
+    import easyocr
+except ImportError:
+    _HAS_EASYOCR = False
+else:
+    _HAS_EASYOCR = True
+
+import pandas as pd
+from PIL import Image
+
+from .base import BaseParser
+from ..core import Table
+from ..utils import TemporaryDirectory, merge_close_lines, scale_pdf, segments_in_bbox
+from ..image_processing import (
+    adaptive_threshold,
+    find_lines,
+    find_contours,
+    find_joints,
+)
+
+
+logger = logging.getLogger("camelot")
+
+
+class LatticeOCR(BaseParser):
+    def __init__(
+        self,
+        table_regions=None,
+        table_areas=None,
+        line_scale=15,
+        line_tol=2,
+        joint_tol=2,
+        threshold_blocksize=15,
+        threshold_constant=-2,
+        iterations=0,
+        resolution=300,
+    ):
+        self.table_regions = table_regions
+        self.table_areas = table_areas
+        self.line_scale = line_scale
+        self.line_tol = line_tol
+        self.joint_tol = joint_tol
+        self.threshold_blocksize = threshold_blocksize
+        self.threshold_constant = threshold_constant
+        self.iterations = iterations
+        self.resolution = resolution
+
+        if _HAS_EASYOCR:
+            self.reader = easyocr.Reader(['en'], gpu=False)
+        else:
+            raise ImportError("easyocr is required to run OCR on image-based PDFs.")
+
+    def _generate_image(self):
+        from ..ext.ghostscript import Ghostscript
+
+        self.imagename = "".join([self.rootname, ".png"])
+        gs_call = "-q -sDEVICE=png16m -o {} -r900 {}".format(
+            self.imagename, self.filename
+        )
+        gs_call = gs_call.encode().split()
+        null = open(os.devnull, "wb")
+        with Ghostscript(*gs_call, stdout=null) as gs:
+            pass
+        null.close()
+
+    def _generate_table_bbox(self):
+        def scale_areas(areas, scalers):
+            scaled_areas = []
+            for area in areas:
+                x1, y1, x2, y2 = area.split(",")
+                x1 = float(x1)
+                y1 = float(y1)
+                x2 = float(x2)
+                y2 = float(y2)
+                x1, y1, x2, y2 = scale_pdf((x1, y1, x2, y2), scalers)
+                scaled_areas.append((x1, y1, abs(x2 - x1), abs(y2 - y1)))
+            return scaled_areas
+
+        self.image, self.threshold = adaptive_threshold(
+            self.imagename, blocksize=self.threshold_blocksize, c=self.threshold_constant
+        )
+
+        image_width = self.image.shape[1]
+        image_height = self.image.shape[0]
+        image_width_scaler = image_width / float(self.pdf_width)
+        image_height_scaler = image_height / float(self.pdf_height)
+        image_scalers = (image_width_scaler, image_height_scaler, self.pdf_height)
+
+        if self.table_areas is None:
+            regions = None
+            if self.table_regions is not None:
+                regions = scale_areas(self.table_regions, image_scalers)
+
+            vertical_mask, vertical_segments = find_lines(
+                self.threshold,
+                regions=regions,
+                direction="vertical",
+                line_scale=self.line_scale,
+                iterations=self.iterations,
+            )
+            horizontal_mask, horizontal_segments = find_lines(
+                self.threshold,
+                regions=regions,
+                direction="horizontal",
+                line_scale=self.line_scale,
+                iterations=self.iterations,
+            )
+
+            contours = find_contours(vertical_mask, horizontal_mask)
+            table_bbox = find_joints(contours, vertical_mask, horizontal_mask)
+        else:
+            vertical_mask, vertical_segments = find_lines(
+                self.threshold,
+                direction="vertical",
+                line_scale=self.line_scale,
+                iterations=self.iterations,
+            )
+            horizontal_mask, horizontal_segments = find_lines(
+                self.threshold,
+                direction="horizontal",
+                line_scale=self.line_scale,
+                iterations=self.iterations,
+            )
+
+            areas = scale_areas(self.table_areas, image_scalers)
+            table_bbox = find_joints(areas, vertical_mask, horizontal_mask)
+
+        self.table_bbox_unscaled = copy.deepcopy(table_bbox)
+
+        self.table_bbox = table_bbox
+        self.vertical_segments = vertical_segments
+        self.horizontal_segments = horizontal_segments
+
+    def _generate_columns_and_rows(self, table_idx, tk):
+        cols, rows = zip(*self.table_bbox[tk])
+        cols, rows = list(cols), list(rows)
+        cols.extend([tk[0], tk[2]])
+        rows.extend([tk[1], tk[3]])
+        # sort horizontal and vertical segments
+        cols = merge_close_lines(sorted(cols), line_tol=self.line_tol)
+        rows = merge_close_lines(sorted(rows), line_tol=self.line_tol)
+        # make grid using x and y coord of shortlisted rows and cols
+        cols = [(cols[i], cols[i + 1]) for i in range(0, len(cols) - 1)]
+        rows = [(rows[i], rows[i + 1]) for i in range(0, len(rows) - 1)]
+
+        return cols, rows
+
+    def _generate_table(self, table_idx, cols, rows, **kwargs):
+        table = Table(cols, rows)
+        # set table edges to True using ver+hor lines
+        table = table.set_edges(self.vertical_segments, self.horizontal_segments, joint_tol=self.joint_tol)
+        # set table border edges to True
+        table = table.set_border()
+        # set spanning cells to True
+        table = table.set_span()
+
+        _seen = set()
+        for r_idx in range(len(table.cells)):
+            for c_idx in range(len(table.cells[r_idx])):
+                if (r_idx, c_idx) in _seen:
+                    continue
+
+                _seen.add((r_idx, c_idx))
+
+                _r_idx = r_idx
+                _c_idx = c_idx
+
+                if table.cells[r_idx][_c_idx].hspan:
+                    while not table.cells[r_idx][_c_idx].right:
+                        _c_idx += 1
+                        _seen.add((r_idx, _c_idx))
+
+                if table.cells[_r_idx][c_idx].vspan:
+                    while not table.cells[_r_idx][c_idx].bottom:
+                        _r_idx += 1
+                        _seen.add((_r_idx, c_idx))
+
+                for i in range(r_idx, _r_idx + 1):
+                    for j in range(c_idx, _c_idx + 1):
+                        _seen.add((i, j))
+
+                x1 = int(table.cells[r_idx][c_idx].x1)
+                y1 = int(table.cells[_r_idx][_c_idx].y1)
+
+                x2 = int(table.cells[_r_idx][_c_idx].x2)
+                y2 = int(table.cells[r_idx][c_idx].y2)
+
+                with TemporaryDirectory() as tempdir:
+                    temp_image_path = os.path.join(tempdir, f"{table_idx}_{r_idx}_{c_idx}.png")
+
+                    cell_image = Image.fromarray(self.image[y2:y1, x1:x2])
+                    cell_image.save(temp_image_path)
+
+                    text = self.reader.readtext(temp_image_path, detail=0)
+                    text = " ".join(text)
+
+                    table.cells[r_idx][c_idx].text = text
+
+        data = table.data
+        table.df = pd.DataFrame(data)
+        table.shape = table.df.shape
+
+        table.flavor = "lattice_ocr"
+        table.accuracy = 0
+        table.whitespace = 0
+        table.order = table_idx + 1
+        table.page = int(os.path.basename(self.rootname).replace("page-", ""))
+
+        # for plotting
+        table._text = None
+        table._image = (self.image, self.table_bbox_unscaled)
+        table._segments = (self.vertical_segments, self.horizontal_segments)
+        table._textedges = None
+
+        return table
+
+    def extract_tables(self, filename, suppress_stdout=False, layout_kwargs={}):
+        self._generate_layout(filename, layout_kwargs)
+        if not suppress_stdout:
+            logger.info("Processing {}".format(os.path.basename(self.rootname)))
+
+        self._generate_image()
+        self._generate_table_bbox()
+
+        _tables = []
+        # sort tables based on y-coord
+        for table_idx, tk in enumerate(
+            sorted(self.table_bbox.keys(), key=lambda x: x[1], reverse=True)
+        ):
+            cols, rows = self._generate_columns_and_rows(table_idx, tk)
+            table = self._generate_table(table_idx, cols, rows)
+            table._bbox = tk
+            _tables.append(table)
+
+        return _tables
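The heart of `LatticeOCR._generate_table` is the cell-OCR step: each cell is cropped from the thresholded page image, written to a temporary PNG, and passed to easyocr (the parser slices with `[y2:y1]` because its bounding boxes come from PDF coordinates, where the y-axis points up). A standalone sketch of that step, with a placeholder image and pixel bounding box:

```python
import numpy as np
import easyocr
from PIL import Image

reader = easyocr.Reader(["en"], gpu=False)  # same reader configuration as LatticeOCR

page = np.array(Image.open("page-1.png").convert("L"))  # placeholder page image
x1, y1, x2, y2 = 120, 840, 360, 880                     # placeholder cell bbox in pixels

cell = Image.fromarray(page[y1:y2, x1:x2])
cell.save("cell.png")

# detail=0 makes readtext return only the recognized strings,
# which LatticeOCR joins into the cell text.
text = " ".join(reader.readtext("cell.png", detail=0))
print(text)
```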
@@ -0,0 +1,7 @@
+# -*- coding: utf-8 -*-
+
+from .base import BaseParser
+
+
+class StreamOCR(BaseParser):
+    pass
@@ -48,7 +48,7 @@ class PlotMethods(object):
         if filename is not None:
             fig.savefig(filename)
             return None

         return fig

     def text(self, table):
@@ -93,7 +93,6 @@ def download_url(url):
    return filepath


-stream_kwargs = ["columns", "edge_tol", "row_tol", "column_tol"]
 lattice_kwargs = [
     "process_background",
     "line_scale",
@@ -106,6 +105,7 @@ lattice_kwargs = [
     "iterations",
     "resolution",
 ]
+stream_kwargs = ["columns", "edge_tol", "row_tol", "column_tol"]


 def validate_input(kwargs, flavor="lattice"):
@@ -116,14 +116,14 @@ def validate_input(kwargs, flavor="lattice"):
                 f"{','.join(sorted(isec))} cannot be used with flavor='{flavor}'"
             )

-    if flavor == "lattice":
+    if flavor in ["lattice", "lattice_ocr"]:
         check_intersection(stream_kwargs, kwargs)
     else:
         check_intersection(lattice_kwargs, kwargs)


 def remove_extra(kwargs, flavor="lattice"):
-    if flavor == "lattice":
+    if flavor in ["lattice", "lattice_ocr"]:
         for key in kwargs.keys():
             if key in stream_kwargs:
                 kwargs.pop(key)
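After this change the OCR flavors share the keyword-argument validation of their text-based counterparts, so stream-only options are rejected for `lattice_ocr` as well. A small illustrative check, assuming this branch is installed:

```python
from camelot.utils import validate_input

# Stream-only keyword arguments are now also rejected for "lattice_ocr".
try:
    validate_input({"row_tol": 5}, flavor="lattice_ocr")
except ValueError as err:
    print(err)  # row_tol cannot be used with flavor='lattice_ocr'
```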
setup.py
@@ -27,6 +27,10 @@ cv_requires = [
     'opencv-python>=3.4.2.17'
 ]

+ocr_requires = [
+    'easyocr>=1.1.10'
+]
+
 plot_requires = [
     'matplotlib>=2.2.3',
 ]
@@ -40,7 +44,7 @@ dev_requires = [
     'Sphinx>=3.1.2'
 ]

-all_requires = cv_requires + plot_requires
+all_requires = cv_requires + ocr_requires + plot_requires
 dev_requires = dev_requires + all_requires
@@ -57,10 +61,11 @@ def setup_package():
         packages=find_packages(exclude=('tests',)),
         install_requires=requires,
         extras_require={
-            'all': all_requires,
             'cv': cv_requires,
+            'ocr': ocr_requires,
+            'plot': plot_requires,
+            'all': all_requires,
             'dev': dev_requires,
-            'plot': plot_requires
         },
         entry_points={
             'console_scripts': [
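With `ocr_requires` added to `extras_require` and folded into `all_requires`, the easyocr dependency stays optional; presumably it would be pulled in with `pip install "camelot-py[ocr]"` (or `[all]`) rather than by a default install.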
Binary file not shown.