#!/usr/bin/env python
# alttomosetup - Sets up com files for processing alternate stacks
#
# Author: David Mastronarde
#
# $Id: alttomosetup,v 45bb719cbe42 2023/07/12 04:57:18 mast $

# Program name, and the prefix placed before error messages written by this script
progname = 'alttomosetup'
prefix = ''.join(('ERROR: ', progname, ' - '))

# Another copy of this function: return the name for the next com file, after advancing
# the global com file counter comNum.  The name is <outRoot>-<3-digit number>.com, with
# '-sync' inserted before '.com' unless sync is False
def nextComName(sync = True):
   global comNum
   comNum += 1
   template = '{}-{:03d}-sync.com' if sync else '{}-{:03d}.com'
   return fmtstr(template, outRoot, comNum)


# Read the com file for a processing step, named comBase plus the current axis
# extension, and find the value of its output file option outOpt.  Returns the com
# file lines and that output name; it is an error if the option is not found
def readAndGetOutput(comBase, outOpt):
   fileName = comBase + axisExt
   fileLines = readTextFile(fileName)
   outputName = optionValue(fileLines, outOpt, STRING_VALUE)
   if not outputName:
      exitError('Cannot find output file name in ' + fileName)
   return fileLines, outputName


# Read an edf file and find all the Trimvol-related tags needed to compose a trimvol
# command.  Make up such a command.  Return the nx, ny, nz of input to the trimming and
# the command, or -1, -1, -1 and an error string.  A size that is missing from the edf
# file or is not an integer is returned as 0.
def makeTrimvolCommandFromEDF(rootname):
   tags = ('XMin', 'XMax', 'YMin', 'YMax', 'ZMin', 'ZMax', 'ScaleXMin', 'ScaleXMax',
           'ScaleYMin', 'ScaleYMax', 'SectionScaleMin', 'SectionScaleMax',
           'TrimvolFlipped', 'SwapYZ', 'RotateX', 'ConvertToBytes', 'FixedScaling',
           'FixedScaleMin', 'FixedScaleMax', 'Input.NColumns', 'Input.NRows',
           'Input.NSections')

   # Trimvol option letters for the min/max pairs formed by the first 12 tags
   opts = ('x', 'y', 'z', 'sx', 'sy', 'sz')
   edfName = rootname + '.edf'
   if not os.path.exists(edfName):
      return (-1, -1, -1, 'Etomo file ' + edfName + ' does not exist')
   edfLines = readTextFile(edfName, 'Etomo data file', returnOnErr = True)

   # With returnOnErr, a string instead of a list of lines is the error message
   if isinstance(edfLines, str):
      return (-1, -1, -1, edfLines)

   # Extract the values from the Trimvol entries of the edf file.  Tags are matched
   # against the entry name (the part before '=') only, and when several tags occur in
   # one name the longest is taken, so that e.g. a ScaleXMin entry does not also set XMin
   values = {}
   for line in edfLines:
      ind = line.find('=')
      if ind <= 0:
         continue
      entryName = line[:ind]
      if 'Trimvol' not in entryName:
         continue
      matched = [tag for tag in tags if tag in entryName]
      if matched:
         values[max(matched, key = len)] = line[ind + 1:].strip()

   # True if the tag was found with the value 'true'
   def isTrue(tag):
      return values.get(tag) == 'true'

   # Compose the command line
   command = '$trimvol -f'
   if isTrue('TrimvolFlipped'):
      if isTrue('RotateX'):
         command += ' -rx'
      elif isTrue('SwapYZ'):
         command += ' -yz'

   # The scaling ranges (-sx, -sy, -sz from pairs 4-6) are added only when converting to
   # bytes without fixed scaling; fixed scaling passes FixedScaleMin,FixedScaleMax with
   # -c as a single comma-separated token, like the other pair options
   numPairs = 3
   if isTrue('ConvertToBytes'):
      if isTrue('FixedScaling'):
         if 'FixedScaleMin' in values and 'FixedScaleMax' in values:
            command += ' -c ' + values['FixedScaleMin'] + ',' + values['FixedScaleMax']
      else:
         numPairs = 6

   for ind in range(numPairs):
      (minTag, maxTag) = tags[2 * ind : 2 * ind + 2]
      if minTag in values and maxTag in values:
         command += fmtstr(' -{} {},{}', opts[ind], values[minTag], values[maxTag])

   # Convert a size entry to an integer, or 0 if it is missing or not an integer
   def sizeValue(tag):
      try:
         return int(values[tag])
      except (KeyError, ValueError):
         return 0

   return (sizeValue('Input.NColumns'), sizeValue('Input.NRows'),
           sizeValue('Input.NSections'), command)


