#!/usr/bin/env python
# tomodataplots - Plot selected data from tomogram processing
#
# Author: David Mastronarde
#
# $Id: tomodataplots,v a3b92fc28684 2025/01/21 00:37:05 mast $
#

# Program name, and the prefix for error messages this script writes by itself
progname = 'tomodataplots'
prefix = 'ERROR: %s - ' % progname

def findLimitingLines(lines, startText, endText, startLook = 0):
   """Find a range of lines delimited by two regular expressions.

   Scanning from index startLook, the first line matching startText is the
   start.  The end is the first line at or after the start (the start line
   itself included) that matches endText.

   Returns (startIndex, endIndex).  startIndex is -1 if no line matched
   startText; endIndex is -1 if no end line was found after the start.
   """
   startPattern = re.compile(startText)
   endPattern = re.compile(endText)

   # Phase 1: locate the start line
   for first in range(startLook, len(lines)):
      if startPattern.search(lines[first]):
         break
   else:
      return (-1, -1)

   # Phase 2: locate the end line, beginning with the start line itself
   for last in range(first, len(lines)):
      if endPattern.search(lines[last]):
         return (first, last)
   return (first, -1)


def extractBinAndRad2(line):
   """Return the (bin, rad2) strings named in a line of Alignframes output.

   Used for "Results" lines and "Best" lines.  Both 'bin 2' and 'bin=2' forms
   are accepted.  Each value is the token that follows its keyword, or None if
   the keyword is absent or is the last token.  If a keyword appears more than
   once, the last occurrence wins.
   """
   tokens = line.replace('=', ' ').split()
   found = {'bin': None, 'rad2': None}

   # Pair each token with its successor; the last token has no successor
   for word, nextWord in zip(tokens, tokens[1:]):
      if word in found:
         found[word] = nextWord
   return (found['bin'], found['rad2'])


def addToAlignframesArrays():
   """Append the current Alignframes set to the plotting arrays.

   Works on module-level state left by the Alignframes parsing loop: setDict
   (value lists keyed by (bin, rad2) tuples, or by the initial key '0'), key,
   bestBin, bestRad2, useHybrid, angle, and setNum.

   If key is not '0' and bestBin is set, key is replaced by the key for the best
   binning.  The rad2 part of that key is None when the hybrid result is used.
   If setDict has an entry for the resulting key, angle, setNum, and that entry
   are appended to angles, setNums, and allValues.  The global key is left
   holding the key that was looked up.
   """
   # Only key is rebound; the three lists are appended to in place
   global key

   if key != '0' and bestBin:
      key = (bestBin, None if useHybrid else bestRad2)

   if key not in setDict:
      return
   angles.append(angle)
   setNums.append(setNum)
   allValues.append(setDict[key])


#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, re, datetime, math

# Set up the runtime environment: IMOD_DIR must be defined so that the IMOD
# Python library directory (IMOD_DIR/pylib) can be put on the import path
if os.getenv('IMOD_DIR') is None:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

IMOD_DIR = os.environ['IMOD_DIR']

# Under Cygwin with Python 3, turn a Windows-style path (C:\...) into /cygdrive/c/...
if sys.platform == 'cygwin' and sys.version_info[0] > 2:
   IMOD_DIR = IMOD_DIR.replace('\\', '/')
   if IMOD_DIR[1] == ':' and IMOD_DIR[2] == '/':
      IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]

sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
from imodpy import *
addIMODbinIgnoreSIGHUP()

#
# load IMOD Libraries
from pip import *

# Name of the temporary file of processed data for onegenplot; it stays empty when
# the input file can be plotted directly
tempFile = ''

# Default hues as 'group,color' strings, used when more than one symbol is plotted;
# an entered HueOfGroup with the same group number replaces the default entry
defaultColors = ['1,navy', '2,maroon', '3,darkgreen']

