#!/usr/bin/env python
# autopatchfit - Run patchcorr and matchorwarp to achieve desired residuals
#
# Author: David Mastronarde
#
# $Id: autopatchfit,v 937343107256 2023/02/19 22:44:16 mast $
#

progname = 'autopatchfit'
prefix = 'ERROR: ' + progname + ' - '

def failureLog(comName, logLines):
   """Print the error lines from a failed command's log and exit with an error.

   comName  - name of the command file that failed; used in the final message
   logLines - lines of the log produced by that command

   A line is printed when it contains 'ERROR:' unless it also contains
   '-StandardInput: exited'.  The final message goes through exitError, which is
   presumably what ends the program (it is not defined in this file - confirm).
   """
   skipMarker = '-StandardInput: exited'
   reportable = [entry for entry in logLines
                 if 'ERROR:' in entry and skipMarker not in entry]
   for entry in reportable:
      prnstr(entry)
   exitError(comName + ' failed with unrecoverable error')


#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, math, shutil

#
# Setup runtime environment
# IMOD_DIR must be defined so that its pylib directory can be put at the front of the
# module search path before imodpy is imported
IMOD_DIR = os.environ.get('IMOD_DIR')
if IMOD_DIR is not None:
   if sys.platform == 'cygwin' and sys.version_info[0] > 2:

      # Under Cygwin with Python 3, switch to forward slashes and turn a leading
      # drive specification "X:/" into "/cygdrive/x/"
      IMOD_DIR = IMOD_DIR.replace('\\', '/')

      # A slice comparison is used instead of indexing characters 1 and 2 so that a
      # value shorter than three characters cannot raise IndexError
      if IMOD_DIR[1:3] == ':/':
         IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]
   sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
   from imodpy import *
   addIMODbinIgnoreSIGHUP()
else:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

#
# load IMOD Libraries
from pip import *
from pysed import *

# Fallbacks from ../manpages/autodoc2man 3 1 autopatchfit
options = ["final:FinalPatchTypeOrXYZ:CH:", "extra:ExtraResidualTargets:CH:",
           "high:HighDensityFinalTrial:I:", "trial:TrialMode:B:",
           "skip:SkipFirstPatchcorr:B:", "help:usage:B:"]

# Hand the command line and the option table to PIP
opts, nonopts = PipReadOrParseOptions(sys.argv, options, progname, 1, 1, 0)

# Only the command file extension from this call is used below; when none comes
# back, the default extension is substituted
rootAxisInfo = findRootAxisAndExtensions(forceSingle = -1)
comExt, dualNum, setroot, typeExt, stackExt = rootAxisInfo
comExt = comExt or defaultComExtension()

# Names of the two command files that this script modifies and runs
patchcom = 'patchcorr.' + comExt
mowcom = 'matchorwarp.' + comExt

# Fetch the option values; the final patch size entry defaults to 'L' and is
# converted to sizes in X, Y and Z
finalSize = PipGetString('FinalPatchTypeOrXYZ', 'L')
nxFinal, nyFinal, nzFinal, err = patchSizeFromEntry(finalSize)
if err:
   exitError('You must enter one of S, M, L, or E or sizes in X,Y,Z for '
             'the -final option')

extraTarget = PipGetString('ExtraResidualTargets', '')
highDensity = PipGetInteger('HighDensityFinalTrial', 0)
trialMode = PipGetBoolean('TrialMode', 0)
skipFirst = PipGetBoolean('SkipFirstPatchcorr', 0)

# Read in the patchcorr command file
patchLines = readTextFile(patchcom)

# If no high density setting was entered (value 0), choose one here: 1 if the
# command file has an initial shift entry, otherwise -1.  Note that if the user
# leaves X and Y fields blank in Etomo, the entry is a ",value/"
if not highDensity:
   initialShift = optionValue(patchLines, 'InitialShiftXYZ', STRING_VALUE)
   highDensity = 1 if initialShift else -1

