#!/usr/bin/env python
# splitcorrection - program to set up command files for parallel ctf correction
#
# Author: David Mastronarde
#
# $Id: splitcorrection,v 937343107256 2023/02/19 22:44:16 mast $
#

# Name of this program, and the prefix for fatal messages that must be
# printed before the imodpy helpers are available
progname = 'splitcorrection'
prefix = 'ERROR: %s - ' % progname

#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys

#
# Setup runtime environment: locate IMOD through IMOD_DIR, put its pylib
# directory first on the module search path, and load the imodpy helpers.
# Without IMOD_DIR nothing below can run, so exit with an error.
IMOD_DIR = os.environ.get('IMOD_DIR')
if IMOD_DIR is not None:
   if sys.platform == 'cygwin' and sys.version_info[0] > 2:
      # Python 3 under Cygwin may be given a Windows-style path; turn
      # e.g. C:\IMOD into /cygdrive/c/IMOD
      IMOD_DIR = IMOD_DIR.replace('\\', '/')
      # Slice comparison so that a value shorter than 3 characters cannot
      # raise IndexError
      if IMOD_DIR[1:3] == ':/':
         IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]
   sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
   from imodpy import *
   addIMODbinIgnoreSIGHUP()
else:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

#
# load IMOD Libraries
from pip import *
from pysed import *

# Extension given to the per-chunk boundary file names that are listed in the
# boundary info file, and the default boundary size in pixels (replaced by the
# -b option when it is entered)
boundExt = 'cbound'
boundPixels = parallelBoundarySize()

# Fallbacks from ../manpages/autodoc2man 3 1 splitcorrection
# Generated option table; presumably used only when the program's autodoc
# cannot be found.  Regenerate it with autodoc2man rather than hand-editing.
options = [":m:I:", ":b:I:", "i:InitialComNumber:I:", "o:OpenForMoreComs:B:",
           "r:RootNameOfOutput:FN:", "dir:DirectoryForOutput:FN:",
           "unique:UniqueInfoFile:B:", "size:InputDimensions:IT:", "help:usage:B:"]

# Parse the command line against the option table.  The returned counts are
# not used below; the meaning of the trailing numeric arguments is defined by
# PipReadOrParseOptions (not visible here).
(numOpts, numNonOpts) = PipReadOrParseOptions(sys.argv, options, progname, 1, 0, 1)

# The one non-option argument is the ctfphaseflip command file;
# completeAndCheckComFile completes its name and derives a default root name
(comfile, rootname) = completeAndCheckComFile(PipGetNonOptionArg(0))
# The new command files reuse the extension (last 4 characters) of this one
comExt = comfile[-4:]

# Remaining options.  PipGetErrNo() has to be read immediately after the
# InitialComNumber lookup, since ifStartNum records the outcome of that call.
maxSlices = PipGetInteger('m', 5)
boundPixels = PipGetInteger('b', boundPixels)
startNum = PipGetInteger('InitialComNumber', 0)
ifStartNum = 1 - PipGetErrNo()
leaveOpen = PipGetBoolean('OpenForMoreComs', 0)
rootname = PipGetString('RootNameOfOutput', rootname)
uniqueInfo = PipGetBoolean('UniqueInfoFile', 0)
outDir = PipGetString('DirectoryForOutput', '')

# Root name with the output directory prepended, when one was given
rootWithDir = os.path.join(outDir, rootname) if outDir else rootname

# Load the ctfphaseflip command file; the output file name is required
comLines = readTextFile(comfile, 'ctfphaseflip command file')
outFile = optionValue(comLines, 'OutputFileName', 0)
if not outFile:
   exitError('Cannot find name of output file in ' + comfile)

# Image size: from the InputDimensions option, or else from the header of the
# input stack named in the command file, which must exist
dimensions = PipGetThreeIntegers('InputDimensions', 0, 0, 0)
sizeErr = PipGetErrNo()
if sizeErr:
   inStack = optionValue(comLines, 'InputStack', 0)
   if not inStack:
      exitError('Cannot find name of input file in ' + comfile)
   if not os.path.exists(inStack):
      exitError('Input stack ' + inStack + ' does not exist')
   try:
      dimensions = getmrcsize(inStack)
   except ImodpyError:
      exitFromImodError(progname)


# Views to correct: from StartingEndingViews in the command file, otherwise
# all views (1 through the third image dimension)
views = optionValue(comLines, 'StartingEndingViews', 1)
if not views:
   views = (1, dimensions[2])

viewdel = 'StartingEndingViews'
total = 1 + views[1] - views[0]

# Reject values that would otherwise crash below with ZeroDivisionError (or
# produce no chunks); check before any existing chunk files are cleaned up
if maxSlices < 1:
   exitError(fmtstr('Maximum number of views per chunk (-m) must be at least 1, not {}',
                    maxSlices))
if total < 1:
   exitError(fmtstr('Starting view {} is after ending view {}', views[0], views[1]))

if not ifStartNum:
   cleanChunkFiles(rootWithDir)

# Use the fewest chunks of at most maxSlices views each; views are spread as
# evenly as possible, with 'remainder' chunks getting one extra view
numSlabs = (total + maxSlices - 1) // maxSlices
slabSize = total // numSlabs
remainder = total % numSlabs

