#!/usr/bin/env python
# transferfid - Transfer fiducials from one axis to another
#
# Author: David Mastronarde
#
# $Id: transferfid,v 54bd721b02e6 2023/04/12 18:00:41 mast $
#

progname = 'transferfid'
# Leader prepended to fatal messages written before the IMOD libraries are loaded
prefix = 'ERROR: %s - ' % progname

# Find the view with the minimum tilt angle from either the tilt file or track.com
# Also return the tilt angles in an array
def getMinimumAngle(setname, src, AA, lines, nz):
   """Find the view at minimum tilt for one axis and return it with the tilt angles.

   The angles come from <setname><src>.rawtlt when that file exists; otherwise they
   are computed for 'nz' views from the FirstTiltAngle and TiltIncrement entries found
   in 'lines' (the beadtrack entries of the track command file).

   Returns (zero, angles): 'zero' is the view number, numbered from 1, with the
   smallest absolute tilt angle, and 'angles' is the list of tilt angles in degrees.
   Exits with an error if the view cannot be determined.
   """
   tiltFile = setname + src + '.rawtlt'
   if not os.path.exists(tiltFile):

      # Try to find starting and increment and compute from them
      first = optionValue(lines, 'FirstTiltAngle', 2)
      increment = optionValue(lines, 'TiltIncrement', 2)
      if not first or not increment:
         exitError(fmtstr('{} not found - it is needed unless you enter the zero-tilt ' +\
                          'view number for {} with -z{} or track{}{} ' +\
                          'has starting angle and increment', tiltFile, AA, src, src,
                          comExt))

      # The option values come back as sequences; use the first value of each
      start = first[0]
      step = increment[0]

      # Reject a zero or negligible increment; a zero step would divide by zero below
      if step == 0. or math.fabs(step) < math.fabs(0.000001 * start):
         exitError('Tilt increment too small to find zero tilt view number')
      zero = 1 + int(math.floor(-start / step + 0.5))
      if zero <= 0:
         exitError('Cannot find zero tilt view number from first angle and increment')
      angles = [start + iz * step for iz in range(nz)]
      return (zero, angles)

   # Find minimum tilt angle in rawtlt file, skipping blank lines.  The view number is
   # counted over the angles actually read so that angles[zero - 1] is that angle.
   angles = []
   amin = 1.e20
   zero = -1
   for line in readTextFile(tiltFile):
      if not line.strip():
         continue
      try:
         angle = float(line)
      except ValueError:
         exitError('Converting lines in ' + tiltFile + ' to floating point values')
      angles.append(angle)
      if math.fabs(angle) < amin:
         amin = math.fabs(angle)
         zero = len(angles)
   if zero <= 0:
      exitError('Cannot find a minimum tilt angle from ' + tiltFile)
   return (zero, angles)


# the format of these 6-element arrays is xpx, xpy, ypx, ypy, dx, dy
# Functions are based on linearxforms with rows = 2
# 
# Return a rotation transform (with zero shifts) for an angle given in degrees
def rotationTransform(angle):
   """Return the 6-element rotation transform for 'angle' degrees, with no shift.

   Elements are ordered as the other transform helpers here expect them.
   """
   theta = math.radians(angle)
   c, s = math.cos(theta), math.sin(theta)
   return [c, -s, s, c, 0., 0.]

# Multiply two matrices, where f1 is applied first, f2 second
def xfMult(f1, f2):
   """Return the 6-element transform equivalent to applying f1 and then f2.

   Elements 0-3 are the linear part and 4-5 the shifts; indexing follows the
   linearxforms layout with rows = 2 (element 2 multiplies the second input
   coordinate in the first output row).
   """
   a0, a1, a2, a3, ax, ay = (f1[k] for k in range(6))
   b0, b1, b2, b3, bx, by = (f2[k] for k in range(6))
   return [b0 * a0 + b2 * a1,
           b1 * a0 + b3 * a1,
           b0 * a2 + b2 * a3,
           b1 * a2 + b3 * a3,
           b0 * ax + b2 * ay + bx,
           b1 * ax + b3 * ay + by]

