#!/usr/bin/env python

"""
Pegasus utility for transfer of files during workflow enactment

Usage: pegasus-transfer [options]

If the runtime environment sets a variable PEGASUS_POLICY_CHECKS to a true
value (e.g. True, enabled, on, yes, 1, etc.), this utility will submit the list
of transfers to a Policy Web Service, which will optionally return the list or
some subset of the list after applying policy restrictions to it. The returned
list is then processed normally.

Communication with the Policy Web Service is optionally controlled either by
setting additional environment variables or command line options. (The
additional command line options are only available if PEGASUS_POLICY_CHECKS is
defined and True.) The environment variables are:

    PEGASUS_POLICY_HOST
    PEGASUS_POLICY_PORT
    PEGASUS_POLICY_URL

If neither environment variables nor command line options are used, the defaults
are, respectively, localhost, 80, and /policy/transfer/. 
"""

##
#  Copyright 2007-2011 University Of Southern California
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
##

import os
import re
import sys
import errno
import logging
import optparse
import tempfile
import subprocess
import signal
import string
import stat
import time
from collections import deque


__author__ = "Mats Rynge <rynge@isi.edu>"

# --- regular expressions -------------------------------------------------------------

# Splits a generic "proto://host/path" URL into (protocol, host, path). The host
# part may carry user@ and :port decorations; the path keeps its leading slash.
# file: and symlink: URLs are not parsed with this (see Transfer.parse_url).
re_parse_url = re.compile(r'(\w+)://([\w.\-:@]*)(/\S*)')

# --- classes -------------------------------------------------------------------------

class Transfer:
    """
    One source/destination URL pair from the transfer input, decomposed into
    protocol, host and path for each side.
    """

    pair_id        = 0       # the id of the pair in the input, the nth pair in the input
    src_proto      = ""      # source protocol, e.g. "file", "gsiftp"
    src_host       = ""      # source host part (may include user@ and :port)
    src_path       = ""      # source path
    dst_proto      = ""      # destination protocol
    dst_host       = ""      # destination host part
    dst_path       = ""      # destination path
    allow_grouping = True    # can this transfer be grouped with others?
    policy_id      = ""      # ID assigned by Policy Web Service (when used)

    def __init__(self, pair_id):
        """
        Create a transfer for input pair number pair_id; the URLs are filled
        in afterwards with set_src() / set_dst().
        """
        self.pair_id = pair_id

    def set_src(self, url):
        """Parse url and store it as the source side of the transfer."""
        (self.src_proto, self.src_host, self.src_path) = self.parse_url(url)

    def set_dst(self, url):
        """Parse url and store it as the destination side of the transfer."""
        (self.dst_proto, self.dst_host, self.dst_path) = self.parse_url(url)

    def set_policy_id(self, id):
        """
        Record the ID the Policy Web Service assigned to this transfer. Only
        used when a policy service is queried about the transfer list.
        """
        self.policy_id = id

    def parse_url(self, url):
        """
        Split url into a (proto, host, path) tuple.

        A URL without any ':' is treated as a local file:// URL. file: and
        symlink: URLs may contain relative paths and $VAR / ${VAR} references
        (which are expanded); their host is always "". Any other URL must
        match re_parse_url, and runs of slashes in its path are collapsed.

        Raises RuntimeError if the URL cannot be parsed.
        """
        # default protocol is file://
        if ":" not in url:
            logger.debug("URL without protocol (" + url + ") - assuming file://")
            url = "file://" + url

        # file/symlink urls are special cases: they can be written as
        # proto://[host]/path or proto:path (no //), and can contain relative
        # paths and env vars
        for local_proto in ("file", "symlink"):
            if url.startswith(local_proto + ":"):
                path = re.sub("^" + local_proto + r":(//[\w\.\-:@]*)?", "", url)
                return local_proto, "", expand_env_vars(path)

        # other than file/symlink urls
        match = re_parse_url.search(url)
        if match is None:
            raise RuntimeError("Unable to parse URL: %s" % (url))
        proto, host, path = match.groups()

        # no double slashes in urls
        return proto, host, re.sub('//+', '/', path)

    def src_url(self):
        """Reassemble the source side as proto://host/path."""
        return "%s://%s%s" % (self.src_proto, self.src_host, self.src_path)

    def src_url_srm(self):
        """
        Source URL as handed to srm-copy: srm-copy uses broken urls and wants
        an extra / after the host for every protocol except srm itself.
        """
        if self.src_proto == "srm":
            return self.src_url()
        return "%s://%s/%s" % (self.src_proto, self.src_host, self.src_path)

    def dst_url(self):
        """Reassemble the destination side as proto://host/path."""
        return "%s://%s%s" % (self.dst_proto, self.dst_host, self.dst_path)

    def dst_url_srm(self):
        """
        Destination URL as handed to srm-copy: srm-copy uses broken urls and
        wants an extra / after the host for every protocol except srm itself.
        """
        if self.dst_proto == "srm":
            return self.dst_url()
        return "%s://%s/%s" % (self.dst_proto, self.dst_host, self.dst_path)

    def dst_url_dirname(self):
        """Destination URL with the last path component removed."""
        parent = os.path.dirname(self.dst_path)
        return "%s://%s%s" % (self.dst_proto, self.dst_host, parent)

    def groupable(self):
        """
        Can this transfer be batched with others? Currently only gridftp
        (gsiftp on either side) allows for grouping.
        """
        if not self.allow_grouping:
            return self.allow_grouping
        return "gsiftp" in (self.src_proto, self.dst_proto)

    def __cmp__(self, other):
        """
        Order transfers by protocols first (src, dst), then hosts, then paths -
        useful for grouping similar types of transfers.
        """
        mine = (self.src_proto, self.dst_proto,
                self.src_host, self.dst_host,
                self.src_path, self.dst_path)
        theirs = (other.src_proto, other.dst_proto,
                  other.src_host, other.dst_host,
                  other.src_path, other.dst_path)
        return cmp(mine, theirs)


class Alarm(Exception):
    """
    Raised by alarm_handler() when SIGALRM fires, so that myexec() can give
    up on a shell command that exceeded its timeout.
    """


# --- global variables ----------------------------------------------------------------

prog_dir  = os.path.normpath(os.path.dirname(sys.argv[0]))
prog_base = os.path.basename(sys.argv[0])   # Name of this program

logger = logging.getLogger("my_logger")

# timeout, in seconds, for when shelling out (6 hours)
default_subshell_timeout = 6 * 60 * 60