# One entry per data type: entry N-1 is used for TypeOfDataToPlot N.  Elements:
#   0 = file type (selects the parsing code and indexes tempNeeded and axisLabels)
#   1 = columns for ColumnsToPlot
#   2 = ordinal: if nonzero, OrdinalsForXvalues is passed to onegenplot
#   3 = default symbols for SymbolsForTypes (overridden by SymbolsForGroups)
#   4 = default for ConnectWithLines when that option is not entered
#   5 = key labels
# Data type numbers by file type: 1-3 -> 0, 4 -> 1 (Blend), 5-10 -> 2 (Align),
# 11-13 -> 3 (clip stats), 14 -> 4, 15-19 -> 5 (Ctfplotter), 20-22 -> 6 (Alignframes)
typeTable = [(0, [5], 1, [7], 1, ['X shift']),
             (0, [6], 1, [7], 1, ['Y shift']),
             (0, [5, 6], 1, [7, 9], 1, ['X shift', 'Y shift']),
             (1, [1, 2], 1, [0, 0], 1, ['Mean error', 'Max error']),
             (2, [1, 2], 0, [0], 0, ['Rotation']),
             (2, [1, 4, 7], 0, [0, 0], 1, ['Delta tilt', 'Skew']),
             (2, [1, 5], 0, [0], 0, ['Mag']),
             (2, [1, 6], 0, [0], 0, ['X-Stretch (dmag)']),
             (2, [1, 2], 0, [9, 15], 0, ['Mean residual', '(view is multiple of 5)']),
             (2, [1, 2], 0, [9, 15], 0, ['Local mean residual',
                                         '(view is multiple of 5)']),
             (3, [1], 1, [7], 0, ['Minimum value']),
             (3, [2], 1, [7], 0, ['Maximum value']),
             (3, [1, 2], 1, [7, 9], 0, ['Minimum value', 'Maximum value']),
             (4, [1, 2, 3, 4], 0, [5, 7, 9], 0, ['X position', 'Y position',
                                                 'Z position']),
             (5, [1, 2], 0, [9], 1, ['Defocus (microns)']),
             (5, [1, 2], 0, [9], 1, ['Astigmatism (um)']),
             (5, [1, 2], 0, [9], 1, ['Astig axis (deg)']),
             (5, [1, 2], 0, [9], 1, ['Phase shift (deg)']),
             (5, [1, 2], 0, [9], 1, ['Cut-on freq (1/nm)']),
             (6, [1, 2, 3], 0, [7, 9], 1, ['Raw distance (pixels)', 
                                           'Smoothed distance']),
             (6, [1, 2, 3], 0, [7, 9], 1, ['Mean leave-out error (pixels)',
                                           'Mean weighted residual (pixels)']),
             (6, [1, 2], 0, [9], 1, ["Max of max weighted resids"])]

# Indexed by file type: whether the input must first be processed into a temporary
# file.  Only file types 0 and 4 are plotted directly from the input file.
tempNeeded = [fType not in (0, 4) for fType in range(7)]

# Indexed by file type: default X axis label (types 0-3 by view, 4-6 by tilt angle)
axisLabels = 4 * ['View number'] + 3 * ['Tilt Angle (degrees)']

# Fallback option specifications passed to PipReadOrParseOptions below.  They were
# generated by the autodoc2man command named on the next line; regenerate them
# that way when the options change instead of editing by hand.
# Fallbacks from ../manpages/autodoc2man 3 1 tomodataplots
options = ["input:InputFile:FN:", "type:TypeOfDataToPlot:IA:",
           "connect:ConnectWithLines:I:", "symbols:SymbolsForGroups:IA:",
           "hue:HueOfGroup:CHM:", "axis:XaxisLabel:CH:", "append:AppendToKey:CH:",
           "size:SizeOfPlot:IP:", "position:PositionOfPlot:IP:",
           "background:BackgroundProcess:B:", ":PID:B:", "help:usage:B:"]

# Parse the command line using the fallback options, then get the input file
# name and require that the file exists
(numOpts, numNonOpts) = PipReadOrParseOptions(sys.argv, options, progname, 1, 1, 0)
dataName = PipGetInOutFile('InputFile', 0)
if not dataName:
   exitError('The input file name must be entered')
if not os.path.exists(dataName):
   exitError('Input file %s does not exist' % dataName)

# Hand the PID option to printPID (presumably it reports the process ID when set)
doPID = PipGetBoolean('PID', 0)
printPID(doPID)

# Get the data type (entered 1-based; stored 0-based) and its file type
dataType = PipGetInteger('TypeOfDataToPlot', 0) - 1
if not 0 <= dataType < len(typeTable):
   exitError('A type of data must be entered between 1 and ' + str(len(typeTable)))
fileType = typeTable[dataType][0]

# File types that need processing: read all lines now, and plot from a temporary
# file named by program and process ID instead of from the input file
if tempNeeded[fileType]:
   tempFile = '%s/%s.%d' % (imodTempDir(), progname, os.getpid())
   dataLines = readTextFile(dataName)
   plotFile = tempFile
