import os
import sys

from bento.commands import hooks

import waflib
import waflib.Errors
from waflib.Task \
    import \
        Task
# Set waf's log verbosity level to 1 for this build script.
waflib.Logs.verbose = 1

# Importing this adds new checkers to waf configure context - I don't like this
# way of working, should find a more explicit way to attach new functions to
# context.
import numpy.build_utils.waf

from code_generators.numpy_api \
    import \
        multiarray_api, ufunc_api
from code_generators import generate_numpy_api, generate_ufunc_api, \
        generate_umath

from setup_common \
    import \
        OPTIONAL_STDFUNCS_MAYBE, OPTIONAL_STDFUNCS, C99_FUNCS_EXTENDED, \
        C99_FUNCS_SINGLE, C99_COMPLEX_TYPES, C99_COMPLEX_FUNCS, \
        MANDATORY_FUNCS, C_ABI_VERSION, C_API_VERSION

# Build switches taken from the environment: any value other than "0" enables
# the option.  Separate compilation defaults to enabled ("1"), relaxed strides
# checking defaults to disabled ("0").
ENABLE_SEPARATE_COMPILATION = os.environ.get('NPY_SEPARATE_COMPILATION', "1") != "0"
NPY_RELAXED_STRIDES_CHECKING = os.environ.get('NPY_RELAXED_STRIDES_CHECKING', "0") != "0"

# List of (key, value) pairs.  write_numpy_config() replaces every "@key@"
# placeholder of include/numpy/_numpyconfig.h.in with str(value).  The
# configuration checks below append further entries.
# FIXME
NUMPYCONFIG_SYM = [
    ('DEFINE_NPY_ENABLE_SEPARATE_COMPILATION',
     '#define NPY_ENABLE_SEPARATE_COMPILATION 1' if ENABLE_SEPARATE_COMPILATION else ''),
    ('DEFINE_NPY_RELAXED_STRIDES_CHECKING',
     '#define NPY_RELAXED_STRIDES_CHECKING 1' if NPY_RELAXED_STRIDES_CHECKING else ''),
    ('VISIBILITY_HIDDEN', '__attribute__((visibility("hidden")))'),
    ('NPY_ABI_VERSION', '0x%.8X' % C_ABI_VERSION),
    ('NPY_API_VERSION', '0x%.8X' % C_API_VERSION),
]

# A `global` statement has no effect at module scope; the name is first bound
# by type_checks().
global PYTHON_HAS_UNICODE_WIDE

def is_npy_no_signal():
    """Tell whether the NPY_NO_SIGNAL symbol must be defined in the
    configuration header.

    The answer is True exactly when ``sys.platform`` equals ``'win32'``."""
    on_win32 = (sys.platform == 'win32')
    return on_win32

def define_no_smp():
    """Return True if we should define NPY_NOSMP, False otherwise.

    NPY_NOSMP is requested either when the running interpreter is older than
    Python 2.4.2, or when the ``NPY_NOSMP`` environment variable is set (its
    value is not inspected, only its presence).
    """
    #--------------------------------
    # Checking SMP and thread options
    #--------------------------------
    # Python 2.3 causes a segfault when
    #  trying to re-acquire the thread-state
    #  which is done in error-handling
    #  ufunc code.  NPY_ALLOW_C_API and friends
    #  cause the segfault. So, we disable threading
    #  for now.
    # Compare version numbers as a tuple of ints: comparing the textual
    # ``sys.version`` prefix is lexicographic and therefore unreliable for
    # multi-digit components.
    if sys.version_info[:3] < (2, 4, 2):
        return True

    # Perhaps a fancier check is in order here.
    #  so that threads are only enabled if there
    #  are actually multiple CPUS? -- but
    #  threaded code can be nice even on a single
    #  CPU so that long-calculating code doesn't
    #  block.
    return 'NPY_NOSMP' in os.environ

