#!/usr/bin/env python
# makejoincom - program to set up for joining
#
# Author: David Mastronarde
#
# $Id: makejoincom,v 937343107256 2023/02/19 22:44:16 mast $
#
# Program name, passed to the option parser and to the error exit routines
progname = 'makejoincom'
# Prefix for the error message written directly to stdout before imodpy is loaded
prefix = 'ERROR: ' + progname + ' - '
# Version number written into the first line of the .info file
version = 5
# Factor for converting degrees to radians (approximately pi / 180)
DTOR = 0.0174533

def getAdjustedZRange(zentry, itomo, nz, newnz, rotangle, ifflip, bigrot, bottop):
   """Convert a bottom or top slice entry into a section range for extraction.

   zentry   - pair of slice numbers as entered, numbered from 1
   itomo    - index of the tomogram, numbered from 0 (used only in the error message)
   nz       - number of slices against which the entry is range-checked
   newnz    - number of slices in the rotated volume; used only when ifflip is 2
   rotangle - three rotation angles in degrees; used only when ifflip is 2
   ifflip   - 2 if slice numbers need adjusting for a volume rotated with rotatevol
   bigrot   - 1 to use the adjustment for rotating the Y axis of an unflipped volume
   bottop   - 'bottom' or 'top', the kind of entry being processed

   Returns a tuple (chunkadd, zrange, ifinvert): the number of slices in the range,
   the range as a string of slice numbers numbered from 0 ('start' or 'start-end'),
   and 1 if a bottom range starts above the middle slice or a top range ends below
   it, otherwise 0.  Calls exitError if either entered slice number is outside
   1 to nz.
   """
   zst = zentry[0] - 1
   znd = zentry[1] - 1
   if zst < 0 or zst >= nz or znd < 0 or znd >= nz:
      exitError(fmtstr('{} entry for tomogram # {} has coordinates out of range', bottop,
                       itomo + 1))

   # If using rotatevol, need to adjust slice numbers
   # bigrot case rotates the Y axis of an unflipped volume
   # Other case is needed for rotation of Z in flipped vol
   if ifflip == 2:

      # The scale factor is the same for both ends of the range, so compute it once
      if bigrot == 1:
         zfactor = (math.cos(DTOR * rotangle[0]) * math.sin(DTOR * rotangle[1]) *
                    math.sin(DTOR * rotangle[2]) + math.sin(DTOR * rotangle[0]) *
                    math.cos(DTOR * rotangle[2]))
      else:
         zfactor = math.cos(DTOR * rotangle[0]) * math.cos(DTOR * rotangle[1])

      # Scale the offset from the center of the original volume and recenter it in
      # the rotated volume
      zst = int(zfactor * (zst - 0.5 * (nz - 1)) + 0.5 * newnz)
      znd = int(zfactor * (znd - 0.5 * (nz - 1)) + 0.5 * newnz)

   # Figure out if inverting, and proper order for extraction
   ifinvert = 0
   if (bottop == 'bottom' and zst > nz // 2) or (bottop == 'top' and znd < nz // 2):
      ifinvert = 1

   # Ascending order is wanted when not inverting, descending order when inverting
   if (zst < znd and ifinvert) or (zst >= znd and not ifinvert):
      (zst, znd) = (znd, zst)
   chunkadd = 1 + abs(zst - znd)

   # Named zrange so that the builtin range is not shadowed
   zrange = str(zst)
   if zst != znd:
      zrange += '-' + str(znd)
   return (chunkadd, zrange, ifinvert)


#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, shutil, math

#
# Setup runtime environment: IMOD_DIR must be defined so that the pylib directory under
# it can be placed at the front of the module search path
IMOD_DIR = os.getenv('IMOD_DIR')
if IMOD_DIR is None:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

if sys.platform == 'cygwin' and sys.version_info[0] > 2:

   # Under Cygwin with Python 3, convert a path starting with a drive letter to a
   # /cygdrive path.  A slice is compared so that a value shorter than three
   # characters cannot raise an IndexError
   IMOD_DIR = IMOD_DIR.replace('\\', '/')
   if IMOD_DIR[1:3] == ':/':
      IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]

sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
from imodpy import *
addIMODbinIgnoreSIGHUP()

#
# load IMOD Libraries
from pip import *

# Fallbacks from ../manpages/autodoc2man 3 1 makejoincom
# Each entry is a colon-separated option specification given to PipReadOrParseOptions;
# the fields appear to be short name, long name and type code, followed by an empty
# help string (TODO confirm against the PIP documentation)
options = ["root:RootName:CH:", "input:InputTomogram:FNM:", "top:TopSlices:IPL:",
           "bottom:BottomSlices:IPL:", "flip:FlipYandZ:BL:",
           "rotate:RotateByAngles:FTL:", "already:AlreadyRotated:BL:",
           "xyzsize:XYZsizeForRotation:ITL:", "fullsize:FullSizeRotation:BL:",
           "maxxysize:MaxXYsizeForRotation:BL:", "dir:DirectoryOfSource:FN:",
           "srcext:SourceExtension:CH:", "tmpext:TemporaryExtension:CH:",
           "reference:ReferenceForDensity:I:", "midaslim:MidasSizeLimit:I:",
           "style:NamingStyle:I:", "pcm:MakeComExtensionPcm:I:", "test:TestMode:B:",
           "help:usage:B:"]

# On Windows, make sure the current directory is rwx or at least writable
if makeCurrentDirWritable():
   exitError('Cannot make the current directory writable or write files to it')

# Declare InputTomogram as the option that the per-tomogram options are linked to
# (presumably this is what makes PipLinkedIndex return a tomogram index - TODO confirm),
# then parse the command line
PipSetLinkedOption('InputTomogram')
(numOpts, numNonOpts) = PipReadOrParseOptions(sys.argv, options, progname, 3, 0, 0)

# NOTE(review): presumably set so that programs run from here do not echo their entries
os.environ['PIP_PRINT_ENTRIES'] = '0'

# get defaults for adjusting filenames
defdir = PipGetString('DirectoryOfSource', '')
deftmpext = PipGetString('TemporaryExtension', 'tmp')
defsrcext = PipGetString('SourceExtension', 'rec')
scaleref = PipGetInteger('ReferenceForDensity', 1)
midasLimit = PipGetInteger('MidasSizeLimit', 1024)
testMode = PipGetBoolean('TestMode', 0)

# Get the naming style, falling back to the default style when getNamingStyle returns
# a negative value, and the extension for the command file
(nameStyleDflt, typeExt) = defaultNamingStyle()
(nameStyle, typeExt) = getNamingStyle(typeExt, allowExtra = True)
if nameStyle < 0:
   nameStyle = nameStyleDflt
comExt = comExtensionFromOption(0)
joincom = "startjoin" + comExt

# Convert backslashes in default directory
if defdir:
   defdir = defdir.replace('\\', '/')

# Get the root name.  If it comes from last non-option argument, decrement number
# so that the remaining non-option arguments are all input tomograms
joinroot = PipGetString('RootName', '')
if not joinroot:
   if not numNonOpts:
      exitError('You must enter the root name')
   numNonOpts -= 1
   joinroot = PipGetNonOptionArg(numNonOpts)

setRootAndExtension(joinroot, typeExt)

# get the input names.  Insist they are either all option entries or all non-option args
numInput = PipNumberOfEntries('InputTomogram')
if numInput and numNonOpts:
   exitError('You must enter input tomograms either with -input or as non-option '+\
             'arguments, but not both')

reclist = []
ntomo = numInput + numNonOpts
if ntomo < 2:
   exitError('You must enter at least two tomograms')
