#!/usr/bin/env python
# splittilt - program to set up command files for parallel reconstruction
#
# Author: David Mastronarde
#
# $Id: splittilt,v a2efe816ad80 2024/04/22 18:43:55 mast $
#

progname = 'splittilt'
# Prefix for error messages, e.g. when IMOD_DIR is not defined
prefix = 'ERROR: ' + progname + ' - '
# Default OldStyleXtiltPenalty: if the percentage of slices needed by a chunk
# (including the extra for X-axis tilt) exceeds 100 * penalty, old-style
# tilting (XTILTINTERP 0) is used unless vertical slices are requested
penalty = 1.33
# Largest acceptable percentage of (chunk size + extra slices for X-axis tilt)
# relative to chunk size when searching for the number of chunks
maxextrapct = 102
# Default number of processors (ProcessorNumber)
numproc = 8
# Default minimum number of slices per chunk (SliceMinimum)
minslices = 50
# Default minimum and target numbers of chunks per processor, used when
# ChunkMinimum or TargetChunks is not entered
minratio = 2
targetratio = 5

#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, math, re, glob

#
# Setup runtime environment
# IMOD_DIR must be defined so that the IMOD Python library directory can be
# put on the module path; exit immediately if it is not
if os.getenv('IMOD_DIR') is None:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

IMOD_DIR = os.environ['IMOD_DIR']
if sys.platform == 'cygwin' and sys.version_info[0] > 2:

   # Convert a Windows-style path (X:\...) to a Cygwin path (/cygdrive/x/...).
   # A slice is compared instead of indexing single characters so that a
   # value shorter than 3 characters cannot raise IndexError
   IMOD_DIR = IMOD_DIR.replace('\\', '/')
   if IMOD_DIR[1:3] == ':/':
      IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]
sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
from imodpy import *
addIMODbinIgnoreSIGHUP()

#
# load IMOD Libraries
from pip import *
from pysed import *


# Initializations (defaults are above or in Pip calls)
# Line added after THICKNESS in each chunk command file; "#" is just a comment
# line, replaced by "XTILTINTERP 0" if old-style X-axis tilting is chosen
oldstyle = "#"
# Extensions of the per-chunk boundary files for the reconstruction and for
# the vertical slice output
boundext = 'rbound'
vsBoundExt = 'vsbound'
# Default for BoundaryPixels (presumably supplied by imodpy)
boundpixels = parallelBoundarySize()

# Fallbacks from ../manpages/autodoc2man 3 1 splittilt
options = [":CommandFile:FN:", "outroot:RootNameOfOutput:CH:", "naming:NamingStyle:I:",
           "n:ProcessorNumber:I:", "s:SliceMinimum:I:", "t:TargetChunks:I:",
           "m:ChunkMinimum:I:", "p:OldStyleXtiltPenalty:F:", "v:VerticalSlices:B:",
           "c:SeparateChunks:B:", "b:BoundaryPixels:I:", "i:InitialComNumber:I:",
           "o:OpenForMoreComs:B:", "unique:UniqueInfoFile:B:",
           "d:DimensionsOfStack:IP:", "help:usage:B:"]

# Parse the command line against the option table above
(numOpts, numNonOpts) = PipReadOrParseOptions(sys.argv, options, progname, 1, 1, 0)

# Get the com file name, derive a root name and new com file name, check exists
# Complete the command file name, check that it exists, and get the default
# root name for the output command files from it
comfile, rootname = completeAndCheckComFile(PipGetInOutFile('CommandFile', 0))

# Get the extensions from the tilt command file; when no command file
# extension is found, use the last three characters of the file name
comExt, dualNum, setroot, typeExt, stackExt = \
   findRootAxisAndExtensions(useTilt=comfile)
comExt = '.' + (comExt if comExt else comfile[-3:])

optNameStyle, typeExt = getNamingStyle(typeExt)