# this is the map of what tool to use for a given protocol pair (src, dest)
tool_map = {
    ('fdt'     , 'file'    ): 'fdt',
    ('file'    , 'fdt'     ): 'fdt',
    ('file'    , 'file'    ): 'cp',
    ('file'    , 'gsiftp'  ): 'gsiftp',
    ('file'    , 'irods'   ): 'irods',
    ('file'    , 'scp'     ): 'scp',
    ('file'    , 's3'      ): 's3',
    ('file'    , 's3s'     ): 's3',
    ('file'    , 'srm'     ): 'srm',
    ('file'    , 'symlink' ): 'symlink',
    ('ftp'     , 'ftp'     ): 'gsiftp',
    ('ftp'     , 'gsiftp'  ): 'gsiftp',
    ('gsiftp'  , 'file'    ): 'gsiftp',
    ('gsiftp'  , 'ftp'     ): 'gsiftp',
    ('gsiftp'  , 'gsiftp'  ): 'gsiftp',
    ('gsiftp'  , 'srm'     ): 'srm',
    ('http'    , 'file'    ): 'webget',
    ('http'    , 'gsiftp'  ): 'gsiftp',
    ('https'   , 'file'    ): 'webget',
    ('irods'   , 'file'    ): 'irods',
    ('s3'      , 'file'    ): 's3',
    ('s3s'     , 'file'    ): 's3',
    ('scp'     , 'file'    ): 'scp',
    ('srm'     , 'file'    ): 'srm',
    ('srm'     , 'gsiftp'  ): 'srm',
    ('srm'     , 'srm'     ): 'srm',
    ('symlink' , 'symlink' ): 'symlink',
}

# per-executable path/version information, filled in by check_tool()
tool_info = {}

# track remote directories created so that we don't have to
# try to create them over and over again
remote_dirs_created = {}

# stats
stats_start = 0
stats_end = 0
stats_total_bytes = 0

# This flag is used to control calls to code that uses an external Policy
# Service. If the runtime environment sets a variable PEGASUS_POLICY_CHECKS,
# the flag is toggled and functions to submit the transfer list to the policy
# service are used. Otherwise, the list is just sorted lexically.
using_policy_service = False

# --- functions -----------------------------------------------------------------------


def setup_logger(level_str):
    """
    Send the module logger's output to the console.

    level_str - "debug", "warning" or "error" (case-insensitive) select that
                level; any other value keeps the default INFO level. The
                logger and its console handler always get the same level.
    """
    console = logging.StreamHandler()

    known_levels = {
        "debug":   logging.DEBUG,
        "warning": logging.WARNING,
        "error":   logging.ERROR,
    }
    level = known_levels.get(level_str.lower(), logging.INFO)
    for target in (logger, console):
        target.setLevel(level)

    console.setFormatter(logging.Formatter("%(asctime)s %(levelname)7s:  %(message)s"))
    logger.addHandler(console)
    logger.debug("Logger has been configured")

def prog_sigint_handler(signum, frame):
    """
    Signal handler: log which signal arrived and exit with status 1.

    signum - number of the signal received
    frame  - interrupted stack frame (unused)
    """
    # Logger.warning instead of the deprecated Logger.warn alias
    logger.warning("Exiting due to signal %d" % (signum))
    myexit(1)

def alarm_handler(signum, frame):
    """
    SIGALRM handler installed by myexec(): converts the alarm into an Alarm
    exception so the command being waited on can be reported as timed out.
    """
    raise Alarm()


def expand_env_vars(s):
    """
    Return s with every $NAME / ${NAME} reference replaced by the value of
    that environment variable (via get_env_var, "" when unset).
    """
    # the braces are optional independently of each other, so "$NAME}" also
    # consumes the trailing "}"
    return re.sub(r'\${?([a-zA-Z0-9_]+)}?', get_env_var, s)


def get_env_var(match):
    """
    re.sub() callback used by expand_env_vars(): return the value of the
    environment variable named by group 1 of match, or "" if it is unset.
    """
    name = match.group(1)
    logger.debug("Looking up " + name)
    return os.environ.get(name, "")


def myexec(cmd_line, timeout_secs, should_log):
    """
    Execute a shell command, with the ability to time out if it hangs.

    cmd_line     - command line, run through the shell
    timeout_secs - seconds to wait before the command is considered hung
    should_log   - log the command at INFO level (it is always logged when
                   debug logging is enabled)

    Raises RuntimeError if the command times out or exits non-zero.
    """
    if should_log or logger.isEnabledFor(logging.DEBUG):
        logger.info(cmd_line)
    sys.stdout.flush()

    # set up signal handler for timeout
    signal.signal(signal.SIGALRM, alarm_handler)
    signal.alarm(timeout_secs)
    try:
        p = subprocess.Popen(cmd_line, shell=True)
        try:
            p.communicate()
        except Alarm:
            if sys.version_info >= (2, 6):
                p.terminate()
            raise RuntimeError("Command '%s' timed out after %s seconds" % (cmd_line, timeout_secs))
    finally:
        # always cancel the pending alarm - otherwise it stays armed after the
        # command has finished and can fire later in unrelated code
        signal.alarm(0)

    rc = p.returncode
    if rc != 0:
        raise RuntimeError("Command '%s' failed with error code %s" % (cmd_line, rc))


def backticks(cmd_line):
    """
    Run cmd_line through the shell and return what it wrote to stdout
    (stderr is not captured) - what would a python program be without
    some perl love?
    """
    proc = subprocess.Popen(cmd_line, shell=True, stdout=subprocess.PIPE)
    out, _ = proc.communicate()
    return out


def check_tool(executable, version_arg, version_regex):
    """
    Look up executable in the current PATH and record its full path and
    version in tool_info[executable], then log what was found.

    executable    - command name to look for
    version_arg   - argument that makes the command print its version
    version_regex - regex whose group 1 extracts the version from that
                    output, or None if the tool has no usable version
    """
    info = {
        'full_path':     None,
        'version':       None,
        'version_major': None,
        'version_minor': None,
        'version_patch': None,
    }
    tool_info[executable] = info

    # figure out the full path to the executable
    full_path = backticks("which " + executable + " 2>/dev/null").rstrip('\n')
    if not full_path:
        logger.info("Command '%s' not found in the current environment" %(executable))
        return
    info['full_path'] = full_path

    # version - falls back to the whole output when the regex does not match
    if version_regex is None:
        version = "N/A"
    else:
        output = backticks(executable + " " + version_arg + " 2>&1").replace('\n', "")
        found = re.compile(version_regex).search(output)
        if found:
            version = found.group(1)
        else:
            version = output
        info['version'] = version

    # if possible, break up version into major, minor, patch
    parts = re.compile(r"([0-9]+)\.([0-9]+)(\.([0-9]+)){0,1}").search(version)
    if parts:
        info['version_major'] = int(parts.group(1))
        info['version_minor'] = int(parts.group(2))
        patch = parts.group(4)
        if patch:
            info['version_patch'] = int(patch)

    logger.info("  %-18s Version: %-7s Path: %s" % (executable, version, full_path))


