Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Compiler Plugins #915

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions docs/source/user-guide/binding/optimization-passes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ create and configure a :class:`PassManagerBuilder`.
methods or :meth:`PassManagerBuilder.populate` to add
optimization passes.

* .. method:: add_pass_by_arg(arg)

Add a pass defined by the supplied ``arg``. See also :meth:`PassRegistry.list_registered_passes`.
Raises ``ValueError`` if no such pass is found.

* .. function:: add_constant_merge_pass()

See `constmerge pass documentation <http://llvm.org/docs/Passes.html#constmerge-merge-duplicate-global-constants>`_.
Expand Down Expand Up @@ -174,3 +179,19 @@ create and configure a :class:`PassManagerBuilder`.

Returns ``True`` if the optimizations made any
modification to the module. Otherwise returns ``False``.

.. function:: load_pass_plugin(path)

Load a shared library defined by ``path`` which contains a custom
LLVM optimization path. See also `'Writing an LLVM Pass' <http://llvm.org/docs/WritingAnLLVMPass.html>`_.

.. class:: PassRegistry()

This class represents the global pass registry for the available optimization passes.

* .. method:: list_registered_passes()

Returns the list of registered passes as a named tuple (arg, name).
``arg`` is the unique identifier used to add passes with :meth:`PassManager.add_pass_by_arg`,
``name`` is the human-readable name of the pass.
See also :func:`load_pass_plugin`.
4 changes: 4 additions & 0 deletions ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ add_definitions(${LLVM_DEFINITIONS})
# Look for SVML
set(CMAKE_REQUIRED_INCLUDES ${LLVM_INCLUDE_DIRS})

if (NOT WIN32)
add_compile_options("-fno-rtti")
endif()

CHECK_INCLUDE_FILES("llvm/IR/SVML.inc" HAVE_SVML)
if(HAVE_SVML)
message(STATUS "SVML found")
Expand Down
4 changes: 2 additions & 2 deletions ffi/Makefile.linux
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Using g++ is recommended for linking
CXX ?= g++

# -flto and --exclude-libs allow us to remove those parts of LLVM we don't use
# -flto allows us to remove those parts of LLVM we don't use
CXX_FLTO_FLAGS ?= -flto
LD_FLTO_FLAGS ?= -flto -Wl,--exclude-libs=ALL
LD_FLTO_FLAGS ?= -flto
# -fPIC is required when compiling objects for a shared library
CXX_FPIC_FLAGS ?= -fPIC

Expand Down
51 changes: 51 additions & 0 deletions ffi/passmanagers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/YAMLTraits.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ADT/SmallString.h"

#include "llvm-c/Transforms/IPO.h"
#include "llvm-c/Transforms/Scalar.h"
Expand All @@ -21,6 +22,8 @@
#include "llvm/Remarks/RemarkStreamer.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/PassRegistry.h"
#include "llvm/Pass.h"

#include <llvm/IR/PassTimingInfo.h>

Expand Down Expand Up @@ -466,4 +469,52 @@ LLVMPY_LLVMAddLoopRotatePass(LLVMPassManagerRef PM) {
LLVMAddLoopRotatePass(PM);
}

API_EXPORT(bool)
LLVMPY_AddPassByArg(LLVMPassManagerRef PM, const char* passArg) {
auto passManager = llvm::unwrap(PM);
auto registry = PassRegistry::getPassRegistry();
const auto* passInfo = registry->getPassInfo(llvm::StringRef(passArg));
if (passInfo == nullptr) {
return false;
}
auto pass = passInfo->createPass();
passManager->add(pass);
return true;
}

namespace {
class NameSavingPassRegListener : public llvm::PassRegistrationListener {
public:
NameSavingPassRegListener() : stream_buf_(), sstream_(stream_buf_) { }
void passEnumerate(const llvm::PassInfo* passInfo) override {
if (stream_buf_.size() != 0) {
sstream_ << ",";

}
sstream_ << passInfo->getPassArgument() << ":" << passInfo->getPassName();
}

const char* getNames() {
return stream_buf_.c_str();
}

private:
llvm::SmallString<128> stream_buf_;
llvm::raw_svector_ostream sstream_;
};
}

API_EXPORT(LLVMPassRegistryRef)
LLVMPY_GetPassRegistry() {
return LLVMGetGlobalPassRegistry();
}