# Get options
# Get options; each PipGetErrNo call must directly follow the call whose
# "was it entered" status it reports
rootname = PipGetString('RootNameOfOutput', rootname)
numproc = PipGetInteger('ProcessorNumber', numproc)
minslices = PipGetInteger('SliceMinimum', minslices)
penalty = PipGetFloat('OldStyleXtiltPenalty', penalty)
targetslabs = PipGetInteger('TargetChunks', 0)
minslabs = PipGetInteger('ChunkMinimum', 0)
vertical = PipGetBoolean('VerticalSlices', 0)
# Chunks are written directly into one output file unless SeparateChunks is set
direct = PipGetBoolean('SeparateChunks', 0) == 0
boundpixels = PipGetInteger('BoundaryPixels', boundpixels)
startnum = PipGetInteger('InitialComNumber', 1)
ifstartnum = 1 - PipGetErrNo()
leaveopen = PipGetBoolean('OpenForMoreComs', 0)
dimens = PipGetTwoIntegers('DimensionsOfStack', 0, 0)
ifDimens = 1 - PipGetErrNo()
uniqueInfo = PipGetBoolean('UniqueInfoFile', 0)

# The chunk computations divide by the number of processors, the number of
# chunks, and the chunk size, so reject or correct values that would lead to
# division by zero there
if numproc < 1:
   exitError('The number of processors must be at least 1')
minslices = max(1, minslices)

# Set min and target slabs if not entered (nonpositive values count as not
# entered, since they cannot be used)
if minslabs <= 0:
   minslabs = minratio * numproc
if targetslabs <= 0:
   targetslabs = targetratio * numproc
targetslabs = max(minslabs, targetslabs)

# Collect info from command file
comlines = readTextFile(comfile, 'tilt command file')
xtiltArr = optionValue(comlines, 'xaxistilt', 2, 1)
fullimage = optionValue(comlines, 'fullimage', 1, 1)
thickArr = optionValue(comlines, 'thickness', 1, 1)
slices = optionValue(comlines, 'slice', 1, 1)
localali = optionValue(comlines, 'localfile', 0, 1)
binningArr = optionValue(comlines, 'imagebinned', 1, 1)
expandedFac = optionValue(comlines, 'expandedbyfactor', 2, 1, numVal = 1)
alifile = optionValue(comlines, 'inputproj', 0, 1)
recfile = optionValue(comlines, 'outputfile', 0, 1)
rectoproj = optionValue(comlines, 'recfiletoreproj', 0, 1)
sliceproj = optionValue(comlines, 'reproject', 2, 1)
widthArr = optionValue(comlines, 'width', 1, 1)
intsirt = optionValue(comlines, 'sirtiterations', 1, 1)
xtiltFile = optionValue(comlines, 'xtiltfile', 0, 1)
zfactors = optionValue(comlines, 'zfactorfile', 0, 1)
vsOutFile = optionValue(comlines, 'vertsliceoutputfile', 0, 1)

# Reduce the entries to the scalars used below, defaulting missing ones.
# Reprojection is flagged only when there is no SIRT iterations entry.
binval = binningArr[0] if binningArr else 1
xaxistilt = xtiltArr[0] if xtiltArr else 0.
reproj = 1 if intsirt is None and (rectoproj or sliceproj) else 0
expandedFac = expandedFac or 1.

# Figure out if vertical slices are even possible
# Vertical slices are possible only with no local alignments, no Z factors,
# and no reprojection, and only if all values in any X-tilt file are the same
vertPossible = localali is None and zfactors is None and not reproj
if vertPossible and xtiltFile:

   # Skip blank lines (e.g. a trailing empty line) instead of passing them to
   # float(), and give a clear error for a file with no values.  A constant
   # X-tilt from the file adds to the global X-axis tilt.
   xtlines = readTextFile(xtiltFile, 'X-tilt file')
   xtvalues = [float(xtline) for xtline in xtlines if xtline.strip()]
   if not xtvalues:
      exitError('X-tilt file ' + xtiltFile + ' contains no values')
   firstxt = xtvalues[0]
   if any(math.fabs(xtval - firstxt) > 1.e-5 for xtval in xtvalues):
      vertPossible = False
   else:
      xaxistilt += firstxt

