#!/usr/bin/env python
# matchorwarp - script to run refinematch and matchvol if fit is good enough
# or findwarp with progressively worse criteria, then warpvol
#
# Author: David Mastronarde
#
# $Id: matchorwarp,v 937343107256 2023/02/19 22:44:16 mast $
#

# Program name, and the prefix that starts error messages written by this script
progname = 'matchorwarp'
prefix = 'ERROR: %s - ' % progname


#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys

#
# Setup runtime environment
if os.getenv('IMOD_DIR') is not None:
   IMOD_DIR = os.environ['IMOD_DIR']
   if sys.platform == 'cygwin' and sys.version_info[0] > 2:
      # Under Cygwin with Python 3 or later, switch to forward slashes and turn a
      # leading drive specification "X:/" into "/cygdrive/x/"
      IMOD_DIR = IMOD_DIR.replace('\\', '/')

      # Compare a slice rather than indexing characters 1 and 2 so that a value
      # shorter than three characters cannot raise IndexError
      if IMOD_DIR[1:3] == ':/':
         IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]

   # Put the IMOD python library first on the path so imodpy and pip come from there
   sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
   from imodpy import *
   addIMODbinIgnoreSIGHUP()
else:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

#
# load IMOD Libraries
from pip import *

# Clip size argument for patch2imod; stays empty unless the clipsize option is entered
clipsize = ''

# Fallbacks from ../manpages/autodoc2man 3 1 matchorwarp
# Each entry has the form short name:long name:type code:
# NOTE(review): extentfit is listed with type FN but is read below with
# PipGetInteger - confirm the type against the autodoc before regenerating
options = [
   "inputvolume:InputVolume:FN:",
   "outputvolume:OutputVolume:FN:",
   "size:SizeXYZorVolume:CH:",
   "refinelimit:RefineLimit:F:",
   "warplimit:WarpLimits:CH:",
   "structurecrit:StructureCriteria:CH:",
   "extentfit:ExtentToFit:FN:",
   "modelfile:ModelFile:FN:",
   "patchfile:PatchFile:FN:",
   "solvefile:SolveFile:FN:",
   "refinefile:RefineFile:FN:",
   "inversefile:InverseFile:FN:",
   "warpfile:WarpFile:FN:",
   "residualfile:ResidualFile:FN:",
   "vectormodel:VectorModel:FN:",
   "clipsize:ClipPlaneBoxSize:I:",
   "tempdir:TemporaryDirectory:FN:",
   "xlowerexclude:XLowerExclude:I:",
   "xupperexclude:XUpperExclude:I:",
   "ylowerexclude:YLowerExclude:I:",
   "yupperexclude:YUpperExclude:I:",
   "zlowerexclude:ZLowerExclude:I:",
   "zupperexclude:ZUpperExclude:I:",
   "linear:LinearInterpolation:B:",
   "trial:TrialMode:B:",
   "help:usage:B:",
]

# Parse the command line or a parameter file against the option list
numOpts, numNonOpts = PipReadOrParseOptions(sys.argv, options, progname, 3, 1, 1)

# Get all the options
# The statements are kept in this order on purpose: each check exits through
# exitError as soon as a required entry is found to be missing, so the order sets
# which problem is reported first

# Input and output volumes are required
# NOTE(review): the second argument of PipGetInOutFile is presumably the index of
# the non-option argument used when the option is not entered - confirm in pip
recfile = PipGetInOutFile('InputVolume', 0)
if not recfile:
   exitError('An input volume must be entered')
matfile = PipGetInOutFile('OutputVolume', 1)
if not matfile:
   exitError('An output volume must be entered')

# The size entry is required; it is resolved to numbers further down if it names
# an existing file
sizein = PipGetString('size', '')
if not sizein:
   exitError('-size must be entered with nx,ny,nz or file being matched to')
patchfile = PipGetString('patchfile', 'patch.out')

# A region model is optional, but it has to exist if one is given
modelfile = PipGetString('modelfile', '')
if modelfile:
   if not os.path.exists(modelfile):
      exitError('Model file ' + modelfile + ' does not exist')

