#!/usr/bin/env python2
"""
Welcome to wikkit %ver, a wickedly fast parallel HTTP/FTP downloader.

Usage: %prog [options] <URL>
%opts

Alternate Usage: %prog [options] <URL_pattern> <filename_pattern>
  Downloads multiple URLs including a counter. The source pattern can contain
  counter specifiers in the '[START-END]' form, where START and END are both
  numbers and the number of digits in START specifies how many digits the
  counter shall have. In the output filename pattern, question marks ('?')
  are replaced by the current value of the counter. If multiple counters are
  defined, the question marks will expand to the numbers in the same order as
  in the URL pattern.
Example:
  %prog "http://www.example.com/imgs/img[01-42].jpg" "example_?.jpg"
  -- will download img01.jpg, img02.jpg, ... img42.jpg from the server and
     store them in example_01.jpg, example_02.jpg, ... example_42.jpg
"""
__version__ = "0.3"
import sys, os, thread, time, signal, urllib, urllib2, urlparse
import re, fnmatch, optparse

################################################################################
## LIBRARY PART                                                               ##
################################################################################

DefaultUserAgent = "Opera/0.07 (X1337; Lunix i286; U)"
DefaultBlockSize = 1024              # network read granularity (bytes)
DefaultMaxDataLength = 1024 * 1024   # max bytes of a body kept in memory per request

# Extensions that usually denote HTML content (compared without the dot).
HTMLExtensions = "html htm xhtml xml php php3 asp apsx msp mspx shtml"
# Table of content-check types and the file extensions they cover.
# FIX: the original had the literal text '"+HTMLExtensions+"' INSIDE the
# triple-quoted string instead of concatenating the variable, so none of the
# HTML extensions ever reached the Extensions map and HTML content checking
# was silently disabled.
Extensions = """
    html: """ + HTMLExtensions + """
    jpeg: jpg jpeg jpe
    png: png apng mng
    gif: gif
    video: mpg mpeg mpv m2v avi asf wmv flv mov mp4
    mp3: mp3 mpa
"""

################################################################################

# Convert the table above into a dict mapping '.ext' -> check type, and the
# HTML extension list into a tuple ('' covers extension-less URLs).
nex = {}
for line in Extensions.strip().split("\n"):
    t, colon, exts = line.strip().partition(':')
    t = t.strip()
    for x in exts.split():
        nex['.' + x] = t
Extensions = nex
HTMLExtensions = tuple(HTMLExtensions.split() + [""])
del nex

def killfile(f, fn):
    """Close file object *f* and delete *fn*; best-effort cleanup of a partial download.

    No-op when *f* is falsy or when the target is stdout ('-').
    Removal errors (e.g. file already gone) are ignored.
    """
    if f and fn != '-':
        f.close()
        try:
            os.remove(fn)
        except OSError:
            pass

def is_html(s):
    """Heuristic check: does *s* contain a typical HTML/XML marker (case-insensitive)?"""
    lowered = s.lower()
    return any(marker in lowered
               for marker in ('<html', '<head', '<body', '<!doctype', '<?xml'))

def thousand(n):
    """Format *n* with a comma inserted before every group of three digits (from the right)."""
    digits = str(n)
    groups = []
    while len(digits) > 3:
        groups.append(digits[-3:])
        digits = digits[:-3]
    groups.append(digits)
    return ','.join(reversed(groups))

################################################################################

# Classification of how a href was spelled in the page source (Link.type).
Relative, HostAbsolute, Absolute = range(3)

