issue FloatingPointError: invalid value encountered in double_scalars - wanghaisheng/awesome-ocr GitHub Wiki
https://github.com/tmbdev/ocropy/issues/93
root@4b2648975d03:/ocropy# cat ocropus-rtrain
#!/usr/bin/env python
# vim: set fileencoding=utf8 :
import sys
import random as pyrandom
import re
from pylab import *
import os.path
import ocrolib
import argparse
import matplotlib
import numpy
from ocrolib import lineest
import ocrolib.lstm as lstm
import traceback
# Make numpy raise FloatingPointError for division by zero, overflow and
# invalid operations (underflow is ignored).  The training loop further down
# catches FloatingPointError around network.trainSequence().
numpy.seterr(divide='raise',over='raise',invalid='raise',under='ignore')
# Python 2 only: sys.setdefaultencoding is not available after interpreter
# start-up, so sys is reloaded first to get it back; the process-wide default
# string encoding is then switched to UTF-8.  The order of these two lines
# matters.
reload(sys)
sys.setdefaultencoding('utf-8')
parser = argparse.ArgumentParser("train an RNN recognizer")
# line normalization
parser.add_argument("-e","--lineest",default="center",
help="type of text line estimator, default: %(default)s")
parser.add_argument("-E","--nolineest",action="store_true",
help="don't perform line estimation and load .dew.png file")
parser.add_argument("-l","--height",default=48,type=int,
help="set the default height for line estimation, default: %(default)s")
parser.add_argument("--dewarp",action="store_true",
help="only perform line estimation and output .dew.png file")
# character set
parser.add_argument("-c","--codec",default=[],nargs='*',
help="construct a codec from the input text")
# learning
parser.add_argument("-C","--clstm",action="store_true",
help="use C++ LSTM")
parser.add_argument("-r","--lrate",type=float,default=1e-4,
help="LSTM learning rate, default: %(default)s")
parser.add_argument("-S","--hiddensize",type=int,default=100,
help="# LSTM state units, default: %(default)s")
parser.add_argument("-o","--output",default=None,
help="LSTM model file")
parser.add_argument("-F","--savefreq",type=int,default=1000,
help="LSTM save frequency, default: %(default)s")
parser.add_argument("--strip",action="store_false",
help="strip the model before saving")
parser.add_argument("-N","--ntrain",type=int,default=1000000,
help="# lines to train before stopping, default: %(default)s")
parser.add_argument("-t","--tests",default=None,
help="test cases for error estimation")
parser.add_argument('--unidirectional',action="store_true",
help="use only unidirectional LSTM")
parser.add_argument("--updates",action="store_true",
help="verbose LSTM updates")
parser.add_argument('--load',default=None,
help="start training with a previously trained model")
parser.add_argument('--start',default=-1,type=int)
# debugging
parser.add_argument("-X","--exec",default="None",dest="execute",
help="execute before anything else (usually used for imports)")
parser.add_argument("-v","--verbose",action="store_true")
parser.add_argument("-d","--display",type=int,default=0,
help="display output for every nth iteration, where n=DISPLAY, default: %(default)s")
parser.add_argument("-m","--movie",default=None)
parser.add_argument("-M","--moviesample",default=None)
parser.add_argument("-q","--quiet",action="store_true")
parser.add_argument("-Q","--nocheck",action="store_true")
parser.add_argument("-p","--pad",type=int,default=16)
parser.add_argument("files",nargs="*")
args = parser.parse_args()
inputs = ocrolib.glob_all(args.files)
if len(inputs)==0:
parser.print_help()
sys.exit(0)
print "# inputs",len(inputs)
# pre-execute any python commands
exec args.execute
# make sure movie mode is used correctly
if args.movie is not None:
if args.display<2:
print "you must set --display to some number greater than 1"
sys.exit(0)
if args.moviesample is None:
args.moviesample = inputs[0]
# make sure an output file has been set
if args.output is None:
print "you must give an output file with %d in it, or a prefix"
sys.exit(0)
if not "%" in args.output:
if args.clstm:
oname = args.output+"-%08d.h5"
else:
oname = args.output+"-%08d.pyrnn"
else:
oname = args.output
# get a separate test set, if present
tests = None
if args.tests is not None:
tests = ocrolib.glob_all(args.tests.split(":"))
print "# tests",(len(tests) if tests is not None else "None")
# load the line normalizer
if args.lineest=="center":
lnorm = lineest.CenterNormalizer()
else:
raise Exception(args.lineest+": unknown line normalizer")
lnorm.setHeight(args.height)
# The `codec` maps between strings and arrays of integers.
if args.codec!=[]:
print "# building codec"
codec = lstm.Codec()
charset = set()
print args.codec
for fname in ocrolib.glob_all(args.codec):
base,_ = ocrolib.allsplitext(fname)
try:
line = ocrolib.read_image_gray(fname)
transcript = ocrolib.read_text(base+".gt.txt").encode('utf-8') ### <- HERE
except IOError as e:
print "ERROR",e
continue
#fix anscii issue
# transcript = ocrolib.read_text(fname).encode('utf-8')
l = list(lstm.normalize_nfkc(transcript))
charset = charset.union(l)
charset = sorted(list(charset))
charset = [c for c in charset if c>" " and c!="~"]
else:
print "# using default codec"
charset = sorted(list(set(list(lstm.ascii_labels) + list(ocrolib.chars.default))))
charset = [""," ","~",]+[c for c in charset if c not in [" ","~"]]
print "# charset size",len(charset),
if len(charset)<200:
print ("["+"".join(charset)+"]").encode('utf-8').strip()
else:
s = ("".join(charset)).encode('utf-8').strip()
print "["+s[:20],"...",s[-20:]+"]"
codec = lstm.Codec().init(charset)
# Load an existing network or construct a new one
# Somewhat convoluted logic for dealing with old style Python
# modules and new style C++ LSTM networks.
def save_lstm(fname,network):
    """Write `network` to the file `fname`.

    With --clstm the wrapped C++ LSTM saves itself.  Otherwise the whole
    network object is saved with ocrolib.save_object(); when args.strip is
    set, the log is cleared and preSave() is called on every component
    yielded by network.walk() before saving, and postLoad() on every
    component afterwards so that training can continue.
    """
    if args.clstm:
        network.lstm.save(fname)
        return
    strip = args.strip
    if strip:
        network.clear_log()
        for component in network.walk():
            component.preSave()
    ocrolib.save_object(fname,network)
    if strip:
        for component in network.walk():
            component.postLoad()