def write_numpy_config(conf):
    """Generate ``_numpyconfig.h`` in the build tree from its ``.in`` template.

    Each ``@KEY@`` placeholder found in ``include/numpy/_numpyconfig.h.in``
    is replaced by ``str(value)`` for the ``(KEY, value)`` pairs collected in
    NUMPYCONFIG_SYM.  The result is written to the node obtained by mirroring
    the template's path (relative to ``conf.srcnode``) under ``conf.bldnode``
    and applying ``change_ext("")``.

    Raises waflib.Errors.ConfigurationError if the template cannot be found.
    """
    template = "include/numpy/_numpyconfig.h.in"
    node = conf.path.find_node(template)
    # Validate the lookup *before* dereferencing the node, so that a missing
    # template is reported clearly instead of failing on ``None.read()``.
    if node is None:
        raise waflib.Errors.ConfigurationError(
            "Template %r not found" % template)

    subst_dict = {}
    for key, value in NUMPYCONFIG_SYM:
        subst_dict["@%s@" % key] = str(value)

    cnt = node.read()
    for placeholder, replacement in subst_dict.items():
        cnt = cnt.replace(placeholder, replacement)

    onode = conf.bldnode.make_node(node.path_from(conf.srcnode)).change_ext("")
    onode.write(cnt)

def write_numpy_inifiles(conf):
    """Generate the npy-pkg-config ``.ini`` files from their ``.in`` templates.

    For ``mlib.ini.in`` and ``npymath.ini.in`` (looked up under ``conf.path``)
    the placeholders ``@sep@`` and ``@pkgname@`` are replaced by
    ``os.path.sep`` and ``"numpy.core"`` respectively.  Each result is written
    under ``conf.bldnode`` at the template's source-relative path, with the
    file name relocated into ``lib/npy-pkg-config`` and ``change_ext("")``
    applied.

    Raises waflib.Errors.ConfigurationError if a template cannot be found.
    """
    subst_dict = {"@sep@": os.path.sep, "@pkgname@": "numpy.core"}
    for inifile in ("mlib.ini.in", "npymath.ini.in"):
        node = conf.path.find_node(inifile)
        # Validate the lookup *before* dereferencing the node, so that a
        # missing template is reported clearly instead of failing on
        # ``None.read()``.
        if node is None:
            raise waflib.Errors.ConfigurationError(
                "Template %r not found" % inifile)

        cnt = node.read()
        for placeholder, replacement in subst_dict.items():
            cnt = cnt.replace(placeholder, replacement)

        outfile = os.path.join("lib", "npy-pkg-config", inifile)
        relative = node.path_from(conf.srcnode).replace(inifile, outfile)
        onode = conf.bldnode.make_node(relative).change_ext("")
        onode.write(cnt)

def type_checks(conf):
    """Run the type-size and declaration checks and record their results.

    Results are appended to NUMPYCONFIG_SYM; the strtod fallback is registered
    with ``conf.define``.  The module-level PYTHON_HAS_UNICODE_WIDE flag is
    bound here as well.  Unless noted otherwise the checks are run against
    ``Python.h`` with the ``c pyext`` features.
    """
    global PYTHON_HAS_UNICODE_WIDE

    sanitize = numpy.build_utils.waf.sanitize_string
    ConfigError = waflib.Errors.ConfigurationError
    header_name = "Python.h"
    features = "c pyext"

    # Integer types: only require that the SIZEOF_* macro is declared, and map
    # the template symbol onto that very macro name.
    for int_type in ("int", "long", "short"):
        sizeof_macro = "SIZEOF_%s" % sanitize(int_type)
        conf.check_declaration(sizeof_macro, header_name=header_name,
                               features=features)
        NUMPYCONFIG_SYM.append((sizeof_macro, sizeof_macro))

    # Floating point types: record the size of the type itself, then of a
    # two-member struct of that type, expected to be twice as large.
    real_types = (("float", 4), ("double", 8), ("long double", [12, 16]))
    for real_type, candidates in real_types:
        real_macro = "SIZEOF_%s" % sanitize(real_type)
        size = conf.check_type_size(real_type, header_name=header_name,
                                    features=features,
                                    expected_sizes=candidates)
        NUMPYCONFIG_SYM.append((real_macro, str(size)))

        complex_macro = "SIZEOF_COMPLEX_%s" % sanitize(real_type)
        complex_def = "struct {%s __x; %s __y;}" % (real_type, real_type)
        complex_size = conf.check_type_size(complex_def,
                                            header_name=header_name,
                                            features=features,
                                            expected_sizes=2*size)
        NUMPYCONFIG_SYM.append((complex_macro, str(complex_size)))

    # The long double representation check is skipped on darwin.
    if sys.platform != 'darwin':
        conf.check_ldouble_representation()

    for c_type in ("Py_intptr_t", "off_t"):
        size = conf.check_type_size(c_type, header_name=header_name,
                                    expected_sizes=[4, 8], features=features)
        NUMPYCONFIG_SYM.append(('SIZEOF_%s' % sanitize(c_type), '%d' % size))

    # We check declaration AND type because that's how distutils does it.
    try:
        conf.check_declaration("PY_LONG_LONG", header_name=header_name,
                               features=features)
        size = conf.check_type_size("PY_LONG_LONG", header_name=header_name,
                                    features=features, expected_sizes=[4, 8])
    except ConfigError:
        longlong_def = ""
        py_longlong_def = ""
    else:
        longlong_def = "#define NPY_SIZEOF_LONGLONG %d" % size
        py_longlong_def = "#define NPY_SIZEOF_PY_LONG_LONG %d" % size
    NUMPYCONFIG_SYM.append(("DEFINE_NPY_SIZEOF_LONGLONG", longlong_def))
    NUMPYCONFIG_SYM.append(("DEFINE_NPY_SIZEOF_PY_LONG_LONG", py_longlong_def))

    conf.check_declaration("CHAR_BIT", header_name=header_name, features=features)

    # Check whether we need our own wide character support.
    # NOTE(review): the flag ends up False when Py_UNICODE_WIDE *is* declared
    # and True when it is not -- confirm this polarity is intended.
    try:
        conf.check_declaration('Py_UNICODE_WIDE', header_name=header_name, features=features)
    except ConfigError:
        PYTHON_HAS_UNICODE_WIDE = True
    else:
        PYTHON_HAS_UNICODE_WIDE = False

    # When Python.h does not declare PyOS_ascii_strtod, alias it to strtod if
    # that function is available; otherwise silently leave it undefined.
    try:
        conf.check_declaration('PyOS_ascii_strtod', header_name=header_name, features=features)
    except ConfigError:
        try:
            conf.check_func('strtod')
            conf.define('PyOS_ascii_strtod', 'strtod')
        except ConfigError:
            pass