# <script> blocks are stripped before link extraction (they often contain
# markup-like string literals that would confuse the tag regexes).
re_script = re.compile(r'<script.*?</script>', re.I + re.S)
# group 1: tag name, group 2: optional quote char, group 3: href/src value,
# group 5 (re_link only): the inner HTML of the anchor
re_link = re.compile(r'<(a)\s[^>]*href=([\'"])?(.*?)(\2.*?>|\s.*?>|>)(.*?)</a.*?>', re.I + re.S)
re_img = re.compile(r'<(img)\s[^>]*src=([\'"])?(.*?)(\2.*?>|\s.*?>|>)', re.I)
re_tag = re.compile(r'<.*?>')  # used to strip tags from anchor text
class Link:
    """One hyperlink (<a>) or image (<img>) extracted from an HTML page.

    Built from a match of re_link or re_img; relative references are
    resolved against *url*, the address of the page the link was found on.
    """
    def __init__(self, m, url=""):
        # split the source page URL so host/directory can be compared later
        sproto, shost, spath, squery, sfrag = urlparse.urlsplit(url)
        if spath.endswith('/'):
            sdir = spath
        else:
            sdir = os.path.dirname(spath) + '/'
        self.tag = m.group(1).lower()   # 'a' or 'img'
        self.outer_html = m.group(0)    # full matched tag text
        self.href = m.group(3)          # raw href/src attribute value
        self.url = urlparse.urljoin(url, self.href)  # absolutized target URL
        self.protocol, self.host, self.path, self.query, self.fragment = urlparse.urlsplit(self.url)
        self.file = os.path.basename(self.path)
        self.ext = os.path.splitext(self.file)[-1][1:].lower()  # extension, no dot
        self.same_host = (self.host == shost)
        self.same_dir = self.same_host and self.path.startswith(sdir)
        # classify how the link was written in the page source
        if ':' in self.href[:8]:
            self.type = Absolute
        elif self.href.startswith('/'):
            self.type = HostAbsolute
        else:
            self.type = Relative
        if self.tag == "a":
            self.inner_html = m.group(5)
            self.text = re_tag.sub("", self.inner_html)  # anchor text without markup
        elif self.tag == "img":
            self.text = "img"

    def match(self, pattern):
        """fnmatch-style test of the link's filename against *pattern*.

        A bare extension pattern like '.jpg' is treated as '*.jpg';
        matching is case-insensitive.
        """
        if pattern.startswith('.'):
            pattern = '*' + pattern
        return fnmatch.fnmatch(self.file.lower(), pattern.lower())

def InitialLink(url):
    """Wrap a bare URL in a synthetic anchor tag and parse it into a Link."""
    fake_anchor = "<a href=\"%s\"></a>" % url
    return Link(re_link.match(fake_anchor))

################################################################################

# Global transfer statistics, updated by Request.run() and RunThreads().
TotalNet = 0L        # total bytes received from the network
TotalDisk = 0L       # total bytes written to disk
TotalFiles = 0       # number of files saved (or, in dry-run mode, counted)
global_kill = False  # set to True to make all worker threads abort ASAP
NetStatus = "nothing downloaded yet"  # human-readable summary set by RunThreads()
FileStatus = "nothing saved yet"      # human-readable summary set by RunThreads()