# Transform file names; the second argument is the name used when none is entered
solvefile = PipGetString('solvefile', 'solve.xf')
refinefile = PipGetString('refinefile', 'refine.xf')
inversefile = PipGetString('inversefile', 'inverse.xf')
warpfile = PipGetString('warpfile', 'warp.xf')
residualfile = PipGetString('residualfile', '')

vectormodel = PipGetString('vectormodel', '')

# PipGetErrNo has to be tested right after this read; clipsize becomes a "-c"
# argument for patch2imod only when it returns a false value
clipin = PipGetInteger('clipsize', 0)
if not PipGetErrNo():
   clipsize = '-c ' + str(clipin)

# Residual limits: a single float for refinematch, and a string that is passed
# unchanged to findwarp
refinelimit = PipGetFloat('refinelimit', 0.3)
warplimit = PipGetString('warplimit', '0.2,0.27,0.35')
structCrit = PipGetString('structurecrit', '')

tempdir = PipGetString('tempdir', '')
linear = PipGetBoolean('linear', 0)

# Amounts to exclude on each side of each axis, passed to findwarp when nonzero
xlower = PipGetInteger('xlowerexclude', 0)
xupper = PipGetInteger('xupperexclude', 0)
ylower = PipGetInteger('ylowerexclude', 0)
yupper = PipGetInteger('yupperexclude', 0)
zlower = PipGetInteger('zlowerexclude', 0)
zupper = PipGetInteger('zupperexclude', 0)

# In trial mode the script stops after a successful fit without making a volume
trial = PipGetBoolean('trial', 0)
extent = PipGetInteger('extentfit', 0)

# Make sure the required input files exist, checking in the order volume, patch
# file, initial transform file, and exiting on the first one that is missing
for (fileKind, neededFile) in (('Input volume ', recfile),
                               ('Input file ', patchfile),
                               ('Input file ', solvefile)):
   if not os.path.exists(neededFile):
      exitError(fileKind + neededFile + ' does not exist')


# The size entry is used as entered, in the hope that it is numbers, unless it
# names an existing file; in that case the nx, ny, nz of that file replace it
size = sizein
if os.path.exists(sizein):
   try:
      nx, ny, nz = getmrcsize(sizein)
      size = fmtstr('{},{},{}', nx, ny, nz)
   except ImodpyError:
      exitFromImodError(progname)

# The vector model is made from the residual file, so it cannot be made without one
if vectormodel:
   if not residualfile:
      exitError('A residual file must be specified to make a vector model')

# Name for the object in the output vector model, extended when a clip size was
# entered
objname = "Values are residuals"
if clipsize:
   objname += "; clip planes exist"

# Input lines shared by the refinematch and findwarp runs
findbase = ['PatchFile ' + patchfile, 'VolumeOrSizeXYZ ' + sizein]
if residualfile:
   findbase += ['ResidualPatchOutput ' + residualfile]
if modelfile:
   findbase += ['RegionModel ' + modelfile]
if structCrit:
   findbase += ['ExtraValueSelection 5,1', 'SelectionCriteria ' + structCrit]

# Input lines shared by the matchvol and warpvol runs
volbase = ['InputFile ' + recfile, 'OutputFile  ' + matfile]
volbase += ['OutputSizeXYZ ' + size]
if tempdir:
   volbase += ['TemporaryDirectory ' + tempdir]
if linear:
   volbase += ['InterpolationOrder 1']

# Announce the refinematch run
# Flush afterwards: in old python (2.5 or below) the output printed from the runcmd
# somehow gets ahead of this output otherwise
prnstr("MATCHORWARP: Running Refinematch to try to find single transformation")
sys.stdout.flush()

# Warping is skipped when the warp limit entry is one of these spellings of zero
skipwarp = warplimit in ("0", "0.", ".0", "0.0")

# Input for refinematch: the shared fitting lines plus its limit and output file
comlines = list(findbase)
comlines.append(fmtstr('MeanResidualLimit {}', refinelimit))
comlines.append('OutputFile ' + refinefile)

# Exit status that refinematch and findwarp use when the fit is above the limit; it
# is the only nonzero status that does not end the script immediately
aboveLimitStatus = 2