for i in range(ntomo):

   # Take the name from the option entries if there are any, otherwise from the
   # non-option arguments.  The default value is passed explicitly to PipGetString, as
   # in every other call to it in this script
   if numInput:
      namein = PipGetString('InputTomogram', '')
   else:
      namein = PipGetNonOptionArg(i)

   # Replace backslashes now
   namein = namein.replace('\\', '/')
   reclist.append(namein)

# Set up per-tomogram lists for the linked options; None means there was no entry
botlist = [None] * ntomo
toplist = [None] * ntomo
angleoplist = [None] * ntomo
sizeoplist = [None] * ntomo
ifrotlist = [None] * ntomo
didrotlist = [None] * ntomo

# Parallel tables: the option name, the list that receives its value, and the kind of
# value (2 for two integers, 3 for three numbers, 0 for a boolean)
linkopts = ('BottomSlices', 'TopSlices', 'FlipYandZ', 'RotateByAngles', 'AlreadyRotated',
            'XYZsizeForRotation', 'FullSizeRotation', 'MaxXYsizeForRotation')
linklists = (botlist, toplist, ifrotlist, angleoplist, didrotlist, sizeoplist, sizeoplist,
             sizeoplist)
linktype = (2, 2, 0, 3, 0, 3, 0, 0)

for (option, destlist, valkind) in zip(linkopts, linklists, linktype):
   for entry in range(PipNumberOfEntries(option)):

      # Find out which tomogram this entry goes with
      index = PipLinkedIndex(option)
      if index >= ntomo:
         exitError('You cannot enter the option ' + option + \
                   ' after the last tomogram entry')

      # Two integers: bottom or top slices
      if valkind == 2:
         destlist[index] = PipGetTwoIntegers(option, 0, 0)
         continue

      # Three numbers: rotation angles as floats, which also mark the tomogram as being
      # rotated, or a size as integers
      if valkind == 3:
         if option == 'RotateByAngles':
            if ifrotlist[index]:
               exitError(fmtstr('You entered both -flip and -rotate for tomogram # {}',
                                index + 1))
            destlist[index] = PipGetThreeFloats(option, 0., 0., 0.)
            ifrotlist[index] = 2
         else:
            destlist[index] = PipGetThreeIntegers(option, 0, 0, 0)
         continue

      # Boolean: nothing is stored unless the value is true; the two size options put
      # a keyword in the size list instead of the value
      bval = PipGetBoolean(option, 0)
      if not bval:
         continue
      if option == 'FullSizeRotation':
         sizeoplist[index] = 'full'
      elif option == 'MaxXYsizeForRotation':
         sizeoplist[index] = 'maxxy'
      else:
         destlist[index] = bval
            
# Check the bottom and top entries now that all is known.  The number of entries alone
# is not enough: two entries linked to the same tomogram give the right count but leave
# another tomogram with None, which would fail with a TypeError when its range is
# computed, so also require an entry for every tomogram that needs one
if PipNumberOfEntries('BottomSlices') != ntomo - 1 or botlist[0] or \
   None in botlist[1:]:
   exitError('You must enter bottom slices for all but the first tomogram')
if PipNumberOfEntries('TopSlices') != ntomo - 1 or toplist[ntomo-1] or \
   None in toplist[:ntomo-1]:
   exitError('You must enter top slices for all but the last tomogram')
if scaleref < 1 or scaleref > ntomo:
   exitError('The tomogram number as reference for scaling is out of range')

# Convert the reference number to an index numbered from 0
scaleref -= 1

