#!/usr/bin/env python
# restrictalign - script to restrict tilt alignment parameters based on number of beads
#
# Author: David Mastronarde
#
# $Id: restrictalign,v 97c143ed5bd4 2025/09/24 04:21:04 mast $
#

# Program name used in messages; prefix is prepended to the fatal message printed when
# IMOD_DIR is not defined
progname = 'restrictalign'
prefix = 'ERROR: ' + progname + ' - '

# Indices into the parameter arrays (param, origParam, newParam, nextParam, ...) and
# into wasTested.
# MAKE SURE GROUP FOLLOWS OPT IN EVERY CASE: setOptAndGrouping and
# cvTestLocalTiltRotMag find a variable's grouping at index opt + 1.
# BEAM_TILT_OPT and PROJ_STRETCH are single values with no grouping entry.
# NOTE(review): LOC_TIL_GROUP is spelled without the second T; the code visible here
# reaches it only as LOC_TILT_OPT + 1, so confirm no other code depends on the name
# before renaming it.
(ROT_OPT, ROT_GROUP, MAG_OPT, MAG_GROUP, TILT_OPT, TILT_GROUP, SKEW_OPT, SKEW_GROUP,
 XSTRETCH_OPT, XSTRETCH_GROUP, XTILT_OPT, XTILT_GROUP, BEAM_TILT_OPT, PROJ_STRETCH,  
 LOC_ROT_OPT, LOC_ROT_GROUP, LOC_MAG_OPT, LOC_MAG_GROUP, LOC_TILT_OPT, LOC_TIL_GROUP,
 LOC_XSTR_OPT, LOC_XSTR_GROUP, LOC_SKEW_OPT, LOC_SKEW_GROUP) = range(24)


# Tiltalign option values, as written to the *Option entries of the command file
# (an option value of 0 means the variable is not solved for)
TA_ONE_ROT = -1
TA_GROUP_ROT = 3
TA_ALL_ROT = 1
TA_ALL_MAG = 1
TA_GROUP_MAG = 3
TA_ALL_TILT = 2
TA_GROUP_TILT = 5
TA_GROUP_SKEW = 3
TA_GROUP_XSTRETCH = 3
TA_BEAM_TILT = 2
TA_LINEAR_XTILT = 3
TA_BLOCK_XTILT = 4

# Restriction codes: the entries of the 'order' list of restrictions to apply
RES_GROUP_ROTS = 1
RES_ONE_ROT = 2
RES_FIX_TILTS = 3
RES_GROUP_MAGS = 4
RES_FIX_MAGS = 5

# Order values for cross-validation: the entries of the 'cvOrder' list, each
# selecting one group of cross-validation trials
(CV_TEST_STRETCH, CV_TEST_XTILT, CV_TEST_TILT, CV_TEST_ROT, CV_TEST_MAG, 
 CV_TEST_SINGLES) = (1, 2, 3, 4, 5, 6)

# Compute number of variables for a grouped parameter
def groupedUnknowns(size):
   """Return the number of unknowns for a variable solved in groups of `size` views.

   A size of 1 or less means one value per view, giving the global numViews unknowns.
   Otherwise the count is the number of groups (numViews / size, rounded up) plus one.
   """
   if size > 1:
      numGroups = (numViews + size - 1) // size
      return numGroups + 1
   return numViews

def prnFinal(strn):
   """Print strn with prnstr, but only when the global finalChanges flag is set."""
   if not finalChanges:
      return
   prnstr(strn)

def measuredToUnknown(param):
   """Estimate the ratio of measurements to unknowns for the parameter set param.

   Each of the numPoints fiducial points contributes 2 measurements.  The baseline
   unknown count is 3 * (numBeads - 1) + 2 * (numViews - 1).  Each variable that
   param solves for adds to this count:
   - rotation, tilt and magnification: one per view, by groups (groupedUnknowns),
     or a single value for one rotation;
   - skew, X stretch and X-axis tilt: by groups;
   - beam tilt and projection stretch: one each.
   Uses the globals numPoints, numBeads and numViews.
   """
   unknowns = 3. * (numBeads - 1) + 2. * (numViews - 1)

   rotOpt = param[ROT_OPT]
   if rotOpt == TA_ALL_ROT:
      unknowns += numViews
   elif rotOpt == TA_GROUP_ROT:
      unknowns += groupedUnknowns(param[ROT_GROUP])
   elif rotOpt == TA_ONE_ROT:
      unknowns += 1

   tiltOpt = param[TILT_OPT]
   if tiltOpt == TA_ALL_TILT:
      unknowns += numViews - 1
   elif tiltOpt == TA_GROUP_TILT:
      unknowns += groupedUnknowns(param[TILT_GROUP])

   magOpt = param[MAG_OPT]
   if magOpt == TA_GROUP_MAG:
      unknowns += groupedUnknowns(param[MAG_GROUP])
   elif magOpt == TA_ALL_MAG:
      unknowns += numViews - 1

   # Skew and X stretch contribute only when they are grouped
   for (optInd, groupOpt, groupInd) in ((SKEW_OPT, TA_GROUP_SKEW, SKEW_GROUP),
                                        (XSTRETCH_OPT, TA_GROUP_XSTRETCH,
                                         XSTRETCH_GROUP)):
      if param[optInd] == groupOpt:
         unknowns += groupedUnknowns(param[groupInd])

   if param[XTILT_OPT] in (TA_LINEAR_XTILT, TA_BLOCK_XTILT):
      unknowns += groupedUnknowns(param[XTILT_GROUP])

   # Beam tilt and projection stretch are single values
   for singleInd in (BEAM_TILT_OPT, PROJ_STRETCH):
      if param[singleInd]:
         unknowns += 1

   return numPoints * 2. / unknowns


def fillAndTestOrder(orderArr, ordUse, name):
   """Copy a possibly partial order entry into ordUse (in place) and validate ordUse.

   The first min(len(orderArr), len(ordUse)) entries of ordUse are replaced by those
   of orderArr; any extra entries in orderArr are ignored.  The resulting list must
   not repeat a value and every value must be in 1..len(ordUse); otherwise exitError
   is called.  name labels the list in the error message.
   """
   numFill = min(len(orderArr), len(ordUse))
   ordUse[:numFill] = orderArr[:numFill]
   numAllowed = len(ordUse)
   for entry in ordUse:
      if ordUse.count(entry) > 1:
         exitError(fmtstr('{} {} is in the order list more than once', name, entry))
      if not 1 <= entry <= numAllowed:
         exitError(fmtstr('{} entry {} is outside the allowed range of 1 to {}',
                          name, entry, numAllowed))


# Change one of the rotation, mag, or tilt and/or its grouping - works for local too
def setOptAndGrouping(param, optInd, groupOpt, variable, prefix):
   """Append sed edits to the global sedcom for one variable's option and grouping.

   optInd: index of the variable's *_OPT entry in param; its grouping is at optInd + 1.
   groupOpt: tiltalign option value meaning "solve in groups".
   variable: name for messages.
   prefix: command-file entry prefix, e.g. 'Rot' gives RotOption and
      RotDefaultGrouping.

   The option entry is edited when it differs from origParam; the grouping entry is
   edited when the option is groupOpt and the grouping differs from origParam.  A
   description of each change goes to prnFinal, with only one message when both
   change at once.
   """
   groupInd = optInd + 1
   newOpt = param[optInd]
   newGroup = param[groupInd]
   groupAnnounced = False

   if origParam[optInd] != newOpt:
      if not newOpt:
         prnFinal('Turned off solving for ' + variable)
      elif newOpt == groupOpt:
         prnFinal('Turned on grouping of ' + variable + 's to ' + str(newGroup))
         groupAnnounced = True
      elif variable == 'rotation' and newOpt == TA_ONE_ROT:
         prnFinal('Switched to solving for one rotation')
      sedcom.append(sedModify(prefix + 'Option', newOpt))

   groupChanged = origParam[groupInd] != newGroup
   if newOpt == groupOpt and groupChanged:
      if not groupAnnounced:
         prnFinal('Changed ' + variable + ' grouping to ' + str(newGroup))
      sedcom.append(sedModify(prefix + 'DefaultGrouping', newGroup))


