#!/usr/bin/env python
# edgepatches - program to set up a supermontage, find overall shifts between
# pieces, and compute patch correlation vectors in the edges (overlap zones)
#
# Author: David Mastronarde
#
# $Id: edgepatches,v bbc2abd22ec9 2024/09/11 17:19:17 mast $
#

# FUNCTIONS

# The meat of the computations - a function to find shifts on edges
def findEdgeShifts(lower, upper, edge):
   """Find the 3D shift between two overlapping pieces and store it in the edge.

   The shift is found in two stages: projections of the overlap zones of the
   two volumes are correlated with tiltxcorr to get a shift in X and Y (or the
   shift is read from an existing .projxf file when useProjxf is set), then
   subvolumes extracted at the center of the overlap are correlated in 3D with
   clip corr.

   lower, upper: piece dictionaries; the kSize and 'file' entries are read.
   edge: edge dictionary; 'name', kShift, kOrigShift and kLastShift are read,
         and kShift, kLastShift and kShiftDone are set.

   Reads the module globals delx and dely (index of the long dimension and of
   the overlap dimension, as set by the main loop), useProjxf, ignoreShifts,
   longFrac, overlapToUse, trimFrac, padFrac, blockSize, verbose, leavetmp and
   progname.  Exits through exitError or exitFromImodError on failure.
   """
   prnstr("Finding global shifts for edge " + edge['name'])
   
   # Compute the starting and ending coordinates in long dimension and overlap dimension
   xyminlo = [0, 0]
   xymaxlo = [0, 0]
   xyminup = [0, 0]
   xymaxup = [0, 0]
   losize = lower[kSize]
   upsize = upper[kSize]

   # First access the appropriate shift at this edge
   # If original shifts exists and actual shifts do not, copy over.  This must be
   # a copy, not a second reference to the same list: the shift is modified in
   # place at the end of this function and the original shifts have to survive
   # that for later runs that ignore the last shifts
   if kOrigShift in edge and kShift not in edge:
      edge[kShift] = list(edge[kOrigShift])
   intshift = [int(round(edge[kShift][0])), int(round(edge[kShift][1])),
              int(round(edge[kShift][2]))]

   # Use last shifts plus original Z shift when using projxf, or use original shifts
   # when ignoring shifts
   if useProjxf:
      if kLastShift not in edge:
         exitError('Cannot use .projxf shifts; there is no lastCorrShift entry for ' + \
                      'this edge in the info file')
      zshift = edge[kShift][2]
      if kOrigShift in edge:
         zshift = edge[kOrigShift][2]
      intshift = [edge[kLastShift][0], edge[kLastShift][1], int(round(zshift))]
                  
   elif ignoreShifts:
      if kOrigShift not in edge:
         exitError('There is no originalShift entry for this edge in the info file')
      intshift = [int(round(edge[kOrigShift][0])), int(round(edge[kOrigShift][1])),
                  int(round(edge[kOrigShift][2]))]

   # Allow for existing shift and cut long dimension to fit
   longhalf = int(longFrac * min(losize[delx], upsize[delx])) // 2
   mid = (losize[delx] - intshift[delx]) // 2
   longhalf = min((longhalf, mid, losize[delx] - 1 - mid, \
                   mid + intshift[delx], \
                   upsize[delx] - 1 - mid - intshift[delx]))
   xyminlo[delx] = mid - longhalf
   xymaxlo[delx] = xyminlo[delx] + 2 * longhalf - 1
   xyminup[delx] = xyminlo[delx] + intshift[delx]
   xymaxup[delx] = xymaxlo[delx] + intshift[delx]

   # Get the overlap to use here: the one stored with the last correlation when
   # reusing its transform, else the user's entry, else the one implied by the shift
   if useProjxf:
      ovsize = edge[kLastShift][2]
   elif overlapToUse > 0:
      ovsize = overlapToUse
   else:
      ovsize = losize[dely] + intshift[dely]

   # get coordinates in short dimension, trim some here
   trim = int(trimFrac * ovsize)
   xyminlo[dely] = trim + losize[dely] - ovsize
   xymaxlo[dely] = losize[dely] - 1 - trim
   xyminup[dely] = trim
   xymaxup[dely] = ovsize - 1 - trim
   tmpproj1 = edge['name'] + '.proj1'
   tmpproj2 = edge['name'] + '.proj2'
   tmpprojst = edge['name'] + '.projst'
   tmpxf = edge['name'] + '.projxf'
   try:
      if useProjxf:

         # Take the X and Y shift from the second line of the existing transform file
         xflines = readTextFile(tmpxf, 'transform file for projections')
         if len(xflines) < 2 or len(xflines[1].split()) < 6:
            exitError('The .projxf file does not have enough lines or entries in it')
         xfsplit = xflines[1].split()
         line = [xfsplit[4], xfsplit[5]]
   
      else:
         prnstr("Projecting overlap zones")
         projcom = fmtstr('xyzproj -ax Y -an 0,0,0 -mo 2 -xm {},{} -ym {},{} ' +\
                          '{} {}', xyminlo[0], xymaxlo[0], xyminlo[1], \
                          xymaxlo[1], lower['file'], tmpproj1)
         if verbose:
            prnstr(projcom)
         runcmd(projcom, None)
         projcom = fmtstr('xyzproj -ax Y -an 0,0,0 -mo 2 -xm {},{} -ym {},{} ' + \
                          '{} {}', xyminup[0], xymaxup[0], xyminup[1], \
                          xymaxup[1], upper['file'], tmpproj2)
         if verbose:
            prnstr(projcom)
         runcmd(projcom, None)
         runcmd(fmtstr('newstack {} {} {}', tmpproj1, tmpproj2, tmpprojst), None)
         padx = int(padFrac * (xymaxup[0] + 1 - xyminup[0]))
         pady = int(padFrac * (xymaxup[1] + 1 - xyminup[1]))

         # Record the shift and overlap that these projections were made with so
         # that a later run can reuse the transform file
         edge[kLastShift] = [intshift[0], intshift[1], ovsize]

         prnstr("Correlating projections of overlap zones")
         txccom = fmtstr('tiltxcorr -ang 0,0 -pad {},{} -sigma1 0.03 -rot 0. ' +\
                         '{} {}', padx, pady, tmpprojst, tmpxf)
         if verbose:
            prnstr(txccom)
         txcout = runcmd(txccom, None)
         verboseOutput('TILTXCORR OUTPUT:', txcout)
            
         # Search backwards from the end of the output for the line with the shifts
         txcout.reverse()
         failed = 0
         for line in txcout:
            if re.search('View.*shifts', line):
               line = re.sub(r'View.*shifts *(\S+) *(\S+).*', '\\1 \\2', line).strip()
               line = line.split(' ')
               if isinstance(line, str) or len(line) < 2:
                  failed = 1
               break
         else:
            failed = 1
         if failed:
            exitError(fmtstr('Cannot find shift in output of tiltxcorr on edge {}',
                             edge['name']))

      # Subtract amount to shift upper to get new coordinate shift
      if verbose:
         prnstr(fmtstr('Initial shift: {}, {}', intshift[0], intshift[1]))
      intshift[delx] -= int(round(float(line[delx])))
      intshift[dely] = ovsize - losize[dely] - int(round(float(line[dely])))
      prnstr(fmtstr('Shift from correlation: {:.1f}, {:.1f}     Total shift' +\
                    ' {}, {}', -float(line[0]), -float(line[1]), intshift[0], \
                    intshift[1]))
      
      # Get coordinates for extracting subvolumes: a block centered in the overlap,
      # reduced in any dimension where it would not fit inside both volumes
      cenlo = [0, 0, 0]
      cenup = [0, 0, 0]
      size = [0, 0, 0]
      pad = [0, 0, 0]
      taper = [0, 0, 0]
      for i in [0, 1, 2]:
         mid = (losize[i] - intshift[i]) // 2
         half = min((blockSize // 2, mid, losize[i] - 1 - mid, \
                     mid + intshift[i], upsize[i] - 1 - mid - intshift[i]))
         cenlo[i] = mid
         cenup[i] = mid + intshift[i]
         size[i] = 2 * half
         pad[i] = size[i] // 5
         taper[i] = size[i] // 10

      # Extract the volumes
      prnstr("Extracting subvolumes at center of overlap")
      tmpvol1 = edge['name'] + '.vol1'
      tmpvol2 = edge['name'] + '.vol2'
      tmpcor = edge['name'] + '.cor'
      taperInput = [lower['file'],
                    tmpvol1,
                    fmtstr('{},{},{}', size[0], size[1], size[2]),
                    fmtstr('{},{},{}', cenlo[0], cenlo[1], cenlo[2]),
                    fmtstr('{},{},{}', pad[0], pad[1], pad[2]),
                    fmtstr('{},{},{}', taper[0], taper[1], taper[2])]
      verboseOutput('tapervoledge input for run 1', taperInput)
      runcmd('tapervoledge', taperInput)

      # Only the file names and the center differ for the upper volume
      taperInput[0] = upper['file']
      taperInput[1] = tmpvol2
      taperInput[3] = fmtstr('{},{},{}', cenup[0], cenup[1], cenup[2])
      verboseOutput('tapervoledge input for run 2', taperInput)
      runcmd('tapervoledge', taperInput)

      # Run clip, this order of files should give shift of B coordinates
      prnstr("Correlating subvolumes to find 3D shift")
      clipcom = fmtstr('clip corr -3d -n 0 {} {} {}', tmpvol1, tmpvol2, tmpcor)
      if verbose:
         prnstr(clipcom)
      clipout = runcmd(clipcom, None)
      verboseOutput('CLIP OUTPUT:', clipout)

      # The shift is taken from the first output line that starts with a
      # parenthesis.  If there is none, line is left as a string or as the
      # two-element list from above, and the test below catches both
      for line in clipout:
         if line.startswith('('):
            line = re.sub(r'^.*\((.*)\).*', '\\1', line).strip().split(',')
            break
      if isinstance(line, str) or len(line) < 3:
         exitError(fmtstr('Cannot find shift in output of clip on edge {}',\
               edge['name']))
      for i in [0, 1, 2]:
         edge[kShift][i] = intshift[i] + float(line[i])
      edge[kShiftDone] = '1'
      prnstr(fmtstr('Shift from correlation: {}, {}, {}    Total shift {}, ' +\
                    '{} {}', float(line[0]), float(line[1]), float(line[2]), \
             edge[kShift][0], edge[kShift][1], edge[kShift][2]))

      # Clean up.  The .projst and .projxf files are never removed here.
      # NOTE(review): when useProjxf is set the .vol1 and .vol2 files are left
      # behind as well; confirm that this is intended
      os.remove(tmpcor)
      if leavetmp or useProjxf:
         return
      os.remove(tmpproj1)
      os.remove(tmpproj2)
      os.remove(tmpvol1)
      os.remove(tmpvol2)
      return

   except ImodpyError:
      exitFromImodError(progname)