def check_env_and_tools():
    """
    Prepare the process environment for the transfer tools and probe which
    tools are available.

    - PATH: appends /usr/bin and /bin when /usr/bin is missing, appends
      fink's /sw/bin when it exists, and moves $PEGASUS_HOME/bin and
      $GLOBUS_LOCATION/bin to the front when those variables are set
    - LD_LIBRARY_PATH / DYLD_LIBRARY_PATH: the existing non-empty entries,
      with $GLOBUS_LOCATION/lib moved to the front when set
    - irodsAuthFileName: .irodsA in the current working directory
    - tool_info: filled in by check_tool() for the tools used later
    """

    # PATH setup
    path = "/usr/bin:/bin"
    if "PATH" in os.environ:
        path = os.environ['PATH']
    path_entries = path.split(':')

    # is /usr/bin in the path?
    if not("/usr/bin" in path_entries):
        path_entries.append("/usr/bin")
        path_entries.append("/bin")

    # fink on macos x
    if os.path.exists("/sw/bin") and not("/sw/bin" in path_entries):
        path_entries.append("/sw/bin")

    # need LD_LIBRARY_PATH for Globus tools. Empty components are dropped: an
    # unset variable would otherwise split into [""] and, once
    # $GLOBUS_LOCATION/lib is prepended, leave a trailing ':' - and an empty
    # entry makes the (glibc) dynamic loader search the current directory
    ld_library_path = ""
    if "LD_LIBRARY_PATH" in os.environ:
        ld_library_path = os.environ['LD_LIBRARY_PATH']
    ld_library_path_entries = [entry for entry in ld_library_path.split(':') if entry != ""]

    # if PEGASUS_HOME is set, prepend it to the PATH (we want it early to override other cruft)
    if "PEGASUS_HOME" in os.environ:
        pegasus_bin = os.environ['PEGASUS_HOME'] + "/bin"
        try:
            path_entries.remove(pegasus_bin)
        except ValueError:
            # not in the PATH yet
            pass
        path_entries.insert(0, pegasus_bin)

    # if GLOBUS_LOCATION is set, prepend it to the PATH and LD_LIBRARY_PATH
    # (we want it early to override other cruft)
    if "GLOBUS_LOCATION" in os.environ:
        globus_bin = os.environ['GLOBUS_LOCATION'] + "/bin"
        globus_lib = os.environ['GLOBUS_LOCATION'] + "/lib"
        try:
            path_entries.remove(globus_bin)
        except ValueError:
            pass
        path_entries.insert(0, globus_bin)
        try:
            ld_library_path_entries.remove(globus_lib)
        except ValueError:
            pass
        ld_library_path_entries.insert(0, globus_lib)

    os.environ['PATH'] = ":".join(path_entries)
    os.environ['LD_LIBRARY_PATH'] = ":".join(ld_library_path_entries)
    os.environ['DYLD_LIBRARY_PATH'] = ":".join(ld_library_path_entries)
    logger.info("PATH=" + os.environ['PATH'])
    logger.info("LD_LIBRARY_PATH=" + os.environ['LD_LIBRARY_PATH'])

    # irods requires a password hash file
    os.environ['irodsAuthFileName'] = os.getcwd() + "/.irodsA"

    # tools we might need later
    check_tool("wget", "--version", "([0-9]+\.[0-9]+)")
    check_tool("globus-version", "--full", "([0-9]+\.[0-9]+\.[0-9]+)")
    check_tool("globus-url-copy", "-version", "([0-9]+\.[0-9]+)")
    check_tool("srm-copy", "-version", "srm-copy[ \t]+([\.0-9a-zA-Z]+)")
    check_tool("iget", "-h", "Version[ \t]+([\.0-9a-zA-Z]+)")
    check_tool("pegasus-s3", "help", None)


def prepare_local_dir(path):
    """
    Make sure a local directory exists before putting files into it.

    An empty path - os.path.dirname() of a bare file name, i.e. the current
    directory - needs no preparation and is ignored, as is a path that
    already exists.

    Raises RuntimeError if the directory cannot be created.
    """
    if not path or os.path.exists(path):
        return
    logger.debug("Creating local directory " + path)
    # rwxr-xr-x (0755), spelled with stat constants so it parses on old and
    # new Python versions alike
    mode = stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH
    try:
        os.makedirs(path, mode)
    except os.error:
        # if dir already exists (e.g. created concurrently), ignore the error
        if not os.path.isdir(path):
            raise RuntimeError(sys.exc_info()[1])


def cp(transfers, failed_q):
    """
    copies locally using /bin/cp
    """
    for i, transfer in enumerate(transfers): 
        prepare_local_dir(os.path.dirname(transfer.dst_path))
        cmd = "/bin/cp -f -L \"%s\" \"%s\"" % (transfer.src_path, transfer.dst_path)
        try:
            myexec(cmd, default_subshell_timeout, True)
        except RuntimeError, err:
            logger.error(err)
            failed_q.append(transfer)
        stats_add(transfer.dst_path)


def symlink(transfers, failed_q):
    """
    symlinks locally using ln
    """

    for i, transfer in enumerate(transfers): 
        prepare_local_dir(os.path.dirname(transfer.dst_path))

        # we do not allow dangling symlinks
        if not os.path.exists(transfer.src_path):
            logger.warning("Symlink source (%s) does not exist" % (transfer.src_path))
            failed_q.append(transfer)
            continue

        if os.path.exists(transfer.src_path) and os.path.exists(transfer.dst_path):
            # make sure src and target are not the same file - have to compare at the
            # inode level as paths can differ
            src_inode = os.stat(transfer.src_path)[stat.ST_INO]
            dst_inode = os.stat(transfer.dst_path)[stat.ST_INO]
            if src_inode == dst_inode:
                logger.warning("symlink: src (%s) and dst (%s) already exists" % (transfer.src_path, transfer.dst_path))
                continue

        cmd = "ln -f -s %s %s" % (transfer.src_path, transfer.dst_path)
        try:
            myexec(cmd, 60, True)
        except RuntimeError, err:
            logger.error(err)
            failed_q.append(transfer)