# Add items to the sed command to achieve the given parameters
def buildUpSedcom(param, robustOff, required, areaOrNum):
   """Rebuild the global sedcom list of edits that turn the command into one for param.

   Each difference between param and origParam appends sed commands made with
   sedModify or sedDelAndAdd.  When testLocalArea is set, differences between
   required/areaOrNum and origRequired/origAreaOrNum are handled the same way.  A
   description of each change is passed to prnFinal, which prints it only when
   finalChanges is set.

   robustOff: if not empty, robust fitting is turned off and this text is appended to
      the message as the reason.
   required: two values written to MinFidsTotalAndEachSurface.
   areaOrNum: two values written to TargetPatchSizeXandY if targetSize is set,
      otherwise to NumberOfLocalPatchesXandY.
   """
   global sedcom
   sedcom = []
   crossValOpt = 1

   # Local alignments are turned off for trial runs unless local parameters are being
   # tested; in the final changes, only when cross-validation was not used
   if localAlign and not testLocal and (not finalChanges or not crossValidate):
      if not crossValidate:
         prnFinal('Turned off local alignments')
      sedcom.append(sedModify('LocalAlignments', 0))
      
   # This is not allowed to happen (presumably prevented elsewhere); if it does,
   # testing local parameters requires turning local alignments on
   if not localAlign and testLocal:
      prnFinal('Turned on local alignments')
      sedcom.append(sedModify('LocalAlignments', 1))

   # batchruntomo looking for 'off robust' - keep this message text unchanged
   if robustOff != '':
      prnFinal('Turned off robust fitting because ' + robustOff + '     [rsa2]')
      sedcom += sedDelAndAdd('RobustFitting', 0, 'OutputTransformFile')

   # X stretch: turning the global one off also turns off the local one; otherwise
   # write option and grouping if either changed
   if origParam[XSTRETCH_OPT] and not param[XSTRETCH_OPT]:
      prnFinal('Turned off solving for X stretch')
      sedcom.append(sedModify('XStretchOption', 0))
      sedcom.append(sedModify('LocalXStretchOption', 0))
   if param[XSTRETCH_OPT] and (origParam[XSTRETCH_OPT] != param[XSTRETCH_OPT] or \
                               origParam[XSTRETCH_GROUP] != param[XSTRETCH_GROUP]):
      prnFinal('Set grouping of X stretch to ' + str(param[XSTRETCH_GROUP]))
      sedcom.append(sedModify('XStretchOption', param[XSTRETCH_OPT]))
      sedcom.append(sedModify('XStretchDefaultGrouping', param[XSTRETCH_GROUP]))
      
   # Skew: same handling as X stretch
   if origParam[SKEW_OPT] and not param[SKEW_OPT]:
      prnFinal('Turned off solving for skew')
      sedcom.append(sedModify('SkewOption', 0))
      sedcom.append(sedModify('LocalSkewOption', 0))
   if param[SKEW_OPT] and (origParam[SKEW_OPT] != param[SKEW_OPT] or \
                               origParam[SKEW_GROUP] != param[SKEW_GROUP]):
      prnFinal('Set grouping of skew to ' + str(param[SKEW_GROUP]))
      sedcom.append(sedModify('SkewOption', param[SKEW_OPT]))
      sedcom.append(sedModify('SkewDefaultGrouping', param[SKEW_GROUP]))

   # Projection stretch is turned off by deleting its entry
   if origParam[PROJ_STRETCH] and not param[PROJ_STRETCH]:
      prnFinal('Turned off solving for projection stretch')
      sedcom.append('/ProjectionStretch/d')

   # Global rotation, tilt and magnification options/groupings
   setOptAndGrouping(param, ROT_OPT, TA_GROUP_ROT, 'rotation', 'Rot')
   setOptAndGrouping(param, TILT_OPT, TA_GROUP_TILT, 'tilt angle', 'Tilt')
   setOptAndGrouping(param, MAG_OPT, TA_GROUP_MAG, 'magnification', 'Mag')

   # Local variable options/groupings
   setOptAndGrouping(param, LOC_ROT_OPT, TA_GROUP_ROT, 'local rotation', 'LocalRot')
   setOptAndGrouping(param, LOC_TILT_OPT, TA_GROUP_TILT, 'local tilt angle', 'LocalTilt')
   setOptAndGrouping(param, LOC_MAG_OPT, TA_GROUP_MAG, 'local magnification', 'LocalMag')
   setOptAndGrouping(param, LOC_XSTR_OPT, TA_GROUP_XSTRETCH, 'local X-stretch', 
                     'LocalXStretch')
   setOptAndGrouping(param, LOC_SKEW_OPT, TA_GROUP_SKEW, 'local skew', 'LocalSkew')
   
   # NOTE(review): the message assumes the only X-tilt change other than turning it off
   # is the switch to a single X-tilt made in cvTestXTilt
   if param[XTILT_OPT] and (origParam[XTILT_OPT] != param[XTILT_OPT] or \
                            origParam[XTILT_GROUP] != param[XTILT_GROUP]):
      prnFinal('Switched to solving for single X-tilt')
      sedcom.append(sedModify('XTiltOption', param[XTILT_OPT]))
      sedcom.append(sedModify('XTiltDefaultGrouping', param[XTILT_GROUP]))
      
   if origParam[XTILT_OPT] and not param[XTILT_OPT]:
      prnFinal('Turned off solving for X-axis tilt')
      sedcom.append(sedModify('XTiltOption', 0))

   # Beam tilt turned on, or changed from a nonzero original value.  NOTE(review): the
   # "added" message assumes beam tilt is only added along with one rotation, as in
   # cvTestRotation
   beamOpt = param[BEAM_TILT_OPT]
   if (not origParam[BEAM_TILT_OPT] and beamOpt) or \
          origParam[BEAM_TILT_OPT] and origParam[BEAM_TILT_OPT] != beamOpt:
      if beamOpt:
         prnFinal('Added beam tilt solution because solving for only one rotation')
      else:
         prnFinal('Turned off solving for beam tilt')
      sedcom += sedDelAndAdd('BeamTiltOption', beamOpt, 'OutputTransformFile')

   # Local area parameters are changed only when testing them; cross-validation then
   # uses option 2
   if testLocalArea:
      crossValOpt = 2
      if required[0] != origRequired[0] or required[1] != origRequired[1]:
         reqStr = fmtstr('{},{}', required[0], required[1])
         prnFinal('Changed required # of fiducials to ' + reqStr);
         sedcom.append(sedModify('MinFidsTotalAndEachSurface', reqStr))
      if areaOrNum[0] != origAreaOrNum[0] or areaOrNum[1] != origAreaOrNum[1]:
         areaStr = fmtstr('{},{}', areaOrNum[0], areaOrNum[1])
         if targetSize:
            prnFinal('Changed target area size to  ' + areaStr)
            sedcom.append(sedModify('TargetPatchSizeXandY', areaStr))
         else:
            prnFinal('Changed number of local areas to  ' + areaStr)
            sedcom.append(sedModify('NumberOfLocalPatchesXandY', areaStr))

   # For a trial run (not final changes), add the cross-validation option and, when
   # crossValidate > 1, LeaveOutPredictAndPad ("the one to do contours")
   if not finalChanges:
      sedcom += sedDelAndAdd('CrossValidate', crossValOpt, 'OutputTransformFile')
      if crossValidate > 1:
         sedcom += sedDelAndAdd('LeaveOutPredictAndPad', '0,0', 'OutputTransformFile')


# Do one run of tiltalign and get the leave-out error lines
def runTiltalignExtractErrors(descrip, robust):
   """Run tiltalign on taLines edited by the global sedcom and return the leave-out errors.

   descrip, robust: text concatenated into warning messages; a non-empty robust also
      makes a failure return four -2 values instead of one.

   Values are taken from output lines containing the tag ('Global', or 'Local' when
   testLocal) and 'leave-out'.  On such a line, the value after any token ending in
   '):' is taken, and so is the value after any token starting with 'w' and containing
   'g' and 't' (e.g. "weighted").
   If tiltalign reports too few points for robust fitting, the global badRobust is
   set.  Up to two trailing values are then dropped (keeping at least one) and three
   -1 values are appended.
   If tiltalign fails, its errors are printed as a warning and [-2] (or four -2) is
   returned.  Tiltalign warnings are collected and printed after the run.

   Exits via exitError if no leave-out errors are found, if a value cannot be
   converted to float, or if a rotation-angle warning gives two angles more than
   22 degrees apart.
   """
   global badRobust
   errors = []
   badRobust = False
   runLines = pysed(sedcom, taLines)
   warnLines = []
   tag = 'Global'
   if testLocal:
      tag = 'Local'
   # Debug: uncomment to list the edited command lines
   #for line in runLines:
   #   prnstr(line)
      
   try:
      outLines = runcmd('tiltalign -StandardInput', runLines)
      for line in outLines:

         # Look for robust failure and turn off robust for further tests, pass on warning
         if 'too few' in line.lower() and 'robust fitting' in line.lower():
            badRobust = True
         elif line.startswith('WARNING'):
            warnLines.append(line.strip())

            # A rotation-angle warning: collect the numbers on the line
            if 'rotation angle' in line and 'closer to' in line:
               lsplit = line.split()
               angles = []
               for word in lsplit:
                  try:
                     value = float(word)
                     angles.append(value)
                  except ValueError:
                     pass

               # Tiltalign's criterion is 15 and it can pull in a correct angle with 30,
               # so make criterion for fail somewhat higher than 15
               if len(angles) == 2 and math.fabs(angles[1] - angles[0]) > 22.:
                  exitError('Fix the initial rotation angle as suggested by that ' + \
                            'warning before trying to optimize the parameters')
               

         # For a leave-out line, extract the values
         if tag in line and 'leave-out' in line:
            lsplit = line.split()
            for ind in range(len(lsplit) - 1):
               if lsplit[ind].endswith('):'):
                  errors.append(float(lsplit[ind + 1]))
               if lsplit[ind].startswith('w') and 'g' in lsplit[ind] and \
                  't' in lsplit[ind]:
                  errors.append(float(lsplit[ind + 1]))

   except ImodpyError:
      # Tiltalign failed: report its errors and return -2 placeholders
      errStr = getErrStrings()
      prnstr('WARNING: Tiltalign failed with error with ' + descrip + robust + ':')
      for line in errStr:
         prnstr(line.strip().replace('ERROR:', '   (error):'))
      errors = [-2]
      if robust:
         errors = [-2, -2, -2, -2]
      return errors

   # Only a leave-out value conversion can reach here; angle conversions are caught
   # in the inner try
   except ValueError:
      exitError('Converting ' + lsplit[ind + 1] + ' to floating point number')

   if not errors:
      exitError('Could not find leave-out errors in Tiltalign output')

   # Robust fitting failed: drop up to two trailing values, keeping the first
   if badRobust and len(errors) > 1:
      errors.pop()
      if len(errors) > 1:
         errors.pop()

   if warnLines:
      prnstr('WARNING: Tiltalign gave warning with ' + descrip + robust + ':')
      for line in warnLines:
         prnstr('    ' + line)

   # Mark the robust values as unavailable
   if badRobust:
      errors += [-1, -1, -1]
   return errors


# Does the sequence of making the sed command and running tiltalign (there were originally
# two runs), does the verbose output, and fills out the error array to 4 values
def doTiltalignRuns(param, descrip, required, areaOrNum):
   """Run one tiltalign trial for param and return its errors padded to 4 values.

   param, required, areaOrNum: passed to buildUpSedcom to build the edited command.
   descrip: description of the trial for messages.

   Robust fitting is turned off ("for eval") when the command file uses robust
   alignment but doingRobust is False.  When doingRobust is set and tiltalign reports
   that robust fitting could not be done (errors[1] == -1), doingRobust is turned
   off for later trials.  Time spent in tiltalign is added to cumNonRobTime.
   Missing entries are filled with -1 so callers can always index errors[0..3].
   Exits via exitError if the number of values is inconsistent with the robust
   setting.
   """
   global doingRobust, cumNonRobTime
   noRobust = ''
   if robustAlign and not doingRobust:
      noRobust = 'for eval'
   buildUpSedcom(param, noRobust, required, areaOrNum)
   startTime = time.time()
   errors = runTiltalignExtractErrors(descrip, '')
   cumNonRobTime += time.time() - startTime
   if len(errors) > 1 and not (robustAlign and doingRobust):
      exitError('Inconsistent output from Tiltalign: weighted error despite ' +\
                'robust fitting turned off')
   if robustAlign and doingRobust:
      if len(errors) < 2:
         exitError('Could not find weighted leave-out error in Tiltalign output with ' +\
                   'robust fitting')
      if errors[1] == -1:
         doingRobust = False

   # Pad BEFORE the verbose output, which indexes errors[2] and errors[3] when robust
   # fitting is active; a 2- or 3-value result would otherwise raise IndexError
   if len(errors) < 4:
      errors += [-1] * (4 - len(errors))

   if verbose:

      # 3 decimals for errors of at least 0.5, 4 decimals for smaller ones
      numFmt = '{:.3f}' if errors[0] >= 0.5 else '{:.4f}'
      if doingRobust:
         prnstr(fmtstr(' '.join([numFmt] * 4) + ': errors with {}', errors[0],
                       errors[1], errors[2], errors[3], descrip), end = '')
      else:
         prnstr(fmtstr(numFmt + ': error with {}', errors[0], descrip), end = '')

   return errors


