Skip to content
81 changes: 61 additions & 20 deletions pix2tex/dataset/latex2png.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import io
import glob
import tempfile
import shlex
import subprocess
import traceback
from PIL import Image


Expand All @@ -27,6 +29,8 @@ def __init__(self, math, dpi=250, font='Latin Modern Math'):
self.math = math
self.dpi = dpi
self.font = font
self.prefix_line = self.BASE.split("\n").index(
"%s") # used for calculate error formula index

def write(self, return_bytes=False):
# inline = bool(re.match('^\$[^$]*\$$', self.math)) and False
Expand All @@ -39,8 +43,9 @@ def write(self, return_bytes=False):
# print(document)
f.write(document)

png = self.convert_file(texfile, workdir, return_bytes=return_bytes)
return png
png, error_index = self.convert_file(
texfile, workdir, return_bytes=return_bytes)
return png, error_index

finally:
if os.path.exists(texfile):
Expand All @@ -53,51 +58,64 @@ def convert_file(self, infile, workdir, return_bytes=False):

try:
# Generate the PDF file
cmd = 'xelatex -halt-on-error -output-directory %s %s' % (workdir, infile)
# not stop on error line, but return error line index,index start from 1
cmd = 'xelatex -interaction nonstopmode -file-line-error -output-directory %s %s' % (
workdir, infile)

p = subprocess.Popen(
cmd,
shell=True,
shlex.split(cmd),
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True
)
sout, serr = p.communicate()
# Something bad happened, abort
if p.returncode != 0:
raise Exception('latex error', serr, sout)

# extract error line from sout
error_index, _ = extract(text=sout, expression="%s:(\d+)" % infile)
# extract success rendered equation
if error_index != []:
# offset index start from 0, same as self.math
error_index = [int(_)-self.prefix_line-1 for _ in error_index]
# Convert the PDF file to PNG's
pdffile = infile.replace('.tex', '.pdf')
result, _ = extract(
text=sout, expression="Output written on %s \((.*)? pages\)" % pdffile)
if int(result[0]) != len(self.math):
raise Exception('xelatex rendering error, generated %d formula\'s page, but the total number of formulas is %d.' % (
int(result[0]), len(self.math)))
pngfile = os.path.join(workdir, infile.replace('.tex', '.png'))