API_EXPORT(void)
LLVMPY_ListRegisteredPasses(LLVMPassRegistryRef PR, const char** out) {
auto registry = unwrap(PR);
auto listener = std::make_unique<NameSavingPassRegListener>();
registry->enumerateWith(listener.get());
*out = LLVMPY_CreateString(listener->getNames());
}

} // end extern "C"
1 change: 1 addition & 0 deletions llvmlite/binding/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def _make_opaque_ref(name):
LLVMExecutionEngineRef = _make_opaque_ref("LLVMExecutionEngine")
LLVMPassManagerBuilderRef = _make_opaque_ref("LLVMPassManagerBuilder")
LLVMPassManagerRef = _make_opaque_ref("LLVMPassManager")
LLVMPassRegistryRef = _make_opaque_ref("LLVMPassRegistryRef")
LLVMTargetDataRef = _make_opaque_ref("LLVMTargetData")
LLVMTargetLibraryInfoRef = _make_opaque_ref("LLVMTargetLibraryInfo")
LLVMTargetRef = _make_opaque_ref("LLVMTarget")
Expand Down
42 changes: 41 additions & 1 deletion llvmlite/binding/passmanagers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ctypes import c_bool, c_char_p, c_int, c_size_t, c_uint, Structure, byref
from ctypes import c_bool, c_char_p, c_int, c_size_t, c_uint, Structure, byref, cdll, POINTER
from collections import namedtuple
from enum import IntFlag
from llvmlite.binding import ffi
Expand Down Expand Up @@ -100,10 +100,39 @@ class RefPruneSubpasses(IntFlag):
ALL = PER_BB | DIAMOND | FANOUT | FANOUT_RAISE


# Maintain a global list of loaded libraries.
_loaded_libraries = []

def load_pass_plugin(path):
"""Load shared library containing a pass"""
pass_lib = cdll.LoadLibrary(path)
pass_lib.LLVMPY_RegisterPass(get_pass_registry())
_loaded_libraries.append(pass_lib)

class PassRegistry(ffi.ObjectRef):
"""
Return the names of the registered passes as a list
of tuples (<pass arg>, <pass name>)
"""
def list_registered_passes(self):
PassInfo = namedtuple("PassInfo", ["arg", "name"])
with ffi.OutputString() as out:
ffi.lib.LLVMPY_ListRegisteredPasses(self, out)
passes = [PassInfo(*pass_info.split(":", 1)) \
for pass_info in str(out).split(',')]
return passes

def get_pass_registry():
return PassRegistry(ffi.lib.LLVMPY_GetPassRegistry())


class PassManager(ffi.ObjectRef):
"""PassManager
"""

def __init__(self, ptr):
super(PassManager, self).__init__(ptr)

def _dispose(self):
self._capi.LLVMPY_DisposePassManager(self)

Expand Down Expand Up @@ -131,6 +160,11 @@ def add_constant_merge_pass(self):
""" # noqa E501
ffi.lib.LLVMPY_AddConstantMergePass(self)

def add_pass_by_arg(self, pass_arg):
"""Add a pass using its registered name"""
if not ffi.lib.LLVMPY_AddPassByArg(self, pass_arg.encode("utf8")):
raise ValueError("Could not add pass '{}'".format(pass_arg))

def add_dead_arg_elimination_pass(self):
"""
See http://llvm.org/docs/Passes.html#deadargelim-dead-argument-elimination
Expand Down Expand Up @@ -914,3 +948,9 @@ def run_with_remarks(self, function, remarks_format='yaml',

ffi.lib.LLVMPY_AddRefPrunePass.argtypes = [ffi.LLVMPassManagerRef, c_int,
c_size_t]

ffi.lib.LLVMPY_GetPassRegistry.argtypes = []
ffi.lib.LLVMPY_GetPassRegistry.restype = ffi.LLVMPassRegistryRef
ffi.lib.LLVMPY_ListRegisteredPasses.argtypes = [ffi.LLVMPassRegistryRef, POINTER(c_char_p)]
ffi.lib.LLVMPY_AddPassByArg.argtypes = [ffi.LLVMPassManagerRef, c_char_p]
ffi.lib.LLVMPY_AddPassByArg.restype = c_bool