#!/usr/bin/env python
# splitbatch - to divide a batch run into multiple jobs
#
# Author: David Mastronarde
#
# $Id: splitbatch,v 4784d6af50ed 2023/05/26 18:54:09 mast $
#

# Program name, and the prefix placed in front of fatal error messages
progname = 'splitbatch'
prefix = 'ERROR: %s - ' % progname

#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, copy, platform

#
# Set up the runtime environment: IMOD_DIR must be defined so that the IMOD
# Python library directory (IMOD_DIR/pylib) can be put on the import path.
if os.getenv('IMOD_DIR') is not None:
   IMOD_DIR = os.environ['IMOD_DIR']

   # Under Cygwin with Python 3, IMOD_DIR may be a Windows path such as
   # C:\imod; convert it to the /cygdrive/c/imod form.  Comparing a slice
   # (rather than indexing characters 1 and 2) avoids an IndexError when the
   # value is shorter than 3 characters, e.g. when IMOD_DIR is set but empty.
   if sys.platform == 'cygwin' and sys.version_info[0] > 2:
      IMOD_DIR = IMOD_DIR.replace('\\', '/')
      if IMOD_DIR[1:3] == ':/':
         IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]
   sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
   from imodpy import *
   addIMODbinIgnoreSIGHUP()
else:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

#
# load IMOD Libraries
from pip import *

# Option table for PIP: each entry is short name:long name:type:help text
options = [
   'comfile:CommandFile:FN:Command file for running Batchruntomo (required)',
   'maxgpu:MaxGPUsForOneJob:I:Maximum # of GPUs that one batch run would use '
   '(default 4)',
   'help:usage:B:',
]

PipExitOnError(False, prefix)
numOpts, numNonOpts = PipParseInput(sys.argv, options)

# Print usage and quit when help was requested or no options were given at all
helpRequested = PipGetBoolean('help', 0)
if helpRequested or not numOpts:
   PipPrintHelp(progname, 0, 0, 0)
   sys.exit(0)

# Get options.  The command file is required; completeAndCheckComFile also
# returns the root name that is used for all of the files written below.
comfile = PipGetInOutFile('CommandFile', 0)
comfile, rootname = completeAndCheckComFile(comfile)
root, comExt = os.path.splitext(comfile)

maxGPU = PipGetInteger('MaxGPUsForOneJob', 4)
if maxGPU <= 0:
   exitError('The maximum number of GPUs to use must be at least 1')

# Read the command file and sort its lines.  Entries specific to a data set go
# into their own lists.  Machine/thread entries are dropped because they are
# replaced further down.  Everything else goes into the list of lines common
# to every output command file.
fullLines = readTextFile(comfile)

commonLines = []
directives = []
setRoots = []
currentDirs = []
deliverDirs = []

# Checked in this order; the first matching prefix wins
perSetLists = (('DirectiveFile', directives),
               ('RootName', setRoots),
               ('CurrentLocation', currentDirs),
               ('DeliverToDirectory', deliverDirs))
droppedPrefixes = ('CPUMachineList', 'SingleOnFirstCP', 'LimitLocalThreads',
                   'GPUMachineList')

for rawLine in fullLines:
   text = rawLine.strip()
   for (key, destination) in perSetLists:
      if text.startswith(key):
         destination.append(text)
         break
   else:    # ELSE ON FOR: not a per-set entry
      if not text.startswith(droppedPrefixes):
         commonLines.append(text)

# If there is a RemoteDirectory entry, work out the underlying mount rule.
# Strip the trailing directory names that the remote path and the local
# command-file directory share.  Whatever prefixes remain are added to the
# batch files as path-translation entries.
remoteRoot = optionValue(fullLines, 'RemoteDirectory', STRING_VALUE)
if remoteRoot:
   comDirRoot = imodAbsPath(os.path.dirname(comfile)).rstrip('/\\')
   remoteRoot = remoteRoot.rstrip('/\\')

   # Strip off directories that match, and quit when the last directory does
   # not match.  Also quit when a path has been reduced to its root:
   # os.path.split('/') returns ('/', ''), so without the empty-tail test two
   # identical paths would compare equal forever and the loop would never end.
   while remoteRoot and comDirRoot:
      (remoteHead, remoteTail) = os.path.split(remoteRoot)
      (comHead, comTail) = os.path.split(comDirRoot)
      if not remoteTail or remoteTail != comTail:
         break
      remoteRoot = remoteHead
      comDirRoot = comHead

   # Add the translation unless the common lines already translate this prefix
   if len(remoteRoot) > 1 or len(comDirRoot) > 1:
      for line in commonLines:
         if 'TranslatePathsFrom' in line and comDirRoot in line:
            break
      else:    # ELSE ON FOR: no existing translation for this prefix
         commonLines.append('TranslatePathsFrom ' + comDirRoot)
         commonLines.append('TranslatePathsTo ' + remoteRoot)