# Loop on tomos
# For each one: compose the source and flipped file names, get the size, work out the
# effect of flipping or rotating, convert the bottom and top entries into extraction
# ranges, and accumulate the lists and the maximum X and Y sizes that are used below
nxmax = 0
nymax = 0
fliplist = []
invertlist = ''
sizelist = []
anglelist = []
chunklist = ''
avglist = []
for itomo in range(ntomo):

   rotsize = [0, 0, 0]
   rotangle = [0., 0., 0.]
   namein = reclist[itomo]
   ifflip = ifrotlist[itomo]
   # NOTE(review): didrot is not referenced again after this assignment
   didrot = didrotlist[itomo]

   # Decompose filename into header, root and tail
   # The default directory is used only when the name has no directory of its own
   (dirname, tail) = os.path.split(namein)
   header = ''
   if not dirname:
      if defdir:
         header = defdir + '/'
   else:
      header = dirname + '/'
   (root, ext) = os.path.splitext(tail)

   # Compose filenames for source and flipped files, get size
   # The default source extension is added only when the name has no extension
   if not ext:
      recsource = header + root + '.' + defsrcext
   else:
      recsource = header + tail

   if os.path.exists(recsource):
      try:
         (nx, ny, nz) = getmrcsize(recsource)
      except ImodpyError:
         exitFromImodError(progname)
   else:
      # A missing file is not an error here; a placeholder size is used instead
      (nx, ny, nz) = (1024, 1024, 100)
      prnstr(fmtstr('WARNING: {} NOT FOUND; SETTING SIZE TO {} {} {} FOR TEST PURPOSES',
                    recsource, nx, ny, nz))

   bigrot = 0
   if ifflip:

      # If flipping, then need a rec source and flip source
      flipsource = datasetFilename('.' + deftmpext, root = root)

      # If flipping, or if ny < nz and rotating, swap ny and nz
      if ny < nz:
         bigrot = 1
      if ifflip == 1 or bigrot:
         (ny, nz) = (nz, ny)

   else:

      # If already flipped, set flip source same (so it says..., is this right?)
      flipsource = recsource

   nxformax = nx
   nyformax = ny
   newnz = nz

   # If doing rotatevol, get the size of output file and the angles
   if ifflip == 2:
      rotangle = angleoplist[itomo]
      if not sizeoplist[itomo]:
         # No size option was entered, so the output keeps the current dimensions
         rotsize = [nx, ny, nz]
      elif sizeoplist[itomo] == "maxxy" or sizeoplist[itomo] == "full":
         # Ask rotatevol for the size needed to contain the rotated volume
         try:
            rotcom = ['QuerySizeNeeded',
                      fmtstr('RotationAnglesZYX {},{},{}', rotangle[2],
                             rotangle[1], rotangle[0]),
                      'InputFile ' + recsource]
            rotout = runcmd('rotatevol -StandardInput',  rotcom)
         except ImodpyError:
            exitFromImodError(progname)
         if not len(rotout) or len(rotout[0].split()) < 3:
            exitError('rotatevol size query failed on ' + recsource)
         rotsplit = rotout[0].split()
         rotsize = [int(rotsplit[0]), int(rotsplit[1]), int(rotsplit[2])]
         # With maxxy, only the X and Y sizes from the query are used
         if sizeoplist[itomo] == "maxxy":
            rotsize[2] = nz
         prnstr(fmtstr('Dimensions of {} will be {} {} {}', flipsource, rotsize[0],
                       rotsize[1], rotsize[2]))
      else:
         # Otherwise the entry holds the three sizes from XYZsizeForRotation
         rotsize = sizeoplist[itomo]

      newnz = rotsize[2]
      nxformax = rotsize[0]
      nyformax = rotsize[1]

   # Keep track of maximum nx, ny
   nxmax = max(nxmax, nxformax)
   nymax = max(nymax, nyformax)

   # Get the slices to extract from bottom of section
   # (every tomogram but the first has a bottom range)
   chunksize = 0
   botrange = None
   if itomo:
      (chunksize, botrange, ifinvert) = getAdjustedZRange(
         botlist[itomo], itomo, nz, newnz, rotangle, ifflip, bigrot, 'bottom')
      avglist.append(chunksize)
      chunklist += ',' + str(chunksize)

   # Get slice numbers to extract at top
   # (every tomogram but the last has a top range)
   toprange = None
   if itomo < ntomo - 1:
      (chunkadd, toprange, ifinvert) = getAdjustedZRange(toplist[itomo], itomo, nz, newnz,
                                                         rotangle, ifflip, bigrot, 'top')
      avglist.append(chunkadd)
      chunksize += chunkadd
      if itomo:
         chunklist += ','
      chunklist += str(chunkadd)

   # Add all parameters to lists
   # The entries in reclist, botlist and toplist are replaced by the full source name
   # and the range strings
   reclist[itomo] = recsource
   fliplist.append(flipsource)
   botlist[itomo] = botrange
   toplist[itomo] = toprange
   sizelist.append(rotsize)
   anglelist.append(rotangle)

   # ifinvert here comes from the top range when there is one, otherwise from the
   # bottom range
   if itomo:
      invertlist += ' '
   invertlist += str(ifinvert)