# Do a comparison of the described type with the current best param
def compareNextParam(descrip):
   """Run tiltalign with the trial parameters and compare it with the current best.

   descrip: description of the trial, used in messages.

   Trial state: nextParam/nextRequired/nextAreaOrNum.  Current best:
   newParam/newRequired/newAreaOrNum, with errors newErrors.
   The fractional improvement diff is computed as follows:
   - when doingRobust and both errors[3] values are positive, the mean of the relative
     decreases in elements 2 and 3;
   - otherwise, when the trial's errors[0] is positive, the relative decrease in
     element 0;
   - otherwise diff stays -999.
   The same measure against lastErrors (the previous trial's errors) goes into
   lastErrDiff.  The prior lastErrors/lastErrDiff are kept in prevLastErrs and
   prevErrDiff.

   If diff > 0, the current best is saved in the prev* globals (so revertToStep can
   restore it) and the trial becomes the new best.  curErrDiff is then set to diff, a
   message is printed and 1 is returned.
   Otherwise next* are reset to copies of the current best, and 0 is returned if diff
   is exactly 0, -1 if not.
   """
   global newErrors, newParam, nextParam, nextRequired, nextAreaOrNum, newRequired
   global newAreaOrNum, lastErrDiff, lastErrors, prevLastErrs, prevErrDiff, curErrDiff
   global prevErrors, prevParam, prevRequired, prevAreaOrNum
   errors = doTiltalignRuns(nextParam, descrip, nextRequired, nextAreaOrNum)
   diff = -999
   prevLastErrs = copy.deepcopy(lastErrors)
   prevErrDiff = lastErrDiff
   lastErrDiff = -999
   if doingRobust and errors[3] > 0 and newErrors[3] > 0:
      diff = ((newErrors[2] - errors[2]) / newErrors[2] + 
              (newErrors[3] - errors[3]) / newErrors[3]) / 2.
   elif errors[0] > 0.:
      diff = (newErrors[0] - errors[0]) / newErrors[0]

   # Get difference from last error too, so that local area can count up unique
   # ones that are worse.  lastErrors is replaced only when a difference is computed
   if doingRobust and errors[3] > 0 and lastErrors[3] > 0:
      lastErrDiff = ((lastErrors[2] - errors[2]) / lastErrors[2] + 
                     (lastErrors[3] - errors[3]) / lastErrors[3]) / 2.
      lastErrors = copy.deepcopy(errors)
   elif errors[0] > 0.:
      lastErrDiff = (lastErrors[0] - errors[0]) / lastErrors[0]
      lastErrors = copy.deepcopy(errors)

   # Give output if verbose or if it is better, do assignments to best if it is better
   # (the verbose line started by doTiltalignRuns is finished here)
   if verbose:
      if diff > 0 or diff == -999:
         prnstr(' ')
      elif diff < 0:
         prnstr(fmtstr(' -  {:.2f}% higher', -100. * diff))
      else:
         prnstr(' -  the same')
   if diff > 0:

      # A new best one, copy to the "new" params and errors, and save the current "new"
      # params in the "prev" ones so it is possible to revert
      curErrDiff = diff
      prevErrors = copy.deepcopy(newErrors)
      prevParam = copy.deepcopy(newParam)
      prevRequired = copy.deepcopy(newRequired)
      prevAreaOrNum = copy.deepcopy(newAreaOrNum)
      newErrors = copy.deepcopy(errors)
      newParam = copy.deepcopy(nextParam)
      newRequired = copy.deepcopy(nextRequired)
      newAreaOrNum = copy.deepcopy(nextAreaOrNum)
      prnstr(fmtstr('Leave-out error {:.2f}% lower with {}', diff * 100., descrip))
      return 1
   else:

      # If not better, set up nextParam as copy of current best
      nextParam = copy.deepcopy(newParam)
      nextRequired = copy.deepcopy(newRequired)
      nextAreaOrNum = copy.deepcopy(newAreaOrNum)
      if diff == 0:
         return 0
      return -1


# Assign the given set of state parameters to the "new" params to return to that state
def revertToStep(errDiff, lastErrs, errors, param, required, areaOrNum):
   """Make the given saved state the current best ("new") and trial ("next") state.

   lastErrors/lastErrDiff are set from lastErrs/errDiff, and newErrors from errors.
   newParam/nextParam, newRequired/nextRequired and newAreaOrNum/nextAreaOrNum each
   receive independent deep copies of param, required and areaOrNum.
   """
   global lastErrors, lastErrDiff, newErrors, newParam, newRequired, newAreaOrNum
   global nextParam, nextRequired, nextAreaOrNum
   lastErrDiff = errDiff
   lastErrors = copy.deepcopy(lastErrs)
   newErrors = copy.deepcopy(errors)

   # "new" and "next" must not share lists, since trials modify nextParam in place
   newParam, nextParam = copy.deepcopy(param), copy.deepcopy(param)
   newRequired, nextRequired = copy.deepcopy(required), copy.deepcopy(required)
   newAreaOrNum, nextAreaOrNum = copy.deepcopy(areaOrNum), copy.deepcopy(areaOrNum)


# Test GLOBAL STRETCH grouping and not solving
def cvTestStretch(takeOneStep):
   """Run cross-validation trials for global X-stretch and skew.

   Trials, each compared with the current best by compareNextParam:
     1. standard grouping (at least dfltSkewGrouping / dfltXStretchGrouping) for
        whichever of skew and X-stretch is solved for but not grouped that much;
     2. bigGrouping for whichever of them is solved for;
     3. not solving for either.
   wasTested[XSTRETCH_OPT] records the last trial done (1-3), so a later call resumes
   with the next one.

   Returns 1 if trial 1 or 2 was better and takeOneStep is set, or if trial 3 was
   better (regardless of takeOneStep); otherwise 0.
   """

   # Test with default grouping of stretch
   needGroupSkew = newParam[SKEW_OPT] and (newParam[SKEW_OPT] != TA_GROUP_SKEW or
                                           newParam[SKEW_GROUP] < dfltSkewGrouping)
   needGroupXstr = newParam[XSTRETCH_OPT] and \
                   (newParam[XSTRETCH_OPT] != TA_GROUP_XSTRETCH or
                    newParam[XSTRETCH_GROUP] < dfltXStretchGrouping)
   if (needGroupSkew or needGroupXstr) and wasTested[XSTRETCH_OPT] < 1:
      if needGroupSkew:
         nextParam[SKEW_OPT] = TA_GROUP_SKEW
         nextParam[SKEW_GROUP] = max(newParam[SKEW_GROUP], dfltSkewGrouping)
      if needGroupXstr:
         nextParam[XSTRETCH_OPT] = TA_GROUP_XSTRETCH
         nextParam[XSTRETCH_GROUP] = max(newParam[XSTRETCH_GROUP],dfltXStretchGrouping)
      wasTested[XSTRETCH_OPT] = 1
      if compareNextParam('standard grouping of X-stretch and skew') > 0 and takeOneStep:
         return 1
      
   # Test with large grouping or no stretch.  After compareNextParam, nextParam is a
   # copy of the current best whether or not the trial was better
   if ((newParam[XSTRETCH_OPT] and nextParam[XSTRETCH_GROUP] < bigGrouping) or \
      (newParam[SKEW_OPT] and nextParam[SKEW_GROUP] < bigGrouping)) and \
      wasTested[XSTRETCH_OPT] < 2:
      if newParam[SKEW_OPT]:
         nextParam[SKEW_OPT] = TA_GROUP_SKEW
         nextParam[SKEW_GROUP] = bigGrouping
      if newParam[XSTRETCH_OPT]:
         nextParam[XSTRETCH_OPT] = TA_GROUP_XSTRETCH
         nextParam[XSTRETCH_GROUP] = bigGrouping
      wasTested[XSTRETCH_OPT] = 2
      if compareNextParam('large grouping of X-stretch and skew') > 0 and takeOneStep:
         return 1

   # Last trial: returns 1 if better even when takeOneStep is not set
   if (nextParam[XSTRETCH_OPT] or nextParam[SKEW_OPT]) and wasTested[XSTRETCH_OPT] < 3:
      nextParam[SKEW_OPT] = 0
      nextParam[XSTRETCH_OPT] = 0
      wasTested[XSTRETCH_OPT] = 3
      if compareNextParam('not solving for X-stretch or skew') > 0:
         return 1

   return 0


# Test VARIABLE X-TILT
def cvTestXTilt():
   """Run the cross-validation trial that replaces variable X-axis tilt with one value.

   Runs when X-tilt is solved for, either with a grouping smaller than numViews or
   with the linear option, and wasTested[XTILT_OPT] < 1.  The trial uses the block
   option with a grouping of 2 * numViews, i.e. a single X-tilt.  Sets
   wasTested[XTILT_OPT] to 1 (cvTestSingleVars later uses 2 for the same entry).

   Returns 1 if the trial was better than the current best, otherwise 0.
   """

   # Test not solving for many x-tilts
   if newParam[XTILT_OPT] and (newParam[XTILT_GROUP] < numViews or \
                               newParam[XTILT_OPT] == TA_LINEAR_XTILT) and\
      wasTested[XTILT_OPT] < 1:
      nextParam[XTILT_OPT] = TA_BLOCK_XTILT
      nextParam[XTILT_GROUP] = 2 * numViews
      wasTested[XTILT_OPT] = 1
      if compareNextParam('solving for only a single X-tilt') > 0:
         return 1

   return 0


# Test tilt angle grouping and fixing
def cvTestTilt(takeOneStep):
   """Run cross-validation trials for tilt angles.

   Trials, each compared with the current best by compareNextParam:
     1. standard grouping (at least minTiltGrouping) if tilt is solved for every view
        or in smaller groups;
     2. bigGrouping if tilt is grouped with a smaller group;
     3. not solving for tilt.
   wasTested[TILT_OPT] records the last trial done (1-3), so a later call resumes
   with the next one.

   Returns 1 if trial 1 or 2 was better and takeOneStep is set, or if trial 3 was
   better (regardless of takeOneStep); otherwise 0.
   """

   # Check tilt with standard grouping.  NOTE(review): uses nextParam's grouping,
   # which presumably equals newParam's on entry - confirm with the caller
   if (newParam[TILT_OPT] == TA_ALL_TILT or \
       (newParam[TILT_OPT] == TA_GROUP_TILT and \
        newParam[TILT_GROUP] < minTiltGrouping)) and wasTested[TILT_OPT] < 1:
      nextParam[TILT_OPT] = TA_GROUP_TILT
      nextParam[TILT_GROUP] = max(nextParam[TILT_GROUP], minTiltGrouping)
      wasTested[TILT_OPT] = 1
      if compareNextParam('standard grouping of tilt') > 0 and takeOneStep:
         return 1

   # Then tilt with big group or not at all
   if newParam[TILT_OPT] == TA_GROUP_TILT and newParam[TILT_GROUP] < bigGrouping and \
      wasTested[TILT_OPT] < 2:
      nextParam[TILT_OPT] = TA_GROUP_TILT
      nextParam[TILT_GROUP] = bigGrouping
      wasTested[TILT_OPT] = 2
      if compareNextParam('large grouping of tilt') > 0 and takeOneStep:
         return 1

   # Last trial: returns 1 if better even when takeOneStep is not set
   if newParam[TILT_OPT] and wasTested[TILT_OPT] < 3:
      nextParam[TILT_OPT] = 0
      wasTested[TILT_OPT] = 3
      if compareNextParam('not solving for tilt') > 0:
         return 1

   return 0
      