# Return the inverse of a matrix
def xfInvert(f):
   """Return the inverse of the 6-element transform 'f' (same layout as xfMult).

   A singular linear part raises ZeroDivisionError.
   """
   det = f[0] * f[3] - f[2] * f[1]
   inv0 = f[3] / det
   inv1 = -f[1] / det
   inv2 = -f[2] / det
   inv3 = f[0] / det

   # The inverse shift is minus the inverted linear part applied to the shift
   return [inv0, inv1, inv2, inv3,
           -(inv0 * f[4] + inv2 * f[5]),
           -(inv1 * f[4] + inv3 * f[5])]
 

#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, glob, math

#
# Setup runtime environment
IMOD_DIR = os.getenv('IMOD_DIR')
if IMOD_DIR is None:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

# Under Cygwin with Python 3, turn a Windows-style "X:\..." path into /cygdrive/x/...
if sys.platform == 'cygwin' and sys.version_info[0] > 2:
   IMOD_DIR = IMOD_DIR.replace('\\', '/')
   if IMOD_DIR[1] == ':' and IMOD_DIR[2] == '/':
      IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]

# Make the IMOD Python support modules importable, then load them
sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
from imodpy import *
addIMODbinIgnoreSIGHUP()

#
# load IMOD Libraries
from pip import *
from pysed import *
from tiltmatch import *

# Fallbacks from ../manpages/autodoc2man 3 1 transferfid
options = [
   "s:Setname:CH:", "b:TransferBtoA:B:",
   "ia:AImageFile:FN:", "ib:BImageFile:FN:",
   "f:FiducialModel:FN:", "o:SeedModel:FN:",
   "boundary:BoundaryModel:FN:", "n:ViewsToSearch:I:",
   "za:ACenterView:I:", "zb:BCenterView:I:",
   "a:AngleOfRotation:I:", "x:MirrorXaxis:I:",
   "m:RunMidas:B:", "scan:ScanRotationMaxAndStep:FP:",
   "c:CorrespondingCoordFile:FN:", "lowest:LowestTiltTransformFile:FN:",
   "t:LeaveTempFiles:B:", ":PID:B:",
   "help:usage:B:",
]

opts, nonopts = PipReadOrParseOptions(sys.argv, options, progname, 1, 0, 0)
passOnKeyInterrupt(True)

# Optionally report the process ID
doPID = PipGetBoolean('PID', 0)
printPID(doPID)

# Temp file names.  getTempNames returns the name of the file from which the best
# transform is read later; per the library, it also sets up pid, tmpRoot, and tmpDir,
# which are fetched here to build the remaining names.
tmpMinxf = getTempNames(progname)
tmpRoot, tmpDir, pid = getTempComponents()

(tmpStack, tmpTwoxf, tmpClip, tmpXfmod, tmpSeed,
 tmpMap1, tmpMap2, tmpMap3, tmpMidxf, tmpTrans) = \
   [tmpRoot + tag + pid for tag in ('stack', 'twoxf', 'clip', 'xfmod', 'seed',
                                    'map1', 'map2', 'map3', 'midxf', 'trans')]

setname = PipGetInOutFile('Setname', 0)
if not setname:
   exitError('You must enter the setname (root name of dataset)')

# Source and destination axis letters: lower case for file names, upper case for
# messages.  The default direction is A to B.
ifBtoA = PipGetBoolean('TransferBtoA', 0)
if ifBtoA:
   src, dst, AA, BB = 'b', 'a', 'B', 'A'
else:
   src, dst, AA, BB = 'a', 'b', 'A', 'B'


# Get the com extension and if this fails, get it specifically from track files
(comExt, dualNum, setroot, typeExt, stackExt) = findRootAxisAndExtensions(forceSingle = -1)
if not comExt:

   # Use the extension for which both tracka and trackb command files exist, provided
   # the other extension does not also have a complete pair
   acomExists = os.path.exists('tracka.com')
   bcomExists = os.path.exists('trackb.com')
   apcmExists = os.path.exists('tracka.pcm')
   bpcmExists = os.path.exists('trackb.pcm')
   if acomExists and bcomExists and (not (apcmExists and bpcmExists)):
      comExt = 'com'
   elif (not (acomExists and bcomExists)) and apcmExists and bpcmExists:
      comExt = 'pcm'
   elif acomExists and bcomExists:

      # Reaching here with a .com pair means a .pcm pair exists too
      exitError('Cannot determine command file extension: both .com and .pcm files exist')
   else:
      exitError('Cannot determine command file extension: neither tracka and trackb ' +
                '.com nor tracka and trackb .pcm files exist')