cmd = 'magick convert -density %i -colorspace gray %s -quality 90 %s' % (
cmd = 'convert -density %i -colorspace gray %s -quality 90 %s' % (
self.dpi,
pdffile,
pngfile,
) # -bg Transparent -z 9
p = subprocess.Popen(
cmd,
shell=True,
shlex.split(cmd),
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)

sout, serr = p.communicate()
if p.returncode != 0:
raise Exception('PDFpng error', serr, cmd, os.path.exists(pdffile), os.path.exists(infile))
raise Exception('PDFpng error', serr, cmd, os.path.exists(
pdffile), os.path.exists(infile))
if return_bytes:
if len(self.math) > 1:
png = [open(pngfile.replace('.png', '')+'-%i.png' % i, 'rb').read() for i in range(len(self.math))]
png = [open(pngfile.replace('.png', '')+'-%i.png' %
i, 'rb').read() for i in range(len(self.math))]
else:
png = [open(pngfile.replace('.png', '')+'.png', 'rb').read()]
return png
png = [open(pngfile.replace(
'.png', '')+'.png', 'rb').read()]
else:
# return path
if len(self.math) > 1:
return [(pngfile.replace('.png', '')+'-%i.png' % i) for i in range(len(self.math))]
png = [(pngfile.replace('.png', '')+'-%i.png' % i)
for i in range(len(self.math))]
else:
return (pngfile.replace('.png', '')+'.png')
png = [(pngfile.replace('.png', '')+'.png')]
return png, error_index
finally:
# Cleanup temporaries
basefile = infile.replace('.tex', '')
Expand All @@ -122,9 +140,32 @@ def tex2png(eq, **kwargs):


def tex2pil(tex, **kwargs):
    """Render LaTeX formulas and load the resulting PNGs as PIL images.

    Args:
        tex: Formulas to render; passed to ``Latex`` as its ``math`` argument.
        **kwargs: Forwarded unchanged to the ``Latex`` constructor
            (it accepts ``dpi`` and ``font``).

    Returns:
        tuple: ``(images, error_index)`` where ``images`` is a list of
        ``PIL.Image`` objects decoded from the PNG bytes returned by
        ``Latex.write(return_bytes=True)``, and ``error_index`` is the second
        value returned by ``Latex.write`` (indices of the formulas for which
        xelatex reported an error).
    """
    # Render exactly once: ``write`` returns a (pngs, error_index) pair, so the
    # result must be unpacked rather than treated as the PNG list itself.
    pngs, error_index = Latex(tex, **kwargs).write(return_bytes=True)
    images = [Image.open(io.BytesIO(data)) for data in pngs]
    return images, error_index


def extract(text, expression=None, type: str = None):
    """Find every match of a regular expression in ``text``.

    Args:
        text (str): Text to search.
        expression (str, optional): Regular expression handed to
            ``re.findall``. Ignored when ``type`` is given. Defaults to None.
        type (str, optional): Name of a predefined pattern that overrides
            ``expression``: ``"en"``, ``"zh"``, ``"num"`` or
            ``"punctuation"``. Defaults to None. (The name shadows the
            builtin, but it is part of the keyword interface.)

    Returns:
        tuple: ``(results, found)``. ``results`` is the list produced by
        ``re.findall`` and ``found`` is True when that list is non-empty.
        If the pattern cannot be compiled or applied (invalid expression, or
        neither ``expression`` nor ``type`` supplied) the traceback is
        printed and ``([], False)`` is returned, so callers can always
        unpack the result.

    Raises:
        KeyError: If ``type`` is not one of the predefined pattern names.
    """
    if type is not None:
        type2expression = {"en": r"[a-zA-Z]+", "zh": r"[\u4e00-\u9fa5]+", "num": r"\d+",
                           "punctuation": u"[\u3002\uff1b\uff0c\uff1a\u201c\u201d\uff08\uff09\u3001\uff1f\u300a\u300b]"}
        expression = type2expression[type]
    try:
        results = re.compile(expression).findall(text)
    except (re.error, TypeError):
        # Best effort: report the problem but keep the two-value contract
        # instead of implicitly returning None.
        traceback.print_exc()
        return [], False
    return results, bool(results)


if __name__ == '__main__':
Expand Down
178 changes: 116 additions & 62 deletions pix2tex/dataset/render.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pix2tex.dataset.latex2png import *

from pix2tex.dataset.latex2png import Latex, tex2pil
import argparse
import sys
import os
Expand All @@ -7,112 +8,165 @@
from tqdm.auto import tqdm
import cv2
import numpy as np
from PIL import Image
import traceback
import subprocess
import shlex


def render_dataset(dataset: np.ndarray, names: np.ndarray, args):
'''Renders a list of tex equations
def get_installed_fonts(tex_path: str):
    """List the math fonts that can be used for rendering.

    Runs ``find <tex_path> -name *Math*.otf`` and collects the file names of
    every match.

    Args:
        tex_path (str): Directory that ``find`` searches recursively.

    Returns:
        list: Base names of the matching ``.otf`` files in the order ``find``
        printed them, followed by ``"Latin Modern Math"``, which is always
        appended.

    Raises:
        Exception: If ``find`` exits with a non-zero status; the message is
            its stderr output.
    """
    # Pass argv as a list instead of formatting and re-splitting a command
    # string: a path containing spaces or quotes stays a single argument and
    # the glob pattern reaches ``find`` verbatim.
    cmd = ["find", tex_path, "-name", "*Math*.otf"]
    process = subprocess.Popen(cmd,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               universal_newlines=True
                               )
    stdout, stderr = process.communicate()
    if process.returncode != 0:
        raise Exception(stderr)
    fonts = [os.path.basename(path) for path in stdout.splitlines() if path]
    fonts.append("Latin Modern Math")
    return fonts


def _crop_and_pad(image, divable):
    """Crop a rendered formula to its ink and pad it to a multiple of ``divable``; None if blank."""
    data = np.asarray(image)
    # Dark pixels of the first channel become 255 (the text), the rest 0.
    ink = 255*(data[..., 0] < 128).astype(np.uint8)
    # Some syntactically wrong equations still produce a page, but an empty
    # one, e.g. $$ \mathit { \Iota \Kappa \Lambda \Mu \Nu \Xi \Omicron \Pi } $$
    if not ink.any():
        return None
    # Minimum bounding box around all text pixels.
    a, b, w, h = cv2.boundingRect(cv2.findNonZero(ink))
    rect = data[b:b+h, a:a+w]
    # NOTE(review): the last channel is inverted here; presumably it is the
    # alpha channel of the converted PNG - confirm against the converter.
    im = Image.fromarray((255-rect[..., -1]).astype(np.uint8)).convert('L')
    # Round width and height up to the next multiple of ``divable``.
    dims = tuple(divable*(-(-x // divable)) for x in (w, h))
    padded = Image.new('L', dims, 255)
    padded.paste(im, (0, 0, im.size[0], im.size[1]))
    return padded


def render_dataset(dataset: np.ndarray, unrenders: np.ndarray, args):
    '''Renders a list of tex equations

    Equations whose image ``<name>.png`` already exists in ``args.out`` are
    skipped. The rest are rendered in batches of ``args.batchsize``; every
    successfully rendered equation is saved as ``'%07d.png' % name``.

    Args:
        dataset (numpy.ndarray): List of equations
        unrenders (numpy.ndarray): List of integers of size `dataset` that give the name of the saved image
        args (Union[Namespace, Munch]): additional arguments: mode (equation or inline), out (output directory), divable (common factor )
            batchsize (how many samples to render at once), dpi, font (Math font), preprocess (crop, alpha off)
            shuffle (bool)

    Returns:
        numpy.ndarray: sorted, de-duplicated integer names (entries of
        `unrenders`) of the equations that could not be rendered. Always an
        integer array, also when empty, so it can be used as an index.
        Equations that render to a blank page are skipped without being
        reported.
    '''
    assert len(unrenders) == len(
        dataset), 'unrenders and dataset must be of equal size'
    math_mode = '$$' if args.mode == 'equation' else '$'
    os.makedirs(args.out, exist_ok=True)
    # Names that already have an image on disk; a set keeps the filter O(n).
    rendered = {int(os.path.basename(img).split('.')[0])
                for img in glob.glob(os.path.join(args.out, '*.png'))}
    valid = [i for i, name in enumerate(unrenders) if name not in rendered]
    dataset = dataset[valid]
    unrenders = unrenders[valid]
    order = np.random.permutation(
        len(dataset)) if args.shuffle else np.arange(len(dataset))
    faulty = []
    for batch_offset in tqdm(range(0, len(dataset), args.batchsize), desc="global batch index"):
        batch_order = order[batch_offset:batch_offset+args.batchsize]
        batch = dataset[batch_order]
        batch_names = unrenders[batch_order]
        # Blank lines are not sent to LaTeX; valid_idx maps the position of a
        # formula in ``math`` back to its position inside the batch.
        keep = [i for i, x in enumerate(batch) if x != '']
        if not keep:
            continue
        valid_idx = np.array(keep, dtype=np.int32)
        # The spaces keep a formula that starts/ends with '\' from escaping the '$'.
        math = np.array(["%s %s %s" % (math_mode, batch[i], math_mode)
                         for i in keep], dtype=object)
        font = np.random.choice(args.font) if len(
            args.font) > 1 else args.font[0]
        dpi = np.random.choice(np.arange(min(args.dpi), max(args.dpi))) if len(
            args.dpi) > 1 else args.dpi[0]
        try:
            if args.preprocess:
                pngs, error_index = tex2pil(math, dpi=dpi, font=font)
            else:
                pngs, error_index = Latex(math, dpi=dpi, font=font).write(
                    return_bytes=False)
            # An error reported outside the formula lines cannot be mapped
            # to one formula. Reject it explicitly: a negative index would
            # otherwise wrap around and blame the wrong equation.
            if any(idx < 0 or idx >= len(math) for idx in error_index):
                raise IndexError(
                    'error line outside of the formulas: %s' % list(error_index))
            failed_local = {int(valid_idx[idx]) for idx in error_index}
        except Exception as e:
            # Rendering of the whole batch failed: report every member.
            print("\n%s" % e, end='')
            faulty.extend(batch_names)
            continue
        faulty.extend(batch_names[i] for i in sorted(failed_local))

        for png_idx, inbatch_idx in enumerate(valid_idx):
            if int(inbatch_idx) in failed_local:
                continue
            outpath = os.path.join(args.out, '%07d.png' %
                                   batch_names[inbatch_idx])
            if args.preprocess:
                try:
                    padded = _crop_and_pad(pngs[png_idx], args.divable)
                    if padded is None:
                        continue
                    padded.save(outpath)
                except Exception as e:
                    # Best effort: one broken image must not stop the batch.
                    print(e)
            else:
                shutil.move(pngs[png_idx], outpath)
    # A formula can be reported more than once across retries of error
    # sources, hence the de-duplication.
    return np.array(sorted(set(faulty)), dtype=np.int64)


if __name__ == '__main__':
    # Command line entry point: render every line of the input file to a PNG
    # and keep retrying the failures until the set of failures stops changing.
    parser = argparse.ArgumentParser(description='Render dataset')
    parser.add_argument('-i', '--data', type=str,
                        required=True, help='file of list of latex code')
    parser.add_argument('-o', '--out', type=str,
                        required=True, help='output directory')
    parser.add_argument('-b', '--batchsize', type=int, default=100,
                        help='How many equations to render at once')
    parser.add_argument('-f', '--font', nargs='+', type=str,
                        default="", help='font to use.')
    parser.add_argument('-fp', '--fonts_path', type=str,
                        default="/usr/local/texlive/", help='installed font path')
    parser.add_argument('-m', '--mode', choices=[
                        'inline', 'equation'], default='equation', help='render as inline or equation')
    parser.add_argument(
        '--dpi', type=int, default=[110, 170], nargs='+', help='dpi range to render in')
    parser.add_argument('-p', '--no-preprocess', dest='preprocess', default=True,
                        action='store_false', help='crop, remove alpha channel, padding')
    parser.add_argument('-d', '--divable', type=int, default=32,
                        help='To what factor to pad the images')
    parser.add_argument('-s', '--shuffle', action='store_true',
                        help='Whether to shuffle the equations in the first iteration')
    args = parser.parse_args(sys.argv[1:])
    # No font given on the command line: use whatever is installed.
    if not args.font:
        args.font = get_installed_fonts(args.fonts_path)
    print(args.font)
    # One equation per line; the context manager closes the file handle.
    with open(args.data, 'r') as data_file:
        dataset = np.array(data_file.read().split('\n'), dtype=object)
    unrenders = np.arange(len(dataset))
    failed = np.array([])
    # Stop when nothing is left to render (an empty result must not be used
    # to index ``dataset`` again) or when a pass fails on exactly the same
    # equations as the previous one.
    while len(unrenders) > 0 and unrenders.tolist() != failed.tolist():
        failed = unrenders
        unrenders = render_dataset(dataset[unrenders], unrenders, args)
        # Few failures left: use smaller batches so that one bad equation
        # takes fewer good ones down with it.
        if len(unrenders) < 50*args.batchsize:
            args.batchsize = max([1, args.batchsize//2])
        # Shuffle retries so failing equations end up in different batches.
        args.shuffle = True