issue FloatingPointError: invalid value encountered in double_scalars - wanghaisheng/awesome-ocr GitHub Wiki

https://github.com/tmbdev/ocropy/issues/93

root@4b2648975d03:/ocropy# cat ocropus-rtrain
#!/usr/bin/env python
# vim: set fileencoding=utf8 :
import sys
import random as pyrandom
import re
from pylab import *
import os.path
import ocrolib
import argparse
import matplotlib
import numpy
from ocrolib import lineest
import ocrolib.lstm as lstm
import traceback

# Make numpy raise FloatingPointError for division by zero, overflow and
# invalid operations; underflow is ignored. The training loop in this script
# catches FloatingPointError around network.trainSequence().
numpy.seterr(divide='raise',over='raise',invalid='raise',under='ignore')

# Python 2 only: sys.setdefaultencoding is normally unavailable after
# interpreter startup, and reload(sys) makes it callable again. The default
# encoding is then switched to UTF-8 for implicit str/unicode conversions.
reload(sys)
sys.setdefaultencoding('utf-8')
# Command line definition.
#
# The text is passed as `description`; the first positional parameter of
# ArgumentParser is `prog`, which would make the usage line print the
# description as if it were the program name.
parser = argparse.ArgumentParser(description="train an RNN recognizer")

# line normalization
parser.add_argument("-e","--lineest",default="center",
                    help="type of text line estimator, default: %(default)s")
parser.add_argument("-E","--nolineest",action="store_true",
                    help="don't perform line estimation and load .dew.png file")
parser.add_argument("-l","--height",default=48,type=int,
                    help="set the default height for line estimation, default: %(default)s")
parser.add_argument("--dewarp",action="store_true",
                    help="only perform line estimation and output .dew.png file")

# character set
parser.add_argument("-c","--codec",default=[],nargs='*',
                    help="construct a codec from the input text")

# learning
parser.add_argument("-C","--clstm",action="store_true",
                    help="use C++ LSTM")
parser.add_argument("-r","--lrate",type=float,default=1e-4,
                    help="LSTM learning rate, default: %(default)s")
parser.add_argument("-S","--hiddensize",type=int,default=100,
                    help="# LSTM state units, default: %(default)s")
parser.add_argument("-o","--output",default=None,
                    help="LSTM model file")
parser.add_argument("-F","--savefreq",type=int,default=1000,
                    help="LSTM save frequency, default: %(default)s")
# store_false: args.strip is True unless the flag is given, so the flag
# turns stripping OFF. The help text now says so.
parser.add_argument("--strip",action="store_false",
                    help="do NOT strip the model before saving (stripping is on by default)")
parser.add_argument("-N","--ntrain",type=int,default=1000000,
                    help="# lines to train before stopping, default: %(default)s")
parser.add_argument("-t","--tests",default=None,
                    help="test cases for error estimation")
parser.add_argument('--unidirectional',action="store_true",
                    help="use only unidirectional LSTM")
parser.add_argument("--updates",action="store_true",
                    help="verbose LSTM updates")
parser.add_argument('--load',default=None,
                    help="start training with a previously trained model")
parser.add_argument('--start',default=-1,type=int,
                    help="trial number to start at; a negative value continues "
                         "from the model's last trial, default: %(default)s")

# debugging
parser.add_argument("-X","--exec",default="None",dest="execute",
                    help="execute before anything else (usually used for imports)")
parser.add_argument("-v","--verbose",action="store_true")
parser.add_argument("-d","--display",type=int,default=0,
                    help="display output for every nth iteration, where n=DISPLAY, default: %(default)s")
parser.add_argument("-m","--movie",default=None,
                    help="prefix for the PNG frames saved on displayed iterations "
                         "(needs --display greater than 1)")
parser.add_argument("-M","--moviesample",default=None,
                    help="input file shown in movie frames, default: the first input file")
parser.add_argument("-q","--quiet",action="store_true",
                    help="don't print the per-line training results")
parser.add_argument("-Q","--nocheck",action="store_true")
parser.add_argument("-p","--pad",type=int,default=16,
                    help="number of zero rows stacked before and after each "
                         "input line, default: %(default)s")

parser.add_argument("files",nargs="*")
args = parser.parse_args()