# Test global ROTATION grouping, solving for one, and fixed
def cvTestRotation(takeOneStep):
   """Run cross-validation trials for global rotation, plus beam tilt with one rotation.

   Trials, each compared with the current best by compareNextParam:
     1. standard grouping (at least minRotGrouping) if rotation is solved for every
        view or in smaller groups;
     2. bigGrouping if rotation is solved for every view or in smaller groups;
     3. a single rotation, if not already solving for one;
     4. rotation fixed at its initial value, if now solving for one;
     5. adding beam tilt, when now solving for one rotation but the original command
        was not, skipBeamTilt is off, and there are at least minBeadsForBeamTilt
        beads.
   wasTested[ROT_OPT] records the last trial done (1-5), so a later call resumes
   after it.

   Returns 1 if one of trials 1-4 was better and takeOneStep is set, or if trial 5
   was better (regardless of takeOneStep); otherwise 0.
   """
   # Rotation standard grouping
   if (newParam[ROT_OPT] == TA_ALL_ROT or \
       (newParam[ROT_OPT] == TA_GROUP_ROT and newParam[ROT_GROUP] < minRotGrouping)) and \
      wasTested[ROT_OPT] < 1:
      nextParam[ROT_OPT] = TA_GROUP_ROT
      nextParam[ROT_GROUP] = max(minRotGrouping, nextParam[ROT_GROUP])
      wasTested[ROT_OPT] = 1
      if compareNextParam('standard grouping of rotation') > 0 and takeOneStep:
         return 1

   # Rotation large grouping and solving for one
   if (newParam[ROT_OPT] == TA_ALL_ROT or \
      (newParam[ROT_OPT] == TA_GROUP_ROT and newParam[ROT_GROUP] < bigGrouping)) and \
      wasTested[ROT_OPT] < 2:
      nextParam[ROT_OPT] = TA_GROUP_ROT
      nextParam[ROT_GROUP] = bigGrouping
      wasTested[ROT_OPT] = 2
      if compareNextParam('large grouping of rotation') > 0 and takeOneStep:
         return 1

   # Also applies when rotation is fixed (option 0)
   if newParam[ROT_OPT] != TA_ONE_ROT and wasTested[ROT_OPT] < 3:
      nextParam[ROT_OPT] = TA_ONE_ROT
      wasTested[ROT_OPT] = 3
      if compareNextParam('solving for one rotation') > 0 and takeOneStep:
         return 1

   # Finally try fixed if solving for one
   if newParam[ROT_OPT] == TA_ONE_ROT and wasTested[ROT_OPT] < 4:
      nextParam[ROT_OPT] = 0
      wasTested[ROT_OPT] = 4
      if compareNextParam('rotation fixed at initial value') > 0 and takeOneStep:
         return 1

   # When this script switched to one rotation, see whether beam tilt helps.  The trial
   # is about solving for a single ROTATION (as in buildUpSedcom's message), not tilt
   if newParam[ROT_OPT] == TA_ONE_ROT and origParam[ROT_OPT] != TA_ONE_ROT and \
      not skipBeamTilt and numBeads >= minBeadsForBeamTilt and wasTested[ROT_OPT] < 5:
      nextParam[BEAM_TILT_OPT] = TA_BEAM_TILT
      wasTested[ROT_OPT] = 5
      if compareNextParam('solving for beam tilt when solving for only one ' +
                          'rotation') > 0:
         return 1

   return 0


# Test GLOBAL MAGNIFICATION grouping and fixing
def cvTestMagnification(takeOneStep):
   """Run cross-validation trials for global magnification.

   Trials, each compared with the current best by compareNextParam:
     1. standard grouping (at least minMagGrouping) if magnification is solved for
        every view or in smaller groups;
     2. bigGrouping if solved for every view or in smaller groups;
     3. not solving for magnification.
   wasTested[MAG_OPT] records the last trial done (1-3), so a later call resumes
   with the next one.

   Returns 1 if trial 1 or 2 was better and takeOneStep is set, or if trial 3 was
   better (regardless of takeOneStep); otherwise 0.
   """
   # Magnification standard grouping
   if (newParam[MAG_OPT] == TA_ALL_MAG or \
       (newParam[MAG_OPT] == TA_GROUP_MAG and newParam[MAG_GROUP] < minMagGrouping)) and \
      wasTested[MAG_OPT] < 1:
      nextParam[MAG_OPT] = TA_GROUP_MAG
      nextParam[MAG_GROUP] = max(minMagGrouping, nextParam[MAG_GROUP])
      wasTested[MAG_OPT] = 1
      if compareNextParam('standard grouping of magnification') > 0 and takeOneStep:
         return 1

   # Magnification large grouping and fixed
   if (newParam[MAG_OPT] == TA_ALL_MAG or \
       (newParam[MAG_OPT] == TA_GROUP_MAG and newParam[MAG_GROUP] < bigGrouping)) and \
      wasTested[MAG_OPT] < 2:
      nextParam[MAG_OPT] = TA_GROUP_MAG
      nextParam[MAG_GROUP] = bigGrouping
      wasTested[MAG_OPT] = 2
      if compareNextParam('large grouping of magnification') > 0 and takeOneStep:
         return 1

   # Last trial: nextParam mirrors the current best after compareNextParam.  Returns 1
   # if better even when takeOneStep is not set
   if nextParam[MAG_OPT] and wasTested[MAG_OPT] < 3:
      nextParam[MAG_OPT] = 0
      wasTested[MAG_OPT] = 3
      if compareNextParam('not solving for magnification') > 0:
         return 1

   return 0


# Test BEAM TILT, single X-TILT, PROJECTION STRETCH
def cvTestSingleVars(takeOneStep):
   """Run cross-validation trials that turn off beam tilt, X-axis tilt or projection stretch.

   Each variable that is solved for and not yet tested here is turned off in a trial
   compared with the current best by compareNextParam.  wasTested[XTILT_OPT] is
   compared with 2 because cvTestXTilt uses 1 for the same entry.

   Returns 0 if no trial was better.  When a trial is better and the function
   returns, the result is the variable's index (BEAM_TILT_OPT, XTILT_OPT or
   PROJ_STRETCH) if takeOneStep > 1, otherwise 1.  The beam tilt and X-tilt trials
   return only when takeOneStep is set; the final projection-stretch trial returns
   whenever it is better.
   """

   # Single-variable items
   if newParam[BEAM_TILT_OPT] and wasTested[BEAM_TILT_OPT] < 1:
      nextParam[BEAM_TILT_OPT] = 0
      wasTested[BEAM_TILT_OPT] = 1
      if compareNextParam('not solving for beam tilt') > 0 and takeOneStep:
         if takeOneStep > 1:
            return BEAM_TILT_OPT
         return 1

   if newParam[XTILT_OPT] and wasTested[XTILT_OPT] < 2:
      nextParam[XTILT_OPT] = 0
      wasTested[XTILT_OPT] = 2
      if compareNextParam('not solving for x-axis tilt') > 0 and takeOneStep:
         if takeOneStep > 1:
            return XTILT_OPT
         return 1

   if newParam[PROJ_STRETCH] and wasTested[PROJ_STRETCH] < 1:
      nextParam[PROJ_STRETCH] = 0
      wasTested[PROJ_STRETCH] = 1
      if compareNextParam('not solving for projection stretch') > 0:
         if takeOneStep > 1:
            return PROJ_STRETCH
         return 1

   return 0