# Runs a fitting program with the given standard input lines and returns 0 if runcmd
# did not raise ImodpyError, otherwise the value from getLastExitStatus
def runFitProgram(command, inputLines):
   try:
      runcmd(command, inputLines, 'stdout')
   except ImodpyError:
      return getLastExitStatus()
   return 0


# Exits with status 1 unless the fit status is 0 or the above-limit status
def quitOnFailedFit(fitStatus):
   if fitStatus and fitStatus != aboveLimitStatus:
      sys.exit(1)


# Runs patch2imod on the residual file to make the vector model and reports it
def makeVectorModel():
   prnstr(" ")
   runcmd(fmtstr('patch2imod {} -n "{}" "{}" "{}"', clipsize, objname, residualfile,
                 vectormodel))
   prnstr('MATCHORWARP: Created ' + vectormodel)


try:
   fitStatus = runFitProgram('refinematch -StandardInput', comlines)
   quitOnFailedFit(fitStatus)

   # The vector model can be written now if refinematch succeeded, or if it did not
   # but warping is being skipped so that findwarp will not be run
   if vectormodel and (fitStatus == 0 or skipwarp):
      makeVectorModel()

   if fitStatus == 0:
      prnstr(" ")
      if trial:
         prnstr("MATCHORWARP: Refinematch found a good transformation")
         sys.exit(0)

      # Not a trial run, so go on to make the output volume with matchvol
      prnstr("MATCHORWARP: Refinematch found a good transformation: next running "
             "Matchvol")
      prnstr(" ")
      sys.stdout.flush()

      matchLines = list(volbase)
      matchLines.extend(['TransformFile ' + solvefile,
                         'TransformFile ' + refinefile,
                         'InverseFile ' + inversefile])
      runcmd('matchvol -StandardInput', matchLines, 'stdout')
      sys.exit(0)

   # Refinematch was above its limit; stop here if the warp limit of 0 rules out
   # running findwarp
   if skipwarp:
      prnstr(" ")
      prnstr(fmtstr("ERROR: MATCHORWARP - Refinematch gave a mean residual error above "
                    "{} and warping is disabled", refinelimit))
      sys.exit(1)

   prnstr(" ")
   prnstr("MATCHORWARP: Running Findwarp to find a warping with given residual limits")
   sys.stdout.flush()

   # Input for findwarp: the shared fitting lines plus its own entries
   warpLines = list(findbase)
   warpLines.extend(['TargetMeanResidual ' + warplimit,
                     'InitialTransformFile ' + solvefile,
                     'OutputFile ' + warpfile])

   # Add an exclusion entry for each axis on which either amount is nonzero
   for (skipOption, lowSkip, highSkip) in (('XSkipLeftAndRight', xlower, xupper),
                                           ('YSkipLowerAndUpper', ylower, yupper),
                                           ('ZSkipLowerAndUpper', zlower, zupper)):
      if lowSkip or highSkip:
         warpLines.append(skipOption + fmtstr(' {},{}', lowSkip, highSkip))
   if extent:
      warpLines.append('MinExtentToFit ' + str(extent))

   fitStatus = runFitProgram('findwarp -StandardInput', warpLines)
   quitOnFailedFit(fitStatus)

   # After findwarp the vector model is written whether or not a limit was met
   if vectormodel:
      makeVectorModel()

   # If findwarp succeeded, run warpvol unless this is a trial
   if fitStatus == 0:
      prnstr(" ")
      if trial:
         prnstr("MATCHORWARP: Findwarp found a good warping")
         sys.exit(0)

      prnstr("MATCHORWARP: Findwarp found a good warping: next running Warpvol")
      prnstr(" ")
      sys.stdout.flush()

      warpvolLines = list(volbase)
      warpvolLines.append('TransformFile ' + warpfile)
      runcmd('warpvol -StandardInput', warpvolLines, 'stdout')
      sys.exit(0)

except ImodpyError:
   exitFromImodError(progname)

# Every path through the try block above that gets as far as a result calls sys.exit
# except the one where findwarp ended with status 2, the code for being above the
# limit, so this is the message for a warping that did not meet any residual limit
# (this assumes exitFromImodError does not return)
prnstr(" ")
exitError("You need to get better patches, edit patches, or eliminate rows or columns")