class Request:
    def __init__(self, url=None, save_file=None, **kwargs):
        self.url = url
        self.save_file = save_file
        self.referer = url
        self.user_agent = DefaultUserAgent
        self.finish_func = None
        self.block_size = 1024
        self.max_data_len = DefaultMaxDataLength
        self.expected_type = 'auto'
        self.ok = False
        self.data = ""
        self.length = 0
        self.__dict__.update(kwargs)

    def run(self):
        global TotalNet, TotalDisk, TotalFiles
        web = None
        if self.save_file == '-':
            f = sys.stdout
        else:
            f = None
        try:
            req = urllib2.Request(self.url)
            req.add_header("Referer", self.referer)
            req.add_header("User-Agent", self.user_agent)
            try:
                web = urllib2.urlopen(req)
            except urllib2.HTTPError, e:
                self.log_error("HTTP %d" % e.code)
                return
            except urllib2.URLError, e:
                self.log("URL Error %d: %s" % e.reason.args)
                return

            if self.save_file:
                try:
                    d = os.path.dirname(self.save_file)
                    if d:
                        os.makedirs(d)
                except OSError:
                    pass
                try:
                    f = file(self.save_file, "wb")
                except IOError:
                    self.log_error("cannot write output file")
                    return

            first_block = True
            while True:
                if global_kill:
                    killfile(f, self.save_file)
                    return

                try:
                    bdata = web.read(self.block_size)
                except KeyboardInterrupt:
                    raise
                except:
                    killfile(f, self.save_file)
                    web.close()
                    return self.log("Error while reading data")

                if first_block:
                    if not getattr(self, "check_" + self.expected_type)(bdata):
                        killfile(f, self.save_file)
                        self.log_error("invalid data")
                        return
                    first_block = False

                if not bdata:
                    break
                self.length += len(bdata)
                TotalNet += len(bdata)
                if len(self.data) < self.max_data_len:
                    self.data += bdata
                if f:
                    try:
                        f.write(bdata)
                        TotalDisk += len(bdata)
                    except IOError:
                        killfile(f, self.save_file)
                        self.log_error("content type mismatch")
                        return
            if f:
                TotalFiles += 1
                if self.save_file != '-':
                    f.close()
            self.ok = True
            self.log("%s bytes" % thousand(self.length))
        finally:
            if web: web.close()
            if self.finish_func:
                self.finish_func(self)

    def log(self, msg=None):
        if self.save_file:
            loc = "%s => %s" % (self.url, self.save_file)
        else:
            loc = self.url
        if msg:
            print >>sys.stderr, "%s [%s]" % (loc, msg)
        else:
            print >>sys.stderr, loc
    def log_error(self, msg=None):
        if msg:
            self.log("ERROR: " + msg)
        else:
            self.log("ERROR")

    def check_auto(self, data):
        dummy1, dummy2, path, dummy3, dummy4 = urlparse.urlsplit(self.url)
        t = Extensions.get(os.path.splitext(path)[-1].lower(), 'none')
        return getattr(self, "check_" + t)(data)
    def check_html(self, data):
        return is_html(data)
    def check_jpeg(self, data):
        return data.startswith("\xff\xd8\xff")
    def check_png(self, data):
        return data.startswith("\x89PNG")
    def check_gif(self, data):
        return data.startswith("GIF8")
    def check_video(self, data):
        return data.startswith("\x30\x26\xb2\x75\x8e\x66\xcf\x11") \
            or data.startswith("RIFF") \
            or data.startswith("\x00\x00\x01") \
            or data.startswith("FLV\x01") \
            or (data[4:8] in ("mdat", "ftyp")) \
            or (data[8:12] == "ftyp")
    def check_mp3(self, data):
        return data.startswith("\xff") \
            or data.startswith("RIFF") \
            or data.startswith("ID3")
    def check_none(self, data):
        return True

    def ParseLinks(self):
        noscript = re_script.sub("", self.data)
        for m in re_link.finditer(noscript):
            yield Link(m, self.url)
        for m in re_img.finditer(noscript):
            yield Link(m, self.url)

################################################################################

# Shared state of the download thread pool; all of it is guarded by `lock`.
running = False          # True while RunThreads() is active
current_threads = 0      # number of live worker threads
max_threads = 0          # upper bound, set by RunThreads()
visited_urls = {}        # url -> True; de-duplicates requests
queue = []               # pending Request objects (FIFO)
lock = thread.allocate_lock()

def WorkerThread(dummy):
    """Worker loop: pop queued Request objects and run them until the queue
    is empty or a global kill is requested; decrements current_threads on exit."""
    global current_threads
    try:
        while not global_kill:
            lock.acquire()
            try:
                job = queue.pop(0) if queue else None
            finally:
                lock.release()
            if job is None:
                break
            job.run()
    finally:
        lock.acquire()
        try:
            current_threads -= 1
        finally:
            lock.release()

def AddRequest(req, force=False):
    """Enqueue *req* unless its URL was seen before (override with *force*).

    If the pool is already running and below its thread limit, a new worker
    is spawned immediately. Returns True if the request was queued.
    """
    global running, current_threads, max_threads
    lock.acquire()
    try:
        already_seen = req.url in visited_urls
        if already_seen and not force:
            return False
        visited_urls[req.url] = True
        queue.append(req)
        can_spawn = running and current_threads < max_threads
        if can_spawn:
            current_threads += 1
            thread.start_new_thread(WorkerThread, (current_threads,))
        return True
    finally:
        lock.release()

