#!/usr/bin/env python
# findsirtdiffs - collect difference statistics from sirt log files and combine
#
# Author: David Mastronarde
#
# $Id: findsirtdiffs,v 937343107256 2023/02/19 22:44:16 mast $

progname = 'findsirtdiffs'
prefix = 'ERROR: ' + progname + ' - '

def diffKey(item):
   """Sort key for one difference-statistics entry.

   Each entry is an indexable record whose element 0 is the iteration number,
   so sorting with this key orders the entries by iteration.
   """
   iteration = item[0]
   return iteration

# load System Libraries
import sys, os, glob, math

#
# Setup runtime environment
# Locate the IMOD installation; without it the IMOD Python library cannot be loaded
IMOD_DIR = os.environ.get('IMOD_DIR')
if IMOD_DIR is None:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

# For Python 3 under Cygwin, convert to a POSIX-style path: switch to forward
# slashes and turn a drive prefix such as "C:/" into "/cygdrive/c/"
if sys.platform == 'cygwin' and sys.version_info[0] > 2:
   IMOD_DIR = IMOD_DIR.replace('\\', '/')

   # Compare a slice rather than indexing characters 1 and 2 so that a value
   # shorter than three characters does not raise IndexError
   if IMOD_DIR[1:3] == ':/':
      IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]

# Put IMOD's pylib first on the module search path before importing from it
sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
from imodpy import *
addIMODbinIgnoreSIGHUP()

#
# load IMOD Libraries
from pip import exitError, setExitPrefix

# Register this program's error prefix with the exit routines (presumably it is
# prepended to the message given to exitError - confirm in IMOD's pip module)
setExitPrefix(prefix)

# The one required argument is the root name of the log files
if not sys.argv[1:]:
   exitError('Root name of log files must be entered')

# Collect one [iteration, range start, range end, mean, sd] record from every
# statistics line in every log file named <root>-<digit>...log
logfiles = glob.glob(sys.argv[1] + '-[0-9]*.log')
diffs = []
for logname in logfiles:
   for text in readTextFile(logname):

      # Only lines containing this marker hold difference statistics
      if 'diff rec mean' not in text:
         continue

      # The split line must stay in lspl: the reporting section further down
      # reuses the tokens of the last statistics line for its text labels
      lspl = text.split()
      try:

         # Tokens 1 and 4 may end in a comma, which is stripped before conversion
         record = [int(lspl[1].rstrip(',')), int(lspl[3]), int(lspl[4].rstrip(',')),
                   float(lspl[8]), float(lspl[9])]
      except Exception:
         exitError('Converting a line of difference statistics')
      else:
         diffs.append(record)

# Sort by iteration so that all entries for one iteration are adjacent
diffs.sort(key = diffKey)

ind = 0
while ind < len(diffs):

   # Find the run of entries [ind, nextInd) that share this iteration number
   itern = diffs[ind][0]
   nextInd = ind + 1
   while nextInd < len(diffs) and diffs[nextInd][0] == itern:
      nextInd += 1
   group = diffs[ind:nextInd]
   ind = nextInd

   # With a single entry for the iteration, its values are reported unchanged
   start, end, mean, sd = group[0][1:5]

   if len(group) > 1:

      # With multiple entries for the iteration, report the full extent of the
      # ranges and combine the statistics, weighting each entry by the size of
      # its range (end + 1 - start)
      start = min(entry[1] for entry in group)
      end = max(entry[2] for entry in group)
      counts = [entry[2] + 1 - entry[1] for entry in group]

      # An empty or reversed range would give a zero or negative weight and could
      # make the variance below undefined, so reject it with a clear message
      if min(counts) < 1:
         exitError('Range end is less than range start in a line of difference ' +
                   'statistics')
      numsum = sum(counts)

      # Overall mean is the weighted mean of the entry means
      dsum = 0.
      for num, entry in zip(counts, group):
         dsum += num * entry[3]
      mean = dsum / numsum

      # Overall sum of squared deviations has two parts: the deviations within
      # each entry, (num - 1) * sd**2, and the deviation of each entry's mean from
      # the overall mean, num * (entry mean - overall mean)**2.  Leaving out the
      # second part underestimates the SD whenever the entry means differ.
      # NOTE(review): this assumes each logged sd is a sample SD (n - 1
      # denominator) over its range - confirm against the program writing the logs
      dsumsq = 0.
      for num, entry in zip(counts, group):
         dsumsq += (num - 1) * entry[4] * entry[4] + num * (entry[3] - mean) ** 2

      # numsum is at least 2 here: two or more entries, each with weight >= 1
      sd = math.sqrt(dsumsq / (numsum - 1))

   # lspl is left over from the parsing loop above: it holds the tokens of the
   # last statistics line read and supplies only the text labels of the output
   prnstr(fmtstr('{} {:3d}, {}{:6d}{:6d}, {} {} {}{:15.3f}{:15.3f}', lspl[0], itern,
                 lspl[2], start, end, lspl[5], lspl[6], lspl[7], mean, sd))

sys.exit(0)