# Use the PrintXYSizeAndExit option to newstack or blendmont to find out the aligned
# stack size.  newstExists selects newstack (otherwise blendmont), and newstLines are the
# lines of the command file for that program.  Returns (nx, ny); errors are reported
# through exitError or exitFromImodError
def getAlignedStackSizeFromProgram(newstExists, newstLines):

   # Extract the input lines and add the option
   prog = 'blendmont'
   if newstExists:
      prog = 'newstack'
   sizeCom = extractProgramEntries(newstLines, prog, 'Standard')
   if not sizeCom:
      exitError('Could not find ' + prog + ' input lines in command file')
   sizeCom.append('PrintXYSizeAndExit')

   # Run the program
   try:
      sizeLines = runcmd(prog + ' -Standard', sizeCom)
   except ImodpyError:
      prnstr('Error trying to determine aligned stack size')
      exitFromImodError(progname)
   if not sizeLines:
      exitError('Empty output when running newstack or blendmont to determine ' + \
                'aligned stack size')

   # Find the 'Output size: nx ny' line, searching from the end: it is expected to be
   # last, but a trailing line after it (e.g., a blank one) should not cause a failure
   sizeSplit = []
   for line in reversed(sizeLines):
      words = line.split()
      if words[:2] == ['Output', 'size:']:
         sizeSplit = words
         break
   if len(sizeSplit) != 4:
      exitError('Unexpected output from running newstack or blendmont to determine ' +\
                'aligned stack size')
   try:
      nxAli = int(sizeSplit[2])
      nyAli = int(sizeSplit[3])

   except ValueError:
      exitError('Converting aligned stack output file size from newstack or blendmont' +\
                ' to integer')

   return (nxAli, nyAli)
         

# THE MAIN
#
# load System Libraries
import sys, os

#
# Setup runtime environment
if os.getenv('IMOD_DIR') is not None:
   IMOD_DIR = os.environ['IMOD_DIR']

   # Under Cygwin with Python 3, convert a Windows drive path (X:/...) to /cygdrive
   # form; slicing rather than indexing avoids an IndexError on a very short value
   if sys.platform == 'cygwin' and sys.version_info[0] > 2:
      IMOD_DIR = IMOD_DIR.replace('\\', '/')
      if IMOD_DIR[1:3] == ':/':
         IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]
   sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
   from imodpy import *
   addIMODbinIgnoreSIGHUP()
else:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

#
# load IMOD Libraries
from pip import *
from pysed import *
from tomocoords import *

# Fallbacks from ../manpages/autodoc2man 3 1 alttomosetup
options = ["rootname:RootnameToProcess:CH:", "evenodd:EvenAndOddPairs:B:",
           "axis:AxisToProcess:CH:", "preproc:PreprocessForExtremes:I:",
           "ctf:CorrectCTF:B:", "erase:EraseFiducials:B:", "filter:FilterIn2D:B:",
           "trim:TrimVolume:B:", "clean:CleanUpIntermediates:B:",
           "procs:NumberOfProcessors:I:", "restore:JustRestoreInitialSet:B:",
           "help:usage:B:"]

# PIP startup and help.  The options list above is the generated fallback (see the
# comment before it), so its text should be left as generated.  numOpts and numNonOpts
# are not used further in this script
(numOpts, numNonOpts) = PipReadOrParseOptions(sys.argv, options, progname, 1, 1, 1)

# Number of chunks per processor when splitting ctfcorrection and tilt
chunksPerProc = 3

