##\file searchreplace.py
##\brief Part of \b MOPPY -- A Meta Object Programming toolkit for Python
##\author Joel Boehland

import sys
import os
import inspect
import new
import types
import re

# Member names that are usually uninteresting when walking an object graph:
# interpreter bookkeeping rather than user-defined members. Kept in a fixed
# order; consumers only test membership.
stdIgnore = """
    __class__ __delattr__ __dict__ __doc__ __getattribute__ __hash__
    __init__ __module__ __new__ __reduce__ __repr__ __setattr__ __str__
    __weakref__ __builtins__ __base__ __bases__ __basicsize__ __call__
    __cmp__ __dictoffset__ __flags__ __itemsize__ __mro__ mro __name__
    __subclasses__ __weakrefoffset__ __file__ __path__
""".split()

# Cache mapping an extension-less absolute module file path to the module's
# __name__; filled lazily by getmodule().
modulesbyfile = dict()

class SearchContext(object):
    """
    Mutable state shared by every level of a recursive member search.

    Bundles the search configuration (name/object filters and traversal
    flags), the accumulated results, the set of already-visited objects,
    and a cursor (current/parent name and object) that filters inspect.
    """
    def __init__(self, nameFilters, objectFilters, breadthFirst=True, singleSelect=False, debugWalk=False):
        # search configuration
        self.nameFilters = nameFilters
        self.objectFilters = objectFilters
        self.breadthFirst = breadthFirst
        self.singleSelect = singleSelect
        self.debugWalk = debugWalk
        # results and bookkeeping
        self.searchDone = False
        self.resultSet = []
        self.visitedIdx = {}
        # cursor describing the member currently being examined
        self.currentName = self.currentObj = None
        self.parentName = self.parentObj = None
        self.currentMod = None

class MemberNotFoundException(Exception):
    """Raised when the member being searched for cannot be located."""

class AnyStringFilter(object):
    """Filter that accepts every candidate, whatever its name."""
    def match(self, ctx):
        # the search context is deliberately ignored
        return True

class NotStringFilter(object):
    """
    Name filter that rejects any member whose name (ctx.currentName)
    appears in the given blacklist and accepts every other member.
    """
    def __init__(self, badStrings):
        self.badStrings = badStrings

    def match(self, ctx):
        rejected = ctx.currentName in self.badStrings
        return not rejected

#Shared name filter for the common case: reject members whose name is one
#of the standard bookkeeping names listed in stdIgnore.
STD_IGNORE = NotStringFilter(stdIgnore)

class RequiredStringFilter(object):
    """
    Name filter that accepts a member only when its name (ctx.currentName)
    is one of the given required strings.
    """
    def __init__(self, goodStrings):
        self.goodStrings = goodStrings

    def match(self, ctx):
        name = ctx.currentName
        return name in self.goodStrings

class ObjectNameFilter(object):
    """
    Name filter that returns true if the member's name (ctx.currentName)
    matches the given regular expression.

    Matching uses re.match, so the pattern is anchored at the start of the
    name only; append '$' to the pattern to require a full-name match.
    """
    def __init__(self, nameRE):
        self.nameRE = re.compile(nameRE)

    def match(self, ctx):
        return self.nameRE.match(ctx.currentName) is not None

class ObjectEqualsFilter(object):
    """
    Object filter that accepts the current member when the reference object
    given at construction compares equal (==) to it.
    """

    def __init__(self, obj):
        self.obj = obj

    def match(self, ctx):
        candidate = ctx.currentObj
        return self.obj == candidate

class ObjectComparatorFilter(object):
    """
    Object filter delegating to a user-supplied predicate: the result of
    comparatorFunc(ctx.currentObj) is returned as-is.
    """

    def __init__(self, comparatorFunc):
        self.comparatorFunc = comparatorFunc

    def match(self, ctx):
        predicate = self.comparatorFunc
        return predicate(ctx.currentObj)

class ClassFilter(object):
    """
    Filter for members of classes: returns False unless ctx.parentObj is a
    class, in which case it returns comparatorFunc(ctx.parentObj), i.e. the
    predicate is applied to the containing class, not to the member.
    """

    def __init__(self, comparatorFunc):
        self.comparatorFunc = comparatorFunc

    def match(self, ctx):
        owner = ctx.parentObj
        if not inspect.isclass(owner):
            return False
        return self.comparatorFunc(owner)
    