# And a function to find the patches
def findPatches(lower, upper, edge):
   """Run corrsearch3d to compute patch correlation vectors in one edge.

   lower, upper: piece dictionaries; the kSize and 'file' entries are read,
                 along with the optional kZlimit (lower only) and kModel entries.
   edge: edge dictionary; 'name', kShift, kXorY and the optional kZlimit entry
         are read.  On success kPatch is set to the patch file name and the
         kReduce, kResid and kResEdit entries are removed.

   Reads the module globals patchSize, intervals, borderZ, forceZ, kernel,
   prefix and progname, and clamps the global borderXY to be non-negative.
   Also runs patch2imod to make a model from the patch file.  Exits with an
   error if the patches do not fit within the limits, or through
   exitFromImodError if a program fails.
   """
   global borderXY
   prnstr('Computing patch correlations in edge ' + edge['name'])

   # Find limits in lower volume.  All mins and maxes are numbered from 1
   sizes = list(patchSize)
   patchname = edge['name'] + '.patch'
   mins = [0, 0, 0]
   maxes = [0, 0, 0]
   numpat = [0, 0, 0]

   # Hierarchy for Z limits is the option entry, then the entry for the lower
   # piece, then an entry for the edge, each one replacing the one before it
   # Mins and maxes for corrsearch3d are numbered from 1
   mins[2] = max(1, borderZ + 1)
   maxes[2] = lower[kSize][2] + 1 - mins[2]
   if kZlimit in lower:
      mins[2] = max(1, lower[kZlimit][0])
      maxes[2] = min(lower[kSize][2], lower[kZlimit][1])
   if kZlimit in edge:
      mins[2] = max(1, edge[kZlimit][0])
      maxes[2] = min(lower[kSize][2], edge[kZlimit][1])

   # Report the limits as they are: they are already numbered from 1
   if maxes[2] - mins[2] < 0.5 * sizes[2]:
      exitError(fmtstr('Z limits ({} and {}) too narrow for patch Z size ' +\
                       '({}) for edge {}', mins[2], maxes[2], \
                       sizes[2], edge['name']))

   # A patch somewhat too thick for the limits is just made thinner
   if maxes[2] - mins[2] < sizes[2]:
      sizes[2] = maxes[2] - mins[2]

   # In X and Y, stay a border away from the edges of both volumes, where a
   # position in the upper volume is the one in the lower volume plus the shift
   borderXY = max(0, borderXY)
   for i in [0, 1]:
      mins[i] = max(borderXY, borderXY - int(edge[kShift][i])) + 1
      maxes[i] = min(lower[kSize][i] - borderXY, \
                     upper[kSize][i] - borderXY - int(edge[kShift][i]))
      if maxes[i] - mins[i] <= sizes[i]:
         prnstr(fmtstr('{} Border in X and Y and/or patch size in {} needs ' +\
                       'to be\nERROR:     reduced to fit in overlap for ' +\
                       'edge {}', prefix, ('X','Y')[i], edge['name']))

         # Suggest a border and patch size scaled down by the same factor
         overspace = min(lower[kSize][i], upper[kSize][i] -edge[kShift][i])\
                     - max(0, -edge[kShift][i])
         wanted = 2. * borderXY + sizes[i]
         frac = overspace / wanted
         prnstr(fmtstr('ERROR:     Overlap = {}; e.g., X/Y border {} and ' +\
                       'patch size {} would work', overspace,
                       int(frac * borderXY), int(frac * sizes[i])))
         sys.exit(1)

   # Get number of patches.  The intervals are entered for the short and the
   # long dimension, so they are swapped for a Y edge to be in X, Y order
   delta = intervals
   if edge[kXorY] == 'Y':
      delta = [intervals[1], intervals[0], intervals[2]]
   for i in [0, 1, 2]:
      numpat[i] = max(1, int(round(float(maxes[i] - mins[i] - sizes[i]) / \
                                   delta[i])) + 1)
   if forceZ > 0:
      numpat[2] = forceZ

   prnstr(fmtstr('Number of patches in X, Y, and Z: {} {} {}', \
         numpat[0], numpat[1], numpat[2]))
   corrInput = ['ReferenceFile ' + lower['file'],
                'FileToAlign ' + upper['file'],
                'OutputFile ' + patchname,
                fmtstr('VolumeShiftXYZ {:f} {:f} {:f}', edge[kShift][0], \
                       edge[kShift][1], edge[kShift][2]),
                fmtstr('PatchSizeXYZ {} {} {}', sizes[0], sizes[1], sizes[2]),
                fmtstr('XMinAndMax {} {}', mins[0], maxes[0]),
                fmtstr('YMinAndMax {} {}', mins[1], maxes[1]),
                fmtstr('ZMinAndMax {} {}', mins[2], maxes[2]),
                fmtstr('BSourceBorderXLoHi {} {}', borderXY, borderXY),
                fmtstr('BSourceBorderYZLoHi {} {}', borderXY, borderXY),
                fmtstr('NumberOfPatchesXYZ {} {} {}', numpat[0], numpat[1],
                       numpat[2])]

   if kModel in lower:
      corrInput.append('RegionModel ' + lower[kModel])
   if kModel in upper:
      corrInput.append('BRegionModel ' + upper[kModel])
   if kernel > 0.:
      corrInput.append(fmtstr('KernelSigma {:f}', kernel))

   try:
      verboseOutput('CORRSEARCH3D COMMAND:', corrInput)
      csout = runcmd('corrsearch3d -StandardInput', corrInput)
      verboseOutput('CORRSEARCH3D OUTPUT:', csout)

      # There are new patches: clear out all keys for derived patches
      edge[kPatch] = patchname
      if kReduce in edge:
         del edge[kReduce]
      if kResid in edge:
         del edge[kResid]
      if kResEdit in edge:
         del edge[kResEdit]

      # Echo the last line of the output that has 'per position' in it
      csout.reverse()
      for line in csout:
         if line.find('per position') >= 0:
            prnstr(line, end='')
            break
         
      modname = edge['name'] + '_ccc.mod'
      patchcom = 'patch2imod -n "Values are correlation coefficients" ' + \
                 fmtstr('-s 5 -f {} {}', patchname, modname)
      runcmd(patchcom, None)
      return
   
   except ImodpyError:
      exitFromImodError(progname)


