#!/usr/bin/env python
# gpuallocator - a program to allocate GPUs when needed to different processes 
# from a common pool of GPUs.  Both the processes and the GPUs can be on different
# machines.  Call with neither -maxadd nor -fulllist to release an allocation
#
# Author: David Mastronarde
#
# $Id: gpuallocator,v 937343107256 2023/02/19 22:44:16 mast $
#

# Program identity and the prefix placed in front of its error messages
progname = 'gpuallocator'
prefix = 'ERROR: %s - ' % progname

# Suffix appended to the root name to form the name of the shared usage file
fileSuffix = '.gpuUse'

# Initial settings; commonDir and verbose are replaced by option values further down
commonDir = '.'
verbose = 0
countLim = 15


#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, socket, datetime, time

#
# Setup runtime environment
# IMOD_DIR must be defined: it locates the pylib directory that holds the IMOD Python
# modules imported below
IMOD_DIR = os.getenv('IMOD_DIR')
if IMOD_DIR is None:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

# Under Cygwin with Python 3, turn a Windows-style path (C:\dir or C:/dir) into the
# /cygdrive/c/dir form.  The slice comparison is safe for values shorter than 3 characters
if sys.platform == 'cygwin' and sys.version_info[0] > 2:
   IMOD_DIR = IMOD_DIR.replace('\\', '/')
   if IMOD_DIR[1:3] == ':/':
      IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]

# Put pylib first on the module search path, then load the IMOD helpers
sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
from imodpy import *
addIMODbinIgnoreSIGHUP()

#
# load IMOD Libraries
from pip import *

isWindows = 'win32' in sys.platform
if isWindows:
   import msvcrt
else:
   import fcntl

# Fallbacks from ../manpages/autodoc2man 3 1 gpuallocator
# Fallbacks from ../manpages/autodoc2man 3 1 gpuallocator
options = [
   "rootname:RootNameForProcesses:CH:",
   "pid:ProcessID:I:",
   "controller:ControllingMachine:CH:",
   "fulllist:FullListOfGPUs:CH:",
   "maxgpu:MaximumGPUsToAssign:I:",
   "common:CommonDirectory:CH:",
   "verbose:VerboseOutput:B:",
   "help:usage:B:",
]

# Parse the command line with PIP entry output turned off; when no options were given
# just print the help and leave
PipEnableEntryOutput(False)
(numOpts, numNonOpts) = PipReadOrParseOptions(sys.argv, options, progname, 1, 1, 0)
if not numOpts:
   PipPrintHelp(progname, 0, 0, 0)
   sys.exit(0)

# Location of the shared files and the verbosity flag
commonDir = PipGetString('CommonDirectory', '.')
verbose = PipGetBoolean('VerboseOutput', 0)

# The root name is required because it names the shared usage file
rootname = PipGetString('RootNameForProcesses', '')
if not rootname:
   exitError('A root name shared by the processes must be entered')

# Identity of the requesting process and what it is asking for
PID = PipGetInteger('ProcessID', -1)
fullList = PipGetString('FullListOfGPUs', '')
maxAdd = PipGetInteger('MaximumGPUsToAssign', 0)
controller = PipGetString('ControllingMachine', '')
if not (PID >= 0 and controller):
   exitError('You must enter a PID and controller')

# The GPU list and the maximum count go together: both for an allocation, neither for
# a release
if bool(fullList) != bool(maxAdd):
   exitError('You must enter both -full and -maxadd, or neither')

# The usage file lives in the common directory; a companion file with '.lock' appended
# is locked to serialize access to it
filename = os.path.join(commonDir, rootname + fileSuffix)
usageList = []
lockfile = filename + '.lock'
retryTime = 5.0
startTime = datetime.datetime.now()
opened = False
mess = 'Could not open'

# Keep trying for retryTime seconds, first to open the lock file and then to take an
# exclusive lock on it.  mess always describes the step that has not yet succeeded and
# is emptied on success.  Both platforms make a non-blocking request so that this loop,
# with its 0.1 sec sleep, is what controls the total wait
while (datetime.datetime.now() - startTime).seconds < retryTime:
   try:
      if not opened:
         lockf = open(lockfile, 'w')
         opened = True
         mess = 'Could not obtain exclusive lock on'

      if isWindows:
         msvcrt.locking(lockf.fileno(), msvcrt.LK_NBLCK, 10)
      else:
         fcntl.lockf(lockf, fcntl.LOCK_EX | fcntl.LOCK_NB)

      mess = ''
      break

   except (IOError, OSError):
      # Open or lock failed, presumably because another process holds it; retry shortly
      time.sleep(0.1)

if mess:
   exitError(mess + ' locking file')

if verbose:
   prnstr('Got lock in ' + str((datetime.datetime.now() - startTime).seconds))

# Read our file if it exists
if os.path.exists(filename):
   usageList = readTextFile(filename)