class ModuleNameFilter(object):
    """
    Filter that matches on the name of the module an object was defined in,
    as determined by getmodule().

    modNameRE is a regular expression, not a glob, and it is applied with
    re.match, so it is anchored only at the start of the module name. For
    example the pattern 'somepkg[.]subpkg' matches 'somepkg.subpkg' and
    every module beneath it (but also 'somepkg.subpkg2'); append '$' to
    require an exact name. Objects whose module cannot be determined never
    match.
    """
    def __init__(self, modNameRE):
        self.modRE = re.compile(modNameRE)

    def match(self, ctx):
        mod = getmodule(ctx.currentObj)
        if mod is None:
            return False
        return self.modRE.match(mod.__name__) is not None

class ClassNameCriteria(object):
    """
    Criteria matching class objects by their qualified name
    '<class __module__>.<member name>' against a regular expression
    (re.match, so anchored at the start only).

    With negativeMatch=True the result is inverted for classes: a class
    matches when its name does NOT match the pattern. Non-class members
    never match, regardless of negativeMatch.
    """
    def __init__(self, className, negativeMatch=False):
        self.classNameRE = re.compile(className)
        self.negativeMatch = negativeMatch

    def match(self, ctx):
        if not inspect.isclass(ctx.currentObj):
            return False
        fullName = ctx.currentObj.__module__+"."+ctx.currentName
        matched = self.classNameRE.match(fullName) is not None
        # boolean inversion: arithmetic negation (-1 / 0) keeps the truth value
        if self.negativeMatch:
            return not matched
        return matched

class ClassMemberCriteria(object):
    """
    Criteria matching members of classes: the containing class's qualified
    name '<class __module__>.<ctx.parentName>' is matched against a regular
    expression (re.match, so anchored at the start only).

    With negativeMatch=True the result is inverted for class members: a
    member matches when its class's name does NOT match the pattern.
    Members whose parent is not a class never match, regardless of
    negativeMatch.
    """
    def __init__(self, className, negativeMatch=False):
        self.classNameRE = re.compile(className)
        self.negativeMatch = negativeMatch

    def match(self, ctx):
        if not inspect.isclass(ctx.parentObj):
            return False
        #need name of class
        className = ctx.parentObj.__module__+"."+ctx.parentName
        matched = self.classNameRE.match(className) is not None
        # boolean inversion: arithmetic negation (-1 / 0) keeps the truth value
        if self.negativeMatch:
            return not matched
        return matched
    
class ModuleMemberCriteria(object):
    """
    Criteria matching members of modules: the containing module's __name__
    is matched against a regular expression (re.match, so anchored at the
    start only).

    With negativeMatch=True the result is inverted for module members: a
    member matches when its module's name does NOT match the pattern.
    Members whose parent is not a module never match, regardless of
    negativeMatch.
    """
    def __init__(self, modName, negativeMatch=False):
        self.modNameRE = re.compile(modName)
        self.negativeMatch = negativeMatch

    def match(self, ctx):
        if not inspect.ismodule(ctx.parentObj):
            return False
        #need name of module
        modName = ctx.parentObj.__name__
        matched = self.modNameRE.match(modName) is not None
        # boolean inversion: arithmetic negation (-1 / 0) keeps the truth value
        if self.negativeMatch:
            return not matched
        return matched

class FunctionNameCriteria(object):
    """
    Criteria matching plain functions by member name (ctx.currentName)
    against a regular expression (re.match, so anchored at the start only).

    With negativeMatch=True the result is inverted for functions: a
    function matches when its name does NOT match the pattern. Non-function
    members never match, regardless of negativeMatch.
    """
    def __init__(self, functionName, negativeMatch=False):
        self.functionNameRE = re.compile(functionName)
        self.negativeMatch = negativeMatch

    def match(self, ctx):
        if not inspect.isfunction(ctx.currentObj):
            return False
        matched = self.functionNameRE.match(ctx.currentName) is not None
        # boolean inversion: arithmetic negation (-1 / 0) keeps the truth value
        if self.negativeMatch:
            return not matched
        return matched
    
class MethodNameCriteria(object):
    """
    Criteria matching method objects (inspect.ismethod) by member name
    (ctx.currentName) against a regular expression (re.match, so anchored
    at the start only).

    With negativeMatch=True the result is inverted for methods: a method
    matches when its name does NOT match the pattern. Non-method members
    never match, regardless of negativeMatch.
    """
    def __init__(self, methodName, negativeMatch=False):
        self.methodNameRE = re.compile(methodName)
        self.negativeMatch = negativeMatch

    def match(self, ctx):
        if not inspect.ismethod(ctx.currentObj):
            return False
        matched = self.methodNameRE.match(ctx.currentName) is not None
        # boolean inversion: arithmetic negation (-1 / 0) keeps the truth value
        if self.negativeMatch:
            return not matched
        return matched