# Exactly one of -evenodd and -rootname must be given to define the alternate stack(s)
evenOdd = PipGetBoolean('EvenAndOddPairs', 0)
fromRoot = PipGetString('RootnameToProcess', '')
if evenOdd and fromRoot:
   exitError('You cannot enter both -evenodd and -rootname')
if not evenOdd and not fromRoot:
   exitError('You must enter either -evenodd or -rootname')

# Steps to include and other settings; -axis defaults to empty
preProcess = PipGetInteger('PreprocessForExtremes', 0)
correctCTF = PipGetBoolean('CorrectCTF', 0)
eraseGold = PipGetBoolean('EraseFiducials', 0)
filter2D = PipGetBoolean('FilterIn2D', 0)
doTrim = PipGetBoolean('TrimVolume', 0)
numProcs = PipGetInteger('NumberOfProcessors', 0)
cleanUp = PipGetBoolean('CleanUpIntermediates', 0)
doAxis = PipGetString('AxisToProcess', '')
justRestore = PipGetBoolean('JustRestoreInitialSet', 0)
if doAxis and doAxis.upper() not in ('A', 'B'):
   exitError('Entry for -axis must be a, A, b, or B')

# Get properties of data set and insist they are all deducible
(comExt, dualNum, rootName, typeExt, stackExt) = findRootAxisAndExtensions()
single = dualNum != 2
if not comExt or dualNum < 0 or not rootName or not stackExt or typeExt is None:
   exitError('There are non-standard files in the directory and not all ' +\
             'of the data set can be determined')

# Error checks; dual-axis status is always tested through 'single'
if single and doAxis:
   exitError('This appears to be a single-axis data set, you cannot enter -axis')
if not single and doTrim:
   exitError('This appears to be a dual-axis data set, you cannot enter -trim')
if not single and evenOdd:
   exitError('This appears to be a dual-axis data set and you cannot use -evenodd')

# Loop twice over the steps for even and odd stacks, or for both axes of a dual-axis
# set when -axis was not entered; otherwise once
numLoop = 2 if evenOdd or (not single and not doAxis) else 1

setRootAndExtension(rootName, typeExt)

# Get the trimvol command if possible; a negative size signals an error message
if doTrim:
   (ncol, nrow, nsec, trimCommand) = makeTrimvolCommandFromEDF(rootName)
   if ncol < 0:
      exitError(trimCommand)

# Root name for the com files and counter advanced by nextComName
outRoot = 'alttomo'
comNum = 0

# Clean up files with this root from a previous run, including any bound info files.
# glob is imported explicitly because this script does not otherwise import it itself
import glob
cleanChunkFiles(outRoot)
boundList = glob.glob(outRoot + '-bound-*.info')
if boundList:
   cleanupFiles(boundList)

# Set up possible default axis letter and entry for -single option to swaptomostacks;
# both stay empty unless one axis was chosen with -axis
if doAxis:
   axisLet = doAxis.lower()
   singleOpt = '-single'
else:
   axisLet = ''
   singleOpt = ''