# The number of data sets is given by the RootName entries, or by the
# DirectiveFile entries when there are no RootName entries.  With neither,
# there is nothing to split, so stop rather than write no per-set files.
numSets = len(setRoots)
if not numSets:
   numSets = len(directives)
if not numSets:
   exitError('The command file has no RootName or DirectiveFile entries, so ' + \
             'there are no data sets to divide into jobs')

# A per-set entry may appear once (its single value is then used for every
# set) or exactly once per data set
for (opts, name) in zip((directives, currentDirs, deliverDirs),
                        ('DirectiveFile', 'CurrentLocation', 'DeliverToDirectory')):
   if len(opts) > 1 and len(opts) != numSets:
      exitError(fmtstr('The number of entries for {} ({}) does not match the number ' + \
                       'of data sets ({})', name, len(opts), numSets))

# Check that the entries describing where the processing runs are consistent.
# A queue needs a MaxJobsOnQueue entry and cannot be combined with a CPU list.
queueCommand = optionValue(fullLines, 'QueueCommand', STRING_VALUE)
if queueCommand and optionValue(fullLines, 'MaxJobsOnQueue', INT_VALUE) is None:
   exitError('Command file must have a MaxJobsOnQueue entry if it has QueueCommand')

cpuList = optionValue(fullLines, 'CPUMachineList', STRING_VALUE)
if cpuList and queueCommand:
   exitError('CPUMachineList cannot be included with QueueCommand')

# A GPU queue likewise needs a MaxGPUJobsOnQueue entry
gpuQueue = optionValue(fullLines, 'GPUQueueCommand', STRING_VALUE)
if gpuQueue and optionValue(fullLines, 'MaxGPUJobsOnQueue', INT_VALUE) is None:
   exitError('Command file must have a MaxGPUJobsOnQueue entry if it has ' + \
             'GPUQueueCommand')

# Cluster jobs.  CoresPerClusterJob excludes a queue and a CPU list.
# GPUsPerClusterJob requires CoresPerClusterJob and excludes a GPU queue.
# Both must be positive.
coresPerJob = optionValue(fullLines, 'CoresPerClusterJob', INT_VALUE, numVal=1)
gpusPerJob = optionValue(fullLines, 'GPUsPerClusterJob', INT_VALUE, numVal=1)
clusterRun = coresPerJob is not None

if clusterRun:
   if queueCommand:
      exitError('The command file has both a CoresPerClusterJob and a QueueCommand')
   if coresPerJob <= 0:
      exitError('The command file has a non-positive value for CoresPerClusterJob')
   if cpuList:
      exitError('CPUMachineList cannot be included with CoresPerClusterJob')

if gpusPerJob is not None:
   if not coresPerJob:
      exitError('The command file has a GPUsPerClusterJob but no CoresPerClusterJob')
   if gpuQueue:
      exitError('The command file has both GPUsPerClusterJob and GPUQueueCommand')
   if gpusPerJob <= 0:
      exitError('The command file has a non-positive value for GPUsPerClusterJob')

# Without a queue or cluster jobs, the work must be spread over a CPU list
if not (cpuList or queueCommand or clusterRun):
   exitError('The command file must include a CPUMachineList entry with multiple CPUs')

# Get the GPU list.  It cannot be combined with a queue or cluster jobs.
# A value of '1' means the local GPU was selected.  In that case, see whether
# the CPU list names other machines, so the local one can be named explicitly.
gpuList = optionValue(fullLines, 'GPUMachineList', STRING_VALUE)
if gpuList is not None and (queueCommand or coresPerJob is not None):
   exitError('GPUMachineList cannot be included with QueueCommand or CoresPerClusterJob')