else:
   plotFile = dataName

# Get the remaining options.  When PipGetErrNo reports an error for
# ConnectWithLines (presumably because it was not entered), use the data type's
# default; otherwise limit the entry to 0 or 1.
background = PipGetBoolean('BackgroundProcess', 0)
connect = PipGetInteger('ConnectWithLines', 0)
if PipGetErrNo():
   connect = typeTable[dataType][4]
else:
   connect = max(0, min(1, connect))
symbolsIn = PipGetIntegerArray('SymbolsForGroups', 0)
numColors = PipNumberOfEntries('HueOfGroup')
axisLabel = PipGetString('XaxisLabel', '')
addToKey = PipGetString('AppendToKey', '')

# Whether onegenplot is told to plot types 1 and 2; set for mean residual plots
ifTypes = False

# Generic setup for onegenplot.  Columns, symbols, and key labels come from the
# table entry for the data type; the X axis label defaults by file type.
typeEntry = typeTable[dataType]
columns = typeEntry[1]
symbols = typeEntry[3]
keys = typeEntry[5]
axisLabel = axisLabel or axisLabels[fileType]

# Clip stats output whose first line mentions pieces ('iece') is labeled by piece
if fileType == 3 and dataLines and 'iece' in dataLines[0]:
   axisLabel = axisLabel.replace('View number', 'Piece number')

comlines = ['InputDataFile ' + plotFile,
            'ConnectWithLines ' + str(connect),
            'XaxisLabel ' + axisLabel]
if typeEntry[2]:
   comlines.append('OrdinalsForXvalues')
   
# Handle specific file types.  Each branch turns the lines read from the input
# file (dataLines) into the lines written for onegenplot (outLines), and may trim
# the columns, symbols, and keys to match the data actually found.
outLines = []

# BLEND
if fileType == 1:
   errMatch = re.compile('^.*mean&max.*after.*:')
   for l in dataLines:
      if errMatch.search(l):
         outLines.append(errMatch.sub('', l))

   # With fewer than 2 points, replace the table's symbols with symbols 7 and 9
   # (presumably so that a point is visible when there is no line to draw)
   if len(outLines) < 2:
      symbols = [7, 9]

# ALIGN
elif fileType == 2:

   # The global solution table runs from its header line to the next blank line
   (start, end) = findLimitingLines(dataLines, '^ view.*deltilt', '^$')
   if start < 0 or end < 0:
      exitError('Could not find global solution table in ' + dataName)
   if 'residual' not in keys[0]:
      outLines = dataLines[start + 1 : end]

      # If the last plotted column has one value on every line, do not plot it
      if len(symbols) > 1:
         col = columns[-1] - 1
         colval = ''
         for l in outLines:
            lsplit = l.split()
            if len(lsplit) <= col:
               exitError('Not enough columns in solution table in ' + dataName)
            if not colval:
               colval = lsplit[col]
            elif colval != lsplit[col]:
               break
         else:   # ELSE ON FOR; peel off the last column
            columns = columns[:-1]
            symbols = symbols[:-1]
            keys = keys[:-1]

   else:

      # Local or global mean residual: first find maximum view in global
      try:
         lsplit = dataLines[end - 1].split()
         lastView = int(lsplit[0])
      except (ValueError, IndexError):
         exitError('Converting view number in align log')
      errSum = (lastView + 1) * [0.]
      numInSum = (lastView + 1) * [0]

      # Plot group 1 (other views) and group 2 (views that are multiples of 5)
      ifTypes = True

      # Analyze each line, add residual to sum for view.  For local residuals,
      # repeat for each solution table that follows the global one.
      while True:
         local = 'global'
         if 'Local mean' in keys[0]:
            local = 'local'
            (start, end) = findLimitingLines(dataLines, '^ view.*deltilt', '^$', end)
            if start < 0 or end < 0:
               break
         for l in dataLines[start + 1 : end]:
            try:
               lsplit = l.split()
               view = int(lsplit[0])
               resid = float(lsplit[7])
            except (ValueError, IndexError):
               exitError('Analyzing ' + local + ' solution in ' + dataName)
            if view > lastView:
               exitError('View number higher in local than global solution')
            errSum[view] += resid
            numInSum[view] += 1

         if local != 'local':
            break

      # Make the output lines: group, view, mean residual
      for view in range(lastView + 1):
         if numInSum[view]:
            group = 1
            if view % 5 == 0:
               group = 2
            outLines.append(fmtstr('{}  {}  {}', group, view,
                                   errSum[view] / numInSum[view]))
      if not outLines:
         exitError('No local solutions found in ' + dataName)

