#!/usr/bin/env python
# imodkillgroup - kill a process group given member PID or tree below given PID
#
# Author: David Mastronarde
#
# $Id: imodkillgroup,v 937343107256 2023/02/19 22:44:16 mast $
#

progname = 'imodkillgroup'
prefix = 'ERROR: ' + progname + ' - '

def processStatus(pid = None):
   """Collect a table of running processes keyed by PID.

   pid - if given and non-zero, restrict the query to that one process.

   Returns a dictionary mapping each PID to a tuple:
     Windows Python: (parent PID, exe value, psutil process object)
     Cygwin Python:  (parent PID, last field of the ps line, third ps column as int);
                     callers compare the third entry against a process group ID
     otherwise:      (parent PID, last field of the ps line)
   If running ps raises ImodpyError, the first string from getErrStrings() is returned
   instead of a dictionary, so callers must test for a str result.
   """
   table = {}

   # Windows Python: everything comes from psutil
   if windows:
      candidates = []
      try:
         if pid:
            candidates.append(psutil.Process(pid))
         else:
            candidates = psutil.process_iter()
      except Exception:
         exitError('Getting process list from psutil: ' +  str(sys.exc_info()[1]))

      for proc in candidates:

         # Entries with a PID of 0 are skipped
         if not proc.pid:
            continue
         try:
            # ppid and exe are called as methods when newPsAPI is set, otherwise they
            # are read as plain attributes
            if newPsAPI:
               entry = (proc.ppid(), proc.exe(), proc)
            else:
               entry = (proc.ppid, proc.exe, proc)
            table[proc.pid] = entry
         except (psutil.AccessDenied, psutil.NoSuchProcess):
            # Inaccessible or vanished processes are silently left out
            pass
         except Exception:
            prnstr('imodkillgroup - Exception accessing status for pid ' + \
                      str(proc.pid) + ': ' + str(sys.exc_info()[1]))
      return table

   # Elsewhere run ps: explicit columns normally, the basic listing under cygwin
   if cygwin:
      command = 'ps'
   else:
      command = 'ps -aeo pid,ppid,comm'
   if pid:
      command = command + ' -p ' + str(pid)
   try:
      pslines = runcmd(command)
   except ImodpyError:
      return getErrStrings()[0]

   # Parse every line after the header; in cygwin drop whatever precedes the first
   # space (the leading S or I flag)
   for line in pslines[1:]:
      start = 0
      if cygwin:
         start = max(0, line.find(' '))
      fields = line[start:].split()
      if len(fields) < 3:
         continue

      # Lines whose numeric columns do not convert are ignored
      try:
         procId, parentId = int(fields[0]), int(fields[1])
         if cygwin:
            table[procId] = (parentId, fields[-1], int(fields[2]))
         else:
            table[procId] = (parentId, fields[-1])
      except ValueError:
         continue

   return table


def killGroup(groupid):
   """Send killSignal to the process group groupid with os.killpg.

   Returns None when the call succeeds, or an error message string when os.killpg
   raises OSError.
   """
   failure = None
   try:
      os.killpg(groupid, killSignal)
   except OSError:
      failure = 'imodkillgroup - Error killing processes with group ID ' + \
          str(groupid) + ': ' + str(sys.exc_info()[1])
   return failure
      

#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, signal, time

#
# Setup runtime environment
# IMOD_DIR must be defined so that the IMOD python library directory can be put at the
# front of sys.path before importing from it
if os.getenv('IMOD_DIR') is not None:
   IMOD_DIR = os.environ['IMOD_DIR']
   if sys.platform == 'cygwin' and sys.version_info[0] > 2:

      # Under Cygwin Python 3, convert a Windows-style path (drive letter, backslashes)
      # to a /cygdrive path.  The slice comparison is safe for values shorter than 3
      # characters, where indexing [1] and [2] would raise IndexError
      IMOD_DIR = IMOD_DIR.replace('\\', '/')
      if IMOD_DIR[1:3] == ':/':
         IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]
   sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
   from imodpy import *
   addIMODbinIgnoreSIGHUP()
else:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

#
# load IMOD Libraries
# The pylib directory was inserted at the front of sys.path above, so presumably this
# resolves to IMOD's own pip module rather than the Python package installer - TODO
# confirm
from pip import exitError, setExitPrefix
# Hand this program's error prefix to the pip module (presumably used by exitError)
setExitPrefix(prefix)
#setRetryLimit(10, 3.)

# Platform flags: Cygwin Python versus native Windows Python
cygwin = 'cygwin' in sys.platform
windows = 'win32' in sys.platform

# Option defaults; pidList is a dictionary keyed by the PIDs given on the command line
killTree = False
pidList = {}
verbose = False
maxTrials = 5
getStatus = False
useTerm = False

# The signal module has no SIGKILL on native Windows Python, so fall back to SIGTERM
# there instead of failing with AttributeError; the Windows path kills through psutil
# and does not use this value
killSignal = getattr(signal, 'SIGKILL', signal.SIGTERM)

# Get/check arguments: recognized flags are recorded in a table, and every other
# argument must be an integer PID, which becomes a key of pidList
flagsGiven = {'-t': False, '-v': False, '-s': False, '-e': False}
try:
   for arg in sys.argv[1:]:
      if arg in flagsGiven:
         flagsGiven[arg] = True
      else:
         pidList[int(arg)] = None

except ValueError:
   exitError('Unrecognized option or non-integer entry: ' + arg)

# Turn on each option whose flag was seen
killTree = killTree or flagsGiven['-t']
verbose = verbose or flagsGiven['-v']
getStatus = getStatus or flagsGiven['-s']
useTerm = useTerm or flagsGiven['-e']

# With no PIDs there is nothing to kill: print the usage summary and exit with status 0
if not pidList:
   prnstr("""Usage: imodkillgroup [-t] [-s] [-e] [-v] PID [PID ...]
    Kills whole process group given PID of any member, or just kills process
    tree below each PID if -t option is given.  Takes multiple PIDs from
    different groups.  Options:
      -t  Tree kill, implied with Windows Python; unreliable with Cygwin Python
      -s  Use initial and follow-up ps to verify children are gone (Cygwin)
      -e  Use SIGTERM instead of SIGKILL to allow cleanup actions
      -v  Verbose output""")
   sys.exit(0)

# The ps-based verification is implemented only for Cygwin Python
if getStatus and not cygwin:
   exitError('The -s option works only in Cygwin Python')

# -e replaces the default kill signal with SIGTERM
if useTerm:
   killSignal = signal.SIGTERM
   
# Windows Python gets all process information and control through psutil; newPsAPI
# records the result of newPsutilAPI for the installed version
if windows:
   try:
      import psutil
      newPsAPI = newPsutilAPI(psutil.__version__)
   except ImportError:
      exitError('The psutil module must be installed to kill processes from Windows ' +\
                'Python')
      