def prepare_scp_dir(rhost, rdir):
    """
    Make sure the directory rdir exists on the remote host rhost before
    files are scp'ed into it, by running '/bin/mkdir -p' there over ssh.

    Uses $SSH_PRIVATE_KEY as the ssh identity file when it is set.
    Raises RuntimeError (from myexec) if the ssh command fails or times out.
    """
    cmd = "/usr/bin/ssh"
    if "SSH_PRIVATE_KEY" in os.environ:
        cmd += " -i " + os.environ['SSH_PRIVATE_KEY']
    cmd += " -q -o StrictHostKeyChecking=no"
    cmd += " " + rhost + " '/bin/mkdir -p " + rdir + "'"
    myexec(cmd, default_subshell_timeout, True)


def scp(transfers, failed_q):
    """
    copies using scp
    """
    for i, transfer in enumerate(transfers): 
        cmd = "/usr/bin/scp"
        if "SSH_PRIVATE_KEY" in os.environ:
            cmd += " -i " + os.environ['SSH_PRIVATE_KEY']
        cmd += " -q -B -o StrictHostKeyChecking=no"
        try:
            if transfer.dst_proto == "file":
                prepare_local_dir(os.path.dirname(transfer.dst_path))
                cmd += " " + transfer.src_host + ":" + transfer.src_path
                cmd += " " + transfer.dst_path
            else:
                mkdir_key = "scp://" + transfer.dst_host + ":" + os.path.dirname(transfer.dst_path)
                if not mkdir_key in remote_dirs_created:
                    prepare_scp_dir(transfer.dst_host, os.path.dirname(transfer.dst_path))
                    remote_dirs_created[mkdir_key] = True
                cmd += " " + transfer.src_path
                cmd += " " + transfer.dst_host + ":" + transfer.dst_path
                stats_add(transfer.src_path)

            myexec(cmd, default_subshell_timeout, True)
            if transfer.dst_proto == "file":
                stats_add(transfer.dst_path)    

        except RuntimeError, err:
            logger.error(err)
            failed_q.append(transfer)


def fdt(transfers, failed_q):
    """
    copies using FDT - http://monalisa.cern.ch/FDT/license.html
    """
    # download fdt.jar on demand - it can not be shipped with Pegasus due to licensing
    if not os.path.exists("fdt.jar"):
        cmd = "wget -nv -O fdt.jar http://monalisa.cern.ch/FDT/lib/fdt.jar"
        try:
            myexec(cmd, 10*60, True)
        except RuntimeError, err:
            logger.error(err)

    for i, transfer in enumerate(transfers): 
        cmd = "echo | java -jar fdt.jar"
        if transfer.dst_proto == "file":
            prepare_local_dir(os.path.dirname(transfer.dst_path))
            cmd += " " + transfer.src_host + ":" + transfer.src_path
            cmd += " " + transfer.dst_path
        else:
            cmd += " " + transfer.src_path
            cmd += " " + transfer.dst_host + ":" + transfer.dst_path
            stats_add(transfer.src_path)
        try:
            myexec(cmd, default_subshell_timeout, True)
            if transfer.dst_proto == "file":
                stats_add(transfer.dst_path)    
        except RuntimeError, err:
            logger.error(err)
            failed_q.append(transfer)


def webget(transfers, failed_q):
    """
    pulls http/https using wget
    """
    if len(transfers) == 0:
        return
    if len(transfers) > 0 and tool_info['wget']['full_path'] == None:
        raise RuntimeError("Unable to do http/https transfers becuase wget could not be found")
    for i, transfer in enumerate(transfers): 
        prepare_local_dir(os.path.dirname(transfer.dst_path))
        cmd = tool_info['wget']['full_path']
        if logger.isEnabledFor(logging.DEBUG):
            cmd += " -v"
        else:
            cmd += " -q"
        cmd += " --no-check-certificate -O \"" + transfer.dst_path + "\" \"" + transfer.src_url() + "\""
        try:
            myexec(cmd, default_subshell_timeout, True)
            stats_add(transfer.dst_path)
        except RuntimeError, err:
            logger.error(err)
            failed_q.append(transfer)


def transfers_groupable(a, b):
    """
    Decide whether two transfers are similar enough to be handed to one tool
    together: both must be groupable and they must share the same source
    protocol and the same destination protocol.
    """
    return bool(a.groupable() and b.groupable()
                and a.src_proto == b.src_proto
                and a.dst_proto == b.dst_proto)


def gsiftp_similar(a, b):
    """
    Decide whether two transfers can share one globus-url-copy input file:
    same source host, same destination host, and the same source and
    destination directories.
    """
    if a.src_host != b.src_host or a.dst_host != b.dst_host:
        return False
    parent = os.path.dirname
    return (parent(a.src_path) == parent(b.src_path)
            and parent(a.dst_path) == parent(b.dst_path))


def gsiftp(full_list, failed_q, attempt):
    """
    gsiftp - globus-url-copy for now, maybe uberftp in the future

    Repeatedly pops transfers off the end of full_list and collects runs of
    consecutive transfers that gsiftp_similar() considers alike (same hosts,
    same source and destination directories). Each run is handed to
    gsiftp_do_transfers(): one transfer on its own with directory creation
    enabled, then the rest of the run in one batch without it. full_list is
    emptied by this call.

    full_list - list of Transfer objects (consumed)
    failed_q  - receives the transfers of any batch that fails
    attempt   - not used by this function

    Raises RuntimeError if globus-url-copy could not be found.
    """
    if len(full_list) == 0:
        return
    
    if tool_info['globus-url-copy']['full_path'] == None:
        raise RuntimeError("Unable to do gsiftp transfers becuase globus-url-copy could not be found")

    # create lists with similar (same src host and dir, same dst host and dir) url pairs
    while len(full_list) > 0:

        similar_list = []

        curr = full_list.pop()
        prev = curr
        # decided from the first transfer of the run only - gsiftp_similar()
        # compares hosts and directories, not protocols
        third_party = curr.src_proto == "gsiftp" and curr.dst_proto == "gsiftp"

        # each newly popped transfer is compared with the one popped before it
        while gsiftp_similar(curr, prev):
            
            similar_list.append(curr)

            if len(full_list) == 0:
                break
            else:
                prev = curr
                curr = full_list.pop()

        if not gsiftp_similar(curr, prev):
            # the last pair is not part of the set: push it back onto the end
            # of the list (where pop() takes from) so it starts the next run
            full_list.append(curr)

        # should not happen: the first popped transfer is always similar to itself
        if len(similar_list) == 0:
            break

        # we now have a list of similar transfers - break up and send the first one with create dir
        # and the rest with no create dir options
        first_list = []
        first_list.append(similar_list.pop())
        gsiftp_do_transfers(first_list, failed_q, True, third_party)
        if len(similar_list) > 0:
            gsiftp_do_transfers(similar_list, failed_q, False, third_party)