def _pidOnPsLine(pline):
   """Return the first field of a ps output line, after skipping any leading characters
   that are not digits, or None if the line contains no digit."""
   for ind in range(len(pline)):
      if pline[ind].isdigit():
         return pline[ind:].split()[0]
   return None


# Each usage line has the form "controller PID gpuMachine gpuNumber".
# Make list of all the controlling machines to do ps on
hostname = socket.gethostname().split('.')[0]
checkList = []

# Nothing has been modified yet; the file is rewritten only if an entry is dropped here
# or added later
listChanged = False
for line in usageList:
   lsplit = line.split()
   if lsplit and lsplit[0] not in checkList:
      checkList.append(lsplit[0])

# The process-listing command is the same for every machine
psCom = 'ps -ae'
if isWindows:
   psCom = 'b3dwinps'

# For each machine, run a ps, locally or through ssh
for contHost in checkList:
   if contHost == hostname:
      command = psCom
   else:
      command = 'ssh -x -o PreferredAuthentications=publickey ' + \
                '-o StrictHostKeyChecking=no ' + contHost + ' ' + psCom
   try:
      pslines = runcmd(command)
   except ImodpyError:
      exitFromImodError(progname)

   # Collect the PIDs from the ps output once, skipping its first (header) line
   runningPIDs = set()
   for pline in pslines[1:]:
      psPID = _pidOnPsLine(pline)
      if psPID is not None:
         runningPIDs.add(psPID)

   # Loop backwards on lines in the usage list so that entries can be popped safely
   checked = []
   for ind in range(len(usageList) - 1, -1, -1):
      lsplit = usageList[ind].split()
      if len(lsplit) < 2:
         continue
      lineHost = lsplit[0]
      linePID = lsplit[1]

      # Only lines for processes controlled from this host can be judged by its ps
      if lineHost != contHost:
         continue

      # A line is ours only if both the PID and the controlling machine match; the same
      # PID under another controller belongs to a different process
      ourLine = linePID == str(PID) and lineHost == controller
      found = linePID in runningPIDs

      # If the controlling process is not found on this machine's ps, take it out
      # of the list.  Also drop our own entries, which are released or replaced
      if not found or ourLine:
         if verbose and not checked.count(linePID) and not ourLine:
            prnstr(fmtstr("Controller {}, PID {} no longer running", lineHost,
                          linePID))
         if verbose and len(lsplit) > 3:
            prnstr(fmtstr("Removing {} {} from usage list", lsplit[2], lsplit[3]))
         usageList.pop(ind)
         listChanged = True

      # Remember the PID so the "no longer running" message is given only once
      checked.append(linePID)

# Allocate up to maxAdd free GPUs from the full list, which is a comma-separated list of
# entries of the form machine or machine:gpu1:gpu2...
if fullList:
   numAdded = 0
   for machine in fullList.split(','):
      gpus = machine.split(':')

      # Skip an empty entry such as one left by a trailing or doubled comma
      if not gpus[0]:
         continue

      # Assume a 0 if no GPU is specified
      if len(gpus) == 1:
         gpus.append('0')
      gpuHost = gpus[0]

      for gpuNum in gpus[1:]:

         # This GPU is taken if any usage line names this machine and this GPU number
         inUse = False
         for line in usageList:
            fields = line.split()
            if len(fields) >= 4 and fields[2] == gpuHost and fields[3] == gpuNum:
               inUse = True
               break

         if inUse:
            if verbose:
               prnstr(fmtstr('GPU {} on {} is in use', gpuNum, gpuHost))
            continue

         usageList.append(fmtstr('{} {} {} {}', controller, PID, gpuHost, gpuNum))
         listChanged = True
         numAdded += 1

         # THE output to the caller if all goes well: machine:gpu, or just the machine
         # for GPU 0
         if gpuNum != '0':
            prnstr(gpuHost + ':' + gpuNum)
         else:
            prnstr(gpuHost)
         if verbose:
            prnstr('Adding: ' + usageList[-1])
         if numAdded >= maxAdd:
            break

      if numAdded >= maxAdd:
         break

# Save the modified usage list, or get rid of the usage file when no entries remain.
# There is nothing to remove if the file was never created
if listChanged:
   if usageList:
      writeTextFile(filename, usageList)
   elif os.path.exists(filename):
      try:
         os.remove(filename)
      except OSError:
         if verbose:
            prnstr("Error removing existing usage log " + filename)

# Unlock and close the lock file; the close is attempted even if the unlock fails.
# Failures here are reported only in verbose mode and do not change the exit status
try:
   try:
      if isWindows:
         msvcrt.locking(lockf.fileno(), msvcrt.LK_UNLCK, 10)
      else:
         fcntl.lockf(lockf, fcntl.LOCK_UN)
   finally:
      lockf.close()
except (IOError, OSError):
   if verbose:
      prnstr("Error removing lock or closing file when done: " + lockfile)

# An allocation request that yielded nothing is an error for the caller
if fullList and not numAdded:
   exitError('No GPUs were free    [GPA1]')

sys.exit(0)