inputs = ocrolib.glob_all(args.files)
if len(inputs)==0:
    parser.print_help()
    sys.exit(0)

print "# inputs",len(inputs)

# pre-execute any python commands
#
# NOTE(security): the -X/--exec string is executed verbatim as Python code at
# module level. Its default is the string "None", which does nothing when
# executed. Never feed this option from untrusted input.

exec args.execute

# make sure movie mode is used correctly

if args.movie is not None:
    if args.display<2:
        print "you must set --display to some number greater than 1"
        sys.exit(0)

if args.moviesample is None:
    args.moviesample = inputs[0]

# make sure an output file has been set

if args.output is None:
    print "you must give an output file with %d in it, or a prefix"
    sys.exit(0)

if not "%" in args.output:
    if args.clstm:
        oname = args.output+"-%08d.h5"
    else:
        oname = args.output+"-%08d.pyrnn"
else:
    oname = args.output

# get a separate test set, if present

tests = None
if args.tests is not None:
    tests = ocrolib.glob_all(args.tests.split(":"))
print "# tests",(len(tests) if tests is not None else "None")

# load the line normalizer

if args.lineest=="center":
  lnorm = lineest.CenterNormalizer()
else:
  raise Exception(args.lineest+": unknown line normalizer")
lnorm.setHeight(args.height)

# The `codec` maps between strings and arrays of integers.
#
# With --codec, the character set is collected from the ground truth of the
# files matched by the --codec patterns; otherwise a default character set
# from lstm.ascii_labels and ocrolib.chars.default is used.

if args.codec!=[]:
    print "# building codec"
    # NOTE(review): this instance is never used; `codec` is rebound at the
    # end of this section.
    codec = lstm.Codec()
    charset = set()
    print args.codec
    for fname in ocrolib.glob_all(args.codec):
        base,_ = ocrolib.allsplitext(fname)
        try:
           # NOTE(review): the image is loaded but not used in this loop.
           line = ocrolib.read_image_gray(fname)
           # Local modification: the transcript is read from the matching
           # .gt.txt file and UTF-8 encoded.
           # NOTE(review): the training loop passes the un-encoded
           # read_text() result to codec.encode(); confirm that both sides
           # agree on str vs. unicode.
           transcript = ocrolib.read_text(base+".gt.txt").encode('utf-8')
        except IOError as e:
           # unreadable image or transcript: report and skip this file
           print "ERROR",e
           continue
        # Earlier attempt at the ASCII/encoding problem, kept for reference:
        # transcript = ocrolib.read_text(fname).encode('utf-8')
        l = list(lstm.normalize_nfkc(transcript))
        charset = charset.union(l)
    charset = sorted(list(charset))
    # drop everything that sorts at or below the space character, and "~"
    charset = [c for c in charset if c>" " and c!="~"]
else:
    print "# using default codec"
    charset = sorted(list(set(list(lstm.ascii_labels) + list(ocrolib.chars.default))))

# The first three entries are always "", " " and "~", in that order.
charset = [""," ","~",]+[c for c in charset if c not in [" ","~"]]
# trailing comma: the charset preview below continues on the same output line
print "# charset size",len(charset),
if len(charset)<200:
    print ("["+"".join(charset)+"]").encode('utf-8').strip()
else:
    # large charsets: show only the first and last 20 bytes of the encoded string
    s = ("".join(charset)).encode('utf-8').strip()
    print "["+s[:20],"...",s[-20:]+"]"
codec = lstm.Codec().init(charset)

# Load an existing network or construct a new one
# Somewhat convoluted logic for dealing with old style Python
# modules and new style C++ LSTM networks.

def save_lstm(fname,network):
    """Write `network` to the file `fname`.

    With --clstm the wrapped C++ LSTM saves itself. Otherwise the whole
    recognizer object is stored with ocrolib.save_object(); when args.strip
    is set, the log is cleared and preSave() is called on every node from
    network.walk() before saving, and postLoad() on every node afterwards.
    """
    if args.clstm:
        network.lstm.save(fname)
        return

    strip = args.strip
    if strip:
        network.clear_log()
        for node in network.walk():
            node.preSave()

    ocrolib.save_object(fname,network)

    if strip:
        # undo preSave() so the in-memory network can keep training
        for node in network.walk():
            node.postLoad()