comExt = '.' + comExt

# Default output is the seed model for the destination axis
outFile = setname + dst + '.seed'
imageA = PipGetString('AImageFile', '')
imageB = PipGetString('BImageFile', '')
fidFile = PipGetString('FiducialModel', setname + src + '.fid')

# Presumably PipGetErrNo is nonzero when the option was not entered, so this is 1 only
# when a fiducial model was given explicitly
ifFidFile = 1 - PipGetErrNo()
correspond = PipGetString('CorrespondingCoordFile', '')

# Get possible boundary file, check if conflicts, get default output file
boundaryFile = PipGetString('BoundaryModel', '')
if boundaryFile:
   if ifFidFile or correspond:
      exitError('You cannot enter a fiducial model or the -c option with a boundary ' +\
                 'model')
   outFile = ''
   if boundaryFile.startswith(setname + src):
      outFile = boundaryFile.replace(setname + src, setname + dst, 1)

outFile = PipGetString('SeedModel', outFile)
if not outFile:
   exitError('You must enter an output file with -o; the boundary model name does ' +\
             'not start with ' + setname + src)

nviews = PipGetInteger('ViewsToSearch', 5)
zeroA = PipGetInteger('ACenterView', -1)
zeroB = PipGetInteger('BCenterView', -1)
lowestXfFile = PipGetString('LowestTiltTransformFile', '')

# Zero views would give an empty search range, so require at least one
if nviews < 1:
   exitError('The number of views to sample must be positive')

# swap inputs for filename and center z's if going backwards
if ifBtoA:
   imageA, imageB = imageB, imageA
   zeroA, zeroB = zeroB, zeroA

# Get the source-axis track command file and insist it be PIP version; take the
# source image file from it if it was not entered
tracka = 'track' + src + comExt
if not os.path.exists(tracka):
   exitError('Cannot find ' + tracka + ' command file')
trackLines = readTextFile(tracka)

trackLines = extractProgramEntries(trackLines, 'beadtrack', '-Standard')
if trackLines is None:
   exitError('Old version of ' + tracka + ' cannot be used; convert it by opening ' +\
             'and closing the fiducial tracking panel in etomo')

if not imageA:
   imageA = optionValue(trackLines, 'ImageFile', 0)
   if not imageA:
      exitError(fmtstr('Cannot find the {} image file name in {}', AA, tracka))

# The destination track command file is needed for its image file, or for finding its
# minimum-tilt view (no center view entered, lowest-tilt transform, or boundary model)
if not imageB or zeroB < 0 or lowestXfFile or boundaryFile:
   trackb = 'track' + dst + comExt
   if not os.path.exists(trackb):
      exitError(fmtstr('Cannot find {} command file; it is needed unless you enter the' +\
                       ' {} image file with -i{}', trackb, BB, dst))
   bLines = readTextFile(trackb)
   if not imageB:
      imageB = optionValue(bLines, 'ImageFile', 0)
   if not imageB:
      exitError(fmtstr('Cannot find the {} image file name in {}', BB, trackb))

# Make sure both image files exist, plus the model being transferred: the boundary
# model when one is given, otherwise the fiducial model.  The fiducial model cannot be
# entered with a boundary model and is not used then, so it is not required.
for imfile in (imageA, imageB):
   if not os.path.exists(imfile):
      exitError('Image file ' + imfile + ' does not exist')

if boundaryFile:
   if not os.path.exists(boundaryFile):
      exitError('Boundary model file ' + boundaryFile + ' does not exist')
elif not os.path.exists(fidFile):
   exitError('Fiducial file ' + fidFile + ' does not exist')