###XXX--these aren't done yet: the attribute setter/getter criteria below only store attName and have no match() method
class AttributeSetterCriteria(object):
    """Unfinished stub: only records the attribute name; there is no match() yet."""
    def __init__(self, attName):
        self.attName = attName

class AttributeGetterCriteria(object):
    """Unfinished stub: only records the attribute name; there is no match() yet."""
    def __init__(self, attName):
        self.attName = attName



def applyFilters(object, filterList):
    """
    Return True when every filter in filterList accepts 'object' (the
    walker passes its SearchContext here) via filter.match(object).

    Filters are tried in order and evaluation stops at the first one that
    rejects, returning False. A None or empty filter list accepts everything.
    """
    if filterList is None or len(filterList) == 0:
        return True
    return all(flt.match(object) for flt in filterList)


def getModuleBaseFilename(modfileName):
    """Get the root filename of a module without the
    suffix i.e somemod.py-->somemod

    Only the extension of the final path component is removed, so a name
    without an extension is returned unchanged and dots in directory names
    are left alone (/x.y/mod stays /x.y/mod).
    """
    # os.path.splitext only looks at the last component; a bare rfind(".")
    # chopped the last character when there was no dot at all.
    return os.path.splitext(modfileName)[0]


def getmodule(object):
    """
    Return the module an object was defined in, or None if not found.

    Lookup order:
      1. a module is returned as-is; a class is resolved through its
         __module__ attribute;
      2. otherwise the absolute source file of the object -- or, for an
         instance, of its class -- is mapped (ignoring its extension) to a
         loaded module via the modulesbyfile cache, which is rebuilt from
         sys.modules on a miss;
      3. finally __main__ and then the builtin module are searched
         recursively for the object.

    This was taken and modified slightly from the getmodule method in the inspect
    std python lib
    """
    if inspect.ismodule(object):
        return object
    if inspect.isclass(object):
        return sys.modules.get(object.__module__)

    try:
        srcFile = inspect.getabsfile(object)
    except TypeError:
        #not a module/class/function/method/code-like object: treat it as
        #an instance and use the file of its class instead
        cls = getattr(object, '__class__', None)
        if cls is None:
            return None
        try:
            srcFile = inspect.getabsfile(cls)
        except TypeError:
            return None

    baseName = getModuleBaseFilename(srcFile)

    #fast path: the file was indexed by an earlier call. .get() tolerates
    #cached modules that have since been removed from sys.modules.
    if baseName in modulesbyfile:
        mod = sys.modules.get(modulesbyfile[baseName])
        if mod is not None:
            return mod

    #cache miss (or stale entry): index every loaded module by its
    #extension-less file path. Iterate over a snapshot of sys.modules.
    for module in list(sys.modules.values()):
        if getattr(module, '__file__', None) is None:
            continue
        try:
            bname = getModuleBaseFilename(inspect.getabsfile(module))
        except TypeError:
            continue
        modulesbyfile[bname] = module.__name__

    if baseName in modulesbyfile:
        mod = sys.modules.get(modulesbyfile[baseName])
        if mod is not None:
            return mod

    main = sys.modules.get('__main__')
    if main is not None and recurseContainsObject(main, object):
        return main

    #'__builtin__' is the Python 2 name of the module, 'builtins' the Python 3 one
    builtin = sys.modules.get('__builtin__') or sys.modules.get('builtins')
    if builtin is not None and recurseContainsObject(builtin, object):
        return builtin

    return None

def recurseContainsObject(container, obj):
    """
    Return True if the provided object lives somewhere in the member
    hierarchy starting with the container object, False otherwise.

    An object that has a __name__ is only looked for under that name, and
    candidates are compared with == (ObjectEqualsFilter).
    """
    nameFilters = []
    if hasattr(obj, '__name__'):
        nameFilters.append(RequiredStringFilter([obj.__name__]))

    #a single hit answers the question, so let the walk stop at the first match
    ctx = SearchContext(nameFilters, [ObjectEqualsFilter(obj)], singleSelect=True)
    recurseSelectMembers(container, ctx)

    #any hit counts, including when obj (and thus the hit) is None
    return len(ctx.resultSet) > 0