# Function for verbose multiline output
def verboseOutput(startLine, outputLines):
   """Print a heading and a set of lines, but only when the global verbose is set.

   startLine: heading printed first; nothing is printed for it if it is empty.
   outputLines: the lines to print, each with or without an ending newline.
   """
   if not verbose:
      return
   if len(startLine) > 0:
      prnstr(startLine)
   for text in outputLines:

      # A line that already ends in a newline must not get a second one
      if text.endswith('\n'):
         prnstr(text, end='')
         continue
      prnstr(text)

#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys

progname = 'edgepatches'
prefix = 'ERROR: ' + progname + ' - '

#
# Setup runtime environment: nothing can be done unless IMOD_DIR is defined,
# because the IMOD Python library is loaded from its pylib directory
if os.getenv('IMOD_DIR') is None:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

IMOD_DIR = os.environ['IMOD_DIR']
if sys.platform == 'cygwin' and sys.version_info[0] > 2:

   # For Python 3 under Cygwin, switch to forward slashes and replace a leading
   # drive letter and colon with /cygdrive/ and the letter in lower case
   IMOD_DIR = IMOD_DIR.replace('\\', '/')
   if IMOD_DIR[1] == ':' and IMOD_DIR[2] == '/':
      IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]

sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
from imodpy import *
addIMODbinIgnoreSIGHUP()