def gsiftp_do_transfers(transfers, failed_q, create_dest, third_party):
    """
    sub to gsiftp() - transfers a list of urls
    """
    
    # keep track of what transfer we attempted so we can add to fail q in case of failures
    attempted_transfers = transfers[:]
    delayed_file_stat = []

    # create tmp file with transfer src/dst pairs
    num_pairs = 0
    try:
        tmp_fd, tmp_name = tempfile.mkstemp(prefix="pegasus-transfer-", suffix=".lst", dir="/tmp")
        tmp_file = os.fdopen(tmp_fd, "w+b")
    except:
        raise RuntimeError("Unable to create tmp file for globus-url-copy transfers")
    for i, t in enumerate(transfers):
        num_pairs += 1
        logger.debug("   adding %s %s" % (t.src_url(), t.dst_url()))

        # delay stating until we have finished the transfers
        if t.src_proto == "file":
            delayed_file_stat.append(t.src_path)
        elif t.dst_proto == "file":
            delayed_file_stat.append(t.dst_path)

        tmp_file.write("%s %s\n" % (t.src_url(), t.dst_url()))

    tmp_file.close()
    
    logger.info("Grouped %d similar gsiftp transfers together in temporary file %s" %(num_pairs, tmp_name))

    # build command line for globus-url-copy
    cmd = tool_info['globus-url-copy']['full_path']

    # make output from guc match our current log level
    if logger.isEnabledFor(logging.DEBUG):
        cmd += " -dbg"
    elif num_pairs < 10:
        cmd += " -verbose"

    # should we try to create directories?
    if create_dest:
        cmd += " -create-dest"
    
    # Only do third party transfers for gsiftp->gsiftp. For other combinations, fall
    # back to settings which will for well over for example NAT
    if third_party:
        cmd += " -parallel 4"

        # -fast only for Globus 4 and above
        if tool_info['globus-version']['version_major'] >= 4:
            cmd += " -fast"
        
        # -pipeline only for Globus 4.2 and above
        if (tool_info['globus-version']['version_major'] == 5 \
            or (tool_info['globus-version']['version_major'] >= 4 \
                and tool_info['globus-version']['version_minor'] >= 2)):
            cmd += " -pipeline"       
    else:
        cmd += " -no-third-party-transfers -no-data-channel-authentication"

    cmd += " -f " + tmp_name
    try:
        myexec(cmd, default_subshell_timeout, True)
    
        # stat the files
        for i, filename in enumerate(delayed_file_stat): 
            stats_add(filename)
    except Exception, err:
        logger.error(err)
        for i, t in enumerate(attempted_transfers):
            failed_q.append(t)
    os.unlink(tmp_name)


def irods_login():
    """
    Log in to irods by using the iinit command - if the auth file named by
    $irodsAuthFileName already exists, we are already logged in.

    The password is read from the irodsPassword entry of the file named by
    $irodsEnvFile and fed to iinit through a temporary .irodsAc file in the
    current directory. That file is created readable by its owner only and
    is removed again even if iinit fails.

    Raises RuntimeError if irodsEnvFile is not set, contains no
    irodsPassword, or iinit fails.
    """
    f = os.environ['irodsAuthFileName']
    if os.path.exists(f):
        return

    # read password from env file
    if not "irodsEnvFile" in os.environ:
        raise RuntimeError("Missing irodsEnvFile - unable to do irods transfers")
    password = None
    h = open(os.environ['irodsEnvFile'], 'r')
    try:
        for line in h:
            items = line.split(" ", 2)
            # len check: a bare "irodsPassword" entry without a value is ignored
            if items[0].lower() == "irodspassword" and len(items) > 1:
                password = items[1].strip(" \t'\"\r\n")
    finally:
        h.close()
    if password == None:
        raise RuntimeError("No irodsPassword specified in irods env file")

    # the file holds the clear text password: start from a fresh file (an
    # old one might have looser permissions) and create it as owner
    # read/write only (0600)
    if os.path.lexists(".irodsAc"):
        os.unlink(".irodsAc")
    fd = os.open(".irodsAc", os.O_WRONLY | os.O_CREAT | os.O_EXCL,
                 stat.S_IRUSR | stat.S_IWUSR)
    try:
        h = os.fdopen(fd, "w")
        try:
            h.write(password + "\n")
        finally:
            h.close()

        cmd = "cat .irodsAc | iinit"
        myexec(cmd, 5*60, True)
    finally:
        # never leave the password on disk, even when iinit failed
        os.unlink(".irodsAc")


def irods(transfers, failed_q):
    """
    irods - use the icommands to interact with irods
    """
    if len(transfers) == 0:
        return

    if tool_info['iget']['full_path'] == None:
        raise RuntimeError("Unable to do irods transfers becuase iget could not be found in the current path")

    # log in to irods
    try:
        irods_login()
    except Exception, loginErr:
        logger.error(loginErr)
        raise RuntimError("Unable to log into irods")

    for i, url_pair in enumerate(transfers): 
        if url_pair.dst_proto == "file":
            # irods->file
            prepare_local_dir(os.path.dirname(url_pair.dst_path))
            cmd = "iget -f " + url_pair.src_path + " " + url_pair.dst_path
        else:
            # file->irods
            cmd = "imkdir -p " + os.path.dirname(url_pair.dst_path)
            try:
                myexec(cmd, 60*60, True)
            except:
                # ignore errors from the mkdir command
                pass
            cmd = "iput -f " + url_pair.src_path + " " +  url_pair.dst_path

        try:
            myexec(cmd, default_subshell_timeout, True)
            # stats      
            if url_pair.dst_proto == "file":
                stats_add(url_pair.dst_path)
            else:
                stats_add(url_pair.src_path)
        except Exception, err:
            logger.error(err)
            failed_q.append(url_pair)


