#!/usr/bin/env python
# matchrotpairs - Find best-matching pair between parts of two tilt series
#
# Author: David Mastronarde
#
# $Id: matchrotpairs,v 937343107256 2023/02/19 22:44:16 mast $
#

# Program name, and the prefix placed in front of error messages
progname = 'matchrotpairs'
prefix = ''.join(('ERROR: ', progname, ' - '))

# Finds interpolated fit position as in libcfshr
def parabolicFitPosition(y1, y2, y3):
   """Return the offset of the extremum of a parabola fit to three values.

   y1, y2 and y3 are values at three successive, equally spaced positions.
   The result is the position of the parabola's extremum relative to the
   middle point, limited to the range -0.5 to 0.5.  The offset is 0 when the
   curvature term is too small relative to the difference between the outer
   values for the division to be done.
   """
   outerDiff = y1 - y3
   curvature = 2. * (y1 + y3 - 2. * y2)

   # Divide only when the curvature is large enough compared to the slope
   if math.fabs(curvature) > math.fabs(1.e-2 * outerDiff):
      offset = outerDiff / curvature
   else:
      offset = 0.

   # Keep the answer within half an interval of the middle point
   return max(-0.5, min(0.5, offset))


#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, glob, math

#
# Setup runtime environment: IMOD_DIR must be defined so that its pylib directory can
# be put at the front of the module search path before importing imodpy
IMOD_DIR = os.environ.get('IMOD_DIR')
if IMOD_DIR is not None:
   if sys.platform == 'cygwin' and sys.version_info[0] > 2:

      # Under cygwin with Python 3, switch to forward slashes and turn a leading
      # drive letter ("c:/...") into the /cygdrive/c/... form
      IMOD_DIR = IMOD_DIR.replace('\\', '/')

      # Comparing a slice instead of indexing characters 1 and 2 means a value shorter
      # than three characters is simply left alone instead of raising IndexError
      if IMOD_DIR[1:3] == ':/':
         IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]
   sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
   from imodpy import *
   addIMODbinIgnoreSIGHUP()
else:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

#
# load IMOD Libraries
from pip import *
from pysed import *
from tiltmatch import *

# Built-in option table handed to PipReadOrParseOptions below.  Each entry is a string
# of the form "short:LongName:type:"; the type codes appearing here are B, I, FN, IP
# and FP.  The last entry (PID) has no short name.
# NOTE(review): the marker line below suggests the list is generated by autodoc2man and
# should presumably be regenerated rather than edited by hand - confirm
# Fallbacks from ../manpages/autodoc2man 3 1 matchrotpairs
options = ["ia:AImageFile:FN:", "ib:BImageFile:FN:", "output:OutputFile:FN:",
           "za:AStartingEndingViews:IP:", "zb:BStartingEndingViews:IP:",
           "swap:SwapAandB:B:", "a:AngleOfRotation:I:", "mirror:MirrorXaxis:I:",
           "d:DistortionFile:FN:", "b:ImagesAreBinned:I:", "m:RunMidas:B:",
           "scan:ScanRotationMaxAndStep:FP:", "nearest:NearestNeighbor:B:",
           "x:WriteAllTransforms:B:", "t:LeaveTempFiles:B:", ":PID:B:"]

# Start up PIP with the command line arguments and the option table above
# NOTE(review): the meaning of the trailing 1, 0, 0 arguments is defined by the pip
# module and is not visible here - confirm there before changing them
(opts, nonopts) = PipReadOrParseOptions(sys.argv, options, progname, 1, 0, 0)

# NOTE(review): presumably this makes a keyboard interrupt surface as a
# KeyboardInterrupt exception, which the try block at the end of the program
# catches - verify against imodpy
passOnKeyInterrupt(True)

# Read the PID option (default 0) and pass it on to printPID
doPID = PipGetBoolean('PID', 0)
printPID(doPID)

# Set names of temp files
tmpMinxf = getTempNames(progname)
(tmpRoot, tmpDir, pid) = getTempComponents()

# The other temp names all consist of the temp root, a distinguishing tag, and the pid
(tmpStack, tmpTwoxf, tmpXfmod, tmpMidxf) = \
    [tmpRoot + tag + pid for tag in ('stack', 'twoxf', 'xfmod', 'midxf')]

# Get the two image files and find out whether the roles of A and B are swapped
imageA = PipGetInOutFile('AImageFile', 0)
imageB = PipGetInOutFile('BImageFile', 1)
ifBtoA = PipGetBoolean('SwapAandB', 0)

# Assign files, letters, and view range options; all are exchanged if doing B to A
if ifBtoA:
   (AA, BB) = ('B', 'A')
   (imageA, imageB) = (imageB, imageA)
   (aviewOpt, bviewOpt) = ('BStartingEndingViews', 'AStartingEndingViews')
else:
   (AA, BB) = ('A', 'B')
   (aviewOpt, bviewOpt) = ('AStartingEndingViews', 'BStartingEndingViews')

# All three file names are required
outFile = PipGetInOutFile('OutputFile', 2)
if not (imageA and imageB and outFile):
   exitError('You must enter two input files and an output file')