#
# load IMOD Libraries
from pip import *
from supermont import *

setSMErrorPrefix(prefix)

# Initializations: these are the defaults for the options below and for the
# values that the functions above read as globals
longFrac = 0.5
padFrac = 0.2
trimFrac = 0.1
blockSize = 200
patchSize = (100, 100, 50)
intervals = (80, 120, 50)
borderXY = 50
borderZ = 10
forceZ = 0
kernel = 0.
testMode = 0
verbose = 0
leavetmp = 0

# Fallbacks from ../manpages/autodoc2man 3 1 edgepatches
options = ["info:InfoFile:FN:", "all:RunAll:B:", "xrun:XRunStartEnd:IP:",
           "yrun:YRunStartEnd:IP:", "zrun:ZRunStartEnd:IP:", "skip:SkipDone:B:",
           "redo:RedoShifts:B:", "use:UseProjXformFile:B:",
           "ignore:IgnoreLastShifts:B:", "long:LongFraction:F:",
           "overlap:OverlapToAssume:I:", "size:PatchSizeXYZ:IT:",
           "intervals:IntervalsShortLongZ:IT:", "force:ForceNumberInZ:I:",
           "borders:BordersInXYandZ:IP:", "kernel:KernelSigma:F:", "test:TestMode:I:"]