# Get image sizes and pixel size
(nxa, nya, nza, mode, pxa, pya, pza, oxa, oya, oza, mna, mxa, mean) = getmrc(imageA, True)
(nxb, nyb, nzb, mode, pxb, pyb, pzb, oxb, oyb, ozb, mnz, mxb, mean) = getmrc(imageB, True)

# Scale the source model to the destination pixel size if they differ by over 2.5%
expandAfac = 1
if math.fabs((pxb - pxa) / pxa) > 0.025:
   prnstr(fmtstr('WARNING: - {}: Pixel sizes do not match: {} = {}, {} = {}; scaling' + \
                    ' to compensate', progname, imageA, pxa, imageB, pxb))
   prnstr('')
   expandAfac = pxa / pxb

# Find the minimum-tilt view of an axis when its center view was not entered, or when
# minimum-tilt views are needed for a lowest-tilt transform file or a boundary model
zeroAview = zeroA
zeroBview = zeroB
if zeroA < 0 or lowestXfFile or boundaryFile:
   zeroAview, anglesA = getMinimumAngle(setname, src, AA, trackLines, nza)
if zeroB < 0 or lowestXfFile or boundaryFile:
   zeroBview, anglesB = getMinimumAngle(setname, dst, BB, bLines, nzb)

# Default the center views to the minimum-tilt views, then number them from 0
zeroA = (zeroAview if zeroA < 0 else zeroA) - 1
zeroB = (zeroBview if zeroB < 0 else zeroB) - 1

# Views to search on each axis, centered on the center view, and the minimum-tilt view
# clamped into that range
asecStart = zeroA - nviews // 2
asecEnd = asecStart + nviews - 1
bsecStart = zeroB - nviews // 2
bsecEnd = bsecStart + nviews - 1
lowestAsec = min(asecEnd, max(asecStart, zeroAview - 1))
lowestBsec = min(bsecEnd, max(bsecStart, zeroBview - 1))

# The search ranges must lie within the stacks
if asecStart < 0 or asecEnd >= nza:
   exitError(fmtstr('The starting or ending section numbers for {} are out of range' + \
                    ' ({} and {})', AA, asecStart, asecEnd))
if not (bsecStart >= 0 and bsecEnd < nzb):
   exitError(fmtstr('The starting or ending section numbers for {} are out of range' + \
                    ' ({} and {})', BB, bsecStart, bsecEnd))

asecBest, bsecBest, junk1, junk2 = searchPairs(
   progname, zeroA, zeroB, nviews, nviews, imageA, imageB, nxa, nxb, nya, nyb, AA, BB,
   lowestXfFile, lowestAsec, lowestBsec, '', 0, False, expandAfac)

