#!/usr/bin/env python
# Xfalign - align serial images
#
# Author: David Mastronarde
#
# $Id: xfalign,v 937343107256 2023/02/19 22:44:16 mast $

# Program name and the standard prefix placed in front of its error messages
progname = 'xfalign'
prefix = 'ERROR: %s - ' % progname

# Name of the temporary transform file; left empty here and assigned further down
tmpxf = ''

# Clean up the temporary transform and image files, plus the same names ending in '~'
def cleanup():
   leftovers = []
   for tmpname in (tmpxf, tmpRawAli):
      leftovers.extend([tmpname, tmpname + '~'])
   cleanupFiles(leftovers)

# Make a gray scale file from color file if necessary and return the name to use
def makeGrayFile(filename):
   """Return the name of the gray scale counterpart of filename, making it if needed.

   The gray file name is formed by inserting '_gray' before the extension of
   filename.  An existing gray file is reused only if its modification time is
   later than that of filename; otherwise it is made, or remade, by running a clip
   command.  If that command raises ImodpyError, exitFromImodError is called.
   """
   (root, ext) = os.path.splitext(filename)
   newname = root + '_gray' + ext
   grayExists = os.path.exists(newname)

   # An up-to-date gray file can be used as is
   if grayExists and os.path.getmtime(newname) > os.path.getmtime(filename):
      return newname

   if grayExists:
      prnstr('Remaking gray scale file ' + newname + ' because it is older than ' + \
                filename)
   else:
      prnstr('Making gray scale file for computing alignments: ' + newname)

   try:
      runcmd(fmtstr('clip resize -m 0 "{}" "{}"', filename, newname))
   except ImodpyError:
      # The program name is passed here as in every other call in this script;
      # it was missing from this one call
      exitFromImodError(progname)
   return newname


# Make or start a command list for running tiltxcorr
def makeTiltxcorrCom(outputFile, warping):
   """Build the list of tiltxcorr option lines used by the runs in this script.

   outputFile is the name placed in the OutputFile entry.  warping selects the kind
   of run: 0 for a run without patches, nonzero to add the patch size and FindWarp
   entries, and a value above 1 to take the input from the temporary file tmpRawAli
   instead of the input image file, in which case no skip or break list is added.
   All other settings are read from module-level variables.
   Returns the list of option strings.
   """
   tiltcom = ['OutputFile ' + outputFile]
   filterOpts = ('FilterSigma1 ', 'FilterSigma2 ', 'FilterRadius1 ', 'FilterRadius2 ')
   for ind in range(4):
      tiltcom.append(filterOpts[ind] + str(xcfilter[ind]))

   if warping > 1:
      tiltcom.append('InputFile ' + tmpRawAli)
   else:
      tiltcom.append('InputFile ' + infile)

   binning = 0
   if warping:
      tiltcom.append(fmtstr('SizeOfPatchesXandY {},{}', warpPatchX, warpPatchY))
      tiltcom.append('FindWarp 1')
      if boundModel:
         tiltcom.append('BoundaryModel ' + boundModel)
      if seedModel:
         tiltcom.append('SeedModel ' + seedModel)
      if warpLimitsX or warpLimitsY:
         tiltcom.append(fmtstr('ShiftLimitsXandY {},{}', warpLimitsX, warpLimitsY))

      # Set binning if specified, provide size to compute it from
      if warpBinning > 0:
         binning = warpBinning
      sizeForBin = min(warpPatchX, warpPatchY)

   else:
      tiltcom += ['RotationAngle 0', 'FirstTiltAngle 0.', 'TiltIncrement 0.']
      if limits[0] > 0. and limits[1] > 0.:
         tiltcom.append(fmtstr('ShiftLimitsXandY {},{}', int(round(limits[0])),
                               int(round(limits[1]))))

      # Set binning if specified, provide size to compute it from
      if preBinning > 0:
         binning = preBinning
      sizeForBin = min(newsize[0], newsize[1])

   if ifsize:
      tiltcom += [fmtstr('XMinAndMax {},{}', xmin, xmax),
                  fmtstr('YMinAndMax {},{}', ymin, ymax)]

   # Add the specified binning or predict the default
   if binning > 0:
      tiltcom.append(fmtstr('BinningToApply {}', binning))
   else:

      # True division is needed so that adding 0.99 rounds a fractional ratio up;
      # with floor division the addition had no effect and the value was too low.
      # NOTE(review): assumes this is meant to match the default binning chosen by
      # tiltxcorr itself - confirm against its documentation
      binning = max(1, int(sizeForBin / 1250. + 0.99))

   # Turn on antialiasing for binning >= 4
   if binning >= 4:
      tiltcom.append('AntialiasFilter 5')

   # Add the skip or break list, adding 1 to get from sections to views
   if (skipList or breakList) and warping < 2:
      if skipList:
         nlist = skipList
         opt = 'SkipViews'
      else:
         nlist = breakList
         opt = 'BreakAtViews'
      opt += ''.join([' ' + str(num + 1) for num in nlist])
      tiltcom.append(opt)

   return tiltcom