# First manage the trial flag in matchorwarp and pull off extra targets
#
# Two layouts of the command file are recognized:
#   standard input form - a line containing both 'matchorwarp' and '-StandardI',
#                         followed by entries on their own lines ('WarpLimits ...',
#                         'Trial...')
#   command line form   - '-warplimit' and '-trial' appear as command line options
# warpOpt records which spelling of the warp limit option applies, and warpLine the
# index of the line holding it
mowLines = readTextFile(mowcom)
found = 0
stdInputLine = -1
trialLine = -1
warpLine = -1
warpOpt = '-warplimit'
for lineInd, rawLine in enumerate(mowLines):

   # All tests below look at the line as it was read; edits go into mowLines
   line = rawLine.strip()
   notComment = not line.startswith('#')
   if 'matchorwarp' in line and '-StandardI' in line:
      stdInputLine = lineInd
      warpOpt = 'WarpLimits'

   # Standard input form: remove any extra targets already on the warp limit line
   # (this script appends them for its final trial) so earlier trials run without them
   if stdInputLine >= 0 and line.startswith('WarpLimits'):
      if extraTarget:
         mowLines[lineInd] = mowLines[lineInd].replace(',' + extraTarget, '')
      warpLine = lineInd

   # Command line form: same removal of extra targets, plus adding the trial flag
   if stdInputLine < 0 and notComment and '-warplimit' in line:
      warpLine = lineInd
      if extraTarget:
         mowLines[lineInd] = mowLines[lineInd].replace(',' + extraTarget, '')
      if trialMode:

         # Insert -trial only when the line does not already have it, so that
         # repeated runs in trial mode do not pile up duplicate flags
         if '-trial' not in line:
            mowLines[lineInd] = mowLines[lineInd].replace('-warplimit',
                                                          '-trial -warplimit')
         break
   if stdInputLine >= 0 and line.startswith('Trial'):
      trialLine = lineInd

   # Command line form without trial mode: take out an existing trial flag
   if stdInputLine < 0 and not trialMode and notComment and '-trial' in line:
      mowLines[lineInd] = mowLines[lineInd].replace('-trial', '')

if warpLine < 0:
   exitError('Cannot find existing residual targets in ' + mowcom)

# Pull out or put in the Trial line, adjusting warpLine as needed
if stdInputLine >= 0:
   if trialMode and trialLine < 0:

      # The new entry goes right after the command line, which is above the warp
      # limit line, so that line moves down by one
      mowLines.insert(stdInputLine + 1, 'TrialMode 1')
      warpLine += 1
   elif not trialMode and trialLine >= 0:
      del mowLines[trialLine]
      if trialLine < warpLine:
         warpLine -= 1

# Save the previous command file and write the adjusted one
makeBackupFile(mowcom)
writeTextFile(mowcom, mowLines)

# Make sure matchorwarp can be modified for extra target before starting
#
# The changed line is only stored in mowLines here; it is written to the command
# file later, when the final trial is reached
if extraTarget:
   line = mowLines[warpLine]
   lsplit = line.split()

   # Walk through the whitespace-separated tokens, keeping track of where each one
   # sits in the line, so that only the value following the warp limit option is
   # changed.  A plain text replacement would also alter any other entry on the line
   # with the same text (e.g., an identical refine limit)
   searchFrom = 0
   prevToken = None
   for token in lsplit:
      tokenStart = line.find(token, searchFrom)
      tokenEnd = tokenStart + len(token)
      searchFrom = tokenEnd
      if prevToken == warpOpt:
         curTarget = token

         # Leave the line alone if the targets already end with the extra ones
         if not curTarget.endswith(extraTarget):
            newTarget = curTarget + ',' + extraTarget
            mowLines[warpLine] = line[:tokenStart] + newTarget + line[tokenEnd:]
         break
      prevToken = token
   else:   # ELSE ON FOR

      # Without this, the final trial would announce added targets but run without
      # them
      exitError('Cannot find ' + warpOpt + ' followed by residual targets in ' + mowcom)

# Figure out the maximum number of runs in advance
#
# Note the crossed names: the command file's Y limits are kept in zMinMax and its Z
# limits in yMinMax, and likewise the middle PatchSizeXYZ value becomes nzCurr and
# the last one nyCurr.  Presumably the volume has Y and Z exchanged relative to the
# X, Y, Z sense of the -final entry - confirm
nxyzPatchCur = optionValue(patchLines, 'PatchSizeXYZ', INT_VALUE, numVal = 3)
numXYZorig = optionValue(patchLines, 'NumberOfPatchesXYZ', INT_VALUE, numVal = 3)
xMinMax = optionValue(patchLines, 'XMinAndMax', INT_VALUE, numVal = 2)
zMinMax = optionValue(patchLines, 'YMinAndMax', INT_VALUE, numVal = 2)
yMinMax = optionValue(patchLines, 'ZMinAndMax', INT_VALUE, numVal = 2)
if not nxyzPatchCur or not numXYZorig or not xMinMax or not yMinMax or not zMinMax:
   exitError('Cannot find one of patch size, number of patches, or X, Y or Z limits in '
             + patchcom)

nxCurr = nxyzPatchCur[0]
nyCurr = nxyzPatchCur[2]
nzCurr = nxyzPatchCur[1]

