#!/usr/bin/env python
# splitblend - program to set up command files for parallel blendmont
#
# Author: David Mastronarde
#
# $Id: splitblend,v 937343107256 2023/02/19 22:44:16 mast $
#

progname = 'splitblend'
prefix = 'ERROR: ' + progname + ' - '

#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, copy, glob

#
# Setup runtime environment
# IMOD_DIR is required: it locates the pylib directory that holds the IMOD
# python modules imported below
IMOD_DIR = os.environ.get('IMOD_DIR')
if IMOD_DIR is None:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

# Under Cygwin with Python 3, switch backslashes to forward slashes and turn
# a leading drive specification (X:/...) into /cygdrive/x/...
if sys.platform == 'cygwin' and sys.version_info[0] > 2:
   IMOD_DIR = IMOD_DIR.replace('\\', '/')

   # Slice comparison is safe even if the value is shorter than 3 characters
   if IMOD_DIR[1:3] == ':/':
      IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]

sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
from imodpy import *
addIMODbinIgnoreSIGHUP()

#
# load IMOD Libraries
from pip import *
from pysed import *

# Fallbacks from ../manpages/autodoc2man 3 1 splitblend
options = [":n:I:", ":t:I:", ":y:B:", ":e:B:", ":u:B:", ":r:B:", ":b:I:"]

(numOptArgs, numNonOptArgs) = PipReadOrParseOptions(sys.argv, options, \
                                                    progname, 1, 1, 0)
os.environ['PIP_PRINT_ENTRIES'] = '0'

# Get the com file name, derive a root name and new com file name, check exists
comfile = PipGetNonOptionArg(0)
(comfile, rootname) = completeAndCheckComFile(comfile)

# The last 4 characters of the completed name are reused as the extension of
# every com file written below
comExt = comfile[-4:]

# Get options
# -b defaults to the value from parallelBoundarySize(); -t defaults to
# targetRatio chunks per processor, where -n gives the processor count
boundPixels = parallelBoundarySize()
targetRatio = 4
numproc = PipGetInteger('n', 4)
targetChunks = PipGetInteger('t', numproc * targetRatio)
boundPixels = PipGetInteger('b', boundPixels)
yChunks = PipGetBoolean('y', 0)
edgeFuncOnly = PipGetBoolean('e', 0)
recomputeEF = PipGetBoolean('r', 0)
useOldFunc = PipGetBoolean('u', 0)

# -r (recompute) and -u (use old) are contradictory requests.  The conflict
# between -u and -e is caught after the command file has been analyzed
if useOldFunc and recomputeEF:
   exitError('You cannot enter -r and -u together (recompute and use old ' +\
             'edge functions)')
   
# Get the com file and analyze for beginning, blendmont, and ending commands
# Don't strip lines for that: $ is required to be at front
comlines = readTextFile(comfile, 'blendmont command file')

# beforeInd: index of the last command line ahead of the blendmont command
# blendInd:  index of the blendmont command line
# afterInd:  index of the first command line following blendmont
# Each stays at -1 if there is no such line
beforeInd = blendInd = afterInd = -1
for ind, line in enumerate(comlines):
   if not line.startswith('$'):
      continue
   if blendInd >= 0:

      # First command past blendmont: everything needed is known, so stop here.
      # Scanning further would let later commands be mistaken for commands
      # before blendmont, or let a second blendmont line move blendInd past
      # afterInd
      afterInd = ind
      break
   if line.find('blendmont') > 0:
      blendInd = ind
   else:
      beforeInd = ind

if blendInd < 0:
   exitError('Cannot find blendmont command in command file')

# The blendmont lines run from its command line up to the next command, or to
# the end of the file when nothing follows
lastInd = afterInd if afterInd >= 0 else len(comlines)
blendLines = comlines[blendInd:lastInd]

# Extract the needed entries from the blendmont lines of the command file.
# The entries fetched with a final argument of 0 are used as strings below and
# the ones fetched with 1 are indexed like lists (presumably optionValue
# returns them in those forms - see imodpy)
(outputFile, inputFile, plInput) = [optionValue(blendLines, optName, 0)
                                    for optName in ('ImageOutputFile',
                                                    'ImageInputFile',
                                                    'PieceListInput')]
(oldfuncArr, readxcorrArr, binningArr) = [optionValue(blendLines, optName, 1)
                                          for optName in ('OldEdgeFunctions',
                                                          'ReadInXcorrs',
                                                          'BinByFactor')]
blendroot = optionValue(blendLines, 'RootNameForEdges', 0)
if not (outputFile and blendroot and inputFile):
   exitError('Cannot find output file or root name for edges in command file')

# Output file name without its extension is the base name for boundary files
setname = os.path.splitext(outputFile)[0]

# A nonzero OldEdgeFunctions entry in the command file turns on use of old
# edge functions unless recomputing was requested
if oldfuncArr and oldfuncArr[0] and not recomputeEF:
   useOldFunc = 1