def load_lstm(fname):
    """Load a recognizer from the model file `fname` and return it.

    With --clstm a fresh lstm.SeqRecognizer is built from the command-line
    settings and the module-level `codec`, and a clstm BIDILSTM initialised
    with the recognizer's dimensions loads its state from `fname`.
    Otherwise the network object is read with ocrolib.load_object(),
    upgraded, and postLoad() is called on every component yielded by
    network.walk().
    """
    if args.clstm:
        network = lstm.SeqRecognizer(args.height,args.hiddensize,
                                     codec=codec,
                                     normalize=lstm.normalize_nfkc)
        import clstm
        mylstm = clstm.make_BIDILSTM()
        mylstm.init(network.No,args.hiddensize,network.Ni)
        mylstm.load(fname)
        network.lstm = clstm.CNetwork(mylstm)
        return network
    # Load the file that was asked for.  This used to read the global
    # `last_save` and silently ignore the `fname` argument.
    network = ocrolib.load_object(fname)
    network.upgrade()
    for x in network.walk():
        x.postLoad()
    return network
if args.load:
print "# loading",args.load
last_save = args.load
network = load_lstm(args.load)
else:
last_save = None
network = lstm.SeqRecognizer(args.height,args.hiddensize,
codec=codec,
normalize=lstm.normalize_nfkc)
if args.clstm:
import clstm
mylstm = clstm.make_BIDILSTM()
mylstm.init(network.No,args.hiddensize,network.Ni)
network.lstm = clstm.CNetwork(mylstm)
if getattr(network,"lnorm",None) is None:
network.lnorm = lnorm
network.upgrade()
if network.last_trial%100==99: network.last_trial += 1
print "# last_trial",network.last_trial
# set up the learning rate
network.setLearningRate(args.lrate,0.9)
if args.updates: network.lstm.verbose = 1
# used for plotting
ion()
matplotlib.rc('xtick',labelsize=7)
matplotlib.rc('ytick',labelsize=7)
matplotlib.rcParams.update({"font.size":7})
def cleandisp(s):
    """Return `s` with every '$' replaced by '#'.

    Used on the strings shown as plot titles -- presumably so that
    matplotlib does not treat '$' as a math-text delimiter (TODO confirm).
    """
    dollar = re.compile('[$]')
    return dollar.sub(r'#',s)