# CLIP STATS
elif fileType == 3:
   (start, end) = findLimitingLines(dataLines, '----', 'all')
   if start < 0 or end < 0:
      exitError('Cannot find starting and ending lines for stats in ' + dataName)

   # Output the second field (min) and the first field after ')' (max) of each line
   for l in dataLines[start + 1 : end]:
      l = l.replace('*', ' ')
      try:
         lsplit = l.split()
         parsplit = l.split(')')
         maxsplit = parsplit[1].split()
         outLines.append(lsplit[1] + '  ' + maxsplit[0])
      except IndexError:
         exitError('Extracting min and max from lines in ' + dataName)
   if not outLines:
      exitError('No min/max data found in ' + dataName)

# CTFPLOTTER
elif fileType == 5:
   startLine = 0
   hasAstig = False
   hasPhase = False
   hasCuton = False
   angles = []
   values = []
   minAxis = 1000.
   maxAxis = -1000.
   minAbsAxis = 1000
   numPlus = 0
   numMinus = 0

   # Parsing errors show up as ValueError (bad numbers) or IndexError (short lines)
   try:

      # A first line with at least 6 fields gives the version in field 6.  For a
      # version above 2, it is a header whose first field holds flags for the kinds
      # of solution present.
      lsplit = dataLines[0].split() if dataLines else []
      if len(lsplit) >= 6:
         version = int(lsplit[5])
         if version > 2:
            startLine = 1
            flags = int(lsplit[0])
            hasAstig = flags & 1
            hasPhase = flags & 4
            hasCuton = flags & 32
      if dataType > 17:
         if not hasCuton:
            exitError('There are no cut-on frequencies in this defocus file')
      elif dataType > 16:
         if not hasPhase:
            exitError('There are no phase shift solutions in this defocus file')
      elif dataType > 14 and not hasAstig:
         exitError('There are no astigmatism solutions in this defocus file')

      # Phase shift and cut-on columns come after the astigmatism columns, if any
      phaseCol = 5
      cutonCol = 6
      if hasAstig:
         phaseCol = 7
         cutonCol = 8

      for l in dataLines[startLine:]:
         if l.strip() == '':
            continue
         lsplit = l.split()

         # Plot against the mean of the two angles on the line
         angle = 0.5 * (float(lsplit[2]) + float(lsplit[3]))
         if dataType < 17:

            # Scale by 1/1000 to match the micron units in the key labels
            defocus1 = float(lsplit[4]) / 1000.
            value = defocus1
            if hasAstig:
               defocus2 = float(lsplit[5]) / 1000.
               axis = float(lsplit[6])
               if defocus2:
                  astig = defocus1 - defocus2
                  defocus1 = 0.5 * (defocus1 + defocus2)
               if dataType == 14:
                  value = defocus1
               else:

                  # No second defocus: nothing to plot for astigmatism or its axis
                  if abs(defocus2) < 1.e-6:
                     continue
                  if dataType == 15:
                     value = astig
                  else:
                     value = axis
                     minAxis = min(axis, minAxis)
                     maxAxis = max(axis, maxAxis)
                     minAbsAxis = min(minAbsAxis, math.fabs(axis))
                     if axis >= 0:
                        numPlus += 1
                     else:
                        numMinus += 1

         elif dataType == 17:
            value = float(lsplit[phaseCol])
         else:
            value = float(lsplit[cutonCol])

         angles.append(angle)
         values.append(value)

   except (ValueError, IndexError):
      exitError('Extracting values from the defocus file')

   # If the axis values are extreme on both sides of zero and none are near zero,
   # then adjust the minority sign values by 180
   if minAxis < -60 and maxAxis > 60 and minAbsAxis > 45:
      for ind in range(len(values)):
         if values[ind] < 0 and numPlus > numMinus:
            values[ind] += 180.
         elif values[ind] > 0 and numPlus <= numMinus:
            values[ind] -= 180.

   for angle, value in zip(angles, values):
      outLines.append(fmtstr('{} {}', angle, value))