if edgeFuncOnly and useOldFunc:
   exitError('You requested edge functions only but command file says '
             'to use old edge functions')

# Binning is 1 unless the command file has a BinByFactor entry
binning = binningArr[0] if binningArr else 1

# Find size to see if need to do in Y
# The command file is not required to have a PieceListInput entry, so only add
# the piece list to the montagesize command when one was found; concatenating
# a missing value would fail before montagesize is even run
sizeCmd = 'montagesize ' + inputFile
if plInput:
   sizeCmd += ' ' + plInput

try:
   sizelines = runcmd(sizeCmd)
   for line in sizelines:
      if line.find('Total NX') >= 0:

         # The Z size is taken from the seventh whitespace-separated field
         lsplit = line.split()
         nz = int(lsplit[6])
         if nz == 1 and not yChunks:
            yChunks = 1
            prnstr('Setting up chunks in Y because the Z size is only 1')

except ImodpyError:
   exitFromImodError(progname)

# Unless only edge functions are wanted, have blendmont report how the work
# divides into chunks
if not edgeFuncOnly:

   # Input is every blendmont line after the command line itself, plus a
   # ParallelMode entry with the target number of chunks and the Y chunk flag
   runLines = blendLines[1:]
   runLines.append(fmtstr('ParallelMode {} {}', targetChunks, yChunks))
   try:
      queryOut = runcmd('blendmont -StandardInput', runLines)
   except ImodpyError:
      prnstr('ERROR: running blendmont to determine chunk extents')
      exitFromImodError(progname)

   # Collect the subset entry line and boundary limits for each chunk, and the
   # output image size
   subsetLines = []
   boundStarts = []
   boundEnds = []
   nxout = 0
   for outLine in queryOut:
      if 'SubsetToDo' in outLine:
         subsetLines.append(outLine)
      if 'ChunkBoundary' in outLine:
         fields = outLine.split()
         boundStarts.append(int(fields[1]))
         boundEnds.append(int(fields[2]))
      if 'Output image size:' in outLine:
         fields = outLine.split()
         nxout = int(fields[3])
         nyout = int(fields[4])

   # There must be at least one chunk, a boundary for every chunk, and a size
   numChunks = len(subsetLines)
   if numChunks == 0 or len(boundStarts) != numChunks or nxout == 0:
      exitError('Could not find chunk information in output from blendmont')

# Remove com and log files left from a previous run, if any
cleanChunkFiles(rootname)

# All of blendLines is written into com files from this point on, so put a
# command setting the output file type at the front to guard against
# incompatible settings here or elsewhere.  Fall back to MRC when the
# environment variable is unset or is not one of the standard types
outFormat = os.getenv('IMOD_OUTPUT_FORMAT')
if not (outFormat and outFormat in standardTypeExtensions()):
   outFormat = 'MRC'
blendLines[0:0] = ['$setenv IMOD_OUTPUT_FORMAT ' + outFormat]

# Scan the blendmont lines for an entry to sum pieces for the gradient and for
# a base intensity value; either option name may be abbreviated in the file
sumInd = -1
baseOpt = ''
for lineNum, entry in enumerate(blendLines):
   words = entry.strip().split()
   if not words:
      continue
   if 'SumPiecesForGradient'.startswith(words[0]):
      sumInd = lineNum
   if len(words) > 1 and 'BaseIntensityForScaling'.startswith(words[0]):
      baseOpt = '-l ' + words[1] + ' '

# When summing was requested, put a clip command just ahead of blendmont in the
# command lines so it goes out with the pre-commands, mark that there is now a
# pre-command if there was none, shift the indexes past the inserted line, and
# replace the summing entry with one that reads the gradient file
if sumInd > 0:
   fname = rootname + '_grad.txt'
   clipCommand = '$clip plane ' + baseOpt + inputFile + ' ' + fname
   comlines.insert(blendInd, clipCommand)
   if beforeInd < 0:
      beforeInd = blendInd
   blendInd += 1
   if afterInd >= 0:
      afterInd += 1
   blendLines[sumInd] = 'OtherSumGradientFile ' + fname

# Unless old edge functions are being used, begin with two chunks that compute
# edge functions, preceded by a sync chunk holding any commands that came
# before blendmont
comnum = 0
combineLines = []
if not useOldFunc:
   if beforeInd >= 0:
      comnum += 1
      comname = rootname + '-001-sync' + comExt
      writeTextFile(comname, comlines[0:blendInd])

   # Substitutions applied to existing entries when writing these com files
   sedlist = [r'/^\s*OldEdgeFunctions.*/s//OldEdgeFunctions  0/']
   if recomputeEF:
      sedlist += [r'/^\s*ReadInXcorrs.*/s//ReadInXcorrs  1/']

   # One com file per value of the EdgeFunctionsOnly entry
   for axis in (1, 2):
      comnum += 1
      comname = rootname + fmtstr('-{:03d}{}', comnum, comExt)
      runLines = blendLines + [fmtstr('EdgeFunctionsOnly  {}', axis)]
      pysed(sedlist, runLines, comname, True)

   # Command to be run afterwards to join the two resulting files into one
   combineLines = [fmtstr('$b3dcatfiles {0}.xecd {0}.yecd {0}.ecd', blendroot)]