# PROCESS GROUP KILLING
# Used when -t was not given and this is not Windows Python: the process group of each
# PID is looked up and signalled as a whole.  exitVal counts failures and becomes the
# exit status.
exitVal = 0
if not killTree and not windows:

   # Set up variables and get all the group IDs first.  All dictionaries are keyed by
   # the PID from the command line:
   #   groupIds   - process group ID
   #   pidDone    - 1 once the PID needs no further attention
   #   groupStats - ps entries of the group's members from the initial ps (-s only)
   #   messList   - kill result (message or None) for each trial
   numPids = len(pidList)
   psFailed = False
   numDone = 0
   if not getStatus:
      maxTrials = 2
   groupIds = {}
   pidDone = {}
   groupStats = {}
   messList = {}
   for pidKill in pidList:
      pidDone[pidKill] = 0
      groupStats[pidKill] = {}
      messList[pidKill] = []
      try:
         groupIds[pidKill] = os.getpgid(pidKill)
         if verbose:
            prnstr('Group id: ' + str(groupIds[pidKill]))
      except OSError:

         # A PID with no group ID counts as an error and is skipped from here on
         prnstr('imodkillgroup - Error getting group ID for process ID ' + str(pidKill) +\
                   ': ' + str(sys.exc_info()[1]))
         exitVal += 1
         pidDone[pidKill] = 1
         numDone += 1

   if exitVal == numPids:
      exitError('No group ID(s) could be found')

   # For cygwin getting status, do an initial ps and get lists of all PID's 
   if cygwin and getStatus:
      psdict = processStatus()
      if isinstance(psdict, str):

         # A string result is the error from ps; fall back to two unverified kills
         psFailed = True
         maxTrials = 2
         prnstr('imodkillgroup - Error getting initial ps: ' + psdict + \
                '  just going to do two group kills without checking processes')
      else:

         # Remember every process whose group (third entry) matches a group to kill
         for pidKill in pidList:
            if not pidDone[pidKill]:
               for pid in psdict:
                  if psdict[pid][2] == groupIds[pidKill]:
                     groupStats[pidKill][pid] = psdict[pid]
      
   # Loop on multiple trials for cygwin
   for trial in range(maxTrials):
      if verbose:
         prnstr('Trial ' + str(trial + 1) + ' for killing groups')
      for pidKill in pidList:
         if not pidDone[pidKill]:
            mess = killGroup(groupIds[pidKill])

            # If doing blind pair of kills because ps failed, or second trial when
            # not getting status, cancel message if it is just no such process
            if mess and (psFailed or not getStatus) and trial and \
                   'No such process' in mess:
               mess = None
            messList[pidKill].append(mess)
            if mess:
               if not cygwin:
                  exitVal += 1
               if not cygwin or verbose:
                  prnstr(mess)

            # If there is no message (anymore), mark as done if not getting status or
            # if on second trial
            elif not getStatus or trial:
               pidDone[pidKill] = 1
               numDone += 1

      # Break out always for Unix, or on first trial when not getting status and all OK
      # Otherwise it will finish the loop so error processing/reporting occurs
      if not cygwin or (numDone == numPids and trial == 0):
         break
      if psFailed or not getStatus:
         continue

      # For status, get new ps and count up the processes left or just rerun if that fails
      # Trial numbers in messages are 1-based, matching the verbose and summary output
      time.sleep(0.05)
      psdict = processStatus()
      if isinstance(psdict, str):
         prnstr('imodkillgroup - Error getting follow-up ps on trial ' + \
                   str(trial + 1) + ': ' + psdict)
         numAllLeft = -1
      else:
         numAllLeft = 0
         for pidKill in pidList:
            numLeft = 0
            if not pidDone[pidKill]:

               # A process counts as still present only if its PID shows the same
               # group, parent, and command as in the initial ps
               for pid in groupStats[pidKill]:
                  if pid in psdict and psdict[pid][2] == groupIds[pidKill] and \
                         psdict[pid][0] == groupStats[pidKill][pid][0] and \
                         psdict[pid][1] == groupStats[pidKill][pid][1]:
                     numLeft += 1
               if not numLeft:
                  pidDone[pidKill] = 1
                  if verbose:
                     prnstr('Children all gone for group ID ' + str(groupIds[pidKill]))
            numAllLeft += numLeft

      if not numAllLeft:
         break
      if numAllLeft > 0:
         prnstr(fmtstr('imodkillgroup - {} processes left after trial {}', numAllLeft,
                       trial + 1))
   else:    # ELSE ON FOR

      # Reached only when the trials ran out without a break: count the PIDs still
      # not done and report the messages collected on each trial
      for pid in pidList:
         if not pidDone[pid]:
            exitVal += 1
      if exitVal:
         if psFailed or not getStatus:
            prnstr('imodkillgroup - Kill may have failed, messages on trials:')
         else:
            prnstr('imodkillgroup - Kill of group failed, messages on trials:')
         for pid in pidList:
            for ind in range(len(messList[pid])):
               if messList[pid][ind]:
                  prnstr(fmtstr('{:5d} - {}: {}', pid, ind + 1, messList[pid][ind]))
      elif verbose and (psFailed or not getStatus):
         prnstr('All processes not present on second group kill')

   sys.exit(exitVal)