# ALIGNFRAMES
elif fileType == 6:
   angles = []
   setNums = []
   allValues = []
   haveAllAngles = True
   haveLeaveOut = False
   haveWeighted = False
   gotSet = False
   useHybrid = False

   # Per-set state, also read by addToAlignframesArrays.  Initialize it here so
   # that a "Results" or "Best" line before the first set cannot raise NameError.
   # values: [0] mean residual, [1] max of max residuals, [2] raw distance,
   # [3] smoothed distance, [4] leave-out error
   setDict = {}
   key = '0'
   values = [None] * 5
   bestBin = None
   bestRad2 = None
   angle = None
   setNum = None
   try:
      for line in dataLines:

         # Convert , to space so splitting is easy
         line = line.replace(',', ' ')

         # Look for valid option to use hybrid result
         if 'UseHybrid' in line and not line.strip().startswith('#'):
            lsplit = line.split('=')
            useHybrid = len(lsplit) < 2 or int(lsplit[1]) != 0

         # Set lines can be of 3 kinds
         if line.startswith(('Set ', 'File ')):

            # A line about the best binning/filter
            if ': Best' in line:
               setDict[key] = values
               (bestBin, bestRad2) = extractBinAndRad2(line)

            # A line NOT about dropped frames in FISE
            elif ': drop' not in line:
               if gotSet:

                  # If there is a previous set, put current values in dictionary and
                  # add to data array
                  setDict[key] = values
                  if angle is None:
                     haveAllAngles = False
                  addToAlignframesArrays()

               # Initialize for a set, get the set number and hopefully the degrees
               gotSet = True
               setDict = {}
               key = '0'
               bestBin = None
               values = [None] * 5
               lsplit = line.split()
               angle = None
               setNum = lsplit[1]
               if 'deg' in lsplit[-1]:
                  angle = lsplit[-2].replace('(', '')

         # For a "Results" line, store the values gathered so far and start new ones
         # under the key for this line's binning and filter
         elif line.startswith(('Results with', 'Hybrid results')):
            setDict[key] = values
            key = extractBinAndRad2(line)
            values = [None] * 5

         elif gotSet:

            # Once there is a set, process each line looking for the values
            line = line.replace('=', ' ')
            lsplit = line.split()
            lastInd = len(lsplit) - 1
            residLine = 'esid' in line and 'mean' in line and 'max max' in line
            if residLine and (' wgtd' in line.lower() or ' weighted' in line.lower()):
               haveWeighted = True
            for ind in range(len(lsplit)):
               if residLine:
                  if ind and lsplit[ind] == 'mean' and 'esid' in lsplit[ind - 1] and \
                     ind < lastInd:
                     values[0] = lsplit[ind + 1]
                  if ind and lsplit[ind] == 'max' and lsplit[ind - 1] == 'max' and \
                     ind < lastInd:
                     values[1] = lsplit[ind + 1]
                  if lsplit[ind] == 'l-o' and ind < lastInd:
                     if ind < lastInd - 1 and lsplit[ind + 1] == 'err':
                        values[4] = lsplit[ind + 2]
                        haveLeaveOut = True
                     elif lsplit[ind + 1] != 'err':
                        values[4] = lsplit[ind + 1]
                        haveLeaveOut = True

               # A keyword at the end of the line has no value after it
               if lsplit[ind] == 'Dist' and ind < lastInd:
                  values[2] = lsplit[ind + 1]
               if ('smooth' in lsplit[ind] or 'smth' in lsplit[ind]) and ind < lastInd:
                  values[3] = lsplit[ind + 1]

      # At end, add last set of data, checking its angle like the earlier sets
      if gotSet:
         setDict[key] = values
         if angle is None:
            haveAllAngles = False
         addToAlignframesArrays()

   except (ValueError, IndexError):
      exitError('Extracting values from Alignframes output file')

   # Plot against angles if every set had one, otherwise against set numbers
   if haveAllAngles:
      xvals = angles
   else:
      xvals = setNums

      # The X axis label is already in comlines, so change it there.  A label
      # differing from the default is taken to be user-entered and is kept.
      if axisLabel == axisLabels[fileType]:
         axisLabel = 'Set number'
         comlines = [('XaxisLabel ' + axisLabel) if cl.startswith('XaxisLabel ')
                     else cl for cl in comlines]

   # Load the data array whenever data exists for an angle
   for xval, setVals in zip(xvals, allValues):
      if dataType == 19 and setVals[2] is not None and setVals[3] is not None:
         outLines.append(xval + ' ' + setVals[2] + ' ' + setVals[3])
      elif dataType == 20:
         if haveLeaveOut and setVals[0] is not None and setVals[4] is not None:
            outLines.append(xval + ' ' + setVals[4] + ' ' + setVals[0])
         elif not haveLeaveOut and setVals[0] is not None:
            outLines.append(xval + ' ' + setVals[0])
      elif dataType == 21 and setVals[1] is not None:
         outLines.append(xval + ' ' + setVals[1])

   # Without leave-out errors there is only one data column to plot
   if dataType == 20 and not haveLeaveOut:
      columns = [1, 2]
      symbols = [symbols[1]]
      keys = [keys[1]]
   if dataType == 20 and not haveWeighted:
      keys[-1] = 'Mean residual (pixels)'