def signal_smp_checks(conf):
    """Record the NPY_NO_SIGNAL and NPY_NO_SMP settings.

    The signal setting comes from is_npy_no_signal(); when it applies, the
    private ``__NPY_PRIVATE_NO_SIGNAL`` symbol is also defined through
    ``conf.define``.  The SMP setting comes from define_no_smp() and is stored
    as the integer 1 or 0.
    """
    no_signal = is_npy_no_signal()
    if no_signal:
        signal_define = "#define NPY_NO_SIGNAL\n"
    else:
        signal_define = ""
    NUMPYCONFIG_SYM.append(("DEFINE_NPY_NO_SIGNAL", signal_define))
    if no_signal:
        conf.define("__NPY_PRIVATE_NO_SIGNAL", 1)

    NUMPYCONFIG_SYM.append(("NPY_NO_SMP", 1 if define_no_smp() else 0))

def check_math_runtime(conf):
    """Locate the math library and probe the math functions it provides.

    Steps, in order:

    1. Try to link ``exp`` with no extra library, then with ``m``, then with
       ``cpml``; a library named by the ``MATHLIB`` environment variable is
       tried first.  The first candidate that works is stored under the
       ``M`` uselib name and recorded as the ``MATHLIB`` template symbol.
    2. Check the mandatory functions (MANDATORY_FUNCS plus ``pow``).
    3. Check the optional and C99 function groups (non mandatory).
    4. Record whether ``isnan``, ``isinf``, ``signbit`` and ``isfinite`` are
       declared.

    Raises waflib.Errors.ConfigurationError if no candidate math library
    works, or if a mandatory check fails.
    """
    header_name = "Python.h math.h"
    # XXX: with MSVC compiler, one needs to have cprogram defined. Find out why.
    features = "c pyext cprogram"

    mlibs = [None, "m", "cpml"]
    mathlib = os.environ.get('MATHLIB')
    if mathlib:
        mlibs.insert(0, mathlib)

    # ``mlib`` stays None until a candidate works; afterwards it is the list
    # of library names needed (empty when no extra library is required).
    mlib = None
    for lib in mlibs:
        kw = {} if lib is None else {"lib": lib}
        try:
            conf.check_functions_at_once(["exp"], uselib_store="M", **kw)
        except waflib.Errors.ConfigurationError:
            continue
        # Keep the name wrapped in a list: joining a bare string below would
        # split it into characters ("cpml" -> "c,p,m,l").
        mlib = [] if lib is None else [lib]
        break
    if mlib is None:
        raise waflib.Errors.ConfigurationError("No math lib found !")

    # XXX: this is ugly: mathlib has nothing to do in a public header file
    NUMPYCONFIG_SYM.append(('MATHLIB', ','.join(mlib)))

    # FIXME: look more into those additional mandatory functions
    # MANDATORY_FUNCS is a shared module-level list: only add "pow" once so
    # repeated calls do not accumulate duplicates.
    if "pow" not in MANDATORY_FUNCS:
        MANDATORY_FUNCS.append("pow")
    conf.check_functions_at_once(MANDATORY_FUNCS, use="M")

    #mfuncs = ('expl', 'expf', 'log1p', 'expm1', 'asinh', 'atanhf', 'atanhl',
    #          'rint', 'trunc')
    #conf.check_functions_at_once(mfuncs, use="M")

    # Functions whose HAVE_* macro is already declared by the headers are
    # dropped from OPTIONAL_STDFUNCS so they are not checked again below.
    for f in OPTIONAL_STDFUNCS_MAYBE:
        try:
            conf.check_declaration("HAVE_%s" % numpy.build_utils.waf.sanitize_string(f),
                                    header_name=header_name,
                                    features=features)
            OPTIONAL_STDFUNCS.remove(f)
        except waflib.Errors.ConfigurationError:
            pass

    for funcs in (OPTIONAL_STDFUNCS, C99_FUNCS_SINGLE, C99_FUNCS_EXTENDED):
        conf.check_functions_at_once(funcs,
                features=features, mandatory=False, use="M")
    # TODO: add OPTIONAL_HEADERS, OPTIONAL_INTRINSICS and
    # OPTIONAL_GCC_ATTRIBUTES (see setup.py and gh-3766).  These are
    # performance optimizations for GCC.

    for f in ["isnan", "isinf", "signbit", "isfinite"]:
        symbol = 'DEFINE_NPY_HAVE_DECL_%s' % f.upper()
        define = '#define NPY_HAVE_DECL_%s' % f.upper()
        try:
            # First look for the HAVE_DECL_* macro ...
            conf.check_declaration("HAVE_DECL_%s" % f.upper(), header_name=header_name,
                                   features=features)
        except waflib.Errors.ConfigurationError:
            try:
                # ... then for a declaration of the function itself.
                conf.check_declaration(f, header_name=header_name, features=features)
            except waflib.Errors.ConfigurationError:
                define = ''
        NUMPYCONFIG_SYM.append((symbol, define))