# Get the input and output image file names from the command file if necessary
# DUPLICATE OF SIRTSETUP
if alifile is None:

   # Old-style command file: the input and output names are the first
   # non-comment lines after the "$tilt" line
   tiltre = re.compile(r'^\s*\$\s*tilt(?:\s|$)')
   for ind in range(len(comlines)):
      if tiltre.search(comlines[ind]):
         break
   else:
      exitError("tilt command not found in com file " + comfile)
   while ind < len(comlines) - 1:
      ind += 1
      if not comlines[ind].strip().startswith('#'):
         alifile = comlines[ind].strip()
         break
   if recfile is None:
      while ind < len(comlines) - 1:
         ind += 1
         if not comlines[ind].strip().startswith('#'):
            recfile = comlines[ind].strip()
            break

# Check both names outside the search so that a missing output file is also
# caught when the input file came from an option entry
if alifile is None or recfile is None:
   exitError("Cannot find input and output file names in command file")


# Without the aligned stack or entered dimensions, the sizes have to come
# from entries in the command file
noali = not os.path.exists(alifile)
sizesUnknown = noali and not ifDimens and not fullimage
if sizesUnknown and not slices:
   exitError('Command file has neither a SLICE nor a FULLIMAGE entry and ' +
             'image file does not exist yet')
if sizesUnknown and direct and not widthArr:
   exitError('Command file has neither a WIDTH nor a FULLIMAGE entry and ' +
             'image file does not exist yet')

# The thickness is expanded and divided by the binning for the computations
if not thickArr:
   exitError('Command file has no THICKNESS entry')
else:
   thickness = int(math.floor(expandedFac * thickArr[0] / binval))

# Extract root and extension for making filenames: if it is a type extension and there
# is an _rec or _xxx before it, pull that off, otherwise drop back to no type extension
recroot, dotExt = os.path.splitext(recfile)
if typeExt and len(recroot) > 4 and recroot[-4] == '_':

   # The three characters after the underscore become the extension
   recroot, recext = recroot[:-4], recroot[-3:]
else:

   # Keep the real extension without its dot and clear the type extension
   recext, typeExt = dotExt[1:], ''
setRootAndExtension(recroot, typeExt)

# Remove files from a previous run now, in case the number of chunks or the
# direct/indirect mode has changed; processchunks takes care of other files
cleanChunkFiles(rootname, ifstartnum != 0)

# Get the size from the supplied dimension or from the aligned stack instead 
# of relying on FULLIMAGE if possible, and scale them up by binning
#
if ifDimens or not noali:
   if not ifDimens:
      try:
         fullimage = list(getmrcsize(alifile))
      except ImodpyError:
         exitFromImodError(progname)
   else:
      fullimage = [dimens[0], dimens[1]]
   fullimage[0], fullimage[1] = fullimage[0] * binval, fullimage[1] * binval
else:
   prnstr("WARNING: " + progname + " - aligned stack not found; sizes will be " +
          "taken from FULLIMAGE entry")

# The slice range covers the whole image in Y by default and is replaced by
# the (expanded, binned) SLICE entry when there is one
if fullimage:
   firstslice, numslices = 0, (fullimage[1] + binval - 1) // binval
if slices:
   firstslice = int(round(expandedFac * slices[0])) // binval
   numslices = int(round(expandedFac * slices[1])) // binval + 1 - firstslice

# Get the width before possibly changing binval.  The width is used only for
# the boundary information of direct mode, which already requires a WIDTH or
# FULLIMAGE entry when the stack does not exist.  In other cases with neither
# entry, use a placeholder of 1 instead of failing on the missing FULLIMAGE.
if widthArr:
   widthnum = int(math.floor(expandedFac * widthArr[0])) // binval