if gpuList == '1':
   # Unique CPU-list entries, keeping the order of first appearance
   cpuEntries = cpuList.split(',')
   cpuArray = [name for (ind, name) in enumerate(cpuEntries)
               if name not in cpuEntries[:ind]]

   # Unless the list is just 'localhost', look for this host in the list
   # (comparing names up to the first '.').
   # NOTE(review): entries may carry a ':N' or '#N' CPU count (see the CPU
   # list parsing below), and such entries will not match the bare host
   # name here; confirm whether that is intended.
   if len(cpuArray) > 1 or cpuArray[0] != 'localhost':
      localName = platform.node()
      localShort = localName.split('.')[0]
      for machine in cpuArray:
         if machine.split('.')[0] == localShort:
            # Name the machine as specified, but only if it is not the only one
            if len(cpuArray) > 1:
               gpuList = machine
            break
      else:    # ELSE ON FOR: this host is not in the list, use its full name
         gpuList = localName

# Entries added to every output command file: the root name of the parallel
# batch and, when GPUs are used, the GPU list and per-run GPU limit
commonLines.append('ParallelBatchRootName ' + rootname)
if gpuList:
   commonLines.extend(['GPUMachineList ' + gpuList,
                       'MaxGPUsInParallelBatch ' + str(maxGPU)])

# Parse the CPU list and check that it provides at least 2 CPUs.  The list is
# either a plain integer count of local CPUs, or comma-separated machine names.
# Each name may be followed by ':N' or '#N' giving its number of CPUs
# (default 1).  Each machine is recorded as [name, count, count].
if cpuList:
   numCPUbyInt = 0
   machines = []
   try:
      numCPUbyInt = int(cpuList)
   except ValueError:
      numCPUtot = 0
      for entry in cpuList.split(','):
         fields = entry.replace('#', ':').split(':')
         if len(fields) > 2:
            exitError('A machine name cannot be followed by two : or # signs')
         count = 1
         if len(fields) == 2:
            try:
               count = int(fields[1])
            except ValueError:
               exitError('Failed to convert value after : or # to integer in ' + entry)
            if count < 1:
               exitError('The value after : or # is less than 1 in ' + entry)
         numCPUtot += count
         machines.append([fields[0], count, count])
   else:
      # A plain number means that many CPUs on the local machine
      numCPUtot = numCPUbyInt
      machines = [['localhost', numCPUtot, numCPUtot]]

   if numCPUtot < 2:
      exitError('The CPUMachineList entry only contains a single CPU')

# Clear out chunk files for this root name.  This presumably removes leftovers
# from an earlier run; see cleanChunkFiles in imodpy to confirm.
cleanChunkFiles(rootname)

# Remember any CheckFile entry for the final instructions.  Every output
# command file also gets SingleOnFirstCPU.
checkFile = optionValue(fullLines, 'CheckFile', STRING_VALUE)
commonLines.append('SingleOnFirstCPU')

# Write one command file per data set, named root-001, root-002, ... with the
# command-file extension.  Each holds the common lines followed by that set's
# DirectiveFile, CurrentLocation, DeliverToDirectory and RootName entries.
# A list with fewer entries than sets supplies its last entry to the
# remaining sets.
for dset in range(numSets):
   setLines = list(commonLines)
   for perSet in (directives, currentDirs, deliverDirs, setRoots):
      if perSet:
         setLines.append(perSet[min(dset, len(perSet) - 1)])
   writeTextFile(fmtstr('{}-{:03d}{}', rootname, dset + 1, comExt), setLines)

# The finish file holds a single b3dremove command.  Its glob patterns match
# the numbered command files and the finish file itself.
removeCommand = fmtstr('$b3dremove -g {0}-[0-9][0-9][0-9]*' + comExt +
                       '* {0}-finish*' + comExt + '*', rootname)
writeTextFile(rootname + '-finish' + comExt, [removeCommand])


# Report how many files were made (the per-set files plus the finish file)
# and how to run them with processchunks
prnstr(str(numSets + 1) + ' command files created with root name ' + rootname + \
       ' to be run with:')
if cpuList:
   prnstr(fmtstr('   "processchunks -M # {} {}"', cpuList, rootname))
   prnstr(' where # is the maximum # of files to run in parallel')
else:
   # With QueueCommand or CoresPerClusterJob there is no CPUMachineList (the
   # checks above forbid it), so show a placeholder instead of printing "None"
   prnstr(fmtstr('   "processchunks -M # machine_list {}"', rootname))
   prnstr(' where # is the maximum # of files to run in parallel and')
   prnstr(' machine_list is the list of machines to run them on')
if checkFile:
   prnstr('To stop all processing, use "echo Q > ' + checkFile + '"')
   prnstr('   echo F instead of Q to finish current sets before quitting')
sys.exit(0)