try:

   # Find the pixel size of the model and a scale factor
   modFile = fidFile
   if boundaryFile:
      modFile = boundaryFile
   try:
      infoLines = runcmd('imodinfo -h "' + modFile + '"')
   except ImodpyError:
      cleanExitError('Extracting pixel size from model')

   # Take the fourth token of the SCALE line, after removing commas and parentheses
   modPixel = 0.
   for l in infoLines:
      if 'SCALE  =' in l:
         l = l.replace(',', '').replace('(', '').replace(')', '')
         try:
            modPixel = float(l.split()[3])
         except (IndexError, ValueError):
            modPixel = 0.
         break

   if not modPixel:
      cleanup()
      exitError('Getting model scale value')

   modScale = pxa / modPixel

   # Get the best transform
   minxf = []
   if os.path.exists(tmpMinxf):
      minxf = readTextFile(tmpMinxf)
   if len(minxf) < 1:
      cleanup()
      exitError('No alignment was computed, cannot continue')

   # do what is needed with boundary file
   if boundaryFile:

      # Find the Z of the boundary-model point nearest to the best source view
      mess = 'Converting boundary model to point file'
      try:
         runcmd(fmtstr('model2point "{}" "{}"', boundaryFile, tmpMap1))
      except ImodpyError:
         cleanExitError(mess)
      boundLines = readTextFile(tmpMap1)
      if len(boundLines) < 1:
         cleanup()
         exitError('The boundary model has no points')
      nearestZ = None
      for line in boundLines:
         try:
            ptz = int(round(float(line.split()[2])))
         except Exception:
            cleanup()
            exitError('Reading Z coordinate from line in point file: ' + line)
         if nearestZ is None or math.fabs(ptz - asecBest) < math.fabs(nearestZ - asecBest):
            nearestZ = ptz

      # Set up transform and a starting matrix
      axisRotation = optionValue(trackLines, 'RotationAngle', 2, 1, numVal = 1)
      try:
         xfsplit = minxf[0].split()
         aToBmat = [float(xfsplit[ind]) for ind in range(6)]
      except (IndexError, ValueError):
         cleanup()
         exitError('Converting transform values to floats')

      baseMat = [1., 0., 0., 1., 0., 0.]

      # If angles are available, rotate by -axisRotation, stretch in X by the ratio of
      # the cosines of the tilt angles at asecBest and at the nearest view with contours,
      # and rotate back.  NOTE(review): assumes the -axisRotation rotation brings the
      # tilt axis perpendicular to X, so the stretch undoes the differing foreshortening.
      if nearestZ >= 0 and len(anglesA) > max(asecBest, nearestZ) and \
         len(anglesB) > bsecBest and axisRotation is not None:
         stretch = math.cos(math.radians(anglesA[asecBest])) / \
                   math.cos(math.radians(anglesA[nearestZ]))
         baseMat = rotationTransform(-axisRotation)

         # Scale the whole first output row (elements 0 and 2 in xfMult's layout) so
         # this is a true stretch rather than a shear
         baseMat[0] *= stretch
         baseMat[2] *= stretch
         baseMat = xfMult(baseMat, rotationTransform(axisRotation))

      # Get the full transform and write it
      fullMat = xfMult(baseMat, aToBmat)
      writeTextFile(tmpTwoxf, [fmtstr('{}  {}  {}  {}  {}  {}', fullMat[0], fullMat[1],
                                      fullMat[2], fullMat[3], fullMat[4], fullMat[5])])
      zadd = bsecBest - asecBest
      xfInputMod = boundaryFile
      mapOutputMod = outFile
      stackForTrans = imageB
      newOpt = ''

   else:

      # Fiducial transfer operations
      prnstr(fmtstr('Transferring fiducials from view {} in {} to view {} in {} with' + \
                    ' Beadtrack:', asecBest + 1, AA, bsecBest + 1, BB))
      prnstr("              (Type Ctrl-C to interrupt)")

      # Stack the two best sections: destination view first (identity transform),
      # then the source view aligned to it with the best transform
      zadd = 0.
      xfInputMod = tmpMap1
      mapOutputMod = tmpMap2
      stackForTrans = tmpStack
      newOpt = '-new 1'
      writeTextFile(tmpTwoxf, ['1 0 0 1 0 0', minxf[0]])
      try:
         runcmd(fmtstr('newstack -sec {} -sec {} -xform {} -use 0,1 -float 2 "{}" "{}" "{}"',
                       bsecBest, asecBest, tmpTwoxf, imageB, imageA, tmpStack))
      except ImodpyError:
         cleanExitError('Stacking two best views')

      # clip out the model and remap it to z = 1
      clipcom = ['InputFile ' + fidFile,
                 'OutputFile ' + tmpClip,
                 fmtstr('ZMinAndMax {},{}', asecBest - 0.5, asecBest + 0.5),
                 'KeepEmptyContours']
      try:
         runcmd('clipmodel -StandardInput', clipcom)
      except ImodpyError:
         cleanExitError('Clipping out best view from ' + AA + ' fiducial model')
      try:
         runcmd(fmtstr('remapmodel -new 1 "{}" "{}"', tmpClip, tmpMap1))
      except ImodpyError:
         cleanExitError('Remapping ' + AA + ' fiducials to section 1')

   # Common operations
   # transform model then adjust its coordinates to new center
   xadd = modScale * (nxb - nxa) / 2.
   yadd = modScale * (nyb - nya) / 2.
   try:
      runcmd(fmtstr('xfmodel -xforms "{}" -scale {} "{}" "{}"', tmpTwoxf, modScale,
                    xfInputMod, tmpXfmod))
   except ImodpyError:
      cleanExitError('Transforming ' + AA + ' model to match ' + BB + ' image')

   # Need to modify the image reference data in the model to match B if there was scaling
   # or if the origins are not the same between the files
   remapIn = tmpXfmod
   if expandAfac != 1 or oxa != oxb or oya != oyb or oza != ozb:
      remapIn = tmpTrans
      try:
         runcmd(fmtstr('imodtrans -I "{}" "{}" "{}"', stackForTrans, tmpXfmod, tmpTrans))
      except ImodpyError:
         cleanExitError('Changing scale information in transformed model')

   try:
      runcmd(fmtstr('remapmodel {} -add {},{},{} "{}" "{}"', newOpt, xadd, yadd, zadd,
                    remapIn, mapOutputMod))
   except ImodpyError:
      cleanExitError('Recentering transformed ' + AA + ' model')

   # Done if it is a boundary model
   if boundaryFile:
      cleanup()
      sys.exit(0)

   # Prepare the beadtrack input; keep tracking parameters but modify for two untilted
   # images
   sedcom = ['?^ImageFile?s?[ \t].*? ' + tmpStack + '?',
             '?^InputSeedModel?s?[ \t].*? ' + tmpMap2 + '?',
             '?^OutputModel?s?[ \t].*? ' + tmpSeed + '?',
             '?^RotationAngle?s?[ \t].*? 0?',
             '?^FirstTiltAngle?d',
             '?^TiltIncrement?d',
             '?^TiltFile?d',
             '?^TiltAngles?d',
             '?^SkipViews?d',
             '?^ShiftsNearZeroTilt?d',
             '?^SeparateGroup?d',
             '?^RoundsOfTracking?s?[ \t].*? 1?',
             '?^RotationAngle?a?TiltAngles 0,0?']
   sedlines = pysed(sedcom, trackLines, None, False, '?')

   # If there is local tracking, definitely track objects together.  The option must go
   # into the edited beadtrack input; the sed commands have already been applied.
   # NOTE(review): type 0 appears to return text (see the ImageFile queries), in which
   # case iflocal[0] is a character and a value of '0' would still count as true.
   iflocal = optionValue(trackLines, 'LocalAreaTracking', 0)
   if iflocal and iflocal[0]:
      sedlines.append('TrackObjectsTogether')

   try:
      tracklog = runcmd('beadtrack -StandardInput', sedlines)
   except ImodpyError:
      cleanExitError('Running Beadtrack to get fiducials onto ' + BB + ' view')

   # The failure count is the last token of the last line of beadtrack output
   try:
      btnum = tracklog[-1].split()
      prnstr('Number of fiducials that failed to transfer: ' + btnum[-1])
   except Exception:
      cleanup()
      exitError('Finding # failed message in track output')

   # Remap seed model to the section in B
   try:
      runcmd(fmtstr('remapmodel -new {},-999 "{}" "{}"', bsecBest, tmpSeed, tmpMap3))
   except ImodpyError:
      cleanExitError('Remapping seed model up to view in ' + BB)

   # Repack the model to remove empty points, and pass through mapping report
   # First find out if the fid.xyz is available and has contour data (at least 6 fields
   # on its last line); an empty file counts as unavailable
   xyzName = setname + src + 'fid.xyz'
   if os.path.exists(xyzName):
      xyzlines = readTextFile(xyzName)
      if not xyzlines or len(xyzlines[-1].split()) < 6:
         xyzName = ''
   else:
      xyzName = ''

   comlines = [fidFile, xyzName, tmpMap3, outFile, correspond,
               fmtstr('{},{},{}', asecBest, bsecBest, ifBtoA)]
   try:
      repLines = runcmd('repackseed', comlines)
   except ImodpyError:
      cleanExitError('Repacking seed model and establishing correspondence')

   # Echo repackseed output from the first line containing 'follow' onward
   doOut = False
   for l in repLines:
      doOut = doOut or 'follow' in l
      if doOut:
         prnstr(l.strip())

except KeyboardInterrupt:
   pass

cleanup()
sys.exit(0)

                     
             