(outRoot, outExt) = os.path.splitext(outFile)
# First command file: '-start', or a numbered '-sync' file when an initial
# command number was given
thisCom = rootWithDir + '-start' + comExt
if ifStartNum:
   thisCom = rootWithDir + fmtstr('-{:03d}-sync{}', startNum, comExt)

# Make the command files set IMOD_OUTPUT_FORMAT.  If the original does not,
# add a setenv at the top, using the current environment value when it is a
# standard type and MRC otherwise.
hasFormatLine = any(comLine.startswith('$setenv IMOD_OUTPUT_FORMAT')
                    for comLine in comLines)
if not hasFormatLine:
   envFormat = os.getenv('IMOD_OUTPUT_FORMAT')
   if envFormat and envFormat in standardTypeExtensions():
      outFormat = envFormat
   else:
      outFormat = 'MRC'
   comLines.insert(0, '$setenv IMOD_OUTPUT_FORMAT ' + outFormat)

# First command file: replace StartingEndingViews with -1 -1, add TotalViews
# for the full range after the DefocusFile line, and end with $sync
startEdits = ['/' + viewdel + '/d',
              '/DefocusFile/a/StartingEndingViews -1  -1/',
              fmtstr('/DefocusFile/a/TotalViews {} {}/', views[0], views[1])]
startLines = pysed(startEdits, comLines)
startLines.append('$sync')
writeTextFile(thisCom, startLines)
totalComs = 1

# Boundary info file name: one shared name, or one numbered by the starting
# command number when unique info files were requested
if uniqueInfo:
   boundFile = fmtstr('{}-bound-{:03d}.info', rootname, startNum)
else:
   boundFile = rootname + '-bound.info'

# Image lines needed to hold boundPixels pixels (rounded up), capped at half
# the image height plus one
width = dimensions[0]
boundLines = min((boundPixels + width - 1) // width, dimensions[1] // 2 + 1)

boundOut = [fmtstr('1 0 {} {} {}', width, boundLines, numSlabs)]

firstview = views[0]
origOffset = firstview
for num in range(1, numSlabs + 1):
   # The first 'remainder' chunks take one extra view
   chunkViews = slabSize + (1 if num <= remainder else 0)
   origEnd = origOffset + chunkViews - 1

   # Chunk command file: its own view range, plus the full range and the
   # boundary info file, all inserted after the DefocusFile line
   thisCom = rootWithDir + fmtstr('-{:03d}{}', num + startNum, comExt)
   chunkEdits = ['/' + viewdel + '/d',
                 fmtstr('/DefocusFile/a/StartingEndingViews {}  {}/', origOffset, origEnd),
                 fmtstr('/DefocusFile/a/TotalViews {} {}/', views[0], views[1]),
                 '/DefocusFile/a/BoundaryInfoFile ' + boundFile + '/']
   pysed(chunkEdits, comLines, thisCom)
   totalComs += 1

   # Boundary entry: file name, then view limits relative to the first view;
   # -1 is used for the start of the first chunk and the end of the last one
   boundStart = -1 if num == 1 else origOffset - firstview
   boundEnd = -1 if num == numSlabs else origEnd - firstview
   boundOut.append(fmtstr('{}-{:03d}.{}', outRoot, num, boundExt))
   boundOut.append(fmtstr('{} 0 {} -1', boundStart, boundEnd))

   origOffset = origEnd + 1

writeTextFile(boundFile, boundOut)

# Quote the output directory for collectmmm and mention it in the summary
dirMess = ''
if outDir:
   outDir = '"' + outDir + '"'
   dirMess = 'in directory ' + outDir

# Final command file: '-finish', or a numbered '-sync' file following the
# chunk files when more command files are to be added
thisCom = rootWithDir + '-finish' + comExt
if leaveOpen:
   thisCom = rootWithDir + fmtstr('-{:03d}-sync{}', numSlabs + startNum + 1, comExt)

finLines = [fmtstr('$fixboundaries "{}" "{}"', outFile, boundFile),
            fmtstr('$collectmmm pixels= "{}" {} "{}" {} {}', rootname, numSlabs, outFile,
                   startNum + 1, outDir),
            fmtstr('$b3dremove -g "{}-[0-9][0-9][0-9]*.{}" "{}"', outRoot, boundExt,
                   boundFile)]
if not leaveOpen:
   # Remove the chunk command and log files and the start/finish files.  Each
   # pattern is quoted, like the boundary-file pattern above, so that a root
   # name or output directory containing spaces stays one argument.  -g
   # presumably makes b3dremove expand the wildcards itself, as the quoted
   # pattern above already relies on.
   finLines.append(
      fmtstr('$b3dremove -g "{0}-[0-9][0-9][0-9]*{1}*" "{0}-[0-9][0-9][0-9]*.log*" ' +
             '"{0}-start*.*" "{0}-finish*{1}*"', rootWithDir, comExt))
writeTextFile(thisCom, finLines)
totalComs += 1

prnstr(fmtstr("{} command files for {} chunks created {}", totalComs, numSlabs, dirMess))
sys.exit(0)