# The sizes are used as divisors below, so stop cleanly on a zero or negative entry
if nxCurr <= 0 or nyCurr <= 0 or nzCurr <= 0:
   exitError('Current patch size in ' + patchcom + ' must be positive in every dimension')
if nxFinal < nxCurr or nyFinal < nyCurr or nzFinal < nzCurr:
   exitError('Final patch size cannot be smaller than current size in any dimension')

# These are read by the loop on trials, so give them values on every path: -1 means
# no match to a stock size, and numSteps/stepFactor apply only to stepping by a factor
curInd = -1
finalInd = -1
numSteps = 0
stepFactor = 1.

if nxFinal == nxCurr and nyFinal == nyCurr and nzFinal == nzCurr:
   maxTrials = 1
else:

   # Try to match up size with the stock ones
   for ind, (stockXY, stockZ) in enumerate(zip(PATCHXY, PATCHZ)):
      if nxFinal == nyFinal and nxFinal == stockXY and nzFinal == stockZ:
         finalInd = ind
      if nxCurr == stockXY and nyCurr == stockXY and nzCurr == stockZ:
         curInd = ind

   if curInd >= 0 and finalInd >= 0:

      # Both sizes are stock ones: one trial per stock size from current to final
      maxTrials = finalInd + 1 - curInd
   else:

      # If no size match, target steps of 1.25, round number of steps up so the steps
      # will be no bigger than ~1.3
      maxFactor = max(float(nxFinal) / nxCurr, float(nyFinal) / nyCurr,
                      float(nzFinal) / nzCurr)
      numSteps = max(1, int(math.log(maxFactor) / math.log(1.25) + 0.75))

      # Equal multiplicative steps that reach maxFactor after numSteps of them
      stepFactor = math.exp(math.log(maxFactor) / numSteps)
      maxTrials = numSteps + 1

# One more trial at the end if a high density trial is to be done
if highDensity > 0:
   maxTrials += 1

# Set up for the loop on trials: sizes start at the current ones and the cumulative
# size factor at 1
nxNew, nyNew, nzNew = nxCurr, nyCurr, nzCurr
cumFactor = 1.

# Limit for the patch size in Z: the extent between the Z limits when the middle
# NumberOfPatchesXYZ value (used as the count in Z by this script) is 1, otherwise 1.5
# times that extent, rounded down
zRange = zMinMax[1] + 1 - zMinMax[0]
nzLimit = zRange if numXYZorig[1] == 1 else (zRange * 3) // 2