elif fullimage:
   widthnum = fullimage[0] // binval
else:
   widthnum = 1

# When reprojecting from a tomogram, chunks are divided with the ZMinAndMax
# entry instead of SLICE, the starting slice and number of slices come from
# that entry, and the binning factor is reset to 1
if reproj and rectoproj:
   slicedel = 'ZMinAndMax'
   slices = optionValue(comlines, slicedel, 1, 1)
   if slices:
      firstslice = slices[0]
      numslices = slices[1] + 1 - firstslice
   else:
      firstslice = 0
   binval = 1
else:
   slicedel = 'SLICE'

# Start with the target number of chunks, keeping chunks at least the minimum
# size
slabsize = max(minslices, numslices // targetslabs)

if vertPossible and xaxistilt:

   # With no local alignments or Z factors and an X-axis tilt, take the largest
   # number of chunks, from the target down to the minimum, for which the
   # percentage of extra slices needed for the tilted thickness is acceptable
   nslabs = targetslabs
   extrathick = math.fabs(thickness * math.sin(0.01745329 * xaxistilt))
   while nslabs >= minslabs:
      slabsize = max(minslices, numslices // nslabs)
      nslabs -= 1

      # Get percent of extra slices required
      extranum = int(100. * (slabsize + extrathick) / slabsize)
      if extranum <= maxextrapct:
         break

   pennum = int(100. * penalty)

   # If the extra is within the penalty, proceed; otherwise drop to old-style
   # tilting unless vertical slices were specified
   if extranum > pennum:
      if not vertical:
         oldstyle = "XTILTINTERP 0"
      else:

         # If vertical specified, compute the size that just breaks even with
         # the penalty for old-style tilting, but limit it; in this case allow
         # it to go down to one chunk per processor.  With a penalty of 1 or
         # less there is no break-even size, so use the largest allowed size
         # instead of dividing by zero or by a negative number.
         maxsize = numslices // numproc
         if penalty > 1.:
            slabsize = int(extrathick / (penalty - 1.))
         else:
            slabsize = maxsize
         slabsize = max(minslices, min(maxsize, slabsize))

# Round to the nearest number of chunks, then spread the slices evenly; the
# first 'remainder' chunks each get one extra slice
numslabs = max(1, (numslices + slabsize // 2) // slabsize)
slabsize = numslices // numslabs
remainder = numslices % numslabs

# With the chunk size known, get the number of boundary lines: enough lines
# of this width to hold the boundary pixels (ceil(boundpixels / widthnum)),
# limited by the chunk size when reprojecting, otherwise by the thickness
if reproj:
   boundLimit = slabsize // 2 + 1
elif slabsize == 1:
   boundLimit = thickness // 2 + 1
else:
   boundLimit = thickness - 1
boundlines = min((boundpixels + widthnum - 1) // widthnum, boundLimit)

# Manage the output file type: keep an IMOD_OUTPUT_FORMAT setting that is
# already in the command file.  Otherwise add one, using the environment
# value only if it is a standard type, to protect against strange settings
# on other machines.
if not any(line.startswith('$setenv IMOD_OUTPUT_FORMAT') for line in comlines):
   outFormat = os.getenv('IMOD_OUTPUT_FORMAT')
   if not (outFormat and outFormat in standardTypeExtensions()):
      outFormat = 'MRC'
   comlines.insert(0, '$setenv IMOD_OUTPUT_FORMAT ' + outFormat)
   
# Values substituted into the chunk command files.  By default each chunk
# writes a temporary file, so the output file name is replaced.  The
# 'gibberish' patterns presumably never match a command file line, which
# makes their sed commands no-ops.
templist = []
recsed = recfile
mintotslice = binval * firstslice
maxtotslice = binval * (firstslice + numslices - 1)
totsed = 'gibberish'
boundsed = 'gibberish'
totalComs = 0

# Name the boundary info file, adding the starting command number when the
# name must be unique
if reproj:
   boundTag = '-rpbound'
else:
   boundTag = '-bound'
if uniqueInfo:
   boundfile = fmtstr('{}{}-{:03d}.info', rootname, boundTag, startnum)
else:
   boundfile = rootname + boundTag + '.info'

# NOTE(review): unlike the boundary info file, this name does not include the
# starting number when UniqueInfoFile is set - confirm that this is intended
if vsOutFile:
   vsBoundFile = rootname + '-vsbound.info'

if direct:

   # In direct mode all chunks write into the one output file, with
   # TOTALSLICES and boundary info lines added after THICKNESS.  A starting
   # command file with the slice range set to -1 -1 is written first; it is
   # named as a numbered sync file when an initial number was entered.
   # Format arguments are passed separately so that braces in the root name
   # cannot be taken as format fields.
   totsed = 'THICKNESS'
   recsed = 'gibberish'
   boundsed = 'THICKNESS'
   if ifstartnum:
      thiscom = fmtstr('{}-{:03d}-sync{}', rootname, startnum, comExt)
      startnum += 1
   else:
      thiscom = rootname + '-start' + comExt

   sedcom = [r"|^\s*" + slicedel + "|d",
             '|savework|d',
             '|^ *THICKNESS|a|' + slicedel + ' -1 -1|',
             fmtstr("|^ *THICKNESS|a|TOTALSLICES {} {}|", mintotslice,
                    maxtotslice),
             sedModify('ActionIfGPUFails', '2,2', delim = '|')]
   sedlines = pysed(sedcom, comlines, None, True, delim = '|')
   sedlines.append('$sync')
   writeTextFile(thiscom, sedlines)
   totalComs += 1

   # Header line of the boundary info file(s); two lines per chunk are added
   # as the chunk command files are written
   boundhead = fmtstr("1 {} {} {} {}", reproj, widthnum, boundlines, numslabs)
   boundtext = [boundhead]
   if vsOutFile:
      vsBoundText = [boundhead]

# Write one command file per chunk.  In direct mode, also record each chunk's
# boundary file name and slice limits for the boundary info file(s).
# Boundary slice numbers are relative to the first slice of all chunks.
# NOTE: 'num' is left at numslabs after this loop, and the finishing command
# file name relies on that.
firstofall = firstslice
for num in range(1, numslabs+1):
   numrec = num + startnum - 1
   numtext = fmtstr('{:03d}', numrec)
   thiscom = rootname + '-' + numtext + comExt
   # Temporary output file for this chunk; listed for the non-direct finish
   tempname = datasetFilename('-' + numtext + '.' + recext)
   templist.append(tempname)
   # The first 'remainder' chunks get one extra slice
   lastslice = firstslice + slabsize - 1
   if num <= remainder:
      lastslice += 1

   # Get unbinned first and last slices for output
   ubfirst = firstslice * binval
   ublast = lastslice * binval

   # Modify the command file: replace the output file with the temporary one
   # (a no-op pattern in direct mode), delete existing slice, get rid of
   # savework, and set the new slice command and the xtiltinterp control
   sedcom = ["|" + recsed + "|s||" + tempname + "|",
             r"|^\s*" + slicedel + "|d",
             '|savework|d',
             fmtstr("|^ *THICKNESS|a|{} {} {}|", slicedel, ubfirst, ublast),
             "|^ *THICKNESS|a|" + oldstyle + "|",
             fmtstr("|{}|a|TOTALSLICES {} {}|", totsed, mintotslice, maxtotslice),
             "|" + boundsed + "|a|BoundaryInfoFile " + boundfile + "|",
             sedModify('ActionIfGPUFails', '2,2', delim = '|')]
   if vsOutFile:
      sedcom.append("|" + boundsed + "|a|VertBoundaryFile " + vsBoundFile + "|")
   sedlines = pysed(sedcom, comlines, None, True, delim = '|')
   sedlines.insert(0, '$sync')
   writeTextFile(thiscom, sedlines)
   totalComs += 1
   if direct:
      # Two lines per chunk: its boundary file name, then its limits.  -1 is
      # used at the start of the first chunk and the end of the last chunk,
      # presumably meaning "no boundary there" - confirm with fixboundaries
      boundtext.append(recroot + '-' + numtext + '.' + boundext)
      boundstart = firstslice - firstofall
      boundend = lastslice - firstofall
      if reproj:
         boundend -= boundlines - 1
      if num == 1:
         boundstart = -1
      if num == numslabs:
         boundend = -1
      if reproj:
         boundtmp = fmtstr("-1 {} -1 {}", boundstart, boundend)
      else:
         boundtmp = fmtstr("{} 0 {} -1", boundstart, boundend)
      boundtext.append(boundtmp)
      if vsOutFile:
         vsBoundText.append(recroot + '-' + numtext + '.' + vsBoundExt)
         vsBoundText.append(boundtmp)
   firstslice = lastslice + 1

# Name the finishing command file and start the cleanup command.  When more
# command files will be added later, the finish file instead gets the number
# after the last chunk, and the chunk command/log files and the boundary info
# files are not removed.  The number is computed from startnum and numslabs,
# not from the leftover loop variable, and format arguments are passed
# separately so that braces in the root name cannot be taken as format fields.
if leaveopen:
   finish = fmtstr('{}-{:03d}-sync{}', rootname, startnum + numslabs, comExt)
   cleanup = '$b3dremove -g '
   cleanbound = ''
else:
   finish = rootname + '-finish' + comExt
   cleanup = '$b3dremove -g ' + rootname + '-[0-9][0-9][0-9]*' + comExt + '* ' + \
             rootname + '-[0-9][0-9][0-9]*.log* '
   cleanbound = '"' + boundfile + '"'
   if vsOutFile:
      cleanbound += ' "' + vsBoundFile + '"'

if direct:

   # Chunks were written into the final file: fix the boundaries (also for
   # any vertical slice output), collect the min/max/mean, and remove the
   # per-chunk boundary files
   finishlines = [fmtstr('$fixboundaries "{}" "{}"', recfile, boundfile),
                  fmtstr('$collectmmm pixels= "{}" {} "{}" {}', rootname, numslabs,
                         recfile, startnum),
                  cleanup + fmtstr('"{}-[0-9][0-9][0-9]*.{}" {}', recroot, boundext,
                                   cleanbound)]
   if vsOutFile:
      finishlines.insert(1, fmtstr('$fixboundaries "{}" "{}"', vsOutFile, vsBoundFile))
      finishlines.append(cleanup + fmtstr('"{}-[0-9][0-9][0-9]*.{}"', recroot,
                                          vsBoundExt))
else:

   # Chunks were written into temporary files: stack them (or assemble them
   # in Y when reprojecting), then remove the temporary files
   if reproj:
      finishlines = ['$assemblevol -StandardInput',
                     'OutputFile ' + recfile,
                     fmtstr('NumberOfFilesInY {}', numslabs)]
   else:
      finishlines = ['$newstack -StandardInput',
                     'OutputFile ' + recfile]
   finishlines += ['InputFile ' + chunkfile for chunkfile in templist]
   finishlines.append(cleanup + fmtstr('"{}-[0-9][0-9][0-9]*.{}"', recroot, recext))

# Write the finishing command file and, in direct mode, the boundary info
# file(s) that the fixboundaries commands read
writeTextFile(finish, finishlines)
totalComs += 1
if direct:
   infoFiles = [(boundfile, boundtext)]
   if vsOutFile:
      infoFiles.append((vsBoundFile, vsBoundText))
   for infoName, infoText in infoFiles:
      writeTextFile(infoName, infoText)

prnstr(fmtstr('{} command files for {} chunks created and ready to run',
              totalComs, numslabs))
prnstr('  with processchunks or parallel processing interface in Etomo')