def RunThreads(count=4):
    global running, current_threads, max_threads, global_kill
    global TotalNet, TotalDisk, NetStatus, FileStatus
    if running: return
    if not queue: return
    global_kill = False
    TotalNet = 0L
    TotalDisk = 0L
    t = time.time()
    current_threads = 0
    max_threads = count
    running = True
    current_threads = min(len(queue), count)
    for i in range(current_threads):
        thread.start_new_thread(WorkerThread, (i,))
    cont = True
    while cont:
        try:
            time.sleep(1)
            lock.acquire()
            try:
                cont = (len(queue) and not(global_kill)) or current_threads
            finally:
                lock.release()
        except KeyboardInterrupt:
            if global_kill: raise
            print >>sys.stderr, "^C: stopping all downloads ..."
            global_kill = True
    running = False
    t = int(time.time() - t + 0.9)
    NetStatus = "downloaded %s bytes in %s seconds -> %s bytes/sec" % \
                (thousand(TotalNet), thousand(t), thousand(TotalNet / t))
    FileStatus = "saved %s bytes in %s files" % \
                 (thousand(TotalDisk), thousand(TotalFiles))
    print >>sys.stderr, NetStatus
    print >>sys.stderr, FileStatus

################################################################################
## APPLICATION PART                                                           ##
################################################################################

def OnDownloadDone(req):
    """finish_func hook: in recursive mode, queue the links of a successfully
    downloaded HTML page while recursion budget (req.level) remains."""
    if Recursive and req.ok and req.level and is_html(req.data):
        for link in req.ParseLinks():
            CheckAndSubmit(link, req.level - 1, referer=req.url)

def CheckAndSubmit(link, level=0, initial=False, output=None, referer=None):
    """Filter *link* against the global options and enqueue a Request for it.

    level    -- remaining recursion budget for pages found via this link
    initial  -- True for URLs given on the command line (always downloaded)
    output   -- explicit destination filename (non-recursive mode)
    referer  -- URL of the page the link was found on
    """
    # check if the file matches the accept/reject patterns
    if LinksOnly and (link.tag != "a"):
        pattern_match = False
    elif Accept:
        pattern_match = False
        for a in Accept:
            if link.match(a):
                pattern_match = True
                break
    else:
        pattern_match = True
    if pattern_match:
        for r in Reject:
            if link.match(r):
                pattern_match = False
                break

    # filter for hierarchy restrictions
    if SpanHosts:
        host_match = True
    elif BaseDirs:
        host_match = link.same_host
    else:
        host_match = link.same_dir

    # force download if it's likely to be a html document
    # (renamed from 'is_html', which shadowed the module-level helper)
    looks_like_html = link.ext in HTMLExtensions

    # reject irrelevant URLs
    need_download = initial or (host_match and (looks_like_html or pattern_match))
    if not need_download: return

    # build and enqueue request object
    req = Request(link.url)
    if Recursive:
        req.level = level
        if referer:
            req.referer = referer
        req.finish_func = OnDownloadDone
        if pattern_match:
            # mirror the URL structure on disk: host/path, with a synthetic
            # index name for directory URLs
            req.save_file = link.host + link.path
            if not link.ext:
                req.save_file = os.path.join(req.save_file, ".index.html")
    elif output:
        req.save_file = output
    elif link.ext:
        req.save_file = link.file
    else:
        req.save_file = ".index.html"
    AddRequest(req)

re_iter = re.compile(r'\[(\d+)-(\d+)\]')
def ResolveIterators(url, dest):
    """Expand '[START-END]' counter specs in *url* into (url, dest) pairs.

    The counter width follows the number of digits in START; each counter
    value also replaces the next '?' placeholder in *dest* (if any).
    Counters are expanded left to right via recursion.
    """
    m = re_iter.search(url)
    if m is None:
        return [(url, dest)]
    qpos = dest.find('?') if dest else -1
    width_fmt = "%%0%dd" % len(m.group(1))  # e.g. '[01-42]' -> '%02d'
    head, tail = url[:m.start()], url[m.end():]
    expanded = []
    for value in range(int(m.group(1)), int(m.group(2)) + 1):
        counter = width_fmt % value
        next_url = head + counter + tail
        if qpos >= 0:
            next_dest = dest[:qpos] + counter + dest[qpos + 1:]
            expanded.extend(ResolveIterators(next_url, next_dest))
        else:
            expanded.extend(ResolveIterators(next_url, dest))
    return expanded