def plot_network_info(network,transcript,pred,gta):
    """Draw the five-row training figure for the current line.

    Rows, top to bottom: the input image (read from the module-level `line`),
    the network outputs titled with the prediction `pred`, the aligned
    targets titled with the aligned ground truth `gta`, curves for output and
    alignment columns 0, 1 and the maximum over the remaining columns, and
    the recent error history on a log scale.
    """
    outputs = network.outputs
    aligned = network.aligned
    shown = len(transcript)

    # row 1: input line image with the ground-truth transcript as title
    subplot(511)
    imshow(line.T,cmap=cm.gray)
    title(cleandisp(transcript))

    # row 2: network outputs without column 0
    subplot(512)
    gca().set_xticks([])
    imshow(outputs.T[1:],vmin=0,cmap=cm.hot)
    title(cleandisp(pred[:shown]))

    # row 3: aligned targets without column 0
    subplot(513)
    imshow(aligned.T[1:],vmin=0,cmap=cm.hot)
    title(cleandisp(gta[:shown]))

    # row 4: outputs (solid) and alignment (dashed) as curves
    subplot(514)
    plot(outputs[:,0],color='yellow',linewidth=3,alpha=0.5)
    plot(outputs[:,1],color='green',linewidth=3,alpha=0.5)
    plot(amax(outputs[:,2:],axis=1),color='blue',linewidth=3,alpha=0.5)
    plot(aligned[:,0],color='orange',linestyle='dashed',alpha=0.7)
    plot(aligned[:,1],color='green',linestyle='dashed',alpha=0.5)
    plot(amax(aligned[:,2:],axis=1),color='blue',linestyle='dashed',alpha=0.5)

    # row 5: error history, x axis ending at the current trial
    subplot(515)
    gca().set_yscale('log')
    window = 10000
    smoothed = network.errors(range=window,smooth=100)
    trials = arange(len(smoothed))+network.last_trial-len(smoothed)
    plot(trials,smoothed,color='black')
    plot(trials,network.errors(range=window),color='black',alpha=0.4)
    plot(trials,network.cerrors(range=window,smooth=100),color='red',linestyle='dashed')
start = args.start if args.start>=0 else network.last_trial
for trial in range(start,args.ntrain):
network.last_trial = trial+1
do_display = (args.display>0 and trial%args.display==0)
do_update = 1
if args.movie and do_display:
fname = args.moviesample
do_update = 0
else:
fname = pyrandom.sample(inputs,1)[0]
base,_ = ocrolib.allsplitext(fname)
try:
line = ocrolib.read_image_gray(fname)
transcript = ocrolib.read_text(base+".gt.txt")
except IOError as e:
print "ERROR",e
continue
if not args.nolineest:
assert "dew.png" not in fname,"don't dewarp already dewarped lines"
network.lnorm.measure(amax(line)-line)
line = network.lnorm.normalize(line,cval=amax(line))
else:
assert "dew.png" in fname,"input must already be dewarped"
if line.size<10 or amax(line)==amin(line):
print "EMPTY-INPUT"
continue
line = line * 1.0/amax(line)
line = amax(line)-line
line = line.T
if args.pad>0:
w = line.shape[1]
line = vstack([zeros((args.pad,w)),line,zeros((args.pad,w))])
cs = array(codec.encode(transcript),'i')
try:
pcs = network.trainSequence(line,cs,update=do_update,key=fname)
except FloatingPointError as e:
print "# oops, got FloatingPointError",e
traceback.print_exc()
network = load_lstm(last_save)
continue
except lstm.RangeError as e:
continue
pred = "".join(codec.decode(pcs))
acs = lstm.translate_back(network.aligned)
gta = "".join(codec.decode(acs))
if not args.quiet:
print "%d %.2f %s"%(trial,network.error,line.shape),fname
print " TRU:",repr(transcript)
print " ALN:",repr(gta[:len(transcript)+5])
print " OUT:",repr(pred[:len(transcript)+5])
pred = re.sub(' ','_',pred)
gta = re.sub(' ','_',gta)
if (trial+1)%args.savefreq==0:
ofile = oname%(trial+1)+".gz"
print "# saving",ofile
save_lstm(ofile,network)
last_save = ofile
if do_display:
figure("training",figsize=(1400//75,800//75),dpi=75)
clf()
gcf().canvas.set_window_title(args.output)
plot_network_info(network,transcript,pred,gta)
ginput(1,0.01)
if args.movie is not None:
draw()
savefig("%s-%08d.png"%(args.movie,trial),bbox_inches=0)