#!/usr/bin/env python
# dualvolmatch - Get initial transformation between dual-axis volumes
#
# Author: David Mastronarde
#
# $Id: dualvolmatch,v 7b900ada5ec6 2025/08/05 19:30:55 mast $
#

# Program name, and the prefix placed in front of error messages from this script
progname = 'dualvolmatch'
prefix = 'ERROR: %s - ' % progname

def cleanup():
   """Pass cleanList to cleanupFiles(), unless testMode is 2 or higher."""
   if testMode >= 2:
      return
   cleanupFiles(cleanList)

      
def cleanExitError(message):
   """Run cleanup() and then hand message to exitError()."""
   cleanup()
   exitError(message)
   

#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, math

#
# Setup runtime environment
# IMOD_DIR locates the 'pylib' directory that is put at the front of the module search
# path before importing imodpy; without it the script cannot run at all
IMOD_DIR = os.environ.get('IMOD_DIR')
if IMOD_DIR is not None:
   if sys.platform == 'cygwin' and sys.version_info[0] > 2:

      # Under cygwin with Python 3, turn a Windows-style path (backslashes, drive
      # letter) into a /cygdrive/<letter>/... path.  The slice comparison is safe
      # for a value shorter than 3 characters, where indexing would raise IndexError
      IMOD_DIR = IMOD_DIR.replace('\\', '/')
      if IMOD_DIR[1:3] == ':/':
         IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]
   sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
   from imodpy import *
   addIMODbinIgnoreSIGHUP()
else:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

#
# load IMOD Libraries
from pip import *
from pysed import *

# Fallbacks from ../manpages/autodoc2man 3 1 dualvolmatch
# Option specifications handed to PipReadOrParseOptions, one per line
options = [
   "name:RootName:CH:",
   "atob:MatchAtoB:B:",
   "binning:BinningToApply:I:",
   "tilt:TiltAngleMaxAndStep:FP:",
   "refine:RefineTiltAngles:I:",
   "center:CenterShiftLimit:F:",
   "maxresid:MaximumResidual:F:",
   "scan:ScanRotationMaxAndStep:FP:",
   "final:FinalOutputFile:FN:",
   "style:NamingStyle:I:",
   "test:TestMode:I:",
   "help:usage:B:",
]

opts, nonopts = PipReadOrParseOptions(sys.argv, options, progname, 1, 1, 0)

# Fixed parameters
maxTiltShift = 2          # Limit on how many times the tilt search can be shifted
leaveOpt = ''             # Extra option for matchrotpairs, set from the test mode
maxDropFrac = 0.17        # Passed to refinematch as MaxFractionToDrop
DTOR = 0.01745329252      # Degrees to radians

(comExt, dualNum, setroot, typeExt, stackExt) = findRootAxisAndExtensions \
                                                (forceSingle = -1)
(nameStyle, typeExt) = getNamingStyle(typeExt)
if typeExt is None:
   typeExt = ''

# Get options
rootName = PipGetInOutFile('RootName', 0)
if not rootName:
   exitError('The root name of the dataset must be entered')

# The volume named with asrc is the reference and the one named with bsrc gets
# transformed; the MatchAtoB option swaps the two letters
ifAtoB = PipGetBoolean('MatchAtoB', 0)
asrc = 'a'
bsrc = 'b'
if ifAtoB:
   asrc = 'b'
   bsrc = 'a'

# The B root is made the default one, so datasetFilename calls without a root
# argument refer to the B volume
rootA = rootName + asrc
rootB = rootName + bsrc
setRootAndExtension(rootB, typeExt)
recNameA = datasetFilename('.rec', root = rootA)
recNameB = datasetFilename('.rec')
for rec in (recNameA, recNameB):
   if not os.path.exists(rec):
      exitError('Tomogram does not exist: ' + rec)

# Get the size of each tomogram from its own file; the B size used to be read from
# the A file, which made the automatic binning and thickness advice wrong whenever
# the two volumes differ in size
(nxa, nya, nza) = getmrcsize(recNameA)
(nxb, nyb, nzb) = getmrcsize(recNameB)