# Loop, set up root names for swapping and axis letter
for loop in range(numLoop):
   altRoot = fromRoot
   setRoot = rootName
   toRoot = setRoot + '_primts'
   if evenOdd:
      altRoot = rootName + ('_even', '_odd')[loop]

   # Both axes are processed when -axis was not entered for a dual-axis set.  Dual-axis
   # status is tested with 'single' (dualNum != 2) as in the rest of the script, so that
   # only a dual-axis set takes its axis letter from the loop index
   bothAxes = not single and not doAxis
   if bothAxes:
      axisLet = ('a', 'b')[loop]
   if doAxis:
      setRoot += axisLet
      altRoot += axisLet
      toRoot += axisLet
   axisExt = axisLet + '.' + comExt
   axisRoot = rootName + axisLet

   # Set up names for testing whether stacks are already swapped in
   toTest = toRoot
   altTest = altRoot
   if bothAxes:
      toTest += axisLet
      altTest += axisLet

   # Test if the _primts stack exists and the alt stack does not; if so swap back
   if os.path.exists(toTest + '.' + stackExt) and not \
      os.path.exists(altTest + '.' + stackExt):
      if not justRestore:
         prnstr('WARNING: ' + progname + ' - the alternate stack ' + altTest + \
                ' is already swapped in; running swaptomostacks to restore files')
      else:
         prnstr('The alternate stack ' + altTest + \
                ' is swapped in; running swaptomostacks to restore files')
      try:
         runcmd(fmtstr('swaptomostacks {} -root "{}" -from "{}" -to "{}"', singleOpt,
                       setRoot, toRoot, altRoot))
      except ImodpyError:
         exitFromImodError(progname)

   # When just restoring, nothing else is done; for even/odd pairs go on to check the
   # other stack
   if justRestore:
      if evenOdd:
         continue
      sys.exit(0)

   # Get newstack or blendmont lines for aligned stack
   newstCom = 'newst' + axisExt
   blendCom = 'blend' + axisExt
   newstExists = os.path.exists(newstCom)
   blendExists = os.path.exists(blendCom)
   if not (newstExists or blendExists):
      exitError('Neither ' + newstCom + ' nor ' + blendCom + ' exists')
   if newstExists and blendExists:
      exitError('Both ' + newstCom + ' and ' + blendCom + ' exist, cannot tell which ' + \
                'to use')
   if newstExists:
      newstLines = readTextFile(newstCom)
   else:
      newstLines = readTextFile(blendCom)

   # Get tilt.com and its input and output, detect GPU and set up default # of procs
   tiltName = 'tilt' + axisExt
   tiltLines = readTextFile(tiltName)
   recName = optionValue(tiltLines, 'outputfile', 0, True)
   aliName = optionValue(tiltLines, 'inputproj', 0, True)
   if not recName or not aliName:
      exitError('Cannot find name of input file or output file in ' + tiltName)

   useGPU = optionValue(tiltLines, 'UseGPU', INT_VALUE, True, numVal = 1)
   if useGPU == None:
      useGPU = -1
   useProcs = numProcs
   if not useProcs:
      useProcs = 8
      if useGPU >= 0:
         useProcs = 1

   # When more than one processing unit, Splittilt will need size
   if useProcs > 1:
      (nxAli, nyAli) = getAlignedStackSizeFromProgram(newstExists, newstLines)

   # If trimming, get size of current rec file
   if doTrim:
      try:
         (nx, ny, nz) = getmrcsize(recName)
         if (ncol and nx != ncol) or (nrow and ny != nrow) or (nsec and nz != nsec):
            exitError(fmtstr('Size of current output from Tilt ({}x{}x{}) does not ' +\
                             'match size when Trimvol was run ({}x{}x{})', nx, ny, nz,
                             ncol, nrow, nsec))
      except ImodpyError:
         pass

   # Set up the swap on first loop or both if even/odd
   comLines = []
   if not loop or evenOdd:
      comLines.append(fmtstr('$swaptomostacks {} -check -root {} -from {} -to {}',
                             singleOpt, setRoot, altRoot, toRoot))

   # Eraser and archive
   if preProcess:
      xrayLines = readTextFile('eraser' + axisExt)
      comLines += xrayLines
      eraseOut = optionValue(xrayLines, 'OutputFile', STRING_VALUE)
      stackName = axisRoot + '.' + stackExt
      if not eraseOut:
         eraseOut = axisRoot + '_fixed.' + stackExt
      comLines.append(fmtstr('$b3drename {} {}_orig.{}', stackName, axisRoot, stackExt))
      comLines.append(fmtstr('$b3drename {} {}.{}', eraseOut, axisRoot, stackExt))
      if preProcess > 1:
         comLines.append('$archiveorig ' + stackName)

   comLines += newstLines
   writeTextFile(nextComName(), comLines)
   comLines = []

   # CTF correction
   if correctCTF:
      (ctfLines, ctfOut) = readAndGetOutput('ctfcorrection', 'OutputFileName')
      ctfMod = pysed(sedDelAndAdd('UseGPU', useGPU, 'DefocusFile'), ctfLines)
      if useProcs > 1:
         mainStack = axisRoot + '.' + stackExt
         try:
            if blendExists:
               mscom = fmtstr('montagesize "{}"', mainStack)
               if os.path.exists(axisRoot + '.pl'):
                  mscom += ' "' + axisRoot + '.pl"'
               montSizeLines = runcmd(mscom)

               # NZ is the last value on the last output line.  Check it here so that
               # empty or unexpected output gives the error message instead of a crash
               # from indexing the output again inside an exception handler
               lastLine = ''
               if montSizeLines:
                  lastLine = montSizeLines[-1]
               try:
                  nzAli = int(lastLine.split()[-1])
               except (ValueError, IndexError):
                  exitError('Getting NZ from last value in montagesize output: ' + \
                            lastLine)
            else:
               (nxRaw, nyRaw, nzAli) = getmrcsize(mainStack)

            # Split correction if multiple procs
            numChunks = useProcs * chunksPerProc
            maxSlices = max(1, (nzAli + numChunks - 1) // numChunks)

            tempName = 'ctfcorrection_tmp' + axisExt
            writeTextFile(tempName, ctfMod)
            splitLines = runcmd(fmtstr('splitcorrection -i {} -o -m {} -uni -size ' +\
                                       '{},{},{} -r {} "{}"', comNum + 1, maxSlices,
                                       nxAli, nyAli, nzAli, outRoot, tempName))

         except ImodpyError:
            exitFromImodError(progname)

         numAdded = findSplitComNumber(splitLines, 'output of splitcorrection')
         comNum += numAdded
         cleanupFiles([tempName])

      else:

         # Or just copy the com lines
         comLines = ctfMod

      comLines.append('$b3drename ' + ctfOut + ' ' + aliName)

   # Gold erasing
   if eraseGold:
      (goldLines, eraseOut) = readAndGetOutput('golderaser', 'OutputFile')
      comLines += goldLines
      comLines.append('$b3drename ' + eraseOut + ' ' + aliName)

   # 2D filtering
   if filter2D:
      (filtLines, filtOut) = readAndGetOutput('mtffilter', 'OutputFile')
      comLines += filtLines
      comLines.append('$b3drename ' + filtOut + ' ' + aliName)

   if comLines:
      writeTextFile(nextComName(), comLines)
      comLines = []

   # Tilt!  Split up if multiple procs, or use the lines as is
   if useProcs > 1:
      splitCom = ['CommandFile  ' + tiltName,
                  'RootNameOfOutput  ' + outRoot,
                  'ProcessorNumber  ' + str(useProcs),
                  'TargetChunks  ' + str(useProcs * chunksPerProc),
                  'InitialComNumber  ' + str(comNum + 1),
                  'OpenForMoreComs  1',
                  'UniqueInfoFile  1',
                  fmtstr('DimensionsOfStack {},{}', nxAli, nyAli)]
      try:
         splitLines = runcmd('splittilt -StandardInput', splitCom)
      except ImodpyError:
         exitFromImodError(progname)

      numAdded = findSplitComNumber(splitLines, 'output of splittilt')
      comNum += numAdded

   else:
      comLines += tiltLines

   # Trim
   if doTrim:
      trimName = datasetFilename('.rec', root = axisRoot)
      comLines.append(trimCommand + ' ' + recName + ' ' + trimName)

   # Cleanup
   if cleanUp:
      comLines.append('$b3dremove ' + aliName)
      if doTrim:
         comLines.append('$b3dremove ' + recName)

   # Swap back at end or for each loop of even/odd
   if evenOdd or loop == numLoop - 1:
      comLines.append(fmtstr('$swaptomostacks {} -root {} -from {} -to {}', singleOpt,
                             setRoot, toRoot, altRoot))

   if comLines:
      writeTextFile(nextComName(), comLines)
      comLines = []

# When only restoring files there is nothing more to write.  Otherwise finish with
# writeFinishAndMessage; its last argument is True when processing was not split
# (useProcs < 2) -- presumably it changes the message given; confirm in imodpy
if justRestore:
   sys.exit(0)
writeFinishAndMessage([], outRoot, comNum, useProcs < 2)
sys.exit(0)