def check_complex(conf):
    """Record C99 complex support in NUMPYCONFIG_SYM.

    Without ``complex.h`` every related template symbol is set to the empty
    string.  With it, NPY_USE_C99_COMPLEX is defined, each type of
    C99_COMPLEX_TYPES gets its own NPY_HAVE_* define when the type check
    succeeds, and the functions of C99_COMPLEX_FUNCS are checked with the
    ``""``, ``"f"`` and ``"l"`` suffixes.
    """
    sanitize = numpy.build_utils.waf.sanitize_string

    if not conf.check_header("complex.h", mandatory=False):
        NUMPYCONFIG_SYM.append(('DEFINE_NPY_USE_C99_COMPLEX', ''))
        for ctype in C99_COMPLEX_TYPES:
            NUMPYCONFIG_SYM.append(('DEFINE_NPY_HAVE_%s' % sanitize(ctype), ''))
        return

    NUMPYCONFIG_SYM.append(('DEFINE_NPY_USE_C99_COMPLEX',
                            '#define NPY_USE_C99_COMPLEX 1'))
    for ctype in C99_COMPLEX_TYPES:
        try:
            conf.check_type(ctype, header_name='complex.h')
        except waflib.Errors.ConfigurationError:
            have_define = ''
        else:
            have_define = '#define NPY_HAVE_%s' % sanitize(ctype)
        NUMPYCONFIG_SYM.append(('DEFINE_NPY_HAVE_%s' % sanitize(ctype),
                                have_define))

    for suffix in ("", "f", "l"):
        conf.check_functions_at_once(
            [name + suffix for name in C99_COMPLEX_FUNCS], use="M")