# Test LOCAL STRETCH
def cvTestLocalStretch(takeOneStep):
   """Run cross-validation trials for local X-stretch and skew.

   Three grouping trials use group sizes computed on entry:
     1. "default": the current local grouping, but at least dfltSkewGrouping /
        dfltXStretchGrouping;
     2. "more": min(2 * default, (default + bigGrouping) // 2);
     3. "large": bigGrouping.
   Each grouping trial applies to whichever variable is solved for and is either not
   grouped or grouped smaller than the trial size.  A final trial turns both off.
   wasTested[LOC_XSTR_OPT] records the last trial done (1-4).

   Returns 1 if a grouping trial was better and takeOneStep is set, or if the final
   trial was better (regardless of takeOneStep); otherwise 0.
   """

   # Stretch/skew is a pain, first test grouping if none, then increase it, then
   # turn it off
   defSkew = max(dfltSkewGrouping, newParam[LOC_SKEW_GROUP])
   defStretch = max(dfltXStretchGrouping, newParam[LOC_XSTR_GROUP])
   mediumSkew = min(defSkew * 2, (defSkew + bigGrouping) // 2)
   mediumStretch = min(defStretch * 2, (defStretch + bigGrouping) // 2)

   for (grpSkew, grpStretch, testVal, typeText) in \
       ((defSkew, defStretch, 1, ''), (mediumSkew, mediumStretch, 2, 'more '),
        (bigGrouping, bigGrouping, 3, 'large ')):
      needGroupSkew = newParam[LOC_SKEW_OPT] and \
                      (newParam[LOC_SKEW_OPT] != TA_GROUP_SKEW or \
                       newParam[LOC_SKEW_GROUP] < grpSkew)
      needGroupXstr = newParam[LOC_XSTR_OPT] and \
                      (newParam[LOC_XSTR_OPT] != TA_GROUP_XSTRETCH or
                       newParam[LOC_XSTR_GROUP] < grpStretch)
                                          
      if (needGroupXstr or needGroupSkew) and wasTested[LOC_XSTR_OPT] < testVal:
         if needGroupXstr:
            nextParam[LOC_XSTR_OPT] = TA_GROUP_XSTRETCH
            nextParam[LOC_XSTR_GROUP] = grpStretch
         if needGroupSkew:
            nextParam[LOC_SKEW_OPT] = TA_GROUP_SKEW
            nextParam[LOC_SKEW_GROUP] = grpSkew
         wasTested[LOC_XSTR_OPT] = testVal
         if compareNextParam(typeText + 'grouping of local X-stretch and skew') > 0 and \
            takeOneStep:
            return 1

   # Turn off both variables
   if (nextParam[LOC_XSTR_OPT] or nextParam[LOC_SKEW_OPT]) and \
      wasTested[LOC_XSTR_OPT] < 4:
      nextParam[LOC_XSTR_OPT] = 0
      nextParam[LOC_SKEW_OPT] = 0
      wasTested[LOC_XSTR_OPT] = 4
      if compareNextParam('not solving for local X-stretch and skew') > 0:
         return 1

   return 0

   
# TEST LOCAL TILT, MAGNIFICATION, or ROTATION
def cvTestLocalTiltRotMag(opt, groupOpt, label, takeOneStep):
   """Run cross-validation trials for a local tilt, rotation or magnification variable.

   opt: index of the local *_OPT entry; its grouping is at opt + 1.
   groupOpt: tiltalign option value for grouping this variable.
   label: variable name for messages.

   If the variable is solved for, grouping trials use group sizes computed on entry:
     1. the current grouping (this applies only if the variable is not yet grouped);
     2. "more": min(2 * current, (current + bigGrouping) // 2);
     3. "large": bigGrouping.
   Each grouping trial runs only if the variable is not grouped or has a smaller
   group.  A final trial turns the variable off.  wasTested[opt] records the last
   trial done (1-4).

   NOTE(review): unlike cvTestLocalStretch, no minimum is applied to the current
   grouping, so an ungrouped variable whose stored grouping is small gets a small trial
   group - confirm that is intended.

   Returns 1 if a grouping trial was better and takeOneStep is set, or if the final
   trial was better (regardless of takeOneStep); otherwise 0.
   """

   defGroup = newParam[opt + 1]
   mediumGroup = min(defGroup * 2, (defGroup + bigGrouping) // 2)
   
   for (group, testVal, typeText) in ((defGroup, 1, ''), (mediumGroup, 2, 'more '),
                                      (bigGrouping, 3, 'large ')):

      # Test grouping if not, or increase it
      if newParam[opt] and (newParam[opt] != groupOpt or newParam[opt + 1] < group) and \
         wasTested[opt] < testVal:
         nextParam[opt] = groupOpt
         nextParam[opt + 1] = group
         wasTested[opt] = testVal
         if compareNextParam(typeText + 'grouping of local ' + label) > 0 and \
            takeOneStep:
            return 1

   # Turn the variable off; returns 1 if better even when takeOneStep is not set
   if nextParam[opt] and wasTested[opt] < 4:
      nextParam[opt] = 0
      wasTested[opt]  = 4
      if compareNextParam('not solving for local ' + label) > 0:
         return 1

   return 0


# For a patch tracking model, determine number of full tracks and unique points
def getNumUnchoppedConts(modelFile):
   """Return (number of tracks, number of unique points) for a possibly chopped model.

   Runs 'imodinfo -a' on modelFile and walks its 'contour' entries, whose 4th field
   is the number of point lines that follow.  Contour handling:
   - If a contour's first point line also appears in the previous contour, the
     contour is taken as a continuation of that track.  It is not counted as a new
     bead, and the overlapping points (from the match to the end of the previous
     contour) are subtracted from the point total.
   - Otherwise the contour counts as a new bead.
   - Contours with no points are skipped.

   Exits via exitFromImodError if imodinfo fails.  Exits via exitError if the output
   is truncated or a point count cannot be converted to an integer.
   """
   try:
      modLines = runcmd('imodinfo -a "' + modelFile + '"')
   except ImodpyError:
      exitFromImodError(progname)
   try:
      prevCont = []
      trueBeads = 0
      truePoints = 0
      ind = 0

      # Loop on lines to find contour headers
      while ind < len(modLines):
         line = modLines[ind]
         ind += 1
         if not line.startswith('contour '):
            continue

         # Found a contour, make sure it can all be read
         lsplit = line.split()
         if len(lsplit) != 4:
            continue
         numPts = int(lsplit[3])
         if ind + numPts > len(modLines):
            exitError('Output from imodinfo -a to analyze for chopped contours is ' +\
                      'truncated')

         # An empty contour is not a bead and has no first point to compare.  A
         # nonpositive count must also not move the line index backwards
         if numPts <= 0:
            continue

         # The slice is already a new list of immutable strings; no deep copy needed
         newCont = modLines[ind:ind + numPts]
         truePoints += numPts

         # Look for a duplicate of the first point in the previous contour; if found,
         # this contour continues that track, so subtract the overlap from the total
         firstPoint = newCont[0]
         if firstPoint in prevCont:
            truePoints -= len(prevCont) - prevCont.index(firstPoint)
         else:
            trueBeads += 1

         prevCont = newCont
         ind += numPts

      return (trueBeads, truePoints)

   except ValueError:
      exitError('Converting # of points to integer in output from imodinfo -a to ' +\
                'analyze for chopped contours')
      

#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, copy, math, time, itertools

#
# Setup runtime environment: IMOD_DIR must be defined, and its pylib directory is put
# first on the module search path so imodpy can be imported
if os.getenv('IMOD_DIR') != None:
   IMOD_DIR = os.environ['IMOD_DIR']

   # Under Cygwin with Python 3, convert a Windows path (backslashes, drive letter)
   # to the /cygdrive/<drive>/... form.  NOTE(review): assumes IMOD_DIR has at least
   # 3 characters
   if sys.platform == 'cygwin' and sys.version_info[0] > 2:
      IMOD_DIR = IMOD_DIR.replace('\\', '/')
      if IMOD_DIR[1] == ':' and IMOD_DIR[2] == '/':
         IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]
   sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
   from imodpy import *

   # Presumably adds IMOD's bin directory to PATH and ignores SIGHUP, per its name -
   # see imodpy
   addIMODbinIgnoreSIGHUP()
else:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

#
# load IMOD Libraries
from pip import *
from pysed import *

# Fallbacks from ../manpages/autodoc2man 3 1 restrictalign
options = ["align:AlignCommandFile:FN:", "fiducials:NumberOfFiducials:I:",
                      "views:NumberOfViews:I:", "target:TargetMeasurementRatio:F:",
                      "minimum:MinMeasurementRatio:F:", "cross:UseCrossValidation:I:",
                      "local:LocalAlignValidation:I:", "benefit:MinRobustBenefit:F:",
                      "order:OrderOfRestrictions:IA:", "cvorder:CrossValTestOrder:IA:",
                      "onestep:OneStepPerVariableTest:I:", "permute:TestPermutations:IA:",
                      "skipbeam:SkipBeamTiltWithOneRot:B:", "trial:TrialMode:B:",
                      "verbose:VerboseOutput:B:", ":PID:B:", "help:usage:B:"]

# Order in which restrictions (RES_* values) are applied to raise the ratio of
# measurements to unknowns, and order of the cross-validation tests (CV_TEST_* values).
# These are presumably filled in by fillAndTestOrder from the OrderOfRestrictions /
# CrossValTestOrder entries when those are given (TODO confirm).
# default in adoc = 1, 4, 3, 2, 5 will override this
order = [RES_GROUP_ROTS, RES_GROUP_MAGS, RES_FIX_TILTS, RES_ONE_ROT, RES_FIX_MAGS]
cvOrder = [CV_TEST_STRETCH, CV_TEST_XTILT, CV_TEST_TILT, CV_TEST_ROT, CV_TEST_MAG, 
         CV_TEST_SINGLES]

# Minimum group sizes imposed when magnifications, rotations, or tilts are restricted
# to grouped variables
minMagGrouping = 4
minRotGrouping = 5
minTiltGrouping = 5

# Default groupings for skew and X-stretch (not used in this part of the file)
dfltSkewGrouping = 11
dfltXStretchGrouping = 7

# When True, prnFinal() prints its messages
finalChanges = False

# With fewer beads than this, restrictions based on the measurement/unknown ratio are
# applied even when cross-validation is requested
minBeadsForCVOnly = 5

# With fewer beads than this, beam tilt and projection stretch are turned off; beam
# tilt is turned on when switching to one rotation only with at least this many beads
minBeadsForBeamTilt = 4

# Default for OneStepPerVariableTest; it is changed to 0 when permutations are entered
oneStepPerVar = 1

# Parse the options defined in the options list above
(opts, nonopts) = PipReadOrParseOptions(sys.argv, options, progname, 1, 1, 0)

# Optionally print the process ID
doPID = PipGetBoolean('PID', 0)
printPID(doPID)

# The Tiltalign command file is required; its lines are kept in comLines
comfile = PipGetInOutFile('AlignCommandFile', 0)
if not comfile:
   exitError('The name of the Tiltalign command file must be entered')

comLines = readTextFile(comfile)

# Cross-validation
# testLocal (LocalAlignValidation): values above 1 test the local area requirements
# and size; values other than 0 and 2 test the local variables; with a value above 3
# the local variables are tested before the local area.
crossValidate = PipGetInteger('UseCrossValidation', 0)
testLocal = PipGetInteger('LocalAlignValidation', 0)
minBenefitForEval = PipGetFloat('MinRobustBenefit', 2.)
testLocalArea = testLocal > 1
testLocalVars = testLocal > 0 and testLocal != 2

# Get the order of actions and check it.
# NOTE: PipGetErrNo() appears to report on the immediately preceding Pip call, so the
# statement order in this section matters.
# With cross-validation, TILTALIGN_SKIP_CROSS_VAL is removed from the environment
# (presumably so that the tiltalign runs do not skip cross-validation - TODO confirm).
if crossValidate:
   if os.getenv('TILTALIGN_SKIP_CROSS_VAL'):
      del os.environ['TILTALIGN_SKIP_CROSS_VAL']
   orderArr = PipGetIntegerArray('CrossValTestOrder', 0)
   if not PipGetErrNo():
      fillAndTestOrder(orderArr, cvOrder, "Variable")
else:
   orderArr = PipGetIntegerArray('OrderOfRestrictions', 0)
   fillAndTestOrder(orderArr, order, "Restriction")

# Entering TestPermutations turns on permutation testing and changes the default for
# OneStepPerVariableTest to 0
permuteArr = PipGetIntegerArray('TestPermutations', 0)
doPermute = 1 - PipGetErrNo()
if doPermute:
   oneStepPerVar = 0
oneStepPerVar = PipGetInteger('OneStepPerVariableTest', oneStepPerVar)

# Permutations need cross-validation, full (not one-step) variable tests, and a list of
# distinct variable numbers from 1 to the number of cross-validation tests
if doPermute:
   if not crossValidate:
      exitError('Permutations can be done only with cross-validation')
   if oneStepPerVar:
      exitError('Permutations cannot be done with taking one step per variable')
   for action in permuteArr:
      if permuteArr.count(action) > 1:
         exitError(fmtstr('Variable {} is in the permutation list more than once', 
                          action))
      if action < 1 or action > len(cvOrder):
         exitError(fmtstr('Permutation entry {} is outside the allowed range of 1 to {}',
                          action, len(cvOrder)))

# Numbers of views and fiducials may be entered; -1 means "not entered".  Each
# PipGetErrNo() check must directly follow the call it tests.
numViews = PipGetInteger('NumberOfViews', -1)
if not PipGetErrNo() and numViews < 1:
   exitError('The number of views entered must be positive')

numBeads = PipGetInteger('NumberOfFiducials', -1)
if not PipGetErrNo() and numBeads < 1:
   exitError('The number of beads entered must be positive')
targetRatio = PipGetFloat('TargetMeasurementRatio', 3.6)   # default in adoc
minRatio = PipGetFloat('MinMeasurementRatio', 3.2)         # default in adoc
skipBeamTilt = PipGetBoolean('SkipBeamTiltWithOneRot', 0)

# Other options
trialMode = PipGetBoolean('TrialMode', 0)
verbose = PipGetBoolean('VerboseOutput', 0)
modelFile = optionValue(comLines, 'ModelFile', STRING_VALUE)

# If the number of fiducials was not entered, get it from imodinfo output on the model
# file: count contours with more than one point and total their points.  Also detect a
# patch tracking model from its NAME line.
numPoints = 0
patchTrack = False
if numBeads < 0:

   # Imported explicitly rather than relying on it arriving through a star import
   import re

   if not modelFile:
      exitError('The number of fiducials was not entered and a ModelFile entry ' +\
                   'cannot be found in ' + comfile)
   if not os.path.exists(modelFile):
      exitError('The number of fiducials was not entered and the model file ' + modelFile
                + ' does not exist')
   try:
      infoLines = runcmd('imodinfo "' + modelFile + '"')
   except ImodpyError:
      exitError('The number of fiducials was not entered and an error occurred ' +\
                   'running imodinfo on the model file ' + modelFile)

   numBeads = 0
   conMatch = re.compile(r'^\s*CONTOUR\s#\S*\s*([0-9]*)\s*points.*')
   for line in infoLines:
      if line.startswith('# NAME') and 'Patch Tracking Model' in line:
         patchTrack = True
      match = conMatch.match(line)
      if match:
         contPoints = convertToInteger(match.group(1),
                                       'number of contour points in imodinfo output')
         if contPoints > 1:
            numBeads += 1
            numPoints += contPoints

   if not numBeads:
      exitError('The number of fiducials was not entered and the model file ' + modelFile
                + ' has no contours with more than one point')

# Determine the current state of the parameters from the command file.  Each *_OPT
# index is directly followed by its *_GROUP index, so the grouping goes in opt + 1.
# Missing entries are tolerated here; presumably optionValue returns None for them,
# and None entries are converted to 0 in the working copy made later.
origParam = 25 * [None]

# Distinct loop names so that the module-level 'prefix' (the error message prefix) and
# 'opts' (the parsed options) are not overwritten
for (opt, optPrefix) in ((ROT_OPT, 'Rot'), (MAG_OPT, 'Mag'), (TILT_OPT, 'Tilt'),
                         (SKEW_OPT, 'Skew'), (XSTRETCH_OPT, 'XStretch'),
                         (XTILT_OPT, 'XTilt'), (LOC_ROT_OPT, 'LocalRot'),
                         (LOC_MAG_OPT, 'LocalMag'), (LOC_TILT_OPT, 'LocalTilt'), 
                         (LOC_SKEW_OPT, 'LocalSkew'), (LOC_XSTR_OPT, 'LocalXStretch')):
   origParam[opt] = optionValue(comLines, optPrefix + 'Option', INT_VALUE, numVal = 1)
   origParam[opt + 1] = optionValue(comLines, optPrefix + 'DefaultGrouping', INT_VALUE,
                                    numVal = 1)

origParam[BEAM_TILT_OPT] = optionValue(comLines, 'BeamTiltOption', INT_VALUE, numVal = 1)
origParam[PROJ_STRETCH] = optionValue(comLines, 'ProjectionStretch', BOOL_VALUE)
localAlign = optionValue(comLines, 'LocalAlignments', BOOL_VALUE)
robustAlign = optionValue(comLines, 'RobustFitting', BOOL_VALUE)
imageFile = optionValue(comLines, 'ImageFile', STRING_VALUE)

# Separate lists, so that an in-place change to one can never alter the other
origRequired = [0, 0]
newRequired = [0, 0]
origAreaOrNum = [0, 0]
newAreaOrNum = [0, 0]
robustOffForEval = False
robustOff = ''
if testLocal and not localAlign:
   exitError('Local alignments must be turned on to test them')

if testLocal and doPermute:
   doPermute = 0
   prnstr('WARNING: permutations are not tested with local alignments')
   
if doPermute:

   # For permutations, remove items from array that are not being solved for
   if CV_TEST_STRETCH in permuteArr and origParam[XSTRETCH_OPT] == 0 and  \
      origParam[SKEW_OPT] == 0:
      permuteArr.remove(CV_TEST_STRETCH)

   for (opt, var) in ((ROT_OPT, CV_TEST_ROT), (TILT_OPT, CV_TEST_TILT),
                      (MAG_OPT, CV_TEST_MAG)):
      if var in permuteArr and origParam[opt] == 0:
         permuteArr.remove(var)

   # Permuting fewer than two variables is pointless
   if len(permuteArr) < 2:
      doPermute = 0
   else:

      # Cross-index from each permuted variable to its position in the order array
      # (first occurrence; variables not in cvOrder are skipped)
      permuteInds = [cvOrder.index(action) for action in permuteArr
                     if action in cvOrder]
      permuteList = itertools.permutations(permuteArr)


# When testing the local area, read the required fiducial numbers (total and on each
# surface) and either the target patch size or the number of patches, insisting that
# exactly one of those two is present and that each entry has two values
if testLocalArea:
   origRequired = optionValue(comLines, 'MinFidsTotalAndEachSurface', INT_VALUE)
   targetSize = optionValue(comLines, 'TargetPatchSizeXandY', INT_VALUE)
   numAreas = optionValue(comLines, 'NumberOfLocalPatchesXandY', INT_VALUE)
   imageSize = optionValue(comLines, 'ImageSizeXandY', INT_VALUE)

   if len(origRequired or []) < 2:
      exitError('Option for required number of fiducials in local areas not found ' +
                'or has only one value')

   # The total requirement must be able to rise by one without passing half the beads
   if origRequired[0] >= numBeads // 2:
      exitError('There are not enough beads to evaluate areas by requiring more beads')

   if targetSize and numAreas:
      exitError('Command file has both TargetPatchSizeXandY and ' +
                'NumberOfLocalPatchesXandY options')
   if not targetSize and not numAreas:
      exitError('Command file has neither TargetPatchSizeXandY nor ' +
                'NumberOfLocalPatchesXandY option')
   if targetSize and len(targetSize) < 2:
      exitError('TargetPatchSizeXandY entry in command file has only one value')
   if numAreas and len(numAreas) < 2:
      exitError('NumberOfLocalPatchesXandY entry in command file has only one value')

   # Exactly one of the two entries is present by now
   origAreaOrNum = targetSize or numAreas

   newRequired = copy.deepcopy(origRequired)
   newAreaOrNum = copy.deepcopy(origAreaOrNum)

# Get the number of views if it was not entered, and the full image size when testing
# the local area with a target patch size.  Sources are tried in order: the header of
# the ImageFile, the ImageSizeXandY entry (X and Y values, size only), then the "max"
# line of imodinfo -a output on the model file.
# fullNx/fullNy start at 0 so they are always defined; they are compared with each
# other when testing the number of local patches, even though no size is looked up
# in that case.
fullNx = 0
fullNy = 0
needSize = testLocalArea and targetSize
imSizeOK = not needSize or (imageSize and len(imageSize) > 1)
if numViews < 0 or needSize:
   mess = 'Need to determine '
   if numViews < 0:
      mess += '# of views '
   if needSize and not imSizeOK:
      if numViews < 0:
         mess += 'and '
      mess += 'image size'
   mess += '; '

   imageFile = optionValue(comLines, 'ImageFile', STRING_VALUE)
   headerFailed = False
   if imageFile and os.path.exists(imageFile):
      try:
         (fullNx, fullNy, nz) = getmrcsize(imageFile)
         if numViews < 0:
            numViews = nz

         # The header supplied the size, so no fallback should be consulted for it
         needSize = False
      except ImodpyError:
         headerFailed = True
         mess += 'an error occurred running header on image file ' + imageFile + ', '

   elif not imageFile:
      mess += 'there is no ImageFile entry in ' + comfile + ', '
   else:
      mess += 'image file ' + imageFile + ' does not exist, '

   # Fall back to the ImageSizeXandY entry for the size
   if not fullNx and needSize and imageSize and len(imageSize) > 1:
      fullNx = imageSize[0]
      fullNy = imageSize[1]
      needSize = False

   # Last resort: the maximum coordinates reported by imodinfo -a on the model
   if numViews < 0 or needSize:
      if not modelFile:
         mess += 'and a ModelFile entry cannot be found in ' + comfile
      else:
         try:
            modLines = runcmd('imodinfo -a "' + modelFile + '"')
            for line in modLines:
               if line.startswith('max'):
                  lsplit = line.split()
                  if len(lsplit) >= 4:
                     if needSize:
                        fullNx = int(lsplit[1])
                        fullNy = int(lsplit[2])
                        needSize = False
                     if numViews < 0:
                        numViews = int(lsplit[3])
                  break
            if numViews < 0 or needSize:
               mess += 'and the "max" line could not be found in imodinfo output on ' + \
                  modelFile
         except ImodpyError:
            mess += 'and there was an error running imodinfo on ' + modelFile
         except ValueError:
            mess += 'and there was an error converting max values from the model file'

   if numViews < 0 or needSize:
      exitError(mess)

# For a patch tracking model, recount as full tracks and unique points, discounting
# points duplicated where a chopped contour continues the previous one
if patchTrack:
   (numBeads, numPoints) = getNumUnchoppedConts(modelFile)

# Without point counts from a model, assume a complete model: every bead on every view
numPoints = numPoints or numBeads * numViews

if verbose and numPoints:
   if patchTrack:
      countFormat = '{} full tracks, {} unique points, {:.1f} average points per view'
   else:
      countFormat = '{} beads, {} total points, {:.1f} average points per view'
   prnstr(fmtstr(countFormat, numBeads, numPoints, numPoints / float(numViews)))

# If the image file does not exist but the command file has a commented-out
# ImageSizeXandY entry, drop the ImageFile line(s) and uncomment ImageSizeXandY so that
# the command file uses the image size entry instead.  (The prefix checked here must
# match the one tested in the condition; it was previously misspelled 'XnadY', so the
# entry was never uncommented.)
if imageFile and not os.path.exists(imageFile) and \
   optionValue(comLines, '#ImageSizeXandY', INT_VALUE):
   comLines = [line[1:] if line.startswith('#ImageSizeXandY') else line
               for line in comLines if not line.startswith('ImageFile')]

# Working copy of the parameters in which entries missing from the command file (None)
# become 0
newParam = [0 if value is None else value for value in origParam]

oneBead = (numBeads == 1)
boostedRatio = False

newRatio = measuredToUnknown(newParam)
if verbose:
   prnstr(fmtstr('Original estimated ratio of measurements to unknowns: {:.2f}', 
                 newRatio))

# Without cross-validation there is nothing to do when the target ratio is already met.
# batchruntomo looking for 'No restriction'
if not crossValidate and newRatio >= targetRatio:
   prnstr(progname + ': No restriction of parameters needed')
   sys.exit(0)

# When the ratio of measurements to unknowns is below the target, and either
# cross-validation is off or there are too few beads to rely on it alone, apply
# restrictions one at a time (first the "hard variable" round, then in the order of
# the 'order' list) until the target ratio is reached or restrictions run out.
if newRatio < targetRatio and (not crossValidate or numBeads < minBeadsForCVOnly):
   for orderInd in range(-1, len(order)):
      nextParam = copy.deepcopy(newParam)

      # Turn off hard variables on first round: skew, X-stretch, and X-axis tilt (when
      # its grouping is less than the number of views or there are fewer than 3 beads);
      # tilt angles solved for all views, or grouped by less than minTiltGrouping, become
      # grouped by at least minTiltGrouping
      if orderInd < 0:
         if newParam[SKEW_OPT]:
            nextParam[SKEW_OPT] = 0
         if newParam[XSTRETCH_OPT]:
            nextParam[XSTRETCH_OPT] = 0
         if newParam[XTILT_OPT] and (newParam[XTILT_GROUP] < numViews or numBeads < 3):
            nextParam[XTILT_OPT] = 0
         if newParam[TILT_OPT] == TA_ALL_TILT \
                or (newParam[TILT_OPT] == TA_GROUP_TILT \
                       and newParam[TILT_GROUP] < minTiltGrouping):
            nextParam[TILT_OPT] = TA_GROUP_TILT
            nextParam[TILT_GROUP] = max(nextParam[TILT_GROUP], minTiltGrouping)
         restrict = 0
      else:
         restrict = order[orderInd]

      # Handle switching to grouped rots of minimum group size and to one rot.
      # ROT_OPT > 0 means all or grouped rotations (TA_ONE_ROT is -1); switching to one
      # rotation also turns on beam tilt (option value 2) unless skipped or too few beads
      if restrict == RES_GROUP_ROTS and (newParam[ROT_OPT] == TA_ALL_ROT or \
                                          (newParam[ROT_OPT] == TA_GROUP_ROT and \
                                              newParam[ROT_GROUP] < minRotGrouping)):
         nextParam[ROT_OPT] = TA_GROUP_ROT
         nextParam[ROT_GROUP] = max(minRotGrouping, nextParam[ROT_GROUP])
      elif restrict == RES_ONE_ROT and newParam[ROT_OPT] > 0:
         nextParam[ROT_OPT] = TA_ONE_ROT
         if not skipBeamTilt and numBeads >= minBeadsForBeamTilt:
            nextParam[BEAM_TILT_OPT] = 2

      # Handle fixing tilts (always fixed with one bead)
      if restrict == RES_FIX_TILTS or oneBead:
         nextParam[TILT_OPT] = 0

      # Handle fixing mags or grouping them (always fixed with one bead)
      if restrict == RES_FIX_MAGS or oneBead:
         nextParam[MAG_OPT] = 0
      elif restrict == RES_GROUP_MAGS and (newParam[MAG_OPT] == TA_ALL_MAG or \
                                            (newParam[MAG_OPT] == TA_GROUP_MAG and \
                                                newParam[MAG_GROUP] < minMagGrouping)):
         nextParam[MAG_GROUP] = minMagGrouping
         nextParam[MAG_OPT] = TA_GROUP_MAG

      # Fix everything else if one bead; skip beam tilt and projection stretch for 2 or 3
      if oneBead:
         nextParam[ROT_OPT] = 0
      if numBeads < minBeadsForBeamTilt:
         nextParam[PROJ_STRETCH] = 0
         nextParam[BEAM_TILT_OPT] = 0

      # Get the ratio on the next restriction and see if it is good enough or if there is
      # just one bead
      nextRatio = measuredToUnknown(nextParam)
      if oneBead or nextRatio >= targetRatio:

         # Adopt the next parameter set if one bead, or last ratio below the minimum, or
         # the next one is closer to target
         if oneBead or newRatio < minRatio or \
                math.fabs(newRatio - targetRatio) > math.fabs(nextRatio - targetRatio):
            newParam = copy.deepcopy(nextParam)
            newRatio = nextRatio
         break

      # Otherwise shift the next set into the "new" set for the next iteration; if the
      # loop runs out, every restriction has been applied
      newParam = copy.deepcopy(nextParam)
      newRatio = nextRatio

   # Robust fitting is turned off if the ratio is still below the minimum
   if robustAlign and newRatio < minRatio:
      robustOff = 'ratio of measurements to unknowns is too low'

   # With cross-validation still to come, list these initial changes now; finalChanges
   # is True only while buildUpSedcom runs so that prnFinal() output is printed
   if crossValidate and not oneBead:
      prnstr('Initial changes were made to boost ratio of measurements to unknowns:')
      finalChanges = True
      boostedRatio = True
      buildUpSedcom(newParam, robustOff, newRequired, newAreaOrNum)
      finalChanges = False

# Cross-validation: evaluate changes by the errors tiltalign reports for points left
# out of the fit.  Done only with more than one bead.
# NOTE(review): the doTiltalignRuns / compareNextParam / cvTest* / revertToStep
# functions defined elsewhere appear to read and update the module-level state set up
# here (newParam, newErrors, lastErrors, prevErrDiff, curErrDiff, prev* snapshots,
# badRobust, ...) - confirm against their definitions before restructuring.
if crossValidate and numBeads > 1:

   # Extract the input to tiltalign: the lines after the '$' line that runs tiltalign
   # with -St, up to the next line starting with '$'
   taLines = []
   gotTA = False
   for line in comLines:
      if line.startswith('$'):
         if gotTA:
            break
         elif 'tiltalign' in line and '-St' in line:
            gotTA = True
      elif gotTA:
         taLines.append(line)

   if not gotTA:
      exitError('Could not find input to Tiltalign in command file')

   # Turn off ridiculous ones with 2 or 3 beads
   if numBeads < minBeadsForBeamTilt:
      newParam[SKEW_OPT] = 0
      newParam[XSTRETCH_OPT] = 0
      newParam[XTILT_OPT] = 0
      newParam[PROJ_STRETCH] = 0
      newParam[BEAM_TILT_OPT] = 0

   # Get baseline run.  Robust fitting is evaluated only when it is on in the command
   # file and the ratio of measurements to unknowns is at least the minimum.
   cumNonRobTime = 0.
   cumRobustTime = 0.
   doingRobust = robustAlign and newRatio >= minRatio
   robustOrig = doingRobust
   newErrors = doTiltalignRuns(newParam, 'initial values', newRequired, newAreaOrNum)

   # Check if robust was bad and whether benefit is worth evaluating with.  The benefit
   # is the percent reduction from newErrors[1] to newErrors[3].
   if doingRobust:
      if badRobust:
         prnstr('Turning off robust alignments for evaluation; it failed with initial ' +\
                'values')
         doingRobust = False

         # Robust fitting gave no usable result, so record it as having no benefit.
         # Otherwise benefitOrig is undefined when the final robust-fitting decision is
         # made with unchanged parameters.
         benefitOrig = 0.
      else:
         benefitOrig = 100. * (newErrors[1] - newErrors[3]) / newErrors[1]
         if benefitOrig < minBenefitForEval:
            if verbose:
               prnstr(' ')
            if benefitOrig < 0:
               prnstr('Turning off robust alignments for evaluation; it has negative ' + \
                      'benefit')
            else:
               prnstr(fmtstr('Turning off robust alignments for evaluation; the ' + \
                             'benefit is only {:.1f}%', benefitOrig))
            doingRobust = False

      # Redo the baseline without robust fitting when it is not used for evaluation
      if not doingRobust:
         robustOffForEval = True
         newErrors = doTiltalignRuns(newParam, 'new initial values', newRequired, 
                                     newAreaOrNum)

   if newErrors[0] < 0 or newErrors[1] == -2:
      exitError('Tiltalign failed on runs with initial parameters')
   if verbose:
      prnstr('')

   # Starting state for the tests
   bigGrouping = numViews // 2
   origErrors = copy.deepcopy(newErrors)
   paramForRestart = copy.deepcopy(newParam)
   lastErrors = copy.deepcopy(newErrors)
   lastErrDiff = 0
   nextParam = copy.deepcopy(newParam)
   nextRequired = copy.deepcopy(newRequired)
   nextAreaOrNum = copy.deepcopy(newAreaOrNum)
   lastRequired = newRequired[0]
   numSameMin = 0

   wasTested = 25 * [0]
   changed = 1
   areasFinished = False
   ordInd = 0
   while changed:
      changed = 0

      # Loop on local tests twice to allow order to be tested: with testLocal > 3 the
      # local variables are tested on the first pass and the area on the second,
      # otherwise the reverse
      for localLoop in (0, 1):
         if testLocalArea and ((localLoop and testLocal > 3) or \
                               (not localLoop and testLocal < 4)):

            # Test local area required number and size
            while not areasFinished:

               # Step the required numbers up and increase size or drop number maybe
               # Terminate only when bead number is too high, just fix the area size or
               # number if that reaches a limit
               nextRequired[0] = max(lastRequired + 1, int(round(1.1 * lastRequired)))
               if nextRequired[0] > numBeads // 2:
                  nextRequired[0] = numBeads // 2
                  if nextRequired[0] == lastRequired:
                     areasFinished = True
                     break

               # Scale the per-surface requirement by the same factor as the total, and
               # the patch size (or number of patches) by its square root
               lastRequired = nextRequired[0]
               ratio = nextRequired[0] / float(origRequired[0])
               nextRequired[1] = int(round(origRequired[1] * ratio))
               ratio = math.sqrt(ratio)

               if targetSize:

                  # Sizes are multiples of 5; keep the previous size if the patch
                  # would cover more than 45% of the image area
                  lastAreaOrNum = copy.deepcopy(nextAreaOrNum)
                  nextAreaOrNum[0] = 5 * int(round(origAreaOrNum[0] * ratio / 5))
                  nextAreaOrNum[1] = 5 * int(round(origAreaOrNum[1] * ratio / 5))
                  if nextAreaOrNum[0] * nextAreaOrNum[1] > 0.45 * fullNx * fullNy:
                     nextAreaOrNum = copy.deepcopy(lastAreaOrNum)

               else:

                  # Fewer patches; at 1 x 1, use 2 patches along the longer image axis
                  nextAreaOrNum[0] = int(math.ceil(origAreaOrNum[0] / ratio))
                  nextAreaOrNum[1] = int(math.ceil(origAreaOrNum[1] / ratio))
                  if nextAreaOrNum[0] < 2 and nextAreaOrNum[1] < 2:
                     if fullNx > fullNy:
                        nextAreaOrNum[0] = 2
                        nextAreaOrNum[1] = 1
                     else:
                        nextAreaOrNum[0] = 1
                        nextAreaOrNum[1] = 2

               better = compareNextParam(fmtstr('area requirements {},{} and {},{}', 
                                                nextRequired[0], nextRequired[1],
                                                nextAreaOrNum[0], nextAreaOrNum[1]))

               # Continue until two unique results that are higher than the best one
               if better > 0:
                  numSameMin = 0
                  changed = 1
                  if oneStepPerVar:
                     break

               if better < 0 and lastErrDiff != 0.:
                  numSameMin += 1
                  if numSameMin > 1:
                     areasFinished = True
                     break

         # Test local variables, resuming the cycle through cvOrder where it left off
         if testLocalVars and ((localLoop and testLocal < 4) or \
                               (not localLoop and testLocal > 3)):
            for loop in range(len(cvOrder)):
               varToTest = cvOrder[ordInd]
               ordInd += 1
               if ordInd >= len(cvOrder):
                  ordInd = 0
               if varToTest == CV_TEST_STRETCH:
                  changed += cvTestLocalStretch(oneStepPerVar)
                  if changed and oneStepPerVar:
                     break

               if varToTest == CV_TEST_TILT:
                  changed += cvTestLocalTiltRotMag(LOC_TILT_OPT, TA_GROUP_TILT, 'tilt', 
                                                   oneStepPerVar)
                  if changed and oneStepPerVar:
                     break

               if varToTest == CV_TEST_ROT:
                  changed += cvTestLocalTiltRotMag(LOC_ROT_OPT, TA_GROUP_ROT, 'rotation', 
                                                   oneStepPerVar)
                  if changed and oneStepPerVar:
                     break

               if varToTest == CV_TEST_MAG:
                  changed += cvTestLocalTiltRotMag(LOC_MAG_OPT, TA_GROUP_MAG,
                                                   'magnification', oneStepPerVar)
                  if changed and oneStepPerVar:
                     break

      # This breaks the while loop not the loop on which tests to do
      if not oneStepPerVar or not changed:
         break

   # Test permutations
   if doPermute:
      bestDiff = 10000.
      for permutation in permuteList:

         # Load the permutation into the order list and start from original state
         for ind in range(len(permutation)):
            cvOrder[permuteInds[ind]] = permutation[ind]

         newErrors = copy.deepcopy(origErrors)
         newParam = copy.deepcopy(paramForRestart)
         lastErrors = copy.deepcopy(newErrors)
         lastErrDiff = 0
         nextParam = copy.deepcopy(newParam)
         wasTested = 25 * [0]

         ordText = ''
         for ordInd in range(len(cvOrder)):
            ordText += ' ' + str(cvOrder[ordInd])
            
         # Do classic full test of each variable in turn
         for ordInd in range(len(cvOrder)):
            varToTest = cvOrder[ordInd]
            if varToTest == CV_TEST_STRETCH:
               result = cvTestStretch(0)

            if varToTest == CV_TEST_XTILT:
               result = cvTestXTilt()
               
            if varToTest == CV_TEST_TILT:
               result = cvTestTilt(0)

            if varToTest == CV_TEST_ROT:
               result = cvTestRotation(0)

            if varToTest == CV_TEST_MAG:
               result = cvTestMagnification(0)

            if varToTest == CV_TEST_SINGLES:
               result = cvTestSingleVars(0)

         # Keep track of best one.  The relative change in leave-out error is the mean
         # of the relative changes in errors [1] and [2] when both [2] values are
         # positive, otherwise the relative change in error [0].
         if newErrors[2] > 0 and origErrors[2] > 0:
            diff = ((newErrors[1] - origErrors[1]) / origErrors[1] +
                    (newErrors[2] - origErrors[2]) / origErrors[2]) / 2.
         else:
            diff = (newErrors[0] - origErrors[0]) / origErrors[0]
         if diff < 0:
            prnstr(fmtstr('Permutation' + ordText + ' reduced leave-out error by {:.1f}%',
                          -diff * 100.))
         else:
            prnstr('Permutation' + ordText + ' did not reduce leave-out error')

         if diff < bestDiff:
            bestPrevErrDiff = prevErrDiff
            bestLastErrors = copy.deepcopy(lastErrors)
            bestErrors = copy.deepcopy(newErrors)
            bestParam = copy.deepcopy(newParam)
            bestRequired = copy.deepcopy(newRequired)
            bestAreaOrNum = copy.deepcopy(newAreaOrNum)
            bestDiff = diff

      prnstr(fmtstr('Biggest change was {:.2f}%', -bestDiff * 100.))
      newParam = copy.deepcopy(bestParam)
      newErrors = copy.deepcopy(bestErrors)

   elif not testLocal:

      # Regular non-local variables in defined order, but with possibility of
      # one step per var.  With oneStepPerVar > 1, each variable's step is tried from
      # the same state and only the best one (worst one if > 2) is kept per round.
      changed = 1
      while changed:
         changed = 0
         compDir = 1
         if oneStepPerVar > 2:
            compDir = -1
            
         bestDiff = -compDir * 10000
         varNames = ('X-stretch and skew', 'X-tilt', 'tilt', 'rotation', 'magnification',
                     'single variable')
         for ordInd in range(len(cvOrder)):
            varToTest = cvOrder[ordInd]
            if varToTest == CV_TEST_STRETCH:
               varOpt = XSTRETCH_OPT
               result = cvTestStretch(oneStepPerVar)

            if varToTest == CV_TEST_XTILT:
               varOpt = XTILT_OPT
               result = cvTestXTilt()
               
            if varToTest == CV_TEST_TILT:
               varOpt = TILT_OPT
               result = cvTestTilt(oneStepPerVar)

            if varToTest == CV_TEST_ROT:
               varOpt = ROT_OPT
               result = cvTestRotation(oneStepPerVar)

            if varToTest == CV_TEST_MAG:
               varOpt = MAG_OPT
               result = cvTestMagnification(oneStepPerVar)

            if varToTest == CV_TEST_SINGLES:
               result = cvTestSingleVars(oneStepPerVar)
               varOpt = result

            # If taking one step, mark as changed if there is any result on this loop,
            # and if doing best or worst step, keep track of that step
            if result and oneStepPerVar:
               changed += 1
               if oneStepPerVar > 1:
                  wasTested[varOpt] -= 1

                  if compDir * curErrDiff > compDir * bestDiff:
                     bestPrevErrDiff = prevErrDiff
                     bestLastErrors = copy.deepcopy(lastErrors)
                     bestErrors = copy.deepcopy(newErrors)
                     bestParam = copy.deepcopy(newParam)
                     bestRequired = copy.deepcopy(newRequired)
                     bestAreaOrNum = copy.deepcopy(newAreaOrNum)
                     bestDiff = curErrDiff
                     bestVar = varToTest
                     bestOpt = varOpt

                  # Undo every trial step, not just the current best, so that each
                  # variable is tested from the same starting state and wasTested stays
                  # consistent; the chosen step is re-applied after this loop
                  revertToStep(prevErrDiff, prevLastErrs, prevErrors, prevParam, 
                               prevRequired, prevAreaOrNum)

         # end of loop, done if not one step, otherwise repeat until no change
         if not oneStepPerVar:
            break

         # If doing best or worst step, now apply that step
         if changed and oneStepPerVar > 1:
            prnstr(fmtstr('Changing {} for {:.2f}% improvement', 
                          varNames[bestVar - 1], 100. * bestDiff))
            revertToStep(bestPrevErrDiff, bestLastErrors, bestErrors, bestParam,
                         bestRequired, bestAreaOrNum)
            wasTested[bestOpt] += 1

   # Robust fitting that was dropped during the runs (not deliberately for evaluation)
   # is reported as failed
   if robustOff == '' and robustAlign and not doingRobust and not robustOffForEval:
      robustOff = 'robust fitting failed'


# After all that, are there any changes?  If not, exit.
# The cross-validation section runs only with more than one bead, so its results
# (newErrors, origErrors, robustOrig) exist only then.  With a single bead the result
# is reported as for a restriction based on the measurement/unknown ratio.
crossValRan = crossValidate and numBeads > 1

# With local area testing, compare required fiducial numbers and patch size/number
noParamChange = True
if testLocalArea:
   noParamChange = newRequired[0] == origRequired[0] and \
      newRequired[1] == origRequired[1] and \
      newAreaOrNum[0] == origAreaOrNum[0] and \
      newAreaOrNum[1] == origAreaOrNum[1]

# Compare alignment variables too, unless only the local area was tested.  A missing
# (None) or zero original entry counts as unchanged when the new entry is zero.
if noParamChange and (not testLocal or testLocalVars):
   if any((origVal != newVal) if origVal else bool(newVal)
          for (origVal, newVal) in zip(origParam, newParam)):
      noParamChange = False

# Decide whether to turn robust fitting off because it gives no benefit
if crossValRan and robustOrig and robustOff == '':
   if noParamChange:
      benefit = benefitOrig
   elif not robustOffForEval:
      benefit = 100. * (newErrors[1] - newErrors[3]) / newErrors[1]
   else:

      # Robust fitting was off during evaluation: rerun with it on (also testing local
      # alignments if they are on) to measure its benefit, then restore final errors
      finalErrors = copy.deepcopy(newErrors)
      doingRobust = True
      saveTestLoc = testLocal
      if localAlign and not testLocal:
         testLocal = 1
      newErrors = doTiltalignRuns(newParam, 'robust alignment back on', newRequired,
                                  newAreaOrNum)
      benefit = 100. * (newErrors[1] - newErrors[3]) / newErrors[1]
      newErrors = copy.deepcopy(finalErrors)
      testLocal = saveTestLoc
      if verbose:
         prnstr(fmtstr(' - benefit now {:.1f}%', benefit))
         
   if benefit <= 0:
      robustOff = fmtstr('robust fitting gives no benefit ({:.1f}%)', benefit)

# batchruntomo looking for 'No restriction'
if noParamChange and robustOff == '':
   prnstr(progname + ': No restriction of parameters needed')
   sys.exit(0)

finalChanges = True

# In trial mode, write <root>_new<ext> instead of replacing the command file
outFile = comfile
changeText = 'Changed ' + comfile
if trialMode:
   (root, ext) = os.path.splitext(comfile)
   outFile = root + '_new' + ext
   changeText = 'Wrote ' + outFile

if crossValidate and noParamChange:
   prnstr(fmtstr('{}: {} because {}', progname, changeText, robustOff))
elif crossValRan:

   # Relative change in leave-out error: mean of the relative changes in errors [1]
   # and [2] when both [2] values are positive, otherwise that of error [0]
   if newErrors[2] > 0 and origErrors[2] > 0:
      diff = ((newErrors[1] - origErrors[1]) / origErrors[1] +
              (newErrors[2] - origErrors[2]) / origErrors[2]) / 2.
   else:
      diff = (newErrors[0] - origErrors[0]) / origErrors[0]
   if diff < 0:
      prnstr(fmtstr('{}: {} to reduce errors of points left out by {:.1f}%',
                    progname, changeText, -diff * 100.))
      if boostedRatio:
         prnstr(' (after initial changes to boost the measurement/unknown ratio -' + \
                'see log file)')
   else:
      prnstr(fmtstr('{}: {} just to boost ratio of measurements to unknowns',
                    progname, changeText))
elif noParamChange:
   prnstr(fmtstr('{}: {} given the measured/unknown ratio of ~{:.1f}', 
                 progname, changeText, newRatio))
else:
   prnstr(fmtstr('{}: {} to achieve measured/unknown ratio of ~{:.1f}', 
                 progname, changeText, newRatio))

# Now that we know what to do, build up the sed command for the changes and list them
buildUpSedcom(newParam, robustOff, newRequired, newAreaOrNum)
makeBackupFile(outFile)
pysed(sedcom, comLines, outFile)
if crossValidate:
   prnstr('Rerunning the alignment with the final file    [rsa1]')
   try:
      runcmd('submfg ' + outFile)
   except ImodpyError:
      exitFromImodError(progname)
sys.exit(0)