def load_lstm(fname):
    """Load a recognizer from the model file `fname` and return it.

    With --clstm a new SeqRecognizer is built from the command line settings
    and the module-level `codec`, and a C++ bidirectional LSTM is initialized
    and loaded from `fname`. Otherwise the recognizer object is read with
    ocrolib.load_object(), upgraded, and postLoad() is called on every node
    from network.walk().

    The non-clstm branch used to read the module-level `last_save` instead
    of `fname`; it only worked because every caller passed that same value.
    """
    if args.clstm:
        network = lstm.SeqRecognizer(args.height,args.hiddensize,
            codec=codec,
            normalize=lstm.normalize_nfkc)
        # imported lazily so the C++ backend is only required with --clstm
        import clstm
        mylstm = clstm.make_BIDILSTM()
        mylstm.init(network.No,args.hiddensize,network.Ni)
        mylstm.load(fname)
        network.lstm = clstm.CNetwork(mylstm)
        return network

    network = ocrolib.load_object(fname)
    network.upgrade()
    for x in network.walk(): x.postLoad()
    return network

if args.load:
    print "# loading",args.load
    last_save = args.load
    network = load_lstm(args.load)
else:
    last_save = None
    network = lstm.SeqRecognizer(args.height,args.hiddensize,
        codec=codec,
        normalize=lstm.normalize_nfkc)
    if args.clstm:
        import clstm
        mylstm = clstm.make_BIDILSTM()
        mylstm.init(network.No,args.hiddensize,network.Ni)
        network.lstm = clstm.CNetwork(mylstm)

if getattr(network,"lnorm",None) is None:
    network.lnorm = lnorm

network.upgrade()
if network.last_trial%100==99: network.last_trial += 1
print "# last_trial",network.last_trial


# set up the learning rate

network.setLearningRate(args.lrate,0.9)
if args.updates: network.lstm.verbose = 1

# used for plotting: interactive mode and a 7pt size for tick labels and fonts

ion()
for tick_group in ('xtick','ytick'):
    matplotlib.rc(tick_group,labelsize=7)
matplotlib.rcParams.update({"font.size":7})

def cleandisp(s):
    """Return `s` with every dollar sign replaced by a hash sign.

    Applied to strings before they are used as plot titles (presumably so
    matplotlib does not treat `$` as a math-text delimiter).
    """
    dollar_pattern = '[$]'
    substitute = r'#'
    return re.sub(dollar_pattern,substitute,s)

def plot_network_info(network,transcript,pred,gta):
    """Draw the five-row diagnostics figure for the most recent training line.

    Rows, top to bottom:
      1. the input image, read from the module-level `line` set by the
         training loop, titled with the transcript
      2. network.outputs as an image, titled with `pred`
      3. network.aligned as an image, titled with `gta`
      4. curves of columns 0, 1 and the maximum over the remaining columns,
         for outputs (solid) and aligned (dashed)
      5. network.errors()/network.cerrors() over the last trials, log scale

    `pred` and `gta` are cut to the length of the transcript for the titles.
    """
    nchars = len(transcript)
    heat = dict(vmin=0,cmap=cm.hot)

    # row 1: input line
    subplot(511)
    imshow(line.T,cmap=cm.gray)
    title(cleandisp(transcript))

    # row 2: outputs; the first row of the transposed array is left out
    subplot(512)
    gca().set_xticks([])
    outputs = network.outputs
    imshow(outputs.T[1:],**heat)
    title(cleandisp(pred[:nchars]))

    # row 3: alignment, same layout as row 2
    subplot(513)
    aligned = network.aligned
    imshow(aligned.T[1:],**heat)
    title(cleandisp(gta[:nchars]))

    # row 4: per-position curves, outputs drawn thick, alignment dashed
    subplot(514)
    thick = dict(linewidth=3,alpha=0.5)
    plot(outputs[:,0],color='yellow',**thick)
    plot(outputs[:,1],color='green',**thick)
    plot(amax(outputs[:,2:],axis=1),color='blue',**thick)
    plot(aligned[:,0],color='orange',linestyle='dashed',alpha=0.7)
    plot(aligned[:,1],color='green',linestyle='dashed',alpha=0.5)
    plot(amax(aligned[:,2:],axis=1),color='blue',linestyle='dashed',alpha=0.5)

    # row 5: error history, x axis ends at the current trial number
    subplot(515)
    gca().set_yscale('log')
    window = 10000
    smoothed = network.errors(range=window,smooth=100)
    nerrs = len(smoothed)
    xs = arange(nerrs)+network.last_trial-nerrs
    plot(xs,smoothed,color='black')
    plot(xs,network.errors(range=window),color='black',alpha=0.4)
    plot(xs,network.cerrors(range=window,smooth=100),color='red',linestyle='dashed')

