Skip to content

Commit

Permalink
TransformIndirectLoadChain at JITServer
Browse files Browse the repository at this point in the history
Implement TransformIndirectLoadChain partially for the JITServer
so it can employ the Vector API during optimization.

Signed-off-by: Luke Li <[email protected]>
  • Loading branch information
luke-li-2003 committed Dec 10, 2024
1 parent db18829 commit d66ba1d
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 6 deletions.
45 changes: 45 additions & 0 deletions runtime/compiler/control/JITClientCompilationThread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2950,6 +2950,51 @@ handleServerMessage(JITServer::ClientStream *client, TR_J9VM *fe, JITServer::Mes
client->write(response, vectorBitSize);
}
break;
case MessageType::KnownObjectTable_addFieldAddressFromBaseIndex:
{
auto recv = client->getRecvData<TR::KnownObjectTable::Index, intptr_t>();
TR::KnownObjectTable::Index baseObjectIndex = std::get<0>(recv);
intptr_t fieldOffset = std::get<1>(recv);

TR::KnownObjectTable::Index resultIndex = TR::KnownObjectTable::UNKNOWN;

{
TR::VMAccessCriticalSection addFieldAddressFromBaseIndex(fe);
uintptr_t baseObjectAddress = knot->getPointer(baseObjectIndex);
uintptr_t fieldAddress = baseObjectAddress + fieldOffset;

uintptr_t objectPointer = fe->getReferenceFieldAtAddress(fieldAddress);

if (objectPointer)
resultIndex = knot->getOrCreateIndex(objectPointer);
}

uintptr_t *resultPointer =
(resultIndex == -1) ? NULL : knot->getPointerLocation(resultIndex);

client->write(response, resultIndex, resultPointer);
}
break;
case MessageType::KnownObjectTable_getFieldAddressData:
{
auto recv = client->getRecvData<TR::KnownObjectTable::Index, intptr_t>();
TR::KnownObjectTable::Index baseObjectIndex = std::get<0>(recv);
intptr_t fieldOffset = std::get<1>(recv);

UDATA data = 0;

{
TR::VMAccessCriticalSection addFieldAddressFromBaseIndex(fe);
uintptr_t baseObjectAddress = knot->getPointer(baseObjectIndex);

uintptr_t fieldAddress = baseObjectAddress + fieldOffset;

data = *(UDATA *) fieldAddress;
}

client->write(response, data);
}
break;
case MessageType::AOTCache_getROMClassBatch:
{
auto recv = client->getRecvData<std::vector<J9Class *>>();
Expand Down
3 changes: 2 additions & 1 deletion runtime/compiler/env/VMJ9Server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ class TR_J9ServerVM: public TR_J9VM
virtual intptr_t getVFTEntry(TR_OpaqueClassBlock *clazz, int32_t offset) override;
virtual bool isClassArray(TR_OpaqueClassBlock *klass) override;
virtual uintptr_t getFieldOffset(TR::Compilation * comp, TR::SymbolReference* classRef, TR::SymbolReference* fieldRef) override { return 0; } // safe answer
virtual bool canDereferenceAtCompileTime(TR::SymbolReference *fieldRef, TR::Compilation *comp) override { return false; } // safe answer, might change in the future
// The base version should be safe, no need to override.
// virtual bool canDereferenceAtCompileTime(TR::SymbolReference *fieldRef, TR::Compilation *comp) override; // safe answer, might change in the future
virtual bool instanceOfOrCheckCast(J9Class *instanceClass, J9Class* castClass) override;
virtual bool instanceOfOrCheckCastNoCacheUpdate(J9Class *instanceClass, J9Class* castClass) override;
virtual bool transformJlrMethodInvoke(J9Method *callerMethod, J9Class *callerClass) override;
Expand Down
2 changes: 1 addition & 1 deletion runtime/compiler/net/CommunicationStream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class CommunicationStream
// likely to lose an increment when merging/rebasing/etc.
//
static const uint8_t MAJOR_NUMBER = 1;
static const uint16_t MINOR_NUMBER = 75; // ID: kzkyjklaOnYjEzzJyIl7
static const uint16_t MINOR_NUMBER = 76; // ID: BpR0Syhau116Bh0vAoVr
static const uint8_t PATCH_NUMBER = 0;
static uint32_t CONFIGURATION_FLAGS;

Expand Down
2 changes: 2 additions & 0 deletions runtime/compiler/net/MessageTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ const char *messageNames[] =
"KnownObjectTable_getKnownObjectTableDumpInfo",
"KnownObjectTable_getOpaqueClass",
"KnownObjectTable_getVectorBitSize",
"KnownObjectTable_addFieldAddressFromBaseIndex",
"KnownObjectTable_getFieldAddressData",
"AOTCache_getROMClassBatch",
"AOTCacheMap_request",
"AOTCacheMap_reply"
Expand Down
3 changes: 3 additions & 0 deletions runtime/compiler/net/MessageTypes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ enum MessageType : uint16_t
KnownObjectTable_getOpaqueClass,
// for getting a vectorBitSize from KnownObjectTable
KnownObjectTable_getVectorBitSize,
// used with J9TransformUtil
KnownObjectTable_addFieldAddressFromBaseIndex,
KnownObjectTable_getFieldAddressData,

AOTCache_getROMClassBatch,

Expand Down
226 changes: 222 additions & 4 deletions runtime/compiler/optimizer/J9TransformUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1717,12 +1717,20 @@ bool
J9::TransformUtil::transformIndirectLoadChain(TR::Compilation *comp, TR::Node *node, TR::Node *baseExpression, TR::KnownObjectTable::Index baseKnownObject, TR::Node **removedNode)
{
#if defined(J9VM_OPT_JITSERVER)
// JITServer KOT: Bypass this method at the JITServer.
// transformIndirectLoadChainImpl requires access to the VM.
// It is already bypassed by transformIndirectLoadChainAt().
// Under JITServer, call a simplified version of transformIndirectLoadChain
// that does not access the VM
if (comp->isOutOfProcessCompilation())
{
return false;
int32_t stableArrayRank =
comp->getKnownObjectTable()->getArrayWithStableElementsRank(baseKnownObject);
bool result =
TR::TransformUtil::transformIndirectLoadChainServerImpl(comp,
node,
baseExpression,
baseKnownObject,
stableArrayRank,
removedNode);
return result;
}
#endif /* defined(J9VM_OPT_JITSERVER) */

Expand All @@ -1733,6 +1741,216 @@ J9::TransformUtil::transformIndirectLoadChain(TR::Compilation *comp, TR::Node *n
return result;
}

#if defined(J9VM_OPT_JITSERVER)
/** Dereference node and fold it into a constant when possible.
*
* A simpler version of transformIndirectLoadChain() for the JITServer mode, which only considers
* the case where the node's symRef is a Java field.
*/
bool
J9::TransformUtil::transformIndirectLoadChainServerImpl(TR::Compilation *comp, TR::Node *node, TR::Node *baseExpression, TR::KnownObjectTable::Index baseKnownObject, int32_t baseStableArrayRank, TR::Node **removedNode)
{
bool isBaseStableArray = baseStableArrayRank > 0;
TR_J9VMBase *fej9 = comp->fej9();

TR_ASSERT(node->getOpCode().isLoadIndirect(),
"Expecting indirect load; found %s %p", node->getOpCode().getName(), node);
TR_ASSERT(node->getNumChildren() == 1,
"Expecting indirect load %s %p to have one child; actually has %d",
node->getOpCode().getName(), node, node->getNumChildren());

TR::SymbolReference *symRef = node->getSymbolReference();

if (comp->compileRelocatableCode() ||
(isBaseStableArray && !symRef->getSymbol()->isArrayShadowSymbol()) ||
symRef->hasKnownObjectIndex())
{
return false;
}

// Ignore the case of the J9Class whose finality is conditional on the holding value for now.
if (!symRef->isUnresolved() &&
symRef == comp->getSymRefTab()->findInitializeStatusFromClassSymbolRef())
{
return false;
}

if (!isBaseStableArray && !fej9->canDereferenceAtCompileTime(symRef, comp))
{
if (comp->getOption(TR_TraceOptDetails))
{
traceMsg(comp, "Abort transformIndirectLoadChain - cannot dereference at compile time!\n");
}
return false;
}


// Instead of the recursive dereferenceStructPointerChain, we only consider a single level
// of indirection
TR::Symbol *field = symRef->getSymbol();
TR::Node *addressChildNode = field->isArrayShadowSymbol() ?
node->getFirstChild()->getFirstChild() :
node->getFirstChild();
if (!addressChildNode->getOpCode().hasSymbolReference()
|| addressChildNode != baseExpression)
return false;
// baseStruct is always the value of baseExpression; dereference is not needed

// We only consider the case where isJavaField is true for verifyFieldAccess
if (isJavaField(symRef, comp))
{
TR_OpaqueClassBlock *fieldClass = NULL;

if (symRef->getCPIndex() < 0 &&
field->getRecognizedField() != TR::Symbol::UnknownField)
{
const char* className;
int32_t length;
className = field->owningClassNameCharsForRecognizedField(length);
fieldClass = fej9->getClassFromSignature(className, length, symRef->getOwningMethod(comp));
}
else
fieldClass = symRef->getOwningMethod(comp)->getDeclaringClassFromFieldOrStatic(comp,
symRef->getCPIndex());

TR_OpaqueClassBlock *objectClass =
fej9->getObjectClassFromKnownObjectIndex(comp, baseKnownObject);

// field access verified
if ((fieldClass != NULL) && (fej9->isInstanceOf(objectClass, fieldClass, true) == TR_yes))
{

// check the recognized fields case of avoidFoldingInstanceField
// the non-null checks are done when we get the actual values
if (field->getRecognizedField() == TR::Symbol::Java_lang_invoke_CallSite_target ||
field->getRecognizedField() == TR::Symbol::Java_lang_invoke_MethodHandle_form)
return false;

TR::DataType loadType = node->getDataType();

if (loadType == TR::Address)
{
if (isFinalFieldPointingAtRepresentableNativeStruct(symRef, comp) ||
isFinalFieldPointingAtNativeStruct(symRef, comp))
{
return false;
}
else if (field->isCollectedReference())
{
auto stream = comp->getStream();
stream->write(
JITServer::MessageType::KnownObjectTable_addFieldAddressFromBaseIndex,
baseKnownObject, symRef->getOffset());
auto recv = stream->read<TR::KnownObjectTable::Index, uintptr_t *>();
TR::KnownObjectTable::Index value = std::get<0>(recv);
uintptr_t *objectReferenceLocationClient = std::get<1>(recv);
comp->getKnownObjectTable()->updateKnownObjectTableAtServer(
value,
objectReferenceLocationClient
);

if (value != -1)
{
TR::SymbolReference *improvedSymRef =
comp->getSymRefTab()->findOrCreateSymRefWithKnownObject(symRef, value);

if (improvedSymRef->hasKnownObjectIndex()
&& performTransformation(comp,
"O^O transformIndirectLoadChain: %s [%p] with fieldOffset %d is obj%d referenceAddr is %p\n", node->getOpCode().getName(), node, improvedSymRef->getKnownObjectIndex(), symRef->getOffset(), value))
{
node->setSymbolReference(improvedSymRef);
node->setIsNull(false);
node->setIsNonNull(true);

int32_t stableArrayRank = isArrayWithStableElements(symRef->getCPIndex(),
symRef->getOwningMethod(comp),
comp);
if (isBaseStableArray)
stableArrayRank = baseStableArrayRank - 1;

if (stableArrayRank > 0)
{
TR::KnownObjectTable *knot = comp->getOrCreateKnownObjectTable();
knot->addStableArray(improvedSymRef->getKnownObjectIndex(),
stableArrayRank);
}
return true;
}
else /* has known object index */
{
return false;
}
}
else /* value != -1 */
{
return false;
}
}
else /* collected reference */
{
return false;
}
}
else // non-address types
{
auto stream = comp->getStream();
stream->write(
JITServer::MessageType::KnownObjectTable_getFieldAddressData,
baseKnownObject, symRef->getOffset());
UDATA data = std::get<0>(stream->read<UDATA>());

if (data == 0)
return false;

switch (loadType)
{
case TR::Int32:
{
int32_t value = (int32_t)data;
if (changeIndirectLoadIntoConst(node, TR::iconst, removedNode, comp))
node->setInt(value);
else
return false;
}
break;
case TR::Int64:
{
int64_t value = (int64_t)data;
if (changeIndirectLoadIntoConst(node, TR::lconst, removedNode, comp))
node->setLongInt(value);
else
return false;
}
break;
case TR::Float:
{
float value = (float)data;
if (changeIndirectLoadIntoConst(node, TR::fconst, removedNode, comp))
node->setFloat(value);
else
return false;
}
break;
case TR::Double:
{
double value = (double)data;
if (changeIndirectLoadIntoConst(node, TR::dconst, removedNode, comp))
node->setDouble(value);
else
return false;
}
break;
default:
return false;
}
return true;
}
}
}
return false;
}
#endif /* defined(J9VM_OPT_JITSERVER) */

/** Dereference node and fold it into a constant when possible
*
* @parm comp The compilation object
Expand Down
3 changes: 3 additions & 0 deletions runtime/compiler/optimizer/J9TransformUtil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ class OMR_EXTENSIBLE TransformUtil : public OMR::TransformUtilConnector
static bool transformIndirectLoadChain(TR::Compilation *, TR::Node *node, TR::Node *baseExpression, TR::KnownObjectTable::Index baseKnownObject, TR::Node **removedNode);
static bool transformIndirectLoadChainAt(TR::Compilation *, TR::Node *node, TR::Node *baseExpression, uintptr_t *baseReferenceLocation, TR::Node **removedNode);
static bool transformIndirectLoadChainImpl( TR::Compilation *, TR::Node *node, TR::Node *baseExpression, void *baseAddress, int32_t baseStableArrayRank, TR::Node **removedNode);
#if defined(J9VM_OPT_JITSERVER)
static bool transformIndirectLoadChainServerImpl( TR::Compilation *, TR::Node *node, TR::Node *baseExpression, TR::KnownObjectTable::Index baseKnownObject, int32_t baseStableArrayRank, TR::Node **removedNode);
#endif /* defined(J9VM_OPT_JITSERVER) */

static bool fieldShouldBeCompressed(TR::Node *node, TR::Compilation *comp);

Expand Down

0 comments on commit d66ba1d

Please sign in to comment.