# Processed data must be written to the temporary file that onegenplot will read;
# it is an error if processing produced nothing
if tempFile and not outLines:
   exitError('Did not find any lines of the selected data type in ' + dataName)
if tempFile:
   writeTextFile(tempFile, outLines)

# Now that everything is set, add the columns and (optionally) the types to plot
colstr = 'ColumnsToPlot ' + ','.join(str(col) for col in columns)
comlines.append(colstr)
if ifTypes:
   comlines.append('TypesToPlot 1,2')

# Symbols entered with SymbolsForGroups override the defaults position by position
chosenSyms = [symbolsIn[pos] if symbolsIn and pos < len(symbolsIn) else defSym
              for (pos, defSym) in enumerate(symbols)]
symstr = 'SymbolsForTypes ' + ','.join(str(sym) for sym in chosenSyms)
comlines.append(symstr)

# Append the file name ('@file'), current time ('@time'), or the literal entry to
# the last key label
if addToKey:
   if addToKey == '@file':
      keySuffix = os.path.basename(dataName)
   elif addToKey == '@time':
      keySuffix = datetime.datetime.now().strftime('%H:%M:%S')
   else:
      keySuffix = addToKey
   keys[-1] += ' ' + keySuffix

comlines.extend('KeyLabels ' + label for label in keys)

# Start with the default hues when more than one symbol is plotted.  Each entered
# HueOfGroup ('group,color') replaces the hue for the same group, or else is added.
# The defaults are copied so that the module-level list is never modified.
colors = list(defaultColors) if len(symbols) > 1 else []
for colorNum in range(numColors):
   newColor = PipGetString('HueOfGroup', '')
   newGroup = newColor.split(',')[0]
   for cind, oldColor in enumerate(colors):
      if oldColor.split(',')[0] == newGroup:
         colors[cind] = newColor
         break
   else:  # ELSE ON FOR: no hue yet for this group
      colors.append(newColor)
for col in colors:
   comlines.append('HueOfGroup ' + col)

# Window size is passed only when both dimensions are positive; position is passed
# whenever PipGetErrNo reports no error for it
xsize, ysize = PipGetTwoIntegers('SizeOfPlot', 0, 0)
if min(xsize, ysize) > 0:
   comlines.append(fmtstr('SizeOfPlot {},{}', xsize, ysize))
xpos, ypos = PipGetTwoIntegers('PositionOfPlot', 0, 0)
if not PipGetErrNo():
   comlines.append(fmtstr('PositionOfPlot {},{}', xpos, ypos))

# Run onegenplot, either detached in the background or in the foreground
if background:

   # Turn each "Option value" line into "-Option" and "value" arguments.
   # NOTE(review): the script exits here without removing the temporary file, and
   # an append of '-remove' was disabled; confirm who should clean it up.
   comArray = ['onegenplot']
   for line in comlines:
      optName, sep, optValue = line.partition(' ')
      comArray.append('-' + optName)
      if sep:
         comArray.append(optValue)

   err = bkgdProcess(comArray, None, 'stdout', True)
   if err:
      exitError('Cannot start onegenplot: ' + err)
   sys.exit(0)

# Foreground: send the option lines to onegenplot on standard input, and remove
# the temporary file on failure as well as after normal completion
try:
   prnstr('Close graph window to exit', flush = True)
   runcmd('onegenplot -StandardInput', comlines)
except ImodpyError:
   if tempFile:
      cleanupFiles([tempFile])
   exitFromImodError(progname)

if tempFile:
   cleanupFiles([tempFile])
sys.exit(0)