def recurseGetParentObject(container, obj):
    """
    Return the object that directly contains the provided object, searching
    the member hierarchy beneath container, or None if it is not found. For
    example, if you pass it the class method A.foo(), it will return the A
    class.

    The search stops at the first match (singleSelect) and relies on the
    walker leaving ctx.parentObj at the container that held that match.
    """
    retVal = None
    nameFilters = []
    #should we do this? An object bound under an alias name (bar = foo)
    #is only found under its original __name__.
    if hasattr(obj, '__name__'):
        nameFilters.append(RequiredStringFilter([obj.__name__]))

    objectFilter = ObjectEqualsFilter(obj)
    ctx = SearchContext(nameFilters, [objectFilter], singleSelect=True)
    _recurseSelectMembers(container, ctx)

    if len(ctx.resultSet) > 0:
        retVal = ctx.parentObj

    return retVal

def recurseGetNormalMembers(obj):
    """
    Recursively collect every member reachable from obj whose name is not
    one of the standard names in stdIgnore (filtered via STD_IGNORE).
    """
    return recurseSelectMembers(obj, SearchContext([STD_IGNORE], []))

def recurseGetAllMembers(obj):
    """
    Recursively collect every member reachable from obj, with no name or
    object filtering at all.
    """
    return recurseSelectMembers(obj, SearchContext([], []))

def recurseSelectMembers(obj, ctx):
    """
    Run the recursive member walk over obj using the filters and flags in
    ctx, and hand back ctx.resultSet (the same list object, not a copy).
    """
    _recurseSelectMembers(obj, ctx)
    results = ctx.resultSet
    return results

def _recurseSelectMembers(obj, ctx):
    """
    Recursively search through all members of an object, and place those
    that match the search criteria into the resultset of the provided
    SearchContext object.

    For each name in dir(obj) the ctx cursor (currentName, currentObj,
    parentObj) is updated, ctx.nameFilters and then ctx.objectFilters are
    applied, and members passing both are appended to ctx.resultSet.
    Members whose exact type is types.TypeType, types.ClassType or
    types.ModuleType are then descended into. With ctx.breadthFirst all
    direct members of obj are examined before any of them is descended
    into; otherwise each one is descended into as soon as it is seen.
    Objects already matched or descended into are skipped (tracked by id()
    in ctx.visitedIdx).

    With ctx.singleSelect the walk stops right after the first match, so
    ctx.currentName/currentObj/parentObj still describe that match.

    NOTE: the exact type() test means classes created by a custom metaclass
    are not descended into.
    """
    #short circuit a recursive search if searchDone flag set to true
    if ctx.searchDone:
        return

    #name the root of the walk; deeper levels get parentName from the
    #caller just before it recurses
    if ctx.parentName is None and (inspect.ismodule(obj) or inspect.isclass(obj)):
        ctx.parentName = obj.__name__

    descentList = []
    for name in dir(obj):
        try:
            o = getattr(obj, name)
        except AttributeError:
            #dir() can report names that getattr() rejects (for example
            #type.__abstractmethods__); skip them rather than abort the walk
            continue

        #no duplicates.
        if id(o) in ctx.visitedIdx:
            continue

        #set up ctx vars for use in filters
        ctx.currentName = name
        ctx.currentObj = o
        ctx.parentObj = obj

        if applyFilters(ctx, ctx.nameFilters) and applyFilters(ctx, ctx.objectFilters):
            #passed filters, add object to result set
            ctx.visitedIdx[id(o)] = o
            ctx.resultSet.append(o)
            if ctx.singleSelect:
                #stop immediately: nothing more may be appended by this or
                #any enclosing level, and the cursor keeps describing this match
                ctx.searchDone = True
                return

        #determine whether we recurse this object's children
        #XXX--should have recursion filter???
        if type(o) in (types.TypeType, types.ClassType, types.ModuleType):
            ctx.visitedIdx[id(o)] = o
            if ctx.breadthFirst:
                descentList.append((name, o))
            else:
                savedParentName = ctx.parentName
                ctx.parentName = name
                _recurseSelectMembers(o, ctx)
                if ctx.searchDone:
                    return
                #restore so this level's remaining members see its own name
                ctx.parentName = savedParentName

    #only populated in breadth-first mode
    for childName, child in descentList:
        if ctx.searchDone:
            return
        ctx.parentName = childName
        _recurseSelectMembers(child, ctx)