# load System Libraries
import os, sys, time

#
# Setup runtime environment: IMOD_DIR must be defined so that its pylib directory
# can be put at the front of the module search path before importing imodpy
IMOD_DIR = os.getenv('IMOD_DIR')
if IMOD_DIR is None:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

# Under Cygwin with Python 3, switch backslashes to forward slashes and turn a leading
# drive letter such as C:/ into /cygdrive/c/
if sys.platform == 'cygwin' and sys.version_info[0] > 2:
   IMOD_DIR = IMOD_DIR.replace('\\', '/')

   # A slice comparison cannot raise an IndexError when the value is very short
   if IMOD_DIR[1:3] == ':/':
      IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]

sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
from imodpy import *
addIMODbinIgnoreSIGHUP()

#
# load IMOD Libraries
from pip import *

# Initializations
tmpdir = imodTempDir()

# Temporary file names are made unique by ending them with the process ID
tmpxf = '/'.join([tmpdir, progname + '.xf.' + str(os.getpid())])
tmpRawAli = '/'.join([tmpdir, progname + '.st.' + str(os.getpid())])

# Values used when the corresponding options are not entered
xcfilter = [0.01, 0.05, 0, 0.25]       # Default in adoc
limits = [0, 0, 0, 0, 0.1, 4]          # Default in adoc
skipList = []
breakList = []

# Fallbacks from ../manpages/autodoc2man 3 1 xfalign
# Each entry has the form short name:long name:type code:
options = [
   ":InputImageFile:FN:",
   ":OutputTransformFile:FN:",
   "size:SizeToAnalyze:IP:",
   "offset:OffsetToSubarea:IP:",
   "matt:EdgeToIgnore:F:",
   "reduce:ReduceByBinning:I:",
   "filter:FilterParameters:FA:",
   "sobel:SobelFilter:B:",
   "params:ParametersToSearch:I:",
   "limits:LimitsOnSearch:FA:",
   "bilinear:BilinearInterpolation:B:",
   "ccc:CorrelationCoefficient:B:",
   "local:LocalPatchSize:I:",
   "reference:ReferenceFile:FN:",
   "prexcorr:PreCrossCorrelation:B:",
   "xcfilter:XcorrFilter:FA:",
   "xcreduce:XcorrReduction:I:",
   "initial:InitialTransforms:FN:",
   "warp:WarpPatchSizeXandY:IP:",
   "boundary:BoundaryModel:FN:",
   "seed:WarpSeedModel:FN:",
   "shift:ShiftLimitsForWarp:IP:",
   "wreduce:WarpReduction:I:",
   "skip:SkipSections:LI:",
   "break:BreakAtSections:LI:",
   "bpair:PairedImages:B:",
   "tomo:TomogramAverages:B:",
   "diff:DifferenceOutput:B:",
   "one:SectionsNumberedFromOne:B:",
   ":PID:B:",
   "help:usage:B:"]

(opts, nonopts) = PipReadOrParseOptions(sys.argv, options, progname, 2, 1, 1)
passOnKeyInterrupt(True)

doPID = PipGetBoolean('PID', 0)
printPID(doPID)

# Both file names are required
infile = PipGetInOutFile('InputImageFile', 0)
if not infile:
   exitError('Input image file must be entered')
xflistfile = PipGetInOutFile('OutputTransformFile', 1)
if not xflistfile:
   exitError('Output file for transforms must be entered')