# With no binning entered, choose one from the smallest X or Z dimension of the two
# volumes: one unit of binning per 512 pixels, kept between 1 and 4
binning = PipGetInteger('BinningToApply', -1)
if binning <= 0:
   minSize = min(nxa, nza, nxb, nzb)
   binning = max(1, min(4, minSize // 512))

(scanMax, scanInterval) = PipGetTwoFloats('ScanRotationMaxAndStep', 20., 4.)
(tiltMax, tiltInterval) = PipGetTwoFloats('TiltAngleMaxAndStep', 4., 2.)
if tiltMax < 0 or tiltInterval <= 0:
   exitError('Both maximum tilt angle and step size should be positive')

# The search runs from -tiltMax to +tiltMax and needs at least 3 angles
numSteps = int(round(2. * tiltMax / tiltInterval)) + 1
if numSteps < 3:
   exitError('The tilt interval must not be bigger than the maximum tilt angle')
                        
numRefine = PipGetInteger('RefineTiltAngles', 1)
if numRefine < 0 or numRefine > 5:
   exitError('Number of tilt angle refinements should be between 0 and 5')

ubMaxResid = PipGetFloat('MaximumResidual', 10.)
maxCenShift = PipGetFloat('CenterShiftLimit', 10.)
solveFile = PipGetString('FinalOutputFile', 'solve.xf')
testMode = PipGetInteger('TestMode', 0)

# cleanup() keeps this script's temporary files when testMode is 2 or more, so tell
# matchrotpairs to leave its files at the same level (the test used to be > 2, which
# left test mode 2 without the matchrotpairs files)
if testMode > 1:
   leaveOpt = '-t'

# Compose every file name used below and collect the temporary ones for cleanup
recNames = (recNameA, recNameB)
binnedRecs = (datasetFilename('_bin.rec', root = rootA),
              datasetFilename('_bin.rec'))

# Without binning, the original tomograms stand in for the binned volumes
binRecs = recNames if binning == 1 else binnedRecs
binProj = (datasetFilename('_bin.proj', root = rootA),
           datasetFilename('_bin.proj'))
bestXfFile = rootName + '_dvmatch.xf'
init3dFile = rootName + '_dvm3d.xf'
refineFile = rootName + '_refine.xf'
patchFile = rootName + '_dvmpatch.out'
onePatchFile = rootName + '_dvmcenpat.out'
matchRec = datasetFilename('_bin.mat')

cleanList = list(binProj) + [bestXfFile, init3dFile, refineFile, patchFile,
                             onePatchFile, matchRec]

# Binned volumes are temporary only when they are distinct from the tomograms
if binning > 1:
   cleanList.extend(binRecs)

# Also list the name with '~' appended for each entry
cleanList += [tempName + '~' for tempName in cleanList]

# When the type extension is one of the standard ones, export it in upper case as
# MRC_OUTPUT_FORMAT so that files of the right type are produced
if typeExt and typeExt in standardTypeExtensions():
   os.environ['MRC_OUTPUT_FORMAT'] = typeExt.upper()

# Main processing: bin the volumes, search for the best matching pair of reprojections,
# convert the 2D transformation between them into an initial 3D transformation, then
# refine it with patch correlation.  Errors from the IMOD programs that are run, and
# from parsing their output, are turned into error exits after cleaning up.
try:

   # Start out with binning the volumes if needed
   if binning > 1:
      for ind in (0, 1):
         prnstr(fmtstr('Making binned by {} volume {}', binning, binRecs[ind]),
                flush = True)
         runcmd(fmtstr('binvol -bin {} "{}" "{}"', binning, recNames[ind], binRecs[ind]))

   # Set up for loop on projection and matching
   indRefine = 0
   indShiftRef = 0
   maxRefShift = 2
   numTiltShift = 0

   # Adjust the interval so that numSteps angles span -tiltMax to +tiltMax exactly
   tiltInterval = 2. * tiltMax / (numSteps - 1)
   tiltStart = [-tiltMax, -tiltMax]
   while True:

      # Reproject at current tilt angles
      for ind in (0, 1):
         start = tiltStart[ind]
         end = tiltStart[ind] + (numSteps - 1) * tiltInterval
         prnstr(fmtstr('Making reprojection {}: {:.1f} to {:.1f} at {:.1f} degrees',
                       binProj[ind], start, end, tiltInterval))
         runcmd(fmtstr('xyzproj -axis Z -angles {},{},{} -mode 2 "{}" "{}"',
                       start, end, tiltInterval, binRecs[ind], binProj[ind]))

      # Find best match
      prnstr('Running matchrotpairs to find best matching tilts')
      matchLines = runcmd(fmtstr('matchrotpairs -near -scan {},{} -swap {} "{}" "{}" "{}"'
                                 , scanMax, scanInterval, leaveOpt, binProj[0],
                                 binProj[1], bestXfFile))

      # Look for best pair, interpolated values, and if it is mirrored
      # endOfRange stays set unless a line with interpolated view numbers is found
      endOfRange = True
      bestA = -1
      numMirror = 0
      numRegular = 0
      for line in matchLines:
         if testMode:
            prnstr(line.rstrip())
         if 'mirror' in line:
            numMirror += 1
         elif 'regular' in line:
            numRegular += 1
         if 'Views in best pair' in line:

            # View numbers are numbered from 1, so view 1 is at the starting angle
            lsplit = line.split()
            bestA = int(lsplit[-1])
            bestB = int(lsplit[-3])
            bestTiltA = tiltStart[0] + (bestA - 1) * tiltInterval
            bestTiltB = tiltStart[1] + (bestB - 1) * tiltInterval
            prnstr(fmtstr('Tilt angles of best pair: {} {:.1f}  {} {:.1f}', asrc.upper(),
                          bestTiltA, bsrc.upper(), bestTiltB))

         elif 'Interpolated view' in line:
            lsplit = line.split()
            endOfRange = False
            interpA = float(lsplit[-1])
            interpB = float(lsplit[-2])
            interpTiltA = tiltStart[0] + (interpA - 1) * tiltInterval
            interpTiltB = tiltStart[1] + (interpB - 1) * tiltInterval
            prnstr(fmtstr('Interpolated tilt angles: {} {:.1f}  {} {:.1f}', asrc.upper(),
                          interpTiltA, bsrc.upper(), interpTiltB))

         elif 'Temporary files' in line:
            prnstr('Matchrotpairs ' + line.strip())

      if bestA <= 0:
         cleanExitError('Cannot find best view pair in output of Matchrotpairs')

      if endOfRange:

         # If end of range flag set, make sure it is not just a failure to get
         # interpolated values, that it is not a refinement search, and that a shift is
         # still OK
         if bestA > 1 and bestA < numSteps and bestB > 1 and bestB < numSteps:
            cleanExitError('Cannot find interpolated view numbers in output of ' +\
                              'Matchrotpairs')
         if indRefine:
            if indShiftRef >= maxRefShift:
               cleanExitError(fmtstr('Search is at end of range in refinement step ' +\
                                     'after shifting refinement {} times; try setting ' +\
                                     'TiltAngleMaxAndStep to more, smaller steps ' + \
                                     '(e.g. 5,1)', maxRefShift))
            indShiftRef += 1
            prnstr('Shifting search to center on latest best pair')
               
         if numTiltShift >= maxTiltShift:
            cleanExitError('The search for best pairs has already been shifted the ' +\
                         'maximum number of times')

         numTiltShift += 1

      else:

         # If not at end of range, stop if refinement steps are over, otherwise
         # make sure there are at least 7 steps now and cut the interval
         # NOTE(review): this comment used to say 5 steps while the code enforces 7;
         # confirm which minimum is intended
         if indRefine >= numRefine:
            break
         
         numSteps = max(7, numSteps)
         indRefine += 1
         tiltInterval /= 2.

      # Set up starting point for next round, centered on the latest best pair
      tiltStart[0] = bestTiltA - (numSteps // 2) * tiltInterval
      tiltStart[1] = bestTiltB - (numSteps // 2) * tiltInterval

   # Search is over, get the transformation from the second line of the file
   xflines = readTextFile(bestXfFile, 'file with transformation between best tilts')
   if len(xflines) != 2:
      cleanExitError('The file with transformation between best tilts does not have ' +\
                        '2 lines')
   lsplit = xflines[1].split()
   if len(lsplit) < 6:
      cleanExitError('Not enough values in transformation between best tilts')
   L11 = float(lsplit[0])
   L12 = float(lsplit[1])
   L21 = float(lsplit[2])
   L22 = float(lsplit[3])
   Ldx = float(lsplit[4])
   Ldy = float(lsplit[5])

   # Get the mean mag, the last value on the last line of xf2rotmagstr output
   magLines = runcmd('xf2rotmagstr ' + bestXfFile)
   if not magLines[-1].strip().startswith('2:'):
      cleanExitError('Cannot find mean mag in output from xf2rotmagstr')
   prnstr('Transformation between best reprojections:')
   prnstr(magLines[-1][4:].strip())
   lsplit = magLines[-1].split()
   mag = float(lsplit[-1])

   # A negative mag is used when more lines in the matchrotpairs output said mirror
   # than regular; off-diagonal terms of opposite sign are not what a mirroring
   # transformation would have, hence the warning
   if numMirror > numRegular:
      mag = -mag
      if L12 * L21 < 0:
         prnstr('WARNING: ' + progname + ' - Matchrotpairs indicated mirroring, but' +\
                   ' the transformation is not consistent with that')

   # Compute the transform.  The 2D transformation occupies the first and third rows
   # and columns of a 3x3 matrix with mag as the middle diagonal term; that matrix is
   # multiplied on the right by a rotation by the B tilt angle about the third axis
   # and on the left by the inverse of a rotation by the A tilt angle
   cosA = math.cos(DTOR * interpTiltA)
   sinA = math.sin(DTOR * interpTiltA)
   cosB = math.cos(DTOR * interpTiltB)
   sinB = math.sin(DTOR * interpTiltB)
   a11 = L11 * cosB * cosA + mag * sinA * sinB
   a21 = -L11 * cosB * sinA + mag * cosA * sinB
   a31 = L21 * cosB
   a12 = -L11 * sinB * cosA + mag * sinA * cosB
   a22 = L11 * sinB * sinA + mag * cosA * cosB
   a32 = -L21 * sinB
   a13 = L12 * cosA
   a23 = -L12 * sinA
   a33 = L22
   dx = Ldx * cosA
   dy = -Ldx * sinA
   dz = Ldy

   # Write it to file, one row of the matrix followed by its shift on each line
   xfFormat = '{:10.6f} {:10.6f} {:10.6f} {:10.3f}'
   lines = [fmtstr(xfFormat, a11, a12, a13, dx),
            fmtstr(xfFormat, a21, a22, a23, dy),
            fmtstr(xfFormat, a31, a32, a33, dz)]
   writeTextFile(init3dFile, lines)

   # Make binned matching volume
   (nxBinA, nyBinA, nzBinA) = getmrcsize(binRecs[0])
   (nxBinB, nyBinB, nzBinB) = getmrcsize(binRecs[1])
   prnstr('Making initial matching binned volume ' + matchRec)
   runcmd(fmtstr('matchvol -size {},{},{} -xffile {} "{}" "{}"', nxBinA,
                 max(nyBinA, nyBinB), nzBinA, init3dFile, binRecs[1], matchRec))

   # Set up and run the correlation search, leaving out a fifth of the extent at each
   # end in X and Z
   xmin = nxBinA // 5
   xmax = nxBinA - xmin
   zmin = nzBinA // 5
   zmax = nzBinA - zmin
   rangeX = xmax - xmin
   rangeZ = zmax - zmin
   patchSize = min(rangeX // 3, rangeZ // 3,  512 // binning)
   numPatchX = max(3, int(round((3. * rangeX + patchSize) / (2. * patchSize))))
   numPatchZ = max(3, int(round((3. * rangeZ + patchSize) / (2. * patchSize))))
   border = 24 + 12 * (min(nxBinA, nzBinA) // 1000)

   # Entries common to both corrsearch3d runs
   cscomBase = ['ReferenceFile ' + binRecs[0],
                'FileToAlign ' + matchRec,
                fmtstr('XMinAndMax {},{}', xmin, xmax),
                fmtstr('YMinAndMax {},{}', 1, nyBinA),
                fmtstr('ZMinAndMax {},{}', zmin, zmax),
                'BSourceOrSizeXYZ ' + binRecs[1],
                'BSourceTransform ' + init3dFile,
                fmtstr('BSourceBorderXLoHi {},{}', border, border),
                fmtstr('BSourceBorderYZLoHi {},{}', border, border),
                'FlipYZMessages']
   cscom = cscomBase + [fmtstr('PatchSizeXYZ {},{},{}', patchSize, nyBinA - 2, patchSize),
                        fmtstr('NumberOfPatchesXYZ {},1,{}', numPatchX, numPatchZ),
                        'OutputFile ' + patchFile]

   prnstr('Running corrsearch3d on large patches in binned volumes')
   runcmd('corrsearch3d -StandardInput', cscom)
   
   # Run refinematch for the transform; binning is passed as the factor for scaling
   # shifts and the X size as the mean residual limit
   rfcom = ['PatchFile ' + patchFile,
            'OutputFile ' + refineFile,
            'VolumeOrSizeXYZ ' + binRecs[0],
            'InitialTransformFile ' + init3dFile,
            'ProductTransformFile ' + solveFile,
            'ScaleShiftByFactor ' + str(binning),
            'MeanResidualLimit ' + str(nxBinA),
            'MaxFractionToDrop ' + str(maxDropFrac)]
   prnstr('Running refinematch to get refined transformation with Z shift')
   refLines = runcmd('refinematch -StandardInput', rfcom)

   # Look for residual and center shift in output
   cenShift = None
   meanResid = None
   for line in refLines:
      if testMode:
         prnstr(line.rstrip())
      if 'center shift' in line:
         lsplit = line.split()
         cenShift = float(lsplit[-1])
         #
         # BRT is using 'implies a center' as a tag
         prnstr(fmtstr('The residual for the center patch implies a center shift in ' +\
                          'Z of {:.1f}', cenShift))
      if 'Mean residual' in line:
         if not testMode:
            prnstr(line.rstrip())
         lsplit = line.replace(',', ' ').split()
         meanResid = float(lsplit[2]) * binning

   # Both values are required below; give a clean error instead of failing on an
   # undefined variable if either line is missing from the output
   if cenShift is None or meanResid is None:
      cleanExitError('Cannot find mean residual and center shift in output of ' +\
                        'Refinematch')

   roundCen = int(round(cenShift))
   roundAbsCen = int(round(math.fabs(cenShift)))
   if meanResid > ubMaxResid:

      # If the residual is too big, just get the center shift from a bigger patch and
      # add it to the initial transform and make that be the output file
      patchSize = min((3 * patchSize) // 2, rangeX - 2, rangeZ - 2)
      cscom = cscomBase + [fmtstr('PatchSizeXYZ {},{},{}', patchSize, nyBinA - 2,
                                  patchSize),
                           'NumberOfPatchesXYZ 1,1,1',
                           'OutputFile ' + onePatchFile]

      # BRT is using 'unbinned mean residual' and 'Falling back' as tags
      prnstr(fmtstr('The unbinned mean residual is {:.2f} which is above the limit',
                    meanResid))
      prnstr('Falling back to the initial estimate of the 3D transformation')
      prnstr('Running corrsearch3d on one bigger patch to get shifts')
      runcmd('corrsearch3d -StandardInput', cscom)
      patLines = readTextFile(onePatchFile, 'file with one center patch displacement')
      if len(patLines) < 2:
         cleanExitError('Too few lines in patch output file')
      lsplit = patLines[1].split()

      # The initial shifts and the patch displacement (presumably the 4th to 6th
      # values on the patch line - confirm against corrsearch3d output) both come from
      # the binned volumes, so the sum is scaled by the binning for the output
      # transformation; previously only the initial shift was scaled
      lines = [fmtstr(xfFormat, a11, a12, a13, binning * (dx + float(lsplit[3]))),
               fmtstr(xfFormat, a21, a22, a23, binning * (dy + float(lsplit[4]))),
               fmtstr(xfFormat, a31, a32, a33, binning * (dz + float(lsplit[5])))]
      makeBackupFile(solveFile)
      writeTextFile(solveFile, lines)

      # If there is much center shift in earlier run, warn that thickness should be bigger
      # perhaps by twice this shift
      # BRT is looking for 'may need to set thickness'
      if roundAbsCen >= 5 and nyb + 2 * roundAbsCen > nya:
         prnstr('')
         prnstr('WARNING: dualvolmatch - You may need to set thickness of initial ' + \
                   'matching file to at least ' + str(nyb + 2 * roundAbsCen))
         prnstr('     (In Etomo, Initial match size for Matchvol1)')
         prnstr('')

   elif math.fabs(cenShift) > maxCenShift:
      prnstr(fmtstr('The unbinned mean residual is {:.2f}', meanResid))

      # When using the solved transform, if the center shift is above limit, behave like
      # solvematch, using similar text, and advise on thickness too
      # BRT is looking for 'InitialShiftXYZ' and 'needs' 
      prnstr('')
      prnstr('   The center shift is bigger than the specified limit')
      prnstr('   The InitialShiftXYZ for corrsearch3d needs to be 0 ' + \
                str(roundCen) + ' 0')
      prnstr('   In Etomo, set Patchcorr Initial shifts in X, Y, Z to 0 0 ' + \
                str(roundCen))
      if nyb > nya:
         prnstr('   You should also set thickness of initial matching file to at ' + \
                   'least ' + str(nyb))
         prnstr('     (In Etomo, Initial match size for Matchvol1)')
            
      # BRT is looking for 'CenterShiftLimit' and 'avoid stopping'
      prnstr('   To avoid stopping with this error, set CenterShiftLimit to ' + \
                str(roundAbsCen + 2))

      # BRT is looking for 'Initial shift needs'
      cleanExitError('Initial shift needs to be set for patch correlation')

   cleanup()
   sys.exit(0)

except ImodpyError:
   cleanup()
   exitFromImodError(progname)
except IndexError:
   cleanExitError('Fewer than expected number of values on line in program output')
except ValueError:
   cleanExitError('Extracting a numerical value from program output')