def srm(transfers, failed_q):
    """
    srm - use srm-copy (Is this generic enough? Do we need to handle space tokens?)
    """
    if len(transfers) == 0:
        return

    if tool_info['srm-copy']['full_path'] == None:
        raise RuntimeError("Unable to do srm transfers becuase srm-copy could not be found")

    for i, url_pair in enumerate(transfers): 
        if url_pair.dst_proto == "file":
            prepare_local_dir(os.path.dirname(url_pair.dst_path))
        #elif url_pair.dst_proto == "gsiftp" or url_pair.dst_proto == "srm":
        #    srm_mkdir(os.path.dirname(url_pair.dst_url_srm()))
            
        third_party = (url_pair.src_proto == "gsiftp" or url_pair.src_proto == "srm") and \
                      (url_pair.dst_proto == "gsiftp" or url_pair.dst_proto == "srm")
            
        cmd = "srm-copy  %s %s -mkdir" % (url_pair.src_url_srm(), url_pair.dst_url_srm())
        if third_party:
            cmd = cmd + " -parallelism 4 -3partycopy"
    
        if not logger.isEnabledFor(logging.DEBUG):
            cmd = cmd + " >/dev/null"
            
        try:
            myexec(cmd, 6*60*60, True)
        except Exception, err:
            logger.error(err)
            failed_q.append(url_pair)
            
            
def srm_mkdir(url):
    """
    implements recursive mkdir as srm-mkdir can not handle it
    """

    # end condition
    if url == "/" or url == "":
        return True

    # does the url exist?
    cmd = "srm-ls %s >/dev/null" %(url)
    try:
        myexec(cmd, 10*60, True)
        return True
    except Exception, err:
        logger.error(err)
    
    # if we get here, the directory does not exist
    # create the parent first
    one_up = os.path.dirname(url)
    srm_mkdir(one_up)
    
    cmd = "srm-mkdir %s >/dev/null" %(url)
    try:
        myexec(cmd, 10*60, True)
    except Exception, err:
        logger.error(err)
        return False

    return True
    

def s3(transfers, failed_q):
    """
    s3 - uses pegasus-s3 to interact with Amazon S3 
    """
    if len(transfers) == 0:
        return

    if tool_info['pegasus-s3']['full_path'] == None:
        raise RuntimeError("Unable to do S3 transfers becuase pegasus-s3 could not be found")

    buckets_created = {}

    for i, url_pair in enumerate(transfers): 

        # get/put?
        if url_pair.dst_proto == "file":
            # this is a 'get'
            local_filename = url_pair.dst_path
            prepare_local_dir(os.path.dirname(url_pair.dst_path))
            cmd = "pegasus-s3 get %s %s" % (url_pair.src_url(), url_pair.dst_path)
        else:
            # this is a 'put'
            local_filename = url_pair.src_path
            
            # extract the bucket part
            re_bucket = re.compile(r'(s3(s){0,1}://\w+@\w+/+[\w]+)')
            bucket = url_pair.dst_url_dirname()
            r = re_bucket.search(bucket)
            if r:
                bucket = r.group(1)
            else:
                raise RuntimeError("Unable to parse bucket: %s" % (bucket))
            
            # first ensure that the bucket exists
            if not bucket in buckets_created:
                buckets_created[bucket] = True
                cmd = "pegasus-s3 mkdir %s" %(bucket)
                try:
                    myexec(cmd, 5*60, True)
                except Exception, err:
                    logger.error("mkdir failed - possibly due to the bucket already existing, so continuing...")
            cmd = "pegasus-s3 put %s %s" % (url_pair.src_path, url_pair.dst_url())

        try:
            myexec(cmd, default_subshell_timeout, True)
            stats_add(local_filename)
        except Exception, err:
            logger.error(err)
            failed_q.append(url_pair)


def handle_transfers(transfers, failed_q, attempt):
    """
    handles a list of transfers - failed ones are added to the failed queue

    The protocol pair (src_proto, dst_proto) of the first transfer in the
    list is looked up in tool_map to pick the tool used for the whole list.
    Unknown protocol pairs, unknown tool names and RuntimeErrors raised by
    a tool are logged as critical and terminate the program via myexit(1).

    @param transfers: list of Transfer objects to process
    @param failed_q: queue the tool appends failed Transfer objects to
    @param attempt: current attempt number (passed on to gsiftp only)
    """
    if not transfers:
        return

    # Select the tool from this list's own first transfer - not from a
    # module-level variable that merely happens to hold a related transfer.
    t_first = transfers[0]
    proto_pair = (t_first.src_proto, t_first.dst_proto)

    try:
        if proto_pair not in tool_map:
            logger.critical("Error: This tool does not know how to transfer from %s:// to %s://" \
                            % proto_pair)
            myexit(1)

        tool = tool_map[proto_pair]
        if tool == "cp":
            cp(transfers, failed_q)
        elif tool == "fdt":
            fdt(transfers, failed_q)
        elif tool == "symlink":
            symlink(transfers, failed_q)
        elif tool == "scp":
            scp(transfers, failed_q)
        elif tool == "webget":
            webget(transfers, failed_q)
        elif tool == "gsiftp":
            gsiftp(transfers, failed_q, attempt)
        elif tool == "irods":
            irods(transfers, failed_q)
        elif tool == "srm":
            srm(transfers, failed_q)
        elif tool == "s3":
            s3(transfers, failed_q)
        else:
            logger.critical("Error: No mapping for the tool '%s'" % (tool))
            myexit(1)

    except RuntimeError as err:
        logger.critical(err)
        myexit(1)


def stats_add(filename):
    """
    Add the size of a local file to the global stats_total_bytes counter.

    Best-effort: if the file cannot be stat'ed (missing, bad path, ...) the
    counter is left unchanged. Only ordinary errors are ignored -
    KeyboardInterrupt and SystemExit are allowed to propagate so that an
    interrupt or exit request is never swallowed here.
    """
    global stats_total_bytes
    try:
        s = os.stat(filename)
        stats_total_bytes = stats_total_bytes + s[stat.ST_SIZE]
    except Exception:
        # statistics are informational only - never fail a transfer over them
        pass


def stats_summarize():
    """
    Log a summary of the bytes counted by stats_add() and the overall rate.

    Reads the globals stats_total_bytes, stats_start and stats_end. If no
    bytes were counted, only a short note is logged.
    """
    if stats_total_bytes == 0:
        logger.info("Stats: no local files in the transfer set")
        return

    total_secs = stats_end - stats_start

    if total_secs <= 0:
        # Very fast transfers can finish within the timer's resolution;
        # dividing by zero here would fail the job after every transfer
        # already succeeded, so report the volume without a rate instead.
        logger.info("Stats: %sB transferred (elapsed time too short to compute a rate)" % (
                    iso_prefix_formatted(stats_total_bytes)))
    else:
        Bps = stats_total_bytes / total_secs
        logger.info("Stats: %sB transferred in %.0f seconds. Rate: %sB/s (%sb/s)" % (
                    iso_prefix_formatted(stats_total_bytes), total_secs,
                    iso_prefix_formatted(Bps), iso_prefix_formatted(Bps*8)))
    logger.info("NOTE: stats do not include third party gsiftp/srm transfers")