def ParseParameters(plist):
    """Expand and submit one command-line/list-file parameter set.

    *plist* is [url] or [url, dest_pattern]. Returns False on a parameter
    count error, True otherwise. In dry-run mode the expanded URLs are
    printed instead of being queued.
    """
    # FIX: without this declaration, the 'TotalFiles += 1' below raised
    # UnboundLocalError in dry-run mode when a destination was given
    global TotalFiles
    if not(plist) or (len(plist) > 2):
        return False
    if len(plist) == 1:
        plist = (plist[0], None)
    for url, dest in ResolveIterators(*plist):
        if ReallyDownload:
            CheckAndSubmit(InitialLink(url), Levels, True, dest)
        elif dest:
            print >>sys.stderr, url, "=>", dest
            TotalFiles += 1
        else:
            print >>sys.stderr, url
    return True

if __name__ == "__main__":
    parser = optparse.OptionParser(usage=optparse.SUPPRESS_USAGE, version=__version__, add_help_option=False)
    parser.add_option("-h", "--help", action="store_true", help="show this help message and exit")
    parser.add_option("-O", "--output", dest="DestFile", metavar="FILE",
                      type="string", default=None,
                      help="write document(s) to FILE (identical to the <dest_filename> parameter; `-' for stdout)")
    parser.add_option("-f", "--file", dest="ListFile", metavar="FILE",
                      type="string", default=None,
                      help="use file with URLs instead of parameters, `-' for stdin")
    parser.add_option("-n", "--dry-run", dest="ReallyDownload",
                      action="store_false", default=True,
                      help="don't download anything, just show what would be done")
    parser.add_option("-U", "--user-agent", dest="DefaultUserAgent", metavar="AGENT",
                      type="string", default=DefaultUserAgent,
                      help="identify as AGENT")
    parser.add_option("-j", "--threads", dest="Threads", metavar="COUNT",
                      type="int", default=4,
                      help="number of simultaneous server connections (default: 4)")
    parser.add_option("-r", "--recursive", dest="Recursive",
                      action="store_true", default=False,
                      help="enable recursive retrieval; output paths will mirror the URL structure, including the server name")
    parser.add_option("-a", "--links-only", dest="LinksOnly",
                      action="store_true", default=False,
                      help="only follow <a> links, not <img>")
    parser.add_option("-l", "--level", dest="Levels", metavar="NUMBER",
                      type="int", default=1000,
                      help="maximum recursion depth (default: infinite)")
    parser.add_option("-A", "--accept", dest="Accept", metavar="LIST",
                      type="string", default=[],
                      help="comma-separated list of accepted extensions or filename patterns"+
                           "(example: .jpg,.jpeg)")
    parser.add_option("-R", "--reject", dest="Reject", metavar="LIST",
                      type="string", default=[],
                      help="comma-separated list of rejected extensions or filename patterns "+
                           "(example: .jpg,.jpeg)")
    parser.add_option("-p", "--parent", dest="BaseDirs",
                      action="store_true", default=False,
                      help="allow to ascend into the parent directory")
    parser.add_option("-H", "--span-hosts", dest="SpanHosts",
                      action="store_true", default=False,
                      help="go to foreign hosts")
    options, args = parser.parse_args()
    if options.help:
        print __doc__.strip().replace('%ver', __version__) \
              .replace('%opts', parser.format_help().strip()) \
              .replace('%prog', os.path.basename(sys.argv[0]))
        sys.exit(0)
    globals().update(options.__dict__)

    if Accept: Accept = Accept.split(',')
    if Reject: Reject = Reject.split(',')

    if ListFile:
        if ListFile == '-':
            f = sys.stdin
        else:
            try:
                f = file(ListFile, "r")
            except IOError:
                parser.error("cannot read list file")
        n = 0
        for line in f:
            n += 1
            line = line.strip()
            if not(line) or line.startswith('#'): continue
            if not ParseParameters(line.split()):
                print >>sys.stderr, "syntax error in line %d of the list file" % n
        f.close()
    elif not ParseParameters(args):
        parser.error("invalid number of parameters")

    if TotalFiles:
        print >>sys.stderr, "%s files total" % thousand(TotalFiles)
    if ReallyDownload:
        RunThreads(Threads)