# ifsize is derived from PipGetErrNo right after the size is fetched; presumably the
# error number reflects the preceding fetch, so keep these two statements adjacent
newsize = PipGetTwoIntegers('SizeToAnalyze', 0, 0)
ifsize = 1 - PipGetErrNo()
newcen = PipGetTwoIntegers('OffsetToSubarea', 0, 0)
filterparam = PipGetFloatArray('FilterParameters', 4)
nreduce = PipGetInteger('ReduceByBinning', 2)      # Default in adoc


# Fetch one section list option; return the entered string and the parsed list, which
# is empty when nothing was entered.  Exit with an error if the entry gives no list.
def getSectionList(option, listName):
   entry = PipGetString(option, '')
   sections = []
   if entry:
      sections = parselist(entry)
      if not sections:
         exitError('Parsing the ' + listName + ' list')
   return (entry, sections)


(ifskip, skipList) = getSectionList('SkipSections', 'skip')
(ifbreak, breakList) = getSectionList('BreakAtSections', 'break')

ifpair = PipGetBoolean('PairedImages', 0)
fracmatt = PipGetFloat('EdgeToIgnore', 0.05)      # Default in adoc

# The number of search parameters is restricted to -1, 0, or 2 through 6
nparam = PipGetInteger('ParametersToSearch', 0)     # Default in adoc
if nparam < -1 or nparam == 1 or nparam > 6:
   exitError('Number of parameters to search must be -1, 0, or 2-6')

ifbilinear = PipGetBoolean('BilinearInterpolation', 0)
reffile = PipGetString('ReferenceFile', '')
prexcorr = PipGetBoolean('PreCrossCorrelation', 0)
(warpPatchX, warpPatchY) = PipGetTwoIntegers('WarpPatchSizeXandY', 0, 0)
boundModel = PipGetString('BoundaryModel', '')
(warpLimitsX, warpLimitsY) = PipGetTwoIntegers('ShiftLimitsForWarp', 0, 0)

# If either patch size is nonzero, neither one may be zero or negative
if (warpPatchX or warpPatchY) and (warpPatchX <= 0 or warpPatchY <= 0):
   exitError('Patch sizes must be positive in both X and Y')
if nparam == -1 and not (prexcorr or warpPatchX):
   exitError('-param entry can be -1 only when doing initial cross-correlation or ' + \
                'warping')

# An entered cross-correlation filter replaces the default one
xcfin = PipGetFloatArray('XcorrFilter', 4)
if xcfin:
   xcfilter = xcfin

prexffile = PipGetString('InitialTransforms', '')
iftomo = PipGetBoolean('TomogramAverages', 0)
diffout = PipGetBoolean('DifferenceOutput', 0)
doCCC = PipGetBoolean('CorrelationCoefficient', 0)
sobel = PipGetBoolean('SobelFilter', 0)
local = PipGetInteger('LocalPatchSize', 0)
fromOne = PipGetBoolean('SectionsNumberedFromOne', 0)

# Entered search limits replace the leading defaults, up to six values
limitsIn = PipGetFloatArray('LimitsOnSearch', 0)
if limitsIn:
   numLimits = min(6, len(limitsIn))
   limits[:numLimits] = limitsIn[:numLimits]

seedModel = PipGetString('WarpSeedModel', '')
warpBinning = PipGetInteger('WarpReduction', 0)
preBinning = PipGetInteger('XcorrReduction', 0)

# Error checks: each entry pairs a failure condition with its message, and the entries
# are tested in order so that the first failure is the one reported
errorChecks = [
   (prexcorr and reffile,
    'You cannot use initial cross-correlation with alignment to one reference'),
   (prexcorr and prexffile,
    'You cannot use initial cross-correlation with initial transforms'),
   (skipList and breakList,
    'You cannot use both a break list and skip list'),
   (breakList and reffile,
    'It is meaningless to use a break list with alignment to a reference'),
   (warpPatchX and reffile,
    'You cannot do warping alignment to a reference'),
   (not os.path.exists(infile),
    'Input image file ' + infile + ' does not exist'),
   (prexffile and not os.path.exists(prexffile),
    'Initial transform file ' + prexffile + ' does not exist'),
   (reffile and not os.path.exists(reffile),
    'Reference image file ' + reffile + ' does not exist')]

for (failed, message) in errorChecks:
   if failed:
      exitError(message)