# Decide whether the sample slices must be squeezed to stay within the midas size limit;
# if so get the squeezed size and write the squeezing and expanding transforms

xymax = float(max(nxmax, nymax))
ifsquoze = 0
if xymax > midasLimit:
   ifsquoze = 1

if ifsquoze:
   squeeze = midasLimit / xymax
   expand = xymax / midasLimit
   (newx, newy) = runGoodframe(int(squeeze * nxmax), int(squeeze * nymax))
   if newx < 0:
      exitError('Getting new size for squeezed slices from goodframe')
   newsize = fmtstr('{},{}', newx, newy)

   # Back up both transform files, then write each with its scale factor on the diagonal
   xformFiles = ((joinroot + '.sqzxf', squeeze), (joinroot + '.xpndxf', expand))
   for (xformName, factor) in xformFiles:
      makeBackupFile(xformName)
   for (xformName, factor) in xformFiles:
      writeTextFile(xformName, [fmtstr('{0:.7f} 0. 0. {0:.7f} 0. 0.', factor)])
else:
   newsize = fmtstr('{},{}', nxmax, nymax)

# Get the scaling
# The output mode is taken from the header of the reference tomogram.  For each other
# tomogram, densmatch is run to get the scaling to the reference, unless in test mode,
# in which case the default scaling of 1.,0. is kept
newmode = 0
scalelist = []
for itomo in range(ntomo):
   try:
      scaling = ['1.', '0.']
      if itomo == scaleref:
         (xt,  yt, zt, newmode, px, py, pz) = getmrc(reclist[itomo])
      elif not testMode:
         prnstr(fmtstr("Determining scaling for tomogram # {} ...", itomo + 1))
         denscom = ['ReportOnly',
                    'ReferenceFile ' + reclist[scaleref],
                    'ScaledFile ' + reclist[itomo]]
         scalelines = runcmd('densmatch -StandardInput', denscom)

         # Look for the line with the scale factors and take the two values after
         # the colon
         for line in scalelines:
            if line.find('Scale factor') >= 0:
               index = line.find(':')
               if index > 0:
                  sctmp = line[index+1:].split()
                  if len(sctmp) == 2:
                     scaling = sctmp
                     break

         # This else belongs to the for loop: it runs when no line gave a scaling
         else:
            prnstr(fmtstr('WARNING: Densmatch failed for tomogram # {} ; using 1.,0.'+\
                          ' for scaling', itomo + 1))

      scalelist.append(scaling[0] + ',' + scaling[1])

   except ImodpyError:
      exitFromImodError(progname)
               
# Write the info file: a line of general parameters, the inversion flags, the scalings
# for all tomograms, then the name of the file to extract from for each tomogram
infofile = joinroot + '.info'
makeBackupFile(infofile)
allscale = ' '.join(scalelist)
paramline = fmtstr('{}  {}  {}  {} {} {} {}', ntomo, ifsquoze, nxmax, nymax, version,
                   newmode, nameStyle)
infolines = [paramline, invertlist, allscale] + fliplist[:ntomo]
writeTextFile(infofile, infolines)