# For edge functions only, all that remains is a sync com file holding the
# command that combines the edge function outputs; there are no finishing
# commands, so runLines is passed as None to writeFinishAndMessage
if edgeFuncOnly:
   comnum += 1
   comname = rootname + fmtstr('-{:03d}-sync{}', comnum, comExt)
   writeTextFile(comname, combineLines)
   runLines = None

else:

   # Put out the file to set up the output file
   # This sync com file also carries the commands that preceded blendmont when
   # old edge functions are used (otherwise they were already written to the
   # first sync file), and the combining command when edge functions were
   # computed
   comnum += 1
   comname = rootname + fmtstr('-{:03d}-sync{}', comnum, comExt)
   runLines = []
   if beforeInd >= 0 and useOldFunc:
      runLines.extend(comlines[0:blendInd])
   if combineLines:
      runLines.extend(combineLines)
   runLines.extend(blendLines)
   runLines.append(fmtstr('ParallelMode  -1 {}', yChunks))

   # The substitutions below only change existing entries, so add the
   # OldEdgeFunctions entry when the command file had none
   if not oldfuncArr:
      runLines.append('OldEdgeFunctions  1')

   # Existing ReadInXcorrs and OldEdgeFunctions entries are both forced to 1 in
   # this com file and in every chunk com file
   sedlist = [r'/^\s*ReadInXcorrs.*/s//ReadInXcorrs  1/',
              r'/^\s*OldEdgeFunctions.*/s//OldEdgeFunctions  1/']
   pysed(sedlist, runLines, comname, True)

   # Figure out the boundary lines
   # This is boundPixels divided by the output X size, rounded up: the number
   # of whole lines needed to hold at least boundPixels pixels
   boundLines = (nxout + boundPixels - 1) // nxout

   # First line of the info file; NOTE(review): the meaning of the leading 1
   # and the field order are defined by the programs reading the file - confirm
   # there
   infoLines = [fmtstr('1 {} {} {} {}', yChunks, nxout, boundLines, numChunks)]
   boundFile = rootname + '-bound.info'

   # Make the com files for chunks and accumulate boundary info
   for ind in range(numChunks):
      comnum += 1
      comname = rootname + fmtstr('-{:03d}{}', comnum, comExt)
      runLines = copy.deepcopy(blendLines)
      runLines.append(fmtstr('ParallelMode  -2 {}', yChunks))
      runLines.append(subsetLines[ind])
      runLines.append('BoundaryInfoFile  ' + boundFile)
      if not oldfuncArr:
         runLines.append('OldEdgeFunctions  1')
      pysed(sedlist, runLines, comname, True)

      # Each chunk contributes two lines to the info file: the name of its
      # boundary file, numbered from 1, and a line of start/end values
      infoLines.append(setname + fmtstr('-{:03d}.bound', ind + 1))

      # The start for the first chunk and the end for the last chunk are
      # replaced by -1 (presumably marking that there is no boundary at the
      # outer limits of the output - confirm against the readers of this file)
      start = boundStarts[ind]
      if not ind:
         start = -1
      if yChunks:

         # For chunks in Y the end is moved back by boundLines - 1 lines and
         # the values go in the second and fourth fields
         end = boundEnds[ind] + 1 - boundLines
         if ind == numChunks - 1:
            end = -1
         infoLines.append(fmtstr('-1 {} -1 {}', start, end))
      else:

         # Otherwise the values go in the first and third fields
         end = boundEnds[ind]
         if ind == numChunks - 1:
            end = -1
         infoLines.append(fmtstr('{} 0 {} -1', start, end))

   # Write the info file and the finishing operations
   # comnum is now the number of the last chunk com file, so
   # comnum + 1 - numChunks is the number of the first one
   writeTextFile(boundFile, infoLines)
   runLines = [fmtstr('$fixboundaries  {} {}', outputFile, boundFile),
               fmtstr('$collectmmm pixels= {} {} {} {}',
                      rootname, numChunks, outputFile, comnum + 1 - numChunks)]

   # Commands that followed blendmont in the original file run at the end,
   # ahead of the removal of the boundary files and info file
   if afterInd > 0:
      runLines.extend(comlines[afterInd:])
   runLines.append('$b3dremove -g ' + setname + '-[0-9][0-9][0-9]*.bound* ' + boundFile)

# runLines holds the finishing commands, or None for edge functions only
writeFinishAndMessage(runLines, rootname, comnum, False, comExt)
sys.exit(0)