(numOpts, numNonOpts) = PipReadOrParseOptions(sys.argv, options, progname, 2, \
                                              1, 0)
infofile = PipGetInOutFile('InfoFile', 0)
if not infofile:
   exitError("Info file name must be entered")

# Get the info file
if not os.path.exists(infofile):
   exitError('Info file ' + infofile + ' does not exist')
   
# Get running options.  The default ranges are wide enough to include every
# piece; they are limited to the actual extent of the montage later
anyrun = 0
xrunStart, yrunStart, zrunStart = -10000, -10000, -10000
xrunEnd, yrunEnd, zrunEnd = 10000, 10000, 10000

# Edges are run if all are requested or if any one of the ranges is entered
# (this presumes PipGetErrNo returns 0 when the option was entered and 1 when
# it was not - TODO confirm against the pip module)
anyrun = PipGetBoolean('RunAll', anyrun)
if not anyrun:
   xrunStart, xrunEnd = PipGetTwoIntegers('XRunStartEnd', xrunStart, xrunEnd)
   anyrun += 1 - PipGetErrNo()
   yrunStart, yrunEnd = PipGetTwoIntegers('YRunStartEnd', yrunStart, yrunEnd)
   anyrun += 1 - PipGetErrNo()
   zrunStart, zrunEnd = PipGetTwoIntegers('ZRunStartEnd', zrunStart, zrunEnd)
   anyrun += 1 - PipGetErrNo()
skipDone = PipGetBoolean('SkipDone', 0)
redoShifts = PipGetBoolean('RedoShifts', 0)
useProjxf = PipGetBoolean('UseProjXformFile', 0)
ignoreShifts = PipGetBoolean('IgnoreLastShifts', 0)

