Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Libraryimport src gen audit #69619

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ public AnsiStringMarshaller(string? str, Span<byte> buffer)
return;
}

int maxBytesNeeded = checked(Marshal.SystemMaxDBCSCharSize * str.Length);

// >= for null terminator
if ((long)Marshal.SystemMaxDBCSCharSize * str.Length >= buffer.Length)
if (maxBytesNeeded >= buffer.Length)
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
{
// Calculate accurate byte count when the provided stack-allocated buffer is not sufficient
int exactByteCount = Marshal.GetAnsiStringByteCount(str); // Includes null terminator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public ArrayMarshaller(T[]? array, Span<byte> buffer, int sizeOfNativeElement)
_managedArray = array;

// Always allocate at least one byte when the array is zero-length.
int spaceToAllocate = Math.Max(array.Length * _sizeOfNativeElement, 1);
int bufferSize = checked(array.Length * _sizeOfNativeElement);
int spaceToAllocate = Math.Max(bufferSize, 1);
if (spaceToAllocate <= buffer.Length)
{
_span = buffer[0..spaceToAllocate];
Expand Down Expand Up @@ -107,7 +108,12 @@ public ArrayMarshaller(T[]? array, Span<byte> buffer, int sizeOfNativeElement)
/// </remarks>
public ReadOnlySpan<byte> GetNativeValuesSource(int length)
{
return _allocatedMemory == IntPtr.Zero ? default : _span = new Span<byte>((void*)_allocatedMemory, length * _sizeOfNativeElement);
if (_allocatedMemory == IntPtr.Zero)
return default;
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved

int allocatedSize = checked(length * _sizeOfNativeElement);
_span = new Span<byte>((void*)_allocatedMemory, allocatedSize);
return _span;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public PointerArrayMarshaller(T*[]? array, Span<byte> buffer, int sizeOfNativeEl
_managedArray = array;

// Always allocate at least one byte when the array is zero-length.
int spaceToAllocate = Math.Max(array.Length * _sizeOfNativeElement, 1);
int bufferSize = checked(array.Length * _sizeOfNativeElement);
int spaceToAllocate = Math.Max(bufferSize, 1);
if (spaceToAllocate <= buffer.Length)
{
_span = buffer[0..spaceToAllocate];
Expand Down Expand Up @@ -117,7 +118,8 @@ public ReadOnlySpan<byte> GetNativeValuesSource(int length)
if (_allocatedMemory == IntPtr.Zero)
return default;

_span = new Span<byte>((void*)_allocatedMemory, length * _sizeOfNativeElement);
int allocatedSize = checked(length * _sizeOfNativeElement);
_span = new Span<byte>((void*)_allocatedMemory, allocatedSize);
return _span;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ public Utf8StringMarshaller(string? str, Span<byte> buffer)
}

const int MaxUtf8BytesPerChar = 3;
int maxBytesNeeded = checked(MaxUtf8BytesPerChar * str.Length);
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved

// >= for null terminator
if ((long)MaxUtf8BytesPerChar * str.Length >= buffer.Length)
if (maxBytesNeeded >= buffer.Length)
{
// Calculate accurate byte count when the provided stack-allocated buffer is not sufficient
int exactByteCount = checked(Encoding.UTF8.GetByteCount(str) + 1); // + 1 for null terminator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateSou

ImmutableArray<AttributeSyntax> forwardedAttributes = pinvokeStub.ForwardedAttributes;

const string innerPInvokeName = "__PInvoke__";
const string innerPInvokeName = "__PInvoke";

BlockSyntax code = stubGenerator.GeneratePInvokeBody(innerPInvokeName);

Expand All @@ -428,10 +428,7 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateSou
dllImport = dllImport.AddAttributeLists(AttributeList(SeparatedList(forwardedAttributes)));
}

dllImport = dllImport.WithLeadingTrivia(
Comment("//"),
Comment("// Local P/Invoke"),
Comment("//"));
dllImport = dllImport.WithLeadingTrivia(Comment("// Local P/Invoke"));
code = code.AddStatements(dllImport);

return (pinvokeStub.ContainingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(PrintGeneratedSource(pinvokeStub.StubMethodSyntaxTemplate, pinvokeStub.SignatureContext, code)), pinvokeStub.Diagnostics.AddRange(diagnostics.Diagnostics));
Expand Down Expand Up @@ -514,7 +511,7 @@ private static LocalFunctionStatementSyntax CreateTargetFunctionAsLocalStatement
SyntaxKind.StringLiteralExpression,
Literal(libraryImportData.EntryPoint ?? stubMethodName))),
AttributeArgument(
NameEquals(nameof(DllImportAttribute.ExactSpelling)),
NameEquals(nameof(DllImportAttribute.SetLastError)),
null,
LiteralExpression(
libraryImportData.SetLastError ? SyntaxKind.TrueLiteralExpression : SyntaxKind.FalseLiteralExpression))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,7 @@ private static ImmutableArray<StatementSyntax> GenerateStatementsForStubContext(
if (statementsToUpdate.Count > 0)
{
// Comment separating each stage
SyntaxTriviaList newLeadingTrivia = TriviaList(
Comment($"//"),
Comment($"// {context.CurrentStage}"),
Comment($"//"));
SyntaxTriviaList newLeadingTrivia = GenerateStageTrivia(context.CurrentStage);
StatementSyntax firstStatementInStage = statementsToUpdate[0];
newLeadingTrivia = newLeadingTrivia.AddRange(firstStatementInStage.GetLeadingTrivia());
statementsToUpdate[0] = firstStatementInStage.WithLeadingTrivia(newLeadingTrivia);
Expand Down Expand Up @@ -108,5 +105,24 @@ private static StatementSyntax GenerateStatementForNativeInvoke(BoundGenerators
IdentifierName(context.GetIdentifiers(marshallers.NativeReturnMarshaller.TypeInfo).native),
invoke));
}

private static SyntaxTriviaList GenerateStageTrivia(StubCodeContext.Stage stage)
{
string comment = stage switch
{
StubCodeContext.Stage.Setup => "Perform required setup.",
StubCodeContext.Stage.Marshal => "Convert managed data to native data.",
StubCodeContext.Stage.Pin => "Pin data in preparation for calling the P/Invoke.",
StubCodeContext.Stage.Invoke => "Call the P/Invoke.",
StubCodeContext.Stage.Unmarshal => "Convert native data to managed data.",
StubCodeContext.Stage.Cleanup => "Perform required cleanup.",
StubCodeContext.Stage.KeepAlive => "Keep alive any managed objects that need to stay alive across the call.",
StubCodeContext.Stage.GuaranteedUnmarshal => "Convert native data to managed data even in the case of an exception during the non-cleanup phases.",
_ => throw new ArgumentOutOfRangeException(nameof(stage))
};
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved

// Comment separating each stage
return TriviaList(Comment($"// {stage} - {comment}"));
}
}
}