def iso_prefix_formatted(n):
    """
    Format a number with a binary (1024-based) unit prefix.

    The value is scaled by the largest of 1024**4 (T), 1024**3 (G),
    1024**2 (M) or 1024 (K) that it strictly exceeds and rendered with one
    decimal, followed by a space and the prefix letter. Values not above
    1024 get an empty prefix, e.g. "512.0 ". Values beyond T stay in T.
    """
    value = float(n)
    for power, symbol in ((4, "T"), (3, "G"), (2, "M"), (1, "K")):
        scale = 1024 ** power
        if value > scale:
            # dividing by a power of two is exact, so results match a
            # single division by the combined factor
            return "%.1f %s" % (value / scale, symbol)
    return "%.1f %s" % (value, "")


def myexit(rc):
    """
    Terminate the program with exit status rc.

    Raises SystemExit(rc); when it is not caught, the interpreter exits with
    that status without printing a stack trace.
    """
    sys.exit(rc)


# --- policy functions -------------------------------------------------------
#
# Functions post_policy_requests and delete_policy_requests are only used if
# PEGASUS_POLICY_CHECKS is set to a true value in the runtime environment.
#
# ----------------------------------------------------------------------------
def post_policy_requests(inputs):
    """
    Format a request to the policy web service with the input Transfer objects.

    All transfers are POSTed as one JSON document to policy_url on
    policy_host:policy_port. The service's JSON reply is turned into a new
    list of Transfer objects, numbered from 1, each carrying the 'id' the
    service returned for it.
    @param inputs: List of transfers to process
    @type inputs: list
    @return: List of Transfer objects as modified by the web service.
    @rtype: list
    @raise RuntimeError: on HTTP or socket errors talking to the service, or
                         when the service replies with a non-2xx status
    """
    logger.debug("post_policy_requests entered; input list has %d Transfers", len(inputs))
    transfers = [{'source' : pt.src_url(), 'destination' : pt.dst_url(), 'properties' : None}
                 for pt in inputs]

    policy_requests = json.dumps(transfers)
    policy_headers = {'Content-type' : 'application/json', 'Encoding' : 'latin-1'}
    ws = None
    try:
        ws = HTTPConnection(policy_host, policy_port)
        ws.request('POST', policy_url, policy_requests, policy_headers)
        resp = ws.getresponse()
        logger.info("policy web service status: %d", resp.status)
        resp_string = resp.read()
    except HTTPException as e:
        # exceptions have no .str() method - format them with str()
        logger.critical("Exception communicating with policy web service: %s", str(e))
        raise RuntimeError("Exception communicating with policy web service: %s" % (str(e)))
    except SocketError as e:
        # not every socket error carries an (errno, message) pair, so read
        # the attributes instead of unpacking the exception arguments
        value = getattr(e, 'errno', None)
        message = getattr(e, 'strerror', None) or str(e)
        logger.critical("socket exception: [ERRNO%s] %s", value, message)
        raise RuntimeError("socket exception: [ERRNO%s] %s" % (value, message))
    finally:
        if ws is not None:
            ws.close()

    if not 200 <= resp.status < 300:
        # the body of an error reply is not the expected JSON transfer list
        logger.critical("policy web service returned HTTP status %d: %s", resp.status, resp_string)
        raise RuntimeError("policy web service returned HTTP status %d" % (resp.status))

    logger.debug("policy web service response: %s", resp_string)
    posted_transfers = json.loads(resp_string)

    policy_inputs = []
    for pair_nr, pt in enumerate(posted_transfers, 1):
        logger.debug("source = %s destination = %s properties = {%s} id = %s",
                      pt['source'], pt['destination'], pt['properties'], pt['id'])
        policy_transfer = Transfer(pair_nr)
        policy_transfer.set_src(pt['source'])
        policy_transfer.set_dst(pt['destination'])
        policy_transfer.set_policy_id(pt['id'])
        logger.debug("appending Transfer %d, id=%s", pair_nr, policy_transfer.policy_id)
        policy_inputs.append(policy_transfer)

    logger.debug("post_policy_requests return")
    return policy_inputs

def delete_policy_requests(transfers):
    """
    Send a DELETE request to the policy web service. Uses the 'policy_id'
    attribute of each Transfer object to remove that transfer from the web
    service database.

    All DELETEs go over a single HTTPConnection to policy_host:policy_port,
    one request per transfer, to policy_url followed by the policy_id.
    @param transfers: A list of transfers to process.
    @type transfers: list
    @return: Nothing.
    @raise RuntimeError: on HTTP or socket errors talking to the service
    """
    logger.debug("delete_policy_requests enter; input list has %d Transfers", len(transfers))
    ws = None
    try:
        ws = HTTPConnection(policy_host, policy_port)
        for transfer in transfers:
            # format rather than concatenate, in case policy_id is not a
            # string (e.g. a numeric id from the JSON reply)
            delete_request = "%s%s" % (policy_url, transfer.policy_id)
            logger.debug("Sending DELETE request: %s", delete_request)
            ws.request('DELETE', delete_request)
            resp = ws.getresponse()
            # httplib requires the previous response to be read completely
            # before the next response on the same connection can be fetched
            resp.read()
            logger.debug("status for %s: %d", delete_request, resp.status)
    except HTTPException as e:
        # exceptions have no .str() method - format them with str()
        logger.critical("Exception communicating with policy web service: %s", str(e))
        raise RuntimeError("Exception communicating with policy web service: %s" % (str(e)))
    except SocketError as e:
        # not every socket error carries an (errno, message) pair, so read
        # the attributes instead of unpacking the exception arguments
        value = getattr(e, 'errno', None)
        message = getattr(e, 'strerror', None) or str(e)
        logger.critical("socket exception: [ERRNO%s] %s", value, message)
        raise RuntimeError("socket exception: [ERRNO%s] %s" % (value, message))
    finally:
        if ws is not None:
            ws.close()

    logger.debug("delete_policy_requests exit")


# --- main ----------------------------------------------------------------------------
#
# Script entry: parse options, read the URL pair list, optionally filter it
# through the policy web service, then run the transfers with retries.

# dup stderr onto stdout
sys.stderr = sys.stdout

# Configure command line option parser
prog_usage = "usage: %s [options]" % (prog_base)
parser = optparse.OptionParser(usage=prog_usage)
parser.add_option("-l", "--loglevel", action = "store", dest = "log_level",
                  help = "Log level. Valid levels are: debug,info,warning,error, Default is info.")