# Make sure image files exist
for imagePath in (imageA, imageB):
   if not os.path.exists(imagePath):
      exitError('Image file ' + imagePath + ' does not exist')

# Get image sizes
(nxa, nya, nza) = getmrcsize(imageA)
(nxb, nyb, nzb) = getmrcsize(imageB)

# Get starting and ending section numbers, numbered from 1 and defaulting to the whole
# file; each pair must be in order and within the file
(asecStart, asecEnd) = PipGetTwoIntegers(aviewOpt, 1, nza)
if not (1 <= asecStart <= asecEnd <= nza):
   exitError('Starting and ending views from ' + AA + ' are out of range or out of order')

(bsecStart, bsecEnd) = PipGetTwoIntegers(bviewOpt, 1, nzb)
if not (1 <= bsecStart <= bsecEnd <= nzb):
   exitError('Starting and ending views from ' + BB + ' are out of range or out of order')

# Convert to number of views and center views
nviewsA = asecEnd - asecStart + 1
nviewsB = bsecEnd - bsecStart + 1
zeroA = (asecStart + asecEnd - 1) // 2
zeroB = (bsecStart + bsecEnd - 1) // 2

# Build the distortion arguments, which stay empty unless a distortion file was
# entered; the binning is added only if a positive value was entered
distortFile = PipGetString('DistortionFile', '')
distort = ''
if distortFile:
   distort = '-dist "' + distortFile + '"'
   imageBinned = PipGetInteger('ImagesAreBinned', -1)
   if imageBinned > 0:
      distort += ' -image ' + str(imageBinned)

# Bilinear flag is 1 unless nearest neighbor was selected
nearest = PipGetBoolean('NearestNeighbor', 0)
bilinear = int(not nearest)

# All transform option, and root of output name for the additional output files
allXfOut = PipGetBoolean('WriteAllTransforms', 0)
outRoot = os.path.splitext(outFile)[0]

# Run the search
searchResult = searchPairs(progname, zeroA, zeroB, nviewsA, nviewsB, imageA, imageB,
                           nxa, nxb, nya, nyb, AA, BB, '', 0, 0, distort, bilinear,
                           allXfOut, 1)
(asecBest, bsecBest, diffList, allXfList) = searchResult

# Report the best pair and write the output files.  A keyboard interrupt during this
# section skips the remaining output, but the temp file cleanup below is still done.
try:
   prnstr(fmtstr('Views in best pair: {} {}  {} {}', AA, asecBest + 1, BB, bsecBest + 1))

   # Get interpolated position or report that search is at end of range
   # First convert the best sections to indexes within the searched ranges
   indA = asecBest - (asecStart - 1)
   indB = bsecBest - (bsecStart - 1)
   if indA > 0 and indA < nviewsA - 1 and indB > 0 and indB < nviewsB -1:

      # diffList is indexed by B view first and A view second, so the neighbors used
      # for interpolating in A come from varying the second index and the neighbors
      # for interpolating in B come from varying the first index
      diffBest = diffList[indB][indA]
      aBefore = diffList[indB][indA - 1]
      aAfter = diffList[indB][indA + 1]
      bBefore = diffList[indB - 1][indA]
      bAfter = diffList[indB + 1][indA]

      # Interpolate only if all four neighbors have positive differences
      # NOTE(review): presumably a non-positive entry marks a pair that was not
      # evaluated by searchPairs - confirm in tiltmatch
      if aBefore > 0. and aAfter > 0. and bBefore > 0. and bAfter > 0.:
         interpA = parabolicFitPosition(-aBefore, -diffBest, -aAfter)
         interpB = parabolicFitPosition(-bBefore, -diffBest, -bAfter)
         prnstr(fmtstr('Interpolated view numbers:  {:.1f}  {:.1f}',
                       asecBest + 1 + interpA, bsecBest + 1 + interpB))
   else:
      prnstr('Best pair is at end of search range')

   # Produce a standard transform file with this transform in second line
   # The transform is the first line of the temp file; the first line written is the
   # unit transform
   bestStack = outRoot + '.stack'
   minxf = []
   makeBackupFile(outFile)
   if os.path.exists(tmpMinxf):
      minxf = readTextFile(tmpMinxf)
   if len(minxf) < 1:
      cleanup()
      exitError('No alignment was computed, cannot continue')
   writeTextFile(outFile, ['1 0 0 1 0 0', minxf[0]])

   # Stack the two best sections unless nearest neighbor
   # The output size is the larger of the two image sizes in each dimension
   if not nearest:
      try:
         runcmd(fmtstr('newstack -sec {} -sec {} -float 2 -size {},{} -use 0,1 -float ' +\
                          '2 {} {} {} "{}"', bsecBest, asecBest, max(nxa, nxb),
                       max(nya, nyb), distort, imageB, imageA, bestStack))
      except ImodpyError:
         cleanExitError('Stacking two best views')

   # Output transform files, one per view in the B range, numbered from 1
   if allXfOut:
      for view in range(nviewsB):
         allName = fmtstr('{}-{}.xf', outRoot, view + 1)
         writeTextFile(allName, allXfList[view])

except KeyboardInterrupt:
   pass

cleanup()
sys.exit(0)