def check_win32_specifics(conf):
    """Apply the configuration tweaks that only concern Windows builds.

    On any other platform this is a no-op: nothing is imported and ``conf`` is
    left untouched.  On Windows, FORCE_NO_LONG_DOUBLE_FORMATTING is defined
    through ``conf.define`` when the build architecture reported by
    ``numpy.distutils`` is "Intel" or "AMD64".
    """
    # This function is called unconditionally by post_configure, so guard here:
    # the setting below is only meant for the MS runtime.
    if not (sys.platform == 'win32' or os.name == 'nt'):
        return

    from numpy.distutils.misc_util import get_build_architecture
    arch = get_build_architecture()

    # On win32, force long double format string to be 'g', not
    # 'Lg', since the MS runtime does not support long double whose
    # size is > sizeof(double)
    if arch in ("Intel", "AMD64"):
        conf.define('FORCE_NO_LONG_DOUBLE_FORMATTING', 1)

@hooks.post_configure
def post_configure(context):
    """Post-configure hook: run every configuration check and write the
    generated configuration files.

    After the header/declaration probes and the individual check functions
    have run, ``config.h`` is written through waf, ``_numpyconfig.h`` and the
    npy-pkg-config ini files are generated, the include paths are set and the
    configuration is stored.
    """
    conf = context.waf_context
    ConfigError = waflib.Errors.ConfigurationError

    try:
        conf.check_header("endian.h")
    except ConfigError:
        endian_define = ''
    else:
        endian_define = '#define NPY_HAVE_ENDIAN_H 1'
    NUMPYCONFIG_SYM.append(('DEFINE_NPY_HAVE_ENDIAN_H', endian_define))

    try:
        conf.check_declaration('PRIdPTR', header_name='inttypes.h')
    except ConfigError:
        c99_formats_define = ''
    else:
        c99_formats_define = '#define NPY_USE_C99_FORMATS 1'
    NUMPYCONFIG_SYM.append(('DEFINE_NPY_USE_C99_FORMATS', c99_formats_define))

    # The individual checks, run in this exact order.
    checks = (
        type_checks,
        signal_smp_checks,
        check_math_runtime,
        numpy.build_utils.waf.check_inline,
        check_complex,
        check_win32_specifics,
    )
    for run_check in checks:
        run_check(conf)

    if ENABLE_SEPARATE_COMPILATION:
        conf.define("ENABLE_SEPARATE_COMPILATION", 1)

    # The generated config.h refuses to be included unless
    # _NPY_NPY_CONFIG_H_ is defined.
    conf.env["CONFIG_HEADER_TEMPLATE"] = """\
%(content)s
#ifndef _NPY_NPY_CONFIG_H_
#error config.h should never be included directly, include npy_config.h instead
#endif"""
    conf.write_config_header("config.h")

    write_numpy_config(conf)
    write_numpy_inifiles(conf)

    conf.env.INCLUDES = [".", "include", "include/numpy"]

    # FIXME: Should be handled in bento context
    conf.store()

class numpy_api_generator(Task):
    """Task running ``generate_numpy_api.do_generate_api`` on its outputs."""
    vars = ["API_TUPLE"]
    color = "BLUE"
    before = ["c"]

    def run(self):
        """Generate the API files; always reports success (returns 0).

        The generator receives the output paths relative to the build's
        source root, together with ``env.API_TUPLE``.
        """
        targets = []
        for output in self.outputs:
            targets.append(output.path_from(self.generator.bld.srcnode))
        generate_numpy_api.do_generate_api(targets, self.env.API_TUPLE)
        return 0

class ufunc_api_generator(Task):
    """Task running ``generate_ufunc_api.do_generate_api`` on its outputs."""
    vars = ["API_TUPLE"]
    color = "BLUE"
    before = ["c"]

    def run(self):
        """Generate the API files; always reports success (returns 0).

        The generator receives the output paths relative to the build's
        source root, together with ``env.API_TUPLE``.
        """
        targets = []
        for output in self.outputs:
            targets.append(output.path_from(self.generator.bld.srcnode))
        generate_ufunc_api.do_generate_api(targets, self.env.API_TUPLE)
        return 0