# COMPOSE THE COMMAND FILE
# Back up any existing command file and start the list of lines for the new one
makeBackupFile(joincom)
joinlines = []
sampleFile = datasetFilename('.sample')
sampAvgFile = datasetFilename('.sampavg')
addOutputFormatVarToLines(joinlines, nameStyle)

# Add a command for each tomogram being flipped, or being rotated and not already rotated
for (rotmode, wasrot, recname, flipname, outsize, angles) in zip(
   ifrotlist, didrotlist, reclist, fliplist, sizelist, anglelist):
   if rotmode == 1:
      joinlines.append(fmtstr('$clip flipyz "{}" {}', recname, flipname))
      continue
   if rotmode != 2 or wasrot:
      continue
   joinlines.append('$rotatevol -StandardInput')
   joinlines.append('InputFile ' + recname)
   joinlines.append('OutputFile ' + flipname)
   joinlines.append(fmtstr('OutputSizeXYZ {},{},{}', outsize[0], outsize[1], outsize[2]))
   joinlines.append(fmtstr('RotationAnglesZYX {},{},{}', angles[2], angles[1],
                           angles[0]))
      
# Compose the newstack command that extracts the sample slices from every tomogram
joinlines.append('$newstack -StandardInput')
lastTomo = ntomo - 1
for itomo in range(ntomo):

   # The first tomogram has only a top range, the last has only a bottom range, and
   # the ones between have a bottom range followed by a top range
   ranges = []
   if itomo:
      ranges.append(botlist[itomo])
   if itomo != lastTomo:
      ranges.append(toplist[itomo])
   joinlines += ['InputFile ' + fliplist[itomo],
                 'SectionsToRead ' + ','.join(ranges),
                 'MultiplyAndAdd ' + scalelist[itomo]]

joinlines.append(fmtstr('ModeToOutput {}', newmode))
joinlines.append('OutputFile ' + sampleFile)
joinlines.append('SizeToOutputInXandY ' + newsize)
if ifsquoze:
   joinlines.append(fmtstr('ShrinkByFactor {:.7f}', expand))

# Add an avgstack command for each bottom or top sample, averaging consecutive runs of
# sections of the sample file into separate temporary files
tmplist = []
endsec = -1
for (ind, numavg) in enumerate(avglist):
   startsec = endsec + 1
   endsec = startsec + numavg - 1
   tmpfile = fmtstr('{}.tmpavg.{}', joinroot, ind + 1)
   joinlines.extend(['$avgstack', sampleFile, tmpfile,
                     fmtstr('{},{}', startsec, endsec)])
   tmplist.append(tmpfile)

# Stack the averages into one file, remove the temporary files, and write the command file
joinlines.append('$newstack -StandardInput')
joinlines.append('OutputFile ' + sampAvgFile)
joinlines += ['InputFile ' + avgname for avgname in tmplist]

joinlines.append('$\\rm -f ' + joinroot + '.tmpavg.*')
writeTextFile(joincom, joinlines)

# Tell the user what was written and what to run next; the message starts with a
# blank line
xfname = joinroot + '.xf'
messageLines = [
   '',
   'The command file ' + joincom + ' has been written and is ready to submit.',
   'Run it to create the files ' + sampleFile + ' and ' + sampAvgFile,
   'For optional automatic alignment of the top/bottom averages, try:',
   '   xfalign -tomo -pre ' + sampAvgFile + ' ' + xfname,
   'Then run:',
   '   midas -cs ' + chunklist + ' ' + sampleFile + ' ' + xfname,
   'to run midas in chunk alignment mode and create or edit the transform file',
   'For optional automatic refinement of transforms, then try:',
   '   xfalign -tomo -ini ' + xfname + ' ' + sampAvgFile + ' ' + xfname,
   '(after which you should check results with the midas command again).',
   'Finally run finishjoin to create the joined tomogram']
prnstr('\n'.join(messageLines))

sys.exit(0)
