#!/usr/bin/env python
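# Trimvol - trims a volume in X, Y and Z with newstack, with optional contrast
# scaling and optional Y/Z flipping or rotation around X with clip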
#
# Author: David Mastronarde
#
# $Id: trimvol,v b07decab1e9b 2024/04/28 17:40:24 mast $

# Program name, and the prefix placed on the error messages this script prints
progname = 'trimvol'
prefix = 'ERROR: %s - ' % progname

# load System Libraries
import os, sys

#
# Setup runtime environment
# IMOD_DIR must be defined so that IMOD's Python library (pylib) can be put
# at the front of the module search path; otherwise exit with an error.
IMOD_DIR = os.environ.get('IMOD_DIR')
if IMOD_DIR is not None:
   if sys.platform == 'cygwin' and sys.version_info[0] > 2:
      # Convert a Windows drive path like C:/... to its /cygdrive/c/... form.
      # The length check keeps a very short value from raising IndexError.
      IMOD_DIR = IMOD_DIR.replace('\\', '/')
      if len(IMOD_DIR) > 2 and IMOD_DIR[1] == ':' and IMOD_DIR[2] == '/':
         IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]
   sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
   from imodpy import *
   addIMODbinIgnoreSIGHUP()
else:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

#
# load IMOD Libraries
from pip import *

# Initial values for pieces of the newstack / findcontrast command lines.
# adjust is cleared when -k (KeepOrigin) is given; the others are filled in
# from the flipping option and the findcontrast X/Y and Z-range entries.
(adjust, fliparg, secout) = ('-ori', '', '')
sxarg = ''
syarg = ''

# Option table handed to PipReadOrParseOptions; presumably used as the fallback
# when the program's autodoc cannot be read (NOTE(review): confirm in PIP docs).
# Entries look like 'short:LongName:TypeCode:'. The type codes line up with how
# each option is fetched below: IP = integer pair, I = integer, FP = float
# pair, B = boolean, CH = string.
# Fallbacks from ../manpages/autodoc2man 3 1 trimvol
options = ["x:XStartAndEnd:IP:", "y:YStartAndEnd:IP:", "z:ZStartAndEnd:IP:",
           "nx:XSize:I:", "ny:YSize:I:", "nz:ZSize:I:", "sz:ZFindStartAndEnd:IP:",
           "sx:XFindStartAndEnd:IP:", "sy:YFindStartAndEnd:IP:",
           "c:ContrastBlackWhite:IP:", "meansd:ScaleToMeanAndSD:FP:",
           "mm:IntegerMinMax:FP:", "mode:ModeToOutput:I:", "rx:RotateX:B:",
           "yz:FlipYZ:B:", "format:FormatOfOutputFile:CH:", "i:IndexCoordinates:B:",
           "f:FlippedCoordinates:B:", "old:OldFlippedCoordinates:I:", "k:KeepOrigin:B:",
           ":PID:B:", "help:usage:B:"]

# Startup and get input and output files.
# The old -s option was removed. Detect it before option parsing so the user
# gets a specific message about what to use instead.
if '-s' in sys.argv:
   removedMsg = 'The -s option has been eliminated; use -sz instead and add -f ' + \
      'if coordinates are from a volume loaded with flipping'
   prnstr(prefix + removedMsg)
   sys.exit(1)

(opts, nonopts) = PipReadOrParseOptions(sys.argv, options, progname, 2, 1, 1)

# Exactly two non-option arguments are required: the input and output files
if nonopts != 2:
   prnstr(prefix + 'wrong number of arguments')
   PipPrintHelp(progname, 0, 1, 1)
   sys.exit(1)

doPID = PipGetBoolean('PID', 0)
printPID(doPID)

(inputFile, outputFile) = (PipGetNonOptionArg(0), PipGetNonOptionArg(1))
if not os.path.exists(inputFile):
   exitError('Input file ' + inputFile + ' does not exist')

# Get scaling-related options
def _enteredFlag():
   """Return 1 - PipGetErrNo() for the PIP query just made.

   This script uses that value as the 'option was entered' flag, so it must
   be called immediately after the query it refers to."""
   return 1 - PipGetErrNo()

(black, white) = PipGetTwoIntegers('ContrastBlackWhite', 0, 0)
contrast = _enteredFlag()
(intmin, intmax) = PipGetTwoFloats('IntegerMinMax', 0., 0.)
inmm = _enteredFlag()
(slicest, slicend) = PipGetTwoIntegers('ZFindStartAndEnd', 0, 0)
zslices = _enteredFlag()

# Optional X/Y limits for findcontrast or densmatch; -1 marks "not entered"
(xsmin, xsmax) = PipGetTwoIntegers('XFindStartAndEnd', -1, -1)
(ysmin, ysmax) = PipGetTwoIntegers('YFindStartAndEnd', -1, -1)

(targMean, targSD) = PipGetTwoFloats('ScaleToMeanAndSD', 0., 0.)
meansd = _enteredFlag()

mode = PipGetInteger('ModeToOutput', 0)
ifMode = _enteredFlag()
contout = clipMode = ''
if ifMode:
   # findcontrast is run for -sz unless -meansd is also given
   if contrast or (zslices and not meansd):
      exitError('You cannot enter -mode with -c, or when running findcontrast')
   # With no other scaling entry, -mode alone sets the mode for newstack and clip
   if inmm + meansd == 0:
      (contout, clipMode) = ('-mode %s' % mode, '-m %s' % mode)

# -c and -mm exclude each other and exclude -sz and -meansd.
# -sz and -meansd may be combined.
numWithSlices = contrast + inmm + zslices
numWithMeanSD = contrast + inmm + meansd
if numWithSlices > 1 or numWithMeanSD > 1:
   exitError('You cannot enter -c and -mm with each other or with -sz or -meansd')

# Manage output format: validate it (case-insensitively), then build the
# matching format arguments for newstack (-format) and clip (-f)
oformat = PipGetString('FormatOfOutputFile', '')
allowed = ('HDF', 'MRC', 'TIFF', 'TIF')
newstFormat = clipFormat = ''
if oformat:
   if oformat.upper() not in allowed:
      exitError('Output format ' + oformat + ' is not a valid format')
   (newstFormat, clipFormat) = ('-format ' + oformat, '-f ' + oformat)
   
# Get size or coordinate limit options for each axis. For each axis, a size
# (-nx/-ny/-nz) and a start,end range (-x/-y/-z) cannot both be entered.
def _sizeAndRange(axis):
   """Fetch the size and start/end options for axis 'X', 'Y' or 'Z'.

   Returns (size, sizeEntered, start, end, rangeEntered); exits with an
   error when both the size and the range were entered."""
   size = PipGetInteger(axis + 'Size', 0)
   sizeGiven = 1 - PipGetErrNo()
   (start, end) = PipGetTwoIntegers(axis + 'StartAndEnd', 0, 0)
   rangeGiven = 1 - PipGetErrNo()
   if rangeGiven + sizeGiven > 1:
      letter = axis.lower()
      exitError('You cannot enter both -' + letter + ' and -n' + letter + ' options')
   return (size, sizeGiven, start, end, rangeGiven)

(xsize, ifxsz, xstart, xend, ifxse) = _sizeAndRange('X')
(ysize, ifysz, ystart, yend, ifyse) = _sizeAndRange('Y')
(zsize, ifzsz, zstart, zend, ifzse) = _sizeAndRange('Z')

# Flipping and rotation options.
# -f means the entered coordinates refer to a flipped volume. The integer -old
# carries two flags: an odd value selects old-style coordinate conversion, and
# a value of 2 or more passes -oldflip on to findcontrast.
flipyz = PipGetBoolean('FlippedCoordinates', 0)
oldFlip = PipGetInteger('OldFlippedCoordinates', 0)
if flipyz:
   fliparg = '-flip -oldflip' if oldFlip // 2 > 0 else '-flip'

doflip = PipGetBoolean('FlipYZ', 0)
dorot = PipGetBoolean('RotateX', 0)
if dorot and doflip:
   exitError('You cannot use both -yz and -rx options')
# Turn the flag into the name of the clip operation to run, if any
if dorot:
   doflip = 'rotx'
elif doflip:
   doflip = 'flipyz'

if PipGetBoolean('KeepOrigin', 0):
   adjust = ''
index = PipGetBoolean('IndexCoordinates', 0)

# Get file size
try:
   (nx, ny, nz) = getmrcsize(inputFile)
except ImodpyError:
   exitFromImodError(progname)

# With flipped coordinates, the Y and Z entries refer to the flipped volume.
# Exchange them so that they refer to the file's own Y and Z axes.
if flipyz:
   (ysize, zsize) = (zsize, ysize)
   (ifysz, ifzsz) = (ifzsz, ifysz)
   (ifyse, ifzse) = (ifzse, ifyse)
   enteredY = (ystart, yend)
   if oldFlip % 2 > 0:
      # Old-style conversion: the Z entries carry over unchanged as Y
      (ystart, yend) = (zstart, zend)
   else:
      # Otherwise the range is reflected, since file Y runs opposite to flipped Z.
      # Use ny - 1 for 0-based (-i) entries and ny + 1 for 1-based entries.
      top = ny - 1 if index else ny + 1
      (ystart, yend) = (top - zend, top - zstart)
   (zstart, zend) = enteredY

# Check and set up the X coordinates. The size defaults to the full file. A
# start,end range sets the size and an offset, which is the shift of the
# output's center from the file's center (as computed below).
xoffset = 0
yoffset = 0
if ifxsz:
   if xsize <= 0 or xsize > nx:
      exitError(fmtstr('Illegal X size in -nx {}', xsize))
else:
   xsize = nx
if ifxse:
   if not index:
      # Convert 1-based entries to 0-based coordinates
      xstart -= 1
      xend -= 1
   if xend < 0 or xstart >= nx or xstart > xend:
      # Report the values in the numbering the user entered them in
      exitError(fmtstr('X coordinates out of range for file in -x {},{}',
                       xstart + 1 - index, xend + 1 - index))
   xsize = xend + 1 - xstart
   xoffset = xstart + xsize // 2 - nx // 2

# Check and set up Y coordinates. With flipped coordinates these limits came
# from the -z / -nz entries, so messages name that option instead.
inlet = 'yz'[flipyz]
if ifysz:
   if ysize <= 0 or ysize > ny:
      exitError(fmtstr('Illegal {} size in -n{} {}', inlet, inlet, ysize))
else:
   ysize = ny
if ifyse:
   if not index:
      # Convert 1-based entries to 0-based coordinates
      ystart -= 1
      yend -= 1
   if yend < 0 or ystart >= ny or ystart > yend:
      # Echo the values as the user entered them. Undo the 1-based shift, and
      # for a non-old flip also undo the reflection of the -z entries.
      if flipyz and oldFlip % 2 == 0:
         (entStart, entEnd) = (ny - index - yend, ny - index - ystart)
      else:
         (entStart, entEnd) = (ystart + 1 - index, yend + 1 - index)
      exitError(fmtstr('{} coordinates out of range for file in -{} {},{}', inlet.upper(),
                       inlet, entStart, entEnd))
   ysize = yend + 1 - ystart
   yoffset = ystart + ysize // 2 - ny // 2

# Check and set up the Z section list if either a size or a range was given.
# With flipped coordinates these came from the -y / -ny entries.
inlet = 'zy'[flipyz]
if ifzsz:
   if zsize <= 0 or zsize > nz:
      exitError(fmtstr('Illegal {} size in -n{} {}', inlet, inlet, zsize))
   # Take the requested number of sections centered in the file
   zstart = (nz - zsize) // 2
   zend = zstart + zsize - 1
   secout = fmtstr('-sec {}-{}', zstart, zend)

if ifzse:
   if not index:
      # Convert 1-based entries to 0-based section numbers
      zstart -= 1
      zend -= 1
   if zend < 0 or zstart >= nz or zstart > zend:
      # Upper-case axis letter and as-entered numbering, matching the X/Y messages
      exitError(fmtstr('{} coordinates out of range for file in -{} {},{}', inlet.upper(),
                       inlet, zstart + 1 - index, zend + 1 - index))
   secout = fmtstr('-sec {}-{}', zstart, zend)
   # A range partly outside the file gets -blank, presumably so newstack
   # supplies blank sections for the missing part
   if zstart < 0 or zend >= nz:
      secout += ' -blank'

# Process the entries for X and Y limits in findcontrast. Only pairs that were
# entered are used (defaults are -1). 1-based entries are shifted to 0-based,
# then turned into -xminmax / -yminmax arguments.
firstOffset = 0 if index else 1
if xsmin >= 0 and xsmax > 0:
   (xsmin, xsmax) = (xsmin - firstOffset, xsmax - firstOffset)
   sxarg = fmtstr('-xminmax {},{}', xsmin, xsmax)
if ysmin >= 0 and ysmax > 0:
   (ysmin, ysmax) = (ysmin - firstOffset, ysmax - firstOffset)
   syarg = fmtstr('-yminmax {},{}', ysmin, ysmax)

# Convert a -c or -mm entry into newstack arguments: -c gives byte output
# with contrast scaling, and -mm gives a -sca scaling to the given range
if contrast:
   contout = fmtstr('-mode 0 -con {},{}', black, white)
if inmm:
   # Output mode defaults to 1 unless -mode was entered
   mode = mode if ifMode else 1
   contout = fmtstr('-mode {} -sca {},{}', mode, intmin, intmax)

# Check entered slice limits for scaling. With flipped coordinates the slices
# run along the file's Y axis, and the flipped Y limits along its Z axis.
(slicelim, ylim) = (ny, nz) if fliparg else (nz, ny)
if zslices and (slicest < 1 or slicend > slicelim or slicest > slicend):
   exitError(fmtstr('Slices out of range for file in -sz {},{}', slicest, slicend))

# When flipping or rotating, newstack writes a temporary file that clip then
# turns into the final output
newstout = (inputFile + '.tmp.' + str(os.getpid())) if doflip else outputFile

# If given a target mean/SD, have densmatch report the scaling factors
if meansd:
   # Without entered limits, use the central 80% of the volume in X and Y
   if not sxarg:
      xsmin = int(0.1 * nx)
      xsmax = int(0.9 * nx)
   if not syarg:
      ysmin = int(0.1 * ylim)
      ysmax = int(0.9 * ylim)
   if zslices:
      zsmin = slicest - 1
      zsmax = slicend - 1
   else:
      zsmin = 0
      zsmax = slicelim - 1
   # With -f the Y and Z limits refer to the flipped volume; convert them to
   # the file's axes, reflecting Y unless old-style flipping was requested
   if flipyz:
      if oldFlip // 2 > 0:
         (ysmin, ysmax, zsmin, zsmax) = (zsmin, zsmax, ysmin, ysmax)
      else:
         (ysmin, ysmax, zsmin, zsmax) = (ny - 1 - zsmax, ny - 1 - zsmin, ysmin, ysmax)

   comlines = ['ScaledFile ' + inputFile,
               fmtstr('TargetMeanAndSD {},{}', targMean, targSD),
               'ReportOnly 1',
               fmtstr('XMinAndMax {},{}', xsmin, xsmax),
               fmtstr('YMinAndMax {},{}', ysmin, ysmax),
               fmtstr('ZMinAndMax {},{}', zsmin, zsmax)]
   try:
      prnstr('Running densmatch to find scaling to mean and SD...')
      densLines = runcmd('densmatch -StandardInput', comlines)
   except ImodpyError:
      exitFromImodError(progname)

   # Echo the output. Take the multiply and add factors from the last two
   # fields of the 'Scale factors to' line.
   for line in densLines:
      prnstr(line.rstrip('\r\n'))
      if line.startswith('Scale factors to'):
         lsplit = line.split()
         try:
            multfac = float(lsplit[-2])
            sdfac = float(lsplit[-1])
         except (ValueError, IndexError):
            # Unparsable line: leave contout unset so the check below reports it
            continue
         if not ifMode:
            mode = 0
         contout = fmtstr('-mode {} -multadd {},{}', mode, multfac, sdfac)

   if not contout:
      exitError('Cannot find scaling information in output of densmatch')

# Or run findcontrast if Z slices were entered. Their range has already been
# validated against the file, with the same limits, before this point.
elif zslices:
   prnstr('Determining byte scaling of ' + inputFile + '...')
   findcom = fmtstr('findcontrast -slice {},{} {} {} {} "{}"', slicest, slicend, fliparg,
                    sxarg, syarg, inputFile)
   prnstr(findcom)
   try:
      findlines = runcmd(findcom)
   except ImodpyError:
      exitFromImodError(progname)

   # Echo the output. Take black/white from the first line containing
   # 'Implied', using the 1st and 3rd words after 'are '.
   black = None
   for line in findlines:
      prnstr(line.strip())
      if black is None and 'Implied' in line:
         ind = line.find('are ')
         if ind > 0:
            bwsplit = line[ind+3:].split()
            if len(bwsplit) > 2:
               try:
                  # Assign both together so a failure leaves black unset
                  (black, white) = (int(bwsplit[0]), int(bwsplit[2]))
               except ValueError:
                  # Non-numeric fields: leave black unset so the error below fires
                  pass

   if black is None:
      exitError('Findcontrast failed to return scaling values')
   contout = fmtstr('-mode 0 -con {},{}', black, white)

# Compose and run the newstack command
newstcom = fmtstr('newstack -siz {},{} -off {},{} {} {} {} {} "{}" "{}"', xsize, ysize,
                  xoffset, yoffset, newstFormat, adjust, contout, secout, inputFile,
                  newstout)

try:
   runcmd(newstcom, None, 'stdout')
except ImodpyError:
   exitFromImodError(progname)

# Black/white levels come from the file only when findcontrast ran, i.e. -sz
# without -meansd. With -meansd, black and white still hold their unentered
# defaults, and reporting them would be misleading.
if zslices and not meansd:
   prnstr(fmtstr('Contrast black/white levels determined from file were {},{}', black,
                 white))
prnstr(' ')
prnstr('The newstack command was:')
prnstr(newstcom)

# Flip or rotate if requested: clip converts newstack's temporary output into
# the final output file, and the temporary file is then removed
if doflip:
   prnstr('Running clip ' + doflip)
   try:
      runcmd(fmtstr('clip {} {} {} "{}" "{}"', doflip, clipFormat, clipMode, newstout,
                    outputFile))
   except ImodpyError:
      exitFromImodError(progname)
   finally:
      # Remove the temporary file even when clip failed. This relies on
      # exitFromImodError exiting via SystemExit, which still runs this clause.
      try:
         os.remove(newstout)
      except OSError:
         prnstr('WARNING: error trying to delete temporary file ' + newstout)

sys.exit(0)