@waflib.TaskGen.feature("numpy_api_gen")
def process_multiarray_api_generator(self):
    """Create the ``numpy_api_generator`` task for the "numpy_api_gen" feature.

    The task's API_TUPLE comes from the task generator's ``api_tuple``
    attribute when present, and otherwise defaults to ``()`` unless the
    environment already holds one.  Outputs are ``__<pattern>.h``,
    ``__<pattern>.c`` and ``<pattern>.txt``; the generated ``.h`` outputs are
    registered in the "numpy_gen_headers" category.  Returns the task.
    """
    tsk = self.create_task("numpy_api_generator")
    if hasattr(self, "api_tuple"):
        tsk.env.API_TUPLE = self.api_tuple
    elif "API_TUPLE" not in tsk.env:
        tsk.env.API_TUPLE = ()

    pattern = self.pattern
    out_names = ["__%s.h" % pattern, "__%s.c" % pattern, pattern + ".txt"]
    tsk.set_outputs([self.path.find_or_declare(name) for name in out_names])

    generated_headers = [node for node in tsk.outputs if node.suffix() == ".h"]
    self.bld.register_outputs("numpy_gen_headers", "multiarray",
                              generated_headers,
                              target_dir="$sitedir/numpy/core/include/numpy")

    return tsk

@waflib.TaskGen.feature("ufunc_api_gen")
def process_api_ufunc_generator(self):
    """Create the ``ufunc_api_generator`` task for the "ufunc_api_gen" feature.

    The task's API_TUPLE comes from the task generator's ``api_tuple``
    attribute when present, and otherwise defaults to ``()`` unless the
    environment already holds one.  Outputs are ``__<pattern>.h``,
    ``__<pattern>.c`` and ``<pattern>.txt``; the generated ``.h`` outputs are
    registered in the "numpy_gen_headers" category.  Returns the task.
    """
    tsk = self.create_task("ufunc_api_generator")
    if hasattr(self, "api_tuple"):
        tsk.env.API_TUPLE = self.api_tuple
    elif "API_TUPLE" not in tsk.env:
        tsk.env.API_TUPLE = ()

    pattern = self.pattern
    out_names = ["__%s.h" % pattern, "__%s.c" % pattern, pattern + ".txt"]
    tsk.set_outputs([self.path.find_or_declare(name) for name in out_names])

    generated_headers = [node for node in tsk.outputs if node.suffix() == ".h"]
    self.bld.register_outputs("numpy_gen_headers", "ufunc", generated_headers,
                              target_dir="$sitedir/numpy/core/include/numpy")
    return tsk

class umath_generator(Task):
    """Task writing the code produced by ``generate_umath.make_code``."""
    vars = ["API_TUPLE"]
    color = "BLUE"
    before = ["c"]
    ext_in = ".in"

    def run(self):
        """Write the generated code to the first output and return 0.

        Raises ValueError when the task has more than one output.
        """
        outputs = self.outputs
        if len(outputs) > 1:
            raise ValueError("Only one target (the .c file) is expected in the umath generator task")
        generated = generate_umath.make_code(generate_umath.defdict,
                                             generate_umath.__file__)
        outputs[0].write(generated)
        return 0

@waflib.TaskGen.feature("umath_gen")
def process_umath_generator(self):
    """Create the ``umath_generator`` task for the "umath_gen" feature.

    Its single output is ``__<pattern>.c``, found or declared relative to the
    task generator's path.  Returns the task.
    """
    tsk = self.create_task("umath_generator")
    output = self.path.find_or_declare("__%s.c" % self.pattern)
    tsk.set_outputs(output)
    return tsk