def isClassMethod(func):
    """
    Heuristically decide whether the (Python 2) method object func is a
    bound classmethod.

    For a classmethod, im_self is the class itself and im_class is that
    class's metaclass, so both have the same type (normally 'type'). For an
    unbound method im_self is None, and for a method bound to an ordinary
    instance type(im_self) is the instance's class while type(im_class) is
    its metaclass, so the types differ.

    NOTE(review): for classmethods of old-style classes im_class is
    types.ClassType and im_self an old-style class, so this returns False.
    """
    owner_kind = type(func.im_class)
    self_kind = type(func.im_self)
    return owner_kind == self_kind


def replaceInGlobals(obj, replacement):
    """
    Walk each loaded module (sys.modules) and rebind every module-level name
    that refers to obj itself (identity, not equality) to replacement.

    This catches copies of obj that were imported into other modules'
    namespaces (e.g. 'from mod import obj'). Only module attributes are
    rebound; references held by classes, instances or containers are not.
    """
    #snapshot: resolving lazy attributes may import modules and mutate
    #sys.modules while we iterate
    for mod in list(sys.modules.values()):
        #sys.modules can hold None placeholders (Python 2 caches failed
        #implicit relative imports that way); they have no namespace to patch
        if mod is None:
            continue
        for name in dir(mod):
            try:
                value = getattr(mod, name)
            except (AttributeError, ImportError):
                #dir() may list lazily-resolved names that fail to load; an
                #unrelated module must not abort the whole walk
                continue
            if value is obj:
                setattr(mod, name, replacement)
                

def replaceCallable(func, replacement, replaceGlobals=True):
    """
    Replace callable object with replacement callable, in place on its owner.

    Supported kinds of func:
      * method objects (types.MethodType): the attribute is replaced on the
        owning class -- for a classmethod that is the class it is bound to
        (im_self, since im_class is the metaclass), otherwise im_class.
        Classmethods are re-wrapped with classmethod(); other methods become
        unbound instancemethods of the owner.
      * plain functions (types.FunctionType): replaced on the module found by
        getmodule() when that module's attribute of the same name is func
        itself; otherwise func is searched for beneath that module (e.g. a
        static method of a class) and replaced on its parent, wrapped in
        staticmethod() when the parent is a class.
    Any other kind of callable is left untouched.

    If replaceGlobals is true, replaceInGlobals() then rebinds module-level
    names that referred to func.

    Raises TypeError if func or replacement is not callable, and
    MemberNotFoundException if the owner of a plain function cannot be found.
    """

    #sanity checks
    if not callable(func):
        raise TypeError("Object: ",func," is not callable")
    if not callable(replacement):
        raise TypeError("Object: ",replacement," is not callable")

    funcType = type(func)

    if funcType == types.MethodType:
        name = func.__name__
        if isClassMethod(func):
            #for a classmethod im_class is the metaclass (normally 'type',
            #whose attributes cannot be set); the owning class is im_self
            owner = func.im_self
            newmethod = classmethod(replacement)
        else:
            owner = func.im_class
            newmethod = new.instancemethod(replacement, None, owner)
        setattr(owner, name, newmethod)
        if replaceGlobals:
            #a global that held func held a method object, so give it an
            #equivalent one (bound to the same im_self, or unbound) rather
            #than the bare replacement function
            if func.im_self is None:
                binder = owner
            else:
                binder = func.im_self
            replaceInGlobals(func, getattr(binder, name))
    elif funcType == types.FunctionType:
        name = func.__name__
        mod = getmodule(func)
        if mod is None:
            raise MemberNotFoundException("Couldn't find owner object for callable: ",func)

        #A module-level function is an attribute of its module. A class
        #static method is also a 'function', but it is not a module
        #attribute -- it has to be found in a member of the module. Check
        #identity so a different object sharing the name is not clobbered.
        if getattr(mod, name, None) is func:
            setattr(mod, name, replacement)
        else:
            parent = recurseGetParentObject(mod, func)
            if parent is None:
                raise MemberNotFoundException("Couldn't find owner object for callable: ",func)
            if inspect.isclass(parent):
                setattr(parent, name, staticmethod(replacement))
            else:
                setattr(parent, name, replacement)

        #globals get the plain replacement: a staticmethod wrapper itself is
        #not callable before Python 3.10
        if replaceGlobals:
            replaceInGlobals(func, replacement)