# Loop on trials.  Each pass after the first changes the patch size or density in the
# patchcorr command file, then patchcorr (optionally skipped on the first pass) and
# matchorwarp are run.  The script exits with status 0 as soon as a matchorwarp run
# completes without raising ImodpyError; if no pass does, it ends with the error below
for trial in range(maxTrials):
   finalTrial = trial == maxTrials - 1
   
   # Modify to the next size or density after the first trial
   if trial:

      # When a high density trial is being done it is the final one; all other
      # passes change the patch size and use density index 0
      if not (highDensity > 0 and finalTrial):
         densityInd = 0
         if curInd >= 0 and finalInd >= 0:

            # Current and final sizes both matched stock sizes: go to the next one
            curInd += 1
            nxNew = nyNew = PATCHXY[curInd]
            nzNew = PATCHZ[curInd]
         elif trial < numSteps:

            # Intermediate step: scale the original size by the cumulative factor
            # and round to an even number
            cumFactor *= stepFactor
            nxNew = 2 * int(round(nxCurr * cumFactor / 2.))
            nyNew = 2 * int(round(nyCurr * cumFactor / 2.))
            nzNew = 2 * int(round(nzCurr * cumFactor / 2.))
         else:

            # Last size step: use the requested final size exactly
            nxNew = nxFinal
            nyNew = nyFinal
            nzNew = nzFinal
            
         # Keep the size in Z within the limit computed before the loop
         nzNew = min(nzNew, nzLimit)

      else:

         # High density final trial: sizes stay as they are, only the density index
         # passed to autoPatchNumber changes
         densityInd = 1
         
      # Get new numbers of patches for these sizes within the min/max limits; the
      # fourth argument is 1 only for Z (its meaning is defined in autoPatchNumber)
      numXnew = autoPatchNumber(nxNew, xMinMax[0], xMinMax[1], 0, densityInd)
      numYnew = autoPatchNumber(nyNew, yMinMax[0], yMinMax[1], 0, densityInd)
      numZnew = autoPatchNumber(nzNew, zMinMax[0], zMinMax[1], 1, densityInd)

      # Keep a single patch in Z if that is what the command file had originally
      if numXYZorig[1] == 1:
         numZnew = 1

      # The entries are written with this script's Z value in the middle and its Y
      # value last, the same order in which they were read from the command file
      sedcom = [sedModify('PatchSizeXYZ', fmtstr('{},{},{}', nxNew, nzNew, nyNew)),
                sedModify('NumberOfPatchesXYZ', fmtstr('{},{},{}', numXnew, numZnew,
                                                       numYnew))]

      # Back up the command file only once, before its first modification
      if trial == 1:
         makeBackupFile(patchcom)

      # The lines originally read from the file are passed each time; presumably
      # pysed applies sedcom to them and writes the result to patchcom - confirm
      pysed(sedcom, patchLines, patchcom)
      prnstr(fmtstr('AUTOPATCHFIT - Changing to patch size {} {} {}, number {} {} {}',
                    nxNew, nyNew, nzNew, numXnew, numYnew, numZnew))

   else:

      # First trial: the command file is used as is; report numbers in X, Y, Z order
      prnstr(fmtstr('AUTOPATCHFIT - Using initial patch size {} {} {}, number {} {} {}',
                    nxNew, nyNew, nzNew, numXYZorig[0], numXYZorig[2], numXYZorig[1]))

   # Modify matchorwarp if there are extra criteria on last round
   # (mowLines had the extra targets added to its warp limit line before the loop)
   if finalTrial and extraTarget:
      prnstr(fmtstr('AUTOPATCHFIT - Adding {} to warp residual limits', extraTarget))
      writeTextFile(mowcom, mowLines)

   # Run patchcorr, except on the first trial when skipping it was requested
   if trial or not skipFirst:
      prnstr('AUTOPATCHFIT - Running ' + patchcom, flush = True)
      try:
         runcmd('vmstopy -x -q ' + patchcom + ' patchcorr.log')
      except ImodpyError:

         # Report errors from the log; failureLog finishes by calling exitError
         logLines = readTextFile('patchcorr.log')
         failureLog(patchcom, logLines)
      
   prnstr('AUTOPATCHFIT - Running ' + mowcom, flush = True)
   try:
      makeBackupFile('matchorwarp.log')
      runcmd('vmstopy -x -q ' + mowcom + ' matchorwarp.log')

      # Success!
      # Scan the log to report which program found a good fit and with what mean
      # residual; the residual must have been seen on an earlier line to be reported
      logLines = readTextFile('matchorwarp.log')
      refineRes = ''
      warpRes = ''
      for line in logLines:
         if 'FOUND A GOOD' in line.upper():
            if 'REFINEMATCH' in line.upper() and refineRes:
               prnstr('AUTOPATCHFIT - Refinematch found a good transformation, mean ' + \
                         'residual ' + refineRes)
               break
            if 'FINDWARP' in line.upper() and warpRes:
               prnstr('AUTOPATCHFIT - Findwarp found a good warping, mean ' + \
                         'residual ' + warpRes)
               break

         # Remember the latest mean residual: a line that also contains 'has' sets
         # the warping residual, any other sets the refinematch residual
         if 'Mean residual' in line:
            lsplit = line.replace(',', ' ').split()
            for token in lsplit:

               # Only the first token containing a '.' is examined; it is kept as
               # text if it converts to a number and ignored otherwise
               if '.' in token:
                  try:
                     dummy = float(token)
                     if 'has' in line:
                        warpRes = token
                     else:
                        refineRes = token
                  except Exception:
                     pass
                  break
               
      # Done: exit with success status
      sys.exit(0)
      
   except ImodpyError:

      # The run failed.  A findwarp failure to find a warping is reported and lets
      # the loop go on; anything else is handed to failureLog
      logLines = readTextFile('matchorwarp.log')
      for line in logLines:
         if 'FINDWARP - FAILED TO FIND' in line.upper():
            prnstr(line.replace('ERROR: ', ''))
            break
      else:   # ELSE ON FOR
         failureLog(mowcom, logLines)

      # Save the current patches unless it is the last round
      # The file is renamed with the patch sizes of this trial in its name
      if not finalTrial:
         patchName = fmtstr('patch_{}x{}x{}.out', nxNew, nyNew, nzNew)
         try:
            makeBackupFile(patchName)
            os.rename('patch.out', patchName)
            prnstr('AUTOPATCHFIT - Renamed patch.out as ' + patchName)
         except OSError:
            exitError('Renaming patch.out to ' + patchName)

# End of loop with no success.  How to leave things?
# Nothing is restored here; the command files stay as last written
exitError('Could not get patch correlations with an acceptable fit')

   