# Check if initial file is warping: it is rejected if its first line has just one field
if prexffile:
   preLines = readTextFile(prexffile, maxLines = 1)
   if preLines:
      if len(preLines[0].split()) == 1:
         exitError('You cannot use warping transforms as initial transforms')

# Get the values returned by getmrc for the input file and for the reference file if any
try:
   (nx, ny, numsec, mode, px, py, pz) = getmrc(infile)
   if reffile:
      (refnx, refny, refnumsec, refmode, px, py, pz) = getmrc(reffile)
except ImodpyError:
   exitFromImodError(progname)

# If sections numbered from 1, shift the entered lists down.  Each list is shifted
# on its own; the break list loop used to modify the skip list instead, which fails
# because both lists cannot be entered together
if fromOne:
   skipList = [secnum - 1 for secnum in skipList]
   breakList = [secnum - 1 for secnum in breakList]
   
# If serial tomograms or paired images, check other options
if ifpair or iftomo:
   if ifbreak or ifskip:
      exitError('You cannot enter a break or a skip list in tomogram or pair mode')
   if reffile:
      exitError('You cannot use alignment to one reference with tomograms or pairs')

   # set up break list as even sections, starting with section 2, and mark the list
   # as present
   ifbreak = True
   breakList = list(range(2, numsec, 2))

# Given new size and center, compute min/max of ranges; without a size entry the
# full image size is used
if not ifsize:
   newsize = [nx, ny]
else:
   xmin = (nx - newsize[0]) // 2 + newcen[0]
   xmax = (nx + newsize[0]) // 2 + newcen[0] - 1
   ymin = (ny - newsize[1]) // 2 + newcen[1]
   ymax = (ny + newsize[1]) // 2 + newcen[1] - 1

   # Each range must lie inside the image and have its minimum below its maximum
   if not (0 <= xmin < xmax < nx and 0 <= ymin < ymax < ny):
      exitError('New center or subarea offset gives an area outside range of image')

# Check matt entry: a value below 1 is applied as a fraction of the size in each
# dimension, anything else is truncated to an integer and used for both dimensions
if fracmatt < 1.:
   xtrim = int(newsize[0] * fracmatt)
   ytrim = int(newsize[1] * fracmatt)
else:
   xtrim = ytrim = int(fracmatt)
if min(newsize[0] - 2 * xtrim, newsize[1] - 2 * ytrim) < 10:
   exitError('Entry for new size or for edge to ignore leaves too small an area')

# Files with mode 16 are replaced by a gray scale version
if mode == 16:
   infile = makeGrayFile(infile)
if reffile:
   if refmode == 16:
      reffile = makeGrayFile(reffile)

# Warp patch tracking can be run on the whole stack if there is no prealignment, no
# initial transforms, and no search; the program is then finished
if warpPatchX and nparam == -1 and not (prexcorr or prexffile):
   tiltcom = makeTiltxcorrCom(xflistfile, 1)
   prnstr('RUNNING TILTXCORR WITH PATCH TRACKING TO FIND ALL WARPING ALIGNMENTS...')
   sys.stdout.flush()
   try:
      runcmd('tiltxcorr -StandardInput', tiltcom)
   except ImodpyError:
      exitFromImodError(progname)
   prnstr('DONE')
   sys.exit(0)

# If doing warps afterwards, adjust the output filename for the preliminaries: the
# entered name is kept for the warping output and the linear transforms go to a file
# with the extension .linxf
preroot, ext = os.path.splitext(xflistfile)
if warpPatchX:
   warpOutFile, xflistfile = xflistfile, preroot + '.linxf'