# Main training loop.
#
# A non-negative --start overrides the trial counter stored in the network.
start = args.start if args.start>=0 else network.last_trial

for trial in range(start,args.ntrain):
    network.last_trial = trial+1

    do_display = (args.display>0 and trial%args.display==0)
    do_update = 1

    if args.movie and do_display:
        # movie frames always show the same sample; it is run with update=0
        fname = args.moviesample
        do_update = 0
    else:
        fname = pyrandom.sample(inputs,1)[0]

    # load the line image and its ground truth; unreadable samples are skipped
    base,_ = ocrolib.allsplitext(fname)
    try:
        line = ocrolib.read_image_gray(fname)
        transcript = ocrolib.read_text(base+".gt.txt")
    except IOError as e:
        print "ERROR",e
        continue

    if not args.nolineest:
        assert "dew.png" not in fname,"don't dewarp already dewarped lines"
        network.lnorm.measure(amax(line)-line)
        line = network.lnorm.normalize(line,cval=amax(line))
    else:
        assert "dew.png" in fname,"input must already be dewarped"

    if line.size<10 or amax(line)==amin(line):
        print "EMPTY-INPUT"
        continue
    # scale to a maximum of 1, invert, and transpose before padding
    line = line * 1.0/amax(line)
    line = amax(line)-line
    line = line.T
    if args.pad>0:
        w = line.shape[1]
        line = vstack([zeros((args.pad,w)),line,zeros((args.pad,w))])
    cs = array(codec.encode(transcript),'i')
    try:
        pcs = network.trainSequence(line,cs,update=do_update,key=fname)
    except FloatingPointError as e:
        print "# oops, got FloatingPointError",e
        traceback.print_exc()
        if last_save is None:
            # No --load model and no checkpoint written yet: there is nothing
            # to roll back to. Calling load_lstm(None) here used to crash, so
            # keep the in-memory network and just skip this sample.
            print "# no saved model to roll back to; skipping",fname
        else:
            print "# reloading",last_save
            network = load_lstm(last_save)
            # repeat the post-load steps done at startup, so the reloaded
            # network has a line normalizer and this run's learning settings
            if getattr(network,"lnorm",None) is None:
                network.lnorm = lnorm
            network.setLearningRate(args.lrate,0.9)
            if args.updates: network.lstm.verbose = 1
        continue
    except lstm.RangeError as e:
        # deliberately silent: the sample is skipped
        continue
    pred = "".join(codec.decode(pcs))
    acs = lstm.translate_back(network.aligned)
    gta = "".join(codec.decode(acs))
    if not args.quiet:
        print "%d %.2f %s"%(trial,network.error,line.shape),fname
        print "   TRU:",repr(transcript)
        print "   ALN:",repr(gta[:len(transcript)+5])
        print "   OUT:",repr(pred[:len(transcript)+5])

    # make spaces visible in the plot titles
    pred = re.sub(' ','_',pred)
    gta = re.sub(' ','_',gta)

    if (trial+1)%args.savefreq==0:
        ofile = oname%(trial+1)+".gz"
        print "# saving",ofile
        save_lstm(ofile,network)
        # remember the checkpoint for the FloatingPointError rollback above
        last_save = ofile

    if do_display:
        figure("training",figsize=(1400//75,800//75),dpi=75)
        clf()
        gcf().canvas.set_window_title(args.output)
        plot_network_info(network,transcript,pred,gta)
        ginput(1,0.01)
        if args.movie is not None:
            draw()
            savefig("%s-%08d.png"%(args.movie,trial),bbox_inches=0)