parser.add_option("-f", "--file", action = "store", dest = "file",
                  help = "File containing URL pairs to be transferred. If not given, list is read from stdin.")
parser.add_option("", "--max-attempts", action = "store", type="int", dest = "max_attempts", default = 2,
                  help = "Number of attempts allowed for each transfer. Default is 2.")

# Check environment to decide whether to use a Policy Service. The flag
# defaults to off so that it is always defined, even when
# PEGASUS_POLICY_CHECKS is not set at all.
using_policy_service = False
if 'PEGASUS_POLICY_CHECKS' in os.environ:
    true_synonyms = ('true', '1', 't', 'y', 'yes', 'on', 'enabled', 'enable')
    using_policy_service = os.environ['PEGASUS_POLICY_CHECKS'].lower() in true_synonyms

if using_policy_service:
    # Extra imports to make policy service work. If not available, fail now
    try:
        import json
    except ImportError:
        sys.stderr.write("Cannot use Policy Check: Failed to import JSON library\n")
        sys.exit(1)

    try:
        from httplib import HTTPConnection, HTTPException
    except ImportError:
        sys.stderr.write("Cannot use Policy Check: Failed to import HTTP library\n")
        sys.exit(1)

    from socket import error as SocketError

    # defaults - overridden below by the environment, then by the command line
    policy_host = "localhost"
    policy_port = 80
    policy_url  = "/policy/transfer/"

    parser.add_option("", "--policy-host", action = "store", dest = "policy_host",
                      help = "hostname for the Policy web service; default is '%s'" %(policy_host))
    parser.add_option("", "--policy-port", type="int", dest = "policy_port",
                      help = ("Port used by the policy web service; default is %d" %(policy_port)))
    parser.add_option("", "--policy-url", dest = "policy_url",
                      help = ("URL of the policy web service; default is '%s'" %(policy_url)))

    if 'PEGASUS_POLICY_HOST' in os.environ:
        policy_host = os.environ['PEGASUS_POLICY_HOST']

    if 'PEGASUS_POLICY_PORT' in os.environ:
        policy_port = int(os.environ['PEGASUS_POLICY_PORT'])

    if 'PEGASUS_POLICY_URL' in os.environ:
        policy_url = os.environ['PEGASUS_POLICY_URL']

# Parse command line options
(options, args) = parser.parse_args()
if options.log_level is None:
    options.log_level = "info"
setup_logger(options.log_level)

# If we're using a Policy Service, check the command line for environment
# overrides.
if using_policy_service:
    if options.policy_host is not None:
        policy_host = options.policy_host
    if options.policy_port is not None:
        policy_port = options.policy_port
    if options.policy_url is not None:
        policy_url = options.policy_url

# Die nicely when asked to (Ctrl+C, system shutdown)
signal.signal(signal.SIGINT, prog_sigint_handler)

# A limit below 1 would never be reached by the attempt counter below,
# retrying failed transfers forever
attempts_max = options.max_attempts
if attempts_max < 1:
    logger.warning("--max-attempts must be at least 1 - using 1")
    attempts_max = 1

# stdin or file input?
if options.file is None:
    logger.info("Reading URL pairs from stdin")
    input_file = sys.stdin
else:
    logger.info("Reading URL pairs from %s" % (options.file))
    try:
        input_file = open(options.file, 'r')
    except Exception as err:
        logger.critical('Error reading url pair list: %s' % (err))
        myexit(1)

# check environment and tools
try:
    check_env_and_tools()
except Exception as err:
    logger.critical(err)
    myexit(1)

# queues to track the work
transfer_q = deque()
failed_q = deque()

# fill the transfer queue with user provided entries: non-comment lines
# alternate between a source URL and its destination URL
line_nr = 0
pair_nr = 0
inputs = []
url_first = True
try:
    for line in input_file.readlines():
        line_nr += 1
        if line[0] != '#' and len(line) > 4:
            line = line.rstrip('\n')
            if url_first:
                pair_nr += 1
                url_pair = Transfer(pair_nr)
                url_pair.set_src(line)
                url_first = False
            else:
                url_pair.set_dst(line)
                inputs.append(url_pair)
                url_first = True
except Exception as err:
    logger.critical('Error reading url pair list: %s' % (err))
    myexit(1)

if options.file is not None:
    input_file.close()

if not url_first:
    # an odd number of URLs - the last source has no destination
    logger.warning("Ignoring source URL without a destination URL: %s" % (url_pair.src_url()))

# Check our policy service flag and branch appropriately.
if using_policy_service:
    # Send the input list to the policy server before continuing
    logger.info("Using policy web service at %s on port %d with URL %s",
                policy_host, policy_port, policy_url)
    inputs = post_policy_requests(inputs)
else:
    # we will now sort the list as some tools (gridftp) can optimize when
    # given a group of similar transfers
    logger.info("Sorting the tranfers based on transfer type and source/destination")
    inputs.sort()

transfer_q = deque(inputs)

# start the stats time
stats_start = time.time()

# attempt transfers until the queue is empty
done = False
attempt_current = 0
while not done:

    attempt_current = attempt_current + 1
    logger.info("----------------------------------------------------------------------")
    logger.info("Starting transfers - attempt %d" % (attempt_current))

    # do the transfers
    while transfer_q:

        t_main = transfer_q.popleft()

        # create a list of transfers to pass to underlying tool
        t_list = [t_main]

        try:
            t_next = transfer_q[0]
        except IndexError:
            t_next = False
        while t_next and transfers_groupable(t_main, t_next):
            t_list.append(t_next)
            transfer_q.popleft()
            try:
                t_next = transfer_q[0]
            except IndexError:
                t_next = False

        # magic!
        handle_transfers(t_list, failed_q, attempt_current)
        logger.debug("%d items in failed_q" % (len(failed_q)))
        if using_policy_service:
            # Remove these from policy web service
            delete_policy_requests(t_list)

    # are we done?
    if attempt_current >= attempts_max or not failed_q:
        done = True
        break

    # retry failed transfers with a delay
    if failed_q and attempt_current < attempts_max:
        time.sleep(10) # do not sleep too long - we want to give quick feed back on failures to the workflow
    while failed_q:
        t = failed_q.popleft()
        t.allow_grouping = False # only allow grouping on the first try
        transfer_q.append(t)

# end the stats timer and show summary
stats_end = time.time()
stats_summarize()

if failed_q:
    logger.critical("Some transfers failed! See above, and possibly stderr.")
    myexit(1)

logger.info("All transfers completed successfully.")

myexit(0)