# For TREE KILL, Start an array of PID's for each level
# pidTree[level] is a dictionary keyed by the PIDs at that depth below the given PIDs.
# Each value is None (nothing recorded), the psutil process object (Windows), or -1 once
# the PID is found to be missing from the process list.
pidTree = [pidList]
for level in range(100):

   # Cygwin and Windows need a process list before the first stop pass; on later
   # levels psdict is left over from the end of the previous pass
   if not level and (cygwin or windows):
      psdict = processStatus()
      if isinstance(psdict, str):

         # The prefix was already registered with setExitPrefix, so pass only the
         # message, as every other exitError call in this script does
         exitError(psdict)

   # Stop processes for PID's at the current level.  On plain Unix every PID is sent
   # SIGSTOP.  In Cygwin nothing is stopped.  In Windows every PID found in the
   # process list is suspended.
   for pid in pidTree[level]:
      stopProc = True
      if cygwin or windows:
         stopProc = False

         if pid in psdict:
            stopProc = windows
            # For windows, save the process object now if it is found in this process list
            if windows:
               if verbose:
                  prnstr('Saving process object for PID ' + str(pid))
               pidTree[level][pid] = psdict[pid][2]

      if stopProc:
         stopstr = 'Stopping PID ' + str(pid)

         # psdict does not exist yet at level 0 on plain Unix, hence the first test
         if (level or cygwin or windows) and pid in psdict:
            stopstr += ': ' + psdict[pid][1]
         if verbose:
            prnstr(stopstr)
         try:
            if windows:
               psdict[pid][2].suspend()
            else:
               os.kill(pid, signal.SIGSTOP)
         except Exception:

            # Failure to stop is reported but does not prevent the later kill
            prnstr('imodkillgroup - Error occurred trying to stop ' + str(pid) + ': ' + \
                str(sys.exc_info()[1]))
      
   # Get a ps and first find out if each PID is still there
   psdict = processStatus()
   if isinstance(psdict, str):
      exitError(psdict)
   pidTree.append({})
   for pid in pidTree[level]:
      if pid not in psdict:
         prnstr('imodkillgroup - PID ' + str(pid) + ' is no longer in the process list')
         exitVal += 1
         pidTree[level][pid] = -1
         
   # find children of these processes: any process whose parent is at this level; in
   # Windows the parent must also have a recorded (non-None) entry
   for pid in psdict:
      parent = psdict[pid][0]
      if parent in pidTree[level] and (not windows or pidTree[level][parent]):
         pidTree[level + 1][pid] = None
         if verbose:
            prnstr(fmtstr('Adding child {} - {} of {} at level {}', pid, psdict[pid][1], 
                          parent, level))

   # Done when a level has no children
   if not pidTree[level + 1]:
      break

# Kill all the processes from the bottom level up, so that the deepest descendants are
# handled before their ancestors
for level in reversed(range(len(pidTree))):
   levelPids = pidTree[level]
   for pid in levelPids:
      entry = levelPids[pid]

      # An integer entry marks a PID that was already gone from the process list
      if isinstance(entry, int):
         continue
      if verbose and (entry or not windows):
         prnstr(fmtstr('Killing PID {} at level {}', pid, level))

      try:
         if not windows:

            # With -e, SIGCONT is sent ahead of the kill signal; presumably so that a
            # process stopped earlier can act on the SIGTERM
            if useTerm:
               os.kill(pid, signal.SIGCONT)
            os.kill(pid, killSignal)
         elif entry:

            # Windows: kill through the saved psutil process object, if there is one
            entry.kill()
      except Exception:
         prnstr('imodkillgroup - Error occurred trying to kill ' + str(pid) + ': ' + \
                str(sys.exc_info()[1]))
         exitVal += 1

# The exit status is the number of errors counted
sys.exit(exitVal)