from os.path import join as pjoin
@hooks.pre_build
def pre_build(context):
    """Pre-build hook: register generated outputs and customise the builders.

    Registers the "numpy_gen_inifiles" and "numpy_gen_headers" output
    categories, adjusts the npymath/npysort libraries and a few extensions,
    and installs dedicated builders for the ``multiarray``, ``umath`` and
    ``_dotblas`` extensions.
    """
    bld = context.waf_context
    local_node = context.local_node

    # Generated npy-pkg-config ini files.
    context.register_category("numpy_gen_inifiles")
    ini_nodes = [
        local_node.declare(pjoin("lib", "npy-pkg-config", "mlib.ini")),
        local_node.declare(pjoin("lib", "npy-pkg-config", "npymath.ini")),
    ]
    context.register_outputs("numpy_gen_inifiles", "numpyconfig", ini_nodes)

    # Generated headers.
    context.register_category("numpy_gen_headers")
    numpyconfig_h = local_node.declare(pjoin("include", "numpy", "_numpyconfig.h"))
    context.register_outputs("numpy_gen_headers", "numpyconfig", [numpyconfig_h])

    context.tweak_library("lib/npymath",
                          includes=["src/private", "src/npymath", "include"])
    context.tweak_library("npysort",
                          includes=[".", "src/private", "src/npysort"],
                          use="npymath")

    def builder_multiarray(extension):
        """Declare the multiarray generators, then build the extension."""
        bld(name="multiarray_api",
            features="numpy_api_gen",
            api_tuple=multiarray_api,
            pattern="multiarray_api")

        template_names = ("scalartypes.c.src",
                          "arraytypes.c.src",
                          "nditer_templ.c.src",
                          "lowlevel_strided_loops.c.src",
                          "einsum.c.src")
        multiarray_templates = ["src/multiarray/" + name
                                for name in template_names]
        bld(target="multiarray_templates", source=multiarray_templates)

        if not ENABLE_SEPARATE_COMPILATION:
            sources = extension.sources
        else:
            # Explicit list of the files compiled separately; the order is
            # kept as is.
            file_names = (
                'arrayobject.c',
                'alloc.c',
                'arraytypes.c.src',
                'array_assign.c',
                'array_assign_array.c',
                'array_assign_scalar.c',
                'buffer.c',
                'calculation.c',
                'common.c',
                'conversion_utils.c',
                'convert.c',
                'convert_datatype.c',
                'ctors.c',
                'datetime.c',
                'datetime_busday.c',
                'datetime_busdaycal.c',
                'datetime_strings.c',
                'descriptor.c',
                'dtype_transfer.c',
                'einsum.c.src',
                'flagsobject.c',
                'getset.c',
                'hashdescr.c',
                'item_selection.c',
                'iterators.c',
                'lowlevel_strided_loops.c.src',
                'mapping.c',
                'methods.c',
                'multiarraymodule.c',
                'nditer_templ.c.src',
                'nditer_api.c',
                'nditer_constr.c',
                'nditer_pywrap.c',
                'number.c',
                'numpymemoryview.c',
                'numpyos.c',
                'refcount.c',
                'scalarapi.c',
                'scalartypes.c.src',
                'sequence.c',
                'shape.c',
                'ucsnarrow.c',
                'usertypes.c',
            )
            sources = [pjoin('src', 'multiarray', name) for name in file_names]

        return context.default_builder(extension,
                                       includes=["src/multiarray", "src/private"],
                                       source=sources,
                                       use="npysort npymath",
                                       defines=['_FILE_OFFSET_BITS=64',
                                                '_LARGEFILE_SOURCE=1',
                                                '_LARGEFILE64_SOURCE=1'])
    context.register_builder("multiarray", builder_multiarray)

    def build_ufunc(extension):
        """Declare the umath generators, then build the extension."""
        bld(features="ufunc_api_gen",
            api_tuple=ufunc_api,
            pattern="ufunc_api",
            name="ufunc_api")

        template_names = ("loops.h.src",
                          "loops.c.src",
                          "funcs.inc.src",
                          "simd.inc.src")
        ufunc_templates = ["src/umath/" + name for name in template_names]
        bld(target="ufunc_templates", source=ufunc_templates)

        bld(features="umath_gen",
            pattern="umath_generated",
            name="umath_gen")

        if not ENABLE_SEPARATE_COMPILATION:
            sources = extension.sources
        else:
            file_names = ("loops.h.src",
                          "loops.c.src",
                          'reduction.c',
                          'ufunc_object.c',
                          'ufunc_type_resolution.c',
                          "umathmodule.c")
            sources = [pjoin("src", "umath", name) for name in file_names]

        return context.default_builder(extension,
                                       includes=["src/umath", "src/multiarray",
                                                 "src/private"],
                                       source=sources,
                                       use="npymath")
    context.register_builder("umath", build_ufunc)

    # These extensions only need npymath and the private headers.
    for ext_name in ("scalarmath", "multiarray_tests", "umath_tests"):
        context.tweak_extension(ext_name, use="npymath", includes=["src/private"])

    def build_dotblas(extension):
        """Build _dotblas only when CBLAS is available; otherwise return None."""
        if not bld.env.HAS_CBLAS:
            return None
        return context.default_builder(extension, use="CBLAS",
                                       includes=["src/multiarray", "src/private"])
    context.register_builder("_dotblas", build_dotblas)