# Using the transform files implies that the shifts are to be redone
if useProjxf:
   redoShifts = 1
   if ignoreShifts:
      prnstr('WARNING: It is not necessary to enter -ignore with -use: shifts from' + \
                ' .projxf files will be used')
if redoShifts and skipDone:
   exitError('You cannot enter -skip with -redo or -use options')

# Get options related to patch correlation
intervals = PipGetThreeIntegers('IntervalsShortLongZ', intervals[0],
                                intervals[1], intervals[2])
borderXY, borderZ = PipGetTwoIntegers('BordersInXYandZ', borderXY, borderZ)
patchSize = PipGetThreeIntegers('PatchSizeXYZ', patchSize[0],
                                patchSize[1], patchSize[2])
longFrac = PipGetFloat('LongFraction', longFrac)
overlapToUse = PipGetInteger('OverlapToAssume', -1)
forceZ = PipGetInteger('ForceNumberInZ', forceZ)
kernel = PipGetFloat('KernelSigma', kernel)

# A test mode above 1 leaves temporary files and an odd one gives verbose output
testMode = PipGetInteger('TestMode', testMode)
if testMode > 1:
   leavetmp = 1
if testMode % 2:
   verbose = 1

# Compose the record of the patch parameters that is stored with each edge
patchParams = fmtstr("-size {},{},{} -interval {},{},{} -border {},{}", \
              patchSize[0], patchSize[1], patchSize[2], intervals[0],
               intervals[1], intervals[2], borderXY, borderZ)
if forceZ:
   patchParams += fmtstr(" -force {}", forceZ)
if kernel:
   patchParams += fmtstr(" -kernel {:.2f}", kernel)

PipDone()

# Make the empty containers for the contents of the info file, then read the
# existing file into them
predata, pieces, edges, slices = {}, [], [], []
addedData = 0

readMontInfo(infofile, predata, slices, pieces, edges)
if len(pieces) == 0:
   exitError(" No pieces are defined in the Info file " + infofile)
xmin, xmax, ymin, ymax, zmin, zmax, zlist = montMinMax(pieces)

# Build a piece map and fill in size entries for pieces
xsize, xysize, addedData, pieceMap = buildPieceMap(pieces, xmin, xmax, ymin, ymax,
                                                   zmin, zmax, addedData, progname)

# Run the desired edges if any
if anyrun:

   # Limit the ranges to be run to the actual extent of the montage
   xrunStart, xrunEnd = max(xmin, xrunStart), min(xmax, xrunEnd)
   yrunStart, yrunEnd = max(ymin, yrunStart), min(ymax, yrunEnd)
   zrunStart, zrunEnd = max(zmin, zrunStart), min(zmax, zrunEnd)

   for z in range(zrunStart, zrunEnd + 1):
      zbase = xysize * (z - zmin)

      # Do the X edges and then the Y edges.  delx and dely are globals that are
      # also read by findEdgeShifts; here they give the step from the lower to
      # the upper piece and determine which loop stops one short of its end
      for xory, delx, dely in (('X', 1, 0), ('Y', 0, 1)):
         for x in range(xrunStart, xrunEnd + dely):
            for y in range(yrunStart, yrunEnd + delx):
               lower = pieceMap[x - xmin + xsize * (y - ymin) + zbase]
               upper = pieceMap[x + delx - xmin +
                                xsize * (y + dely - ymin) + zbase]

               # There is an edge only if both pieces exist
               if lower < 0 or upper < 0:
                  continue

               # Look up the edge of this type that has this lower piece
               for edge in edges:
                  if edge[kXorY] == xory and edge[kLower][0] == x and \
                         edge[kLower][1] == y and edge[kLower][2] == z:
                     break
               else:
                  exitError(fmtstr('Failed to find {} edge in edges array'+\
                        ' for lower piece {} {} {}', xory, x, y, z))

               # The info file is rewritten after each step that is done
               if redoShifts or kShiftDone not in edge:
                  addedData = 1
                  findEdgeShifts(pieces[lower], pieces[upper], edge)
                  writeMontInfo(infofile, predata, slices, pieces, edges)
               if not skipDone or kPatch not in edge:
                  addedData = 1
                  findPatches(pieces[lower], pieces[upper], edge)
                  edge[kPatchParam] = patchParams
                  writeMontInfo(infofile, predata, slices, pieces, edges)

# Report whether anything was written to the info file
prnstr('New info file written' if addedData else 'Nothing was done')
sys.exit(0)