# Set up initial cross-correlation: run tiltxcorr on the input file, writing transforms
# to a file with extension .xcxf that then serves as the initial transform file
if prexcorr:
   prexffile = preroot + '.xcxf'
   tiltcom = makeTiltxcorrCom(prexffile, 0)
   prnstr('RUNNING TILTXCORR FOR INITIAL CROSS-CORRELATION ALIGNMENTS...')
   sys.stdout.flush()

   try:
      runcmd('tiltxcorr -StandardInput', tiltcom)
   except ImodpyError:
      exitFromImodError(progname)

   prexflist = readTextFile(prexffile)

   # trim the transform list for tomos: entries 2, 4, 6, ... are removed, leaving
   # entries 0, 1, 3, 5, ..., and the file is rewritten
   if iftomo:
      del prexflist[2::2]
      writeTextFile(prexffile, prexflist)

   # Print the fifth and sixth fields of each transform line; with paired images only
   # the odd-numbered entries are printed
   prnstr('X, Y SHIFTS FOUND:')
   for (ind, xfline) in enumerate(prexflist):
      if ifpair and ind % 2 == 0:
         continue

      # Only a short line or a non-numeric field is treated as a bad transform file;
      # a bare except here would also swallow an interrupt
      try:
         lsplit = xfline.split()
         prnstr(fmtstr('{:9.2f}  {:9.2f}', float(lsplit[4]),  float(lsplit[5])))
      except (IndexError, ValueError):
         exitError('Getting shifts from ' + prexffile)
   sys.stdout.flush()

   # With no search to be done, these transforms are the linear output, and they are
   # the final output unless warping follows
   if nparam < 0:
      writeTextFile(xflistfile, prexflist)
      if not warpPatchX:
         sys.exit(0)

# Loop on all the sections in file, running xfsimplex on each section that is to be
# aligned and collecting one transform line per output section.  Nothing is done for
# any section when the number of search parameters is negative.
try:
   nextRef = -1
   xfoutList = []
   if nparam >= 0:
      prnstr('TRANSFORMS FOUND BY XFSIMPLEX:')
   for sec in range(numsec):
      if nparam < 0:
         continue

      # The reference is section 0 of the reference file if there is one, otherwise
      # the section of the input file recorded on an earlier pass
      if reffile:
         (fileref, refsec) = (reffile, 0)
      else:
         (fileref, refsec) = (infile, nextRef)

      # A section on the skip or break list is not aligned.  It becomes the next
      # reference only when a break list is in use; a section on neither list always
      # becomes the next reference
      onList = (ifskip and sec in skipList) or (ifbreak and sec in breakList)
      doskip = dobreak = 0
      if onList:
         (doskip, dobreak) = (ifskip, ifbreak)
      if ifbreak or not onList:
         nextRef = sec

      # Sections not being aligned get a unit transform, except that in tomogram mode
      # only section 0 gets one
      if doskip or dobreak or refsec < 0:
         if not (sec and iftomo):
            xfoutList.append('   1.0000000   0.0000000   0.0000000   1.0000000       ' +\
                             '0.000       0.000')
         continue

      # Basic xfsimplex entries
      simpcom = ['AImageFile ' + fileref,
                 'BImageFile ' + infile,
                 'OutputFile ' + tmpxf,
                 fmtstr('SectionsToUse {},{}', refsec, sec),
                 fmtstr('VariablesToSearch {}', nparam),
                 fmtstr('BinningToApply {}', nreduce),
                 fmtstr('LinearInterpolation {}', ifbilinear),
                 fmtstr('EdgeToIgnore {}', fracmatt),
                 fmtstr('CorrelationCoefficient {}', doCCC),
                 fmtstr('SobelFilter {}', sobel),
                 'FloatOption 1',
                 'AntialiasFilter -1']

      if local:
         simpcom.append(fmtstr('LocalPatchSize {}', local))

      # Set up initial transform; in tomogram mode the line number is (sec + 1) // 2
      if prexffile:
         presec = sec
         if iftomo:
            presec = (sec + 1) // 2
         simpcom.append('InitialTransformFile ' + prexffile)
         simpcom.append(fmtstr('UseTransformLine {}', presec))

      # Add filter
      if filterparam:
         simpcom.extend(['FilterSigma1 ' + str(filterparam[0]),
                         'FilterSigma2 ' + str(filterparam[1]),
                         'FilterRadius1 ' + str(filterparam[2]),
                         'FilterRadius2 ' + str(filterparam[3])])

      # Handle subarea, trimmed by the edge to ignore
      if ifsize:
         simpcom.append(fmtstr('XMinAndMax {},{}', xmin + xtrim, xmax - xtrim))
         simpcom.append(fmtstr('YMinAndMax {},{}', ymin + ytrim, ymax - ytrim))

      # All six limits are passed when searching, just the first two otherwise
      if nparam:
         simpcom.append(fmtstr('LimitsOnSearch {},{},{},{},{},{}', *limits))
      else:
         simpcom.append(fmtstr('LimitsOnSearch {},{}', *limits[:2]))

      try:
         simplines = runcmd('xfsimplex -StandardInput', simpcom)
      except ImodpyError:
         cleanup()
         exitFromImodError(progname)

      # Print the second field of the line after each one containing FINAL VALUES
      if diffout:
         for (line, following) in zip(simplines, simplines[1:]):
            if 'FINAL VALUES' in line:
               lsplit = following.split()
               if len(lsplit) > 1:
                  label = 'Difference:  '
                  if doCCC:
                     label = 'CCC:  '
                  prnstr(label + lsplit[1])

      # The first line of the temporary output file is the transform for this section
      xfone = readTextFile(tmpxf)
      prnstr(fmtstr('{} : {}', sec, xfone[0]))
      sys.stdout.flush()
      xfoutList.append(xfone[0])

except KeyboardInterrupt:
   cleanup()
   sys.exit(1)

# Write the transforms if a search was done, after backing up an existing file
if nparam >= 0:
   makeBackupFile(xflistfile)
   writeTextFile(xflistfile, xfoutList)
   cleanup()

# Done unless warping is to be found
if not warpPatchX:
   sys.exit(0)

# Find warping one section at a time: each section to be aligned is stacked with its
# reference by newstack into a temporary file, and tiltxcorr is run on that pair
try:
   nextRef = -1
   anyDone = False

   # Set file for prealign transforms if any and get them in: the linear output file
   # is used if cross-correlation or a search was done, otherwise any initial file
   xfForWarp = prexffile
   if prexcorr or nparam >= 0:
      xfForWarp = xflistfile

   prnstr('RUNNING TILTXCORR TO FIND WARPING ALIGNMENTS ONE SECTION AT A TIME...')
   sys.stdout.flush()

   # Loop on sections as above
   for sec in range(numsec):
      refsec = nextRef

      # A section on the skip or break list is not aligned, and it becomes the next
      # reference only when a break list is in use
      onList = (ifskip and sec in skipList) or (ifbreak and sec in breakList)
      doskip = dobreak = 0
      if onList:
         (doskip, dobreak) = (ifskip, ifbreak)
      if ifbreak or not onList:
         nextRef = sec

      if doskip or dobreak or refsec < 0:

         # If skipping or breaking and file has been started, append unit
         # transform and a leading 0 for no control points.  Both break and skip require
         # a unit transform here regardless of what is in the initial file
         if anyDone:
            try:
               action = 'Opening'
               warpfile = open(warpOutFile, 'a')
               action = 'Appending to'

               # Close the file even if the write fails
               try:
                  warpfile.write('0\n   1.0000000   0.0000000   0.0000000   1.0000000    ' +\
                   '   0.000       0.000\n')
               finally:
                  warpfile.close()

            except IOError:

               # Get the message before doing anything else, then remove temporary
               # files as the other error exits in this loop do
               errmess = str(sys.exc_info()[1])
               cleanup()
               exitError(action + ' warp transform file: ' + errmess)
         continue

      # Newstack the pair of sections, applying transform line 0 and the line for
      # this section if there is a transform file
      newstcom = ['InputFile ' + infile,
                  'OutputFile ' + tmpRawAli,
                  fmtstr('SectionsToRead {},{}', refsec, sec)]
      if xfForWarp:
         newstcom += ['TransformFile ' + xfForWarp,
                      'UseTransformLines 0,' + str(sec)]
      try:
         runcmd('newstack -StandardInput', newstcom)
      except ImodpyError:
         cleanup()
         exitFromImodError(progname)

      # Make the tiltxcorr command, appending to the warp file after the first run
      tiltcom = makeTiltxcorrCom(warpOutFile, 2)
      tiltcom.append(fmtstr('RawAndAlignedPair {},{}', sec + 1, numsec))
      if xfForWarp:
         tiltcom.append('PrealignmentTransformFile ' + xfForWarp)
      if anyDone:
         tiltcom.append('AppendToWarpFile')

      try:
         runcmd('tiltxcorr -StandardInput', tiltcom)
      except ImodpyError:
         cleanup()
         exitFromImodError(progname)

      # Print a dot for each section done
      anyDone = True
      prnstr('.', end = '')
      sys.stdout.flush()

   prnstr('\nDONE')

except KeyboardInterrupt:
   cleanup()
   sys.exit(1)

cleanup()
sys.exit(0)
