diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 450154598..db2f67127 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -82,4 +82,16 @@ std::optional tryGetGlobalBinding(GlobalTypes& globals, const std::stri Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name); TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name); + +/** A number of built-in functions are magical enough that we need to match on them specifically by + * name when they are called. These are listed here to be used whenever necessary, instead of duplicating this logic repeatedly. + */ + +bool matchSetMetatable(const AstExprCall& call); +bool matchTableFreeze(const AstExprCall& call); +bool matchAssert(const AstExprCall& call); + +// Returns `true` if the function should introduce typestate for its first argument. +bool shouldTypestateForFirstArgument(const AstExprCall& call); + } // namespace Luau diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 103b5bbd5..b0c8fd17c 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -22,6 +22,15 @@ struct CloneState SeenTypePacks seenTypePacks; }; +/** `shallowClone` will make a copy of only the _top level_ constructor of the type, + * while `clone` will make a deep copy of the entire type and its every component. + * + * Be mindful about which behavior you actually _want_. + */ + +TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState); +TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState); + TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index 600574f06..435c62fb6 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -146,6 +146,8 @@ struct ConstraintGenerator */ void visitModuleRoot(AstStatBlock* block); + void visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block); + private: std::vector> interiorTypes; diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 4d38118af..c9336c1d0 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -3,6 +3,7 @@ #pragma once #include "Luau/Constraint.h" +#include "Luau/DataFlowGraph.h" #include "Luau/DenseHash.h" #include "Luau/Error.h" #include "Luau/Location.h" @@ -69,6 +70,9 @@ struct ConstraintSolver NotNull rootScope; ModuleName currentModuleName; + // The dataflow graph of the program, used in constraint generation and for magic functions. + NotNull dfg; + // Constraints that the solver has generated, rather than sourcing from the // scope tree. std::vector> solverConstraints; @@ -120,6 +124,7 @@ struct ConstraintSolver NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger, + NotNull dfg, TypeCheckLimits limits ); @@ -167,9 +172,9 @@ struct ConstraintSolver */ bool tryDispatch(NotNull c, bool force); - bool tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const SubtypeConstraint& c, NotNull constraint); + bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint); + bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint); bool tryDispatch(const IterableConstraint& c, NotNull constraint, bool force); bool tryDispatch(const NameConstraint& c, NotNull constraint); bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint); @@ -194,14 +199,14 @@ struct ConstraintSolver bool tryDispatch(const UnpackConstraint& c, NotNull constraint); bool tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force); bool tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const EqualityConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const EqualityConstraint& c, NotNull constraint); // for a, ... in some_table do // also handles __iter metamethod bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force); // for a, ... in next_function, t, ... do - bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint, bool force); + bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint); std::pair, std::optional> lookupTableProp( NotNull constraint, diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index 718a03506..662e50aa1 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -35,6 +35,8 @@ struct DataFlowGraph DataFlowGraph& operator=(DataFlowGraph&&) = default; DefId getDef(const AstExpr* expr) const; + // Look up the definition optionally, knowing it may not be present. + std::optional getDefOptional(const AstExpr* expr) const; // Look up for the rvalue def for a compound assignment. std::optional getRValueDefForCompoundAssign(const AstExpr* expr) const; diff --git a/Analysis/include/Luau/FragmentAutocomplete.h b/Analysis/include/Luau/FragmentAutocomplete.h index bfc5f6e69..671cbb693 100644 --- a/Analysis/include/Luau/FragmentAutocomplete.h +++ b/Analysis/include/Luau/FragmentAutocomplete.h @@ -12,6 +12,7 @@ namespace Luau { +struct FrontendOptions; struct FragmentAutocompleteAncestryResult { @@ -29,15 +30,30 @@ struct FragmentParseResult std::unique_ptr alloc = std::make_unique(); }; +struct FragmentTypeCheckResult +{ + ModulePtr incrementalModule = nullptr; + Scope* freshScope = nullptr; +}; + FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_view src, const Position& cursorPos); +FragmentTypeCheckResult typecheckFragment( + Frontend& frontend, + const ModuleName& moduleName, + const Position& cursorPos, + std::optional opts, + std::string_view src +); + AutocompleteResult fragmentAutocomplete( Frontend& frontend, std::string_view src, const ModuleName& moduleName, Position& cursorPosition, + std::optional opts, StringCompletionCallback callback ); diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h index 0fd2817ab..73345f98c 100644 --- a/Analysis/include/Luau/Instantiation.h +++ b/Analysis/include/Luau/Instantiation.h @@ -60,7 +60,7 @@ struct ReplaceGenerics : Substitution }; // A substitution which replaces generic functions by monomorphic functions -struct Instantiation : Substitution +struct Instantiation final : Substitution { Instantiation(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope) : Substitution(log, arena) diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index d844d211d..97d13a600 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -23,8 +23,6 @@ using ModulePtr = std::shared_ptr; bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); -bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); -bool isConsistentSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); class TypeIds { @@ -336,6 +334,7 @@ struct NormalizedType }; +using SeenTablePropPairs = Set, TypeIdPairHash>; class Normalizer { @@ -390,7 +389,13 @@ class Normalizer void unionTablesWithTable(TypeIds& heres, TypeId there); void unionTables(TypeIds& heres, const TypeIds& theres); NormalizationResult unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); - NormalizationResult unionNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes, int ignoreSmallerTyvars = -1); + NormalizationResult unionNormalWithTy( + NormalizedType& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes, + int ignoreSmallerTyvars = -1 + ); // ------- Negations std::optional negateNormal(const NormalizedType& here); @@ -407,16 +412,26 @@ class Normalizer void intersectClassesWithClass(NormalizedClassType& heres, TypeId there); void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); std::optional intersectionOfTypePacks(TypePackId here, TypePackId there); - std::optional intersectionOfTables(TypeId here, TypeId there, Set& seenSet); - void intersectTablesWithTable(TypeIds& heres, TypeId there, Set& seenSetTypes); + std::optional intersectionOfTables(TypeId here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSet); + void intersectTablesWithTable(TypeIds& heres, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSetTypes); void intersectTables(TypeIds& heres, const TypeIds& theres); std::optional intersectionOfFunctions(TypeId here, TypeId there); void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); - NormalizationResult intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set& seenSetTypes); + NormalizationResult intersectTyvarsWithTy( + NormalizedTyvars& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes + ); NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); - NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes); - NormalizationResult normalizeIntersections(const std::vector& intersections, NormalizedType& outType, Set& seenSet); + NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSetTypes); + NormalizationResult normalizeIntersections( + const std::vector& intersections, + NormalizedType& outType, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSet + ); // Check for inhabitance NormalizationResult isInhabited(TypeId ty); @@ -426,7 +441,7 @@ class Normalizer // Check for intersections being inhabited NormalizationResult isIntersectionInhabited(TypeId left, TypeId right); - NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, Set& seenSet); + NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, SeenTablePropPairs& seenTablePropPairs, Set& seenSet); // -------- Convert back from a normalized type to a type TypeId typeFromNormal(const NormalizedType& norm); diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 55089aa3e..d100fa4d7 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -806,6 +806,13 @@ struct Type final Type& operator=(const TypeVariant& rhs); Type& operator=(TypeVariant&& rhs); + Type(Type&&) = default; + Type& operator=(Type&&) = default; + + Type clone() const; + +private: + Type(const Type&) = default; Type& operator=(const Type& rhs); }; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index b0a855d37..3de841ed2 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -179,7 +179,7 @@ struct Unifier bool occursCheck(TypePackId needle, TypePackId haystack, bool reversed); bool occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); - Unifier makeChildUnifier(); + std::unique_ptr makeChildUnifier(); void reportError(TypeError err); LUAU_NOINLINE void reportError(Location location, TypeErrorData data); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index f2235bb96..c89d77931 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -16,7 +16,6 @@ #include LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauAutocompleteNewSolverLimit) LUAU_FASTFLAGVARIABLE(AutocompleteRequirePathSuggestions, false) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) @@ -157,11 +156,8 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T NotNull{&iceReporter}, NotNull{&limits} }; // TODO: maybe subtyping checks should not invoke user-defined type function runtime - if (FFlag::LuauAutocompleteNewSolverLimit) - { - unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; - unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; - } + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 041d1bed6..84d2d6e9d 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -2,6 +2,8 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Ast.h" +#include "Luau/Clone.h" +#include "Luau/Error.h" #include "Luau/Frontend.h" #include "Luau/Symbol.h" #include "Luau/Common.h" @@ -25,9 +27,12 @@ * about a function that takes any number of values, but where each value must have some specific type. */ -LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(LuauSolverV2) +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) +LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins, false) +LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix, false) -LUAU_FASTFLAG(AutocompleteRequirePathSuggestions); +LUAU_FASTFLAG(AutocompleteRequirePathSuggestions) namespace Luau { @@ -67,6 +72,7 @@ static std::optional> magicFunctionRequire( static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); static bool dcrMagicFunctionPack(MagicFunctionCallContext context); +static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context); TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -395,8 +401,10 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC // but it'll be ok for now. TypeId genericTy = arena.addType(GenericType{"T"}); TypePackId thePack = arena.addTypePack({genericTy}); + TypeId idTyWithMagic = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack}); + ttv->props["freeze"] = makeProperty(idTyWithMagic, "@luau/global/table.freeze"); + TypeId idTy = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack}); - ttv->props["freeze"] = makeProperty(idTy, "@luau/global/table.freeze"); ttv->props["clone"] = makeProperty(idTy, "@luau/global/table.clone"); } else @@ -413,6 +421,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack); attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack); + if (FFlag::LuauTypestateBuiltins) + attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze); } if (FFlag::AutocompleteRequirePathSuggestions) @@ -574,7 +584,11 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex fmt = context.callSite->args.data[0]->as(); if (!fmt) + { + if (FFlag::LuauStringFormatArityFix) + context.typechecker->reportError(CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location); return; + } std::vector expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(context.arguments); @@ -1324,6 +1338,58 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context) return true; } +static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context) +{ + LUAU_ASSERT(FFlag::LuauTypestateBuiltins); + + TypeArena* arena = context.solver->arena; + const DataFlowGraph* dfg = context.solver->dfg.get(); + Scope* scope = context.constraint->scope.get(); + + const auto& [paramTypes, paramTail] = extendTypePack(*arena, context.solver->builtinTypes, context.arguments, 1); + LUAU_ASSERT(paramTypes.size() >= 1); + + TypeId inputType = follow(paramTypes.at(0)); + + // we'll check if it's a table first since this magic function also produces the error if it's not until we have bounded generics + if (!get(inputType)) + { + context.solver->reportError(TypeMismatch{context.solver->builtinTypes->tableType, inputType}, context.callSite->argLocation); + return false; + } + + AstExpr* targetExpr = context.callSite->args.data[0]; + std::optional resultDef = dfg->getDefOptional(targetExpr); + std::optional resultTy = resultDef ? scope->lookup(*resultDef) : std::nullopt; + + // Clone the input type, this will become our final result type after we mutate it. + CloneState cloneState{context.solver->builtinTypes}; + TypeId clonedType = shallowClone(inputType, *arena, cloneState); + auto tableTy = getMutable(clonedType); + // `clone` should not break this. + LUAU_ASSERT(tableTy); + tableTy->state = TableState::Sealed; + tableTy->syntheticName = std::nullopt; + + // We'll mutate the table to make every property type read-only. + for (auto iter = tableTy->props.begin(); iter != tableTy->props.end();) + { + if (iter->second.isWriteOnly()) + iter = tableTy->props.erase(iter); + else + { + iter->second.writeTy = std::nullopt; + iter++; + } + } + + if (resultTy) + asMutable(*resultTy)->ty.emplace(clonedType); + asMutable(context.result)->ty.emplace(arena->addTypePack({clonedType})); + + return true; +} + static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) { // require(foo.parent.bar) will technically work, but it depends on legacy goop that @@ -1415,4 +1481,52 @@ static bool dcrMagicFunctionRequire(MagicFunctionCallContext context) return false; } +bool matchSetMetatable(const AstExprCall& call) +{ + const char* smt = "setmetatable"; + + if (call.args.size != 2) + return false; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != smt) + return false; + + return true; +} + +bool matchTableFreeze(const AstExprCall& call) +{ + if (call.args.size < 1) + return false; + + const AstExprIndexName* index = call.func->as(); + if (!index || index->index != "freeze") + return false; + + const AstExprGlobal* global = index->expr->as(); + if (!global || global->name != "table") + return false; + + return true; +} + +bool matchAssert(const AstExprCall& call) +{ + if (call.args.size < 1) + return false; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != "assert") + return false; + + return true; +} + +bool shouldTypestateForFirstArgument(const AstExprCall& call) +{ + // TODO: magic function for setmetatable and assert and then add them + return matchTableFreeze(call); +} + } // namespace Luau diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index d5793c932..745a03074 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -140,7 +140,7 @@ class TypeCloner } } -private: +public: TypeId shallowClone(TypeId ty) { // We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s. @@ -189,6 +189,7 @@ class TypeCloner return target; } +private: Property shallowClone(const Property& p) { if (FFlag::LuauSolverV2) @@ -453,6 +454,24 @@ class TypeCloner } // namespace +TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState) +{ + if (tp->persistent) + return tp; + + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + return cloner.shallowClone(tp); +} + +TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState) +{ + if (typeId->persistent) + return typeId; + + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + return cloner.shallowClone(typeId); +} + TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) { if (tp->persistent) diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 9c30668c4..e242df8ec 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -2,11 +2,12 @@ #include "Luau/ConstraintGenerator.h" #include "Luau/Ast.h" -#include "Luau/Def.h" +#include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" #include "Luau/Constraint.h" #include "Luau/ControlFlow.h" #include "Luau/DcrLogger.h" +#include "Luau/Def.h" #include "Luau/DenseHash.h" #include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" @@ -30,6 +31,9 @@ LUAU_FASTINT(LuauCheckRecursionLimit) LUAU_FASTFLAG(DebugLuauLogSolverToJson) LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) +LUAU_FASTFLAG(LuauTypestateBuiltins) + +LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues, false) namespace Luau { @@ -54,20 +58,6 @@ static std::optional matchRequire(const AstExprCall& call) return call.args.data[0]; } -static bool matchSetmetatable(const AstExprCall& call) -{ - const char* smt = "setmetatable"; - - if (call.args.size != 2) - return false; - - const AstExprGlobal* funcAsGlobal = call.func->as(); - if (!funcAsGlobal || funcAsGlobal->name != smt) - return false; - - return true; -} - struct TypeGuard { bool isTypeof; @@ -110,18 +100,6 @@ static std::optional matchTypeGuard(const AstExprBinary* binary) }; } -static bool matchAssert(const AstExprCall& call) -{ - if (call.args.size < 1) - return false; - - const AstExprGlobal* funcAsGlobal = call.func->as(); - if (!funcAsGlobal || funcAsGlobal->name != "assert") - return false; - - return true; -} - namespace { @@ -285,6 +263,31 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) } } +void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block) +{ + visitBlockWithoutChildScope(resumeScope, block); + fillInInferredBindings(resumeScope, block); + + if (logger) + logger->captureGenerationModule(module); + + for (const auto& [ty, domain] : localTypes) + { + // FIXME: This isn't the most efficient thing. + TypeId domainTy = builtinTypes->neverType; + for (TypeId d : domain) + { + d = follow(d); + if (d == ty) + continue; + domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + } + + LUAU_ASSERT(get(ty)); + asMutable(ty)->ty.emplace(domainTy); + } +} + TypeId ConstraintGenerator::freshType(const ScopePtr& scope) { return Luau::freshType(arena, builtinTypes, scope.get()); @@ -1075,9 +1078,17 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); else if (const AstExprCall* call = value->as()) { - if (const AstExprGlobal* global = call->func->as(); global && global->name == "setmetatable") + if (FFlag::LuauTypestateBuiltins) + { + if (matchSetMetatable(*call)) + addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); + } + else { - addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); + if (const AstExprGlobal* global = call->func->as(); global && global->name == "setmetatable") + { + addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); + } } } } @@ -1975,7 +1986,7 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* Checkpoint argEndCheckpoint = checkpoint(this); - if (matchSetmetatable(*call)) + if (matchSetMetatable(*call)) { TypePack argTailPack; if (argTail && args.size() < 2) @@ -2050,72 +2061,80 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; } - else + + if (FFlag::LuauTypestateBuiltins && shouldTypestateForFirstArgument(*call) && call->args.size > 0 && isLValue(call->args.data[0])) { - if (matchAssert(*call) && !argumentRefinements.empty()) - applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]); + AstExpr* targetExpr = call->args.data[0]; + auto resultTy = arena->addType(BlockedType{}); - // TODO: How do expectedTypes play into this? Do they? - TypePackId rets = arena->addTypePack(BlockedTypePack{}); - TypePackId argPack = addTypePack(std::move(args), argTail); - FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self); + if (auto def = dfg->getDefOptional(targetExpr)) + { + scope->lvalueTypes[*def] = resultTy; + scope->rvalueRefinements[*def] = resultTy; + } + } - /* - * To make bidirectional type checking work, we need to solve these constraints in a particular order: - * - * 1. Solve the function type - * 2. Propagate type information from the function type to the argument types - * 3. Solve the argument types - * 4. Solve the call - */ + if (matchAssert(*call) && !argumentRefinements.empty()) + applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]); - NotNull checkConstraint = addConstraint( - scope, - call->func->location, - FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}} - ); + // TODO: How do expectedTypes play into this? Do they? + TypePackId rets = arena->addTypePack(BlockedTypePack{}); + TypePackId argPack = addTypePack(std::move(args), argTail); + FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self); - forEachConstraint( - funcBeginCheckpoint, - funcEndCheckpoint, - this, - [checkConstraint](const ConstraintPtr& constraint) - { - checkConstraint->dependencies.emplace_back(constraint.get()); - } - ); + /* + * To make bidirectional type checking work, we need to solve these constraints in a particular order: + * + * 1. Solve the function type + * 2. Propagate type information from the function type to the argument types + * 3. Solve the argument types + * 4. Solve the call + */ - NotNull callConstraint = addConstraint( - scope, - call->func->location, - FunctionCallConstraint{ - fnType, - argPack, - rets, - call, - std::move(discriminantTypes), - &module->astOverloadResolvedTypes, - } - ); + NotNull checkConstraint = addConstraint( + scope, call->func->location, FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}} + ); - getMutable(rets)->owner = callConstraint.get(); + forEachConstraint( + funcBeginCheckpoint, + funcEndCheckpoint, + this, + [checkConstraint](const ConstraintPtr& constraint) + { + checkConstraint->dependencies.emplace_back(constraint.get()); + } + ); - callConstraint->dependencies.push_back(checkConstraint); + NotNull callConstraint = addConstraint( + scope, + call->func->location, + FunctionCallConstraint{ + fnType, + argPack, + rets, + call, + std::move(discriminantTypes), + &module->astOverloadResolvedTypes, + } + ); - forEachConstraint( - argBeginCheckpoint, - argEndCheckpoint, - this, - [checkConstraint, callConstraint](const ConstraintPtr& constraint) - { - constraint->dependencies.emplace_back(checkConstraint); + getMutable(rets)->owner = callConstraint.get(); - callConstraint->dependencies.emplace_back(constraint.get()); - } - ); + callConstraint->dependencies.push_back(checkConstraint); - return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; - } + forEachConstraint( + argBeginCheckpoint, + argEndCheckpoint, + this, + [checkConstraint, callConstraint](const ConstraintPtr& constraint) + { + constraint->dependencies.emplace_back(checkConstraint); + + callConstraint->dependencies.emplace_back(constraint.get()); + } + ); + + return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; } Inference ConstraintGenerator::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType, bool forceSingleton, bool generalize) @@ -2703,7 +2722,16 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExpr* expr, Type visitLValue(scope, e, rhsType); else if (auto e = expr->as()) { - // Nothing? + if (FFlag::LuauNewSolverVisitErrorExprLvalues) + { + // If we end up with some sort of error expression in an lvalue + // position, at least go and check the expressions so that when + // we visit them later, there aren't any invalid assumptions. + for (auto subExpr : e->expressions) + { + check(scope, subExpr); + } + } } else ice->ice("Unexpected lvalue expression", expr->location); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index c0d30137c..31afabb23 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -326,6 +326,7 @@ ConstraintSolver::ConstraintSolver( NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger, + NotNull dfg, TypeCheckLimits limits ) : arena(normalizer->arena) @@ -335,6 +336,7 @@ ConstraintSolver::ConstraintSolver( , constraints(std::move(constraints)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) + , dfg(dfg) , moduleResolver(moduleResolver) , requireCycles(std::move(requireCycles)) , logger(logger) @@ -618,11 +620,11 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo bool success = false; if (auto sc = get(*constraint)) - success = tryDispatch(*sc, constraint, force); + success = tryDispatch(*sc, constraint); else if (auto psc = get(*constraint)) - success = tryDispatch(*psc, constraint, force); + success = tryDispatch(*psc, constraint); else if (auto gc = get(*constraint)) - success = tryDispatch(*gc, constraint, force); + success = tryDispatch(*gc, constraint); else if (auto ic = get(*constraint)) success = tryDispatch(*ic, constraint, force); else if (auto nc = get(*constraint)) @@ -650,14 +652,14 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo else if (auto rpc = get(*constraint)) success = tryDispatch(*rpc, constraint, force); else if (auto eqc = get(*constraint)) - success = tryDispatch(*eqc, constraint, force); + success = tryDispatch(*eqc, constraint); else LUAU_ASSERT(false); return success; } -bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint) { if (isBlocked(c.subType)) return block(c.subType, constraint); @@ -669,7 +671,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint) { if (isBlocked(c.subPack)) return block(c.subPack, constraint); @@ -681,7 +683,7 @@ bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull constraint) { TypeId generalizedType = follow(c.generalizedType); @@ -828,7 +830,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull= 2) tableTy = iterator.head[1]; - return tryDispatchIterableFunction(nextTy, tableTy, c, constraint, force); + return tryDispatchIterableFunction(nextTy, tableTy, c, constraint); } else @@ -2165,7 +2167,7 @@ bool ConstraintSolver::tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const EqualityConstraint& c, NotNull constraint) { unify(constraint, c.resultType, c.assignmentType); unify(constraint, c.assignmentType, c.resultType); @@ -2328,8 +2330,7 @@ bool ConstraintSolver::tryDispatchIterableFunction( TypeId nextTy, TypeId tableTy, const IterableConstraint& c, - NotNull constraint, - bool force + NotNull constraint ) { const FunctionType* nextFn = get(nextTy); diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index 1307da8d7..4225942b9 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -2,6 +2,7 @@ #include "Luau/DataFlowGraph.h" #include "Luau/Ast.h" +#include "Luau/BuiltinDefinitions.h" #include "Luau/Def.h" #include "Luau/Common.h" #include "Luau/Error.h" @@ -12,6 +13,7 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauTypestateBuiltins) namespace Luau { @@ -67,6 +69,14 @@ DefId DataFlowGraph::getDef(const AstExpr* expr) const return NotNull{*def}; } +std::optional DataFlowGraph::getDefOptional(const AstExpr* expr) const +{ + auto def = astDefs.find(expr); + if (!def) + return std::nullopt; + return NotNull{*def}; +} + std::optional DataFlowGraph::getRValueDefForCompoundAssign(const AstExpr* expr) const { auto def = compoundAssignDefs.find(expr); @@ -929,6 +939,39 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprCall* c) { visitExpr(c->func); + if (FFlag::LuauTypestateBuiltins && shouldTypestateForFirstArgument(*c) && c->args.size > 1 && isLValue(*c->args.begin())) + { + AstExpr* firstArg = *c->args.begin(); + + // this logic has to handle the name-like subset of expressions. + std::optional result; + if (auto l = firstArg->as()) + result = visitExpr(l); + else if (auto g = firstArg->as()) + result = visitExpr(g); + else if (auto i = firstArg->as()) + result = visitExpr(i); + else if (auto i = firstArg->as()) + result = visitExpr(i); + else + LUAU_UNREACHABLE(); // This is unreachable because the whole thing is guarded by `isLValue`. + + LUAU_ASSERT(result); + + Location location = currentScope()->location; + // This scope starts at the end of the call site and continues to the end of the original scope. + location.begin = c->location.end; + DfgScope* child = makeChildScope(location); + scopeStack.push_back(child); + + auto [def, key] = *result; + graph.astDefs[firstArg] = def; + if (key) + graph.astRefinementKeys[firstArg] = key; + + visitLValue(firstArg, def); + } + for (AstExpr* arg : c->args) visitExpr(arg); diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp index 853a3d89a..d4f3ebd99 100644 --- a/Analysis/src/FragmentAutocomplete.cpp +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -4,11 +4,44 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Common.h" +#include "Luau/Parser.h" +#include "Luau/ParseOptions.h" +#include "Luau/Module.h" +#include "Luau/TimeTrace.h" +#include "Luau/UnifierSharedState.h" +#include "Luau/TypeFunction.h" +#include "Luau/DataFlowGraph.h" +#include "Luau/ConstraintGenerator.h" +#include "Luau/ConstraintSolver.h" #include "Luau/Frontend.h" #include "Luau/Parser.h" #include "Luau/ParseOptions.h" #include "Luau/Module.h" +LUAU_FASTINT(LuauTypeInferRecursionLimit); +LUAU_FASTINT(LuauTypeInferIterationLimit); +LUAU_FASTINT(LuauTarjanChildLimit) +LUAU_FASTFLAG(LuauAllowFragmentParsing); +LUAU_FASTFLAG(LuauStoreDFGOnModule2); + +namespace +{ +template +void copyModuleVec(std::vector& result, const std::vector& input) +{ + result.insert(result.end(), input.begin(), input.end()); +} + +template +void copyModuleMap(Luau::DenseHashMap& result, const Luau::DenseHashMap& input) +{ + for (auto [k, v] : input) + result[k] = v; +} + +} // namespace + + namespace Luau { @@ -147,17 +180,173 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie return fragmentResult; } +ModulePtr copyModule(const ModulePtr& result, std::unique_ptr alloc) +{ + freeze(result->internalTypes); + freeze(result->interfaceTypes); + ModulePtr incrementalModule = std::make_shared(); + incrementalModule->name = result->name; + incrementalModule->humanReadableName = result->humanReadableName; + incrementalModule->allocator = std::move(alloc); + // Don't need to keep this alive (it's already on the source module) + copyModuleVec(incrementalModule->scopes, result->scopes); + copyModuleMap(incrementalModule->astTypes, result->astTypes); + copyModuleMap(incrementalModule->astTypePacks, result->astTypePacks); + copyModuleMap(incrementalModule->astExpectedTypes, result->astExpectedTypes); + // Don't need to clone astOriginalCallTypes + copyModuleMap(incrementalModule->astOverloadResolvedTypes, result->astOverloadResolvedTypes); + // Don't need to clone astForInNextTypes + copyModuleMap(incrementalModule->astForInNextTypes, result->astForInNextTypes); + // Don't need to clone astResolvedTypes + // Don't need to clone astResolvedTypePacks + // Don't need to clone upperBoundContributors + copyModuleMap(incrementalModule->astScopes, result->astScopes); + // Don't need to clone declared Globals; + return incrementalModule; +} + +FragmentTypeCheckResult typeCheckFragmentHelper( + Frontend& frontend, + AstStatBlock* root, + const ModulePtr& stale, + const ScopePtr& closestScope, + const Position& cursorPos, + std::unique_ptr astAllocator, + const FrontendOptions& opts +) +{ + freeze(stale->internalTypes); + freeze(stale->interfaceTypes); + ModulePtr incrementalModule = copyModule(stale, std::move(astAllocator)); + unfreeze(incrementalModule->internalTypes); + unfreeze(incrementalModule->interfaceTypes); + + /// Setup typecheck limits + TypeCheckLimits limits; + if (opts.moduleTimeLimitSec) + limits.finishTime = TimeTrace::getClock() + *opts.moduleTimeLimitSec; + else + limits.finishTime = std::nullopt; + limits.cancellationToken = opts.cancellationToken; + + /// Icehandler + NotNull iceHandler{&frontend.iceHandler}; + /// Make the shared state for the unifier (recursion + iteration limits) + UnifierSharedState unifierState{iceHandler}; + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); + + /// Initialize the normalizer + Normalizer normalizer{&incrementalModule->internalTypes, frontend.builtinTypes, NotNull{&unifierState}}; + + /// User defined type functions runtime + TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits}); + + /// Create a DataFlowGraph just for the surrounding context + auto updatedDfg = DataFlowGraphBuilder::updateGraph(*stale->dataFlowGraph.get(), stale->dfgScopes, root, cursorPos, iceHandler); + + /// Contraint Generator + ConstraintGenerator cg{ + incrementalModule, + NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, + NotNull{&frontend.moduleResolver}, + frontend.builtinTypes, + iceHandler, + frontend.globals.globalScope, + nullptr, + nullptr, + NotNull{&updatedDfg}, + {} + }; + cg.rootScope = stale->getModuleScope().get(); + // Any additions to the scope must occur in a fresh scope + auto freshChildOfNearestScope = std::make_shared(closestScope); + incrementalModule->scopes.push_back({root->location, freshChildOfNearestScope}); + + // closest Scope -> children = { ...., freshChildOfNearestScope} + // We need to trim nearestChild from the scope hierarcy + closestScope->children.push_back(NotNull{freshChildOfNearestScope.get()}); + // Visit just the root - we know the scope it should be in + cg.visitFragmentRoot(freshChildOfNearestScope, root); + // Trim nearestChild from the closestScope + Scope* back = closestScope->children.back().get(); + LUAU_ASSERT(back == freshChildOfNearestScope.get()); + closestScope->children.pop_back(); + + /// Initialize the constraint solver and run it + ConstraintSolver cs{ + NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, + NotNull(cg.rootScope), + borrowConstraints(cg.constraints), + incrementalModule->name, + NotNull{&frontend.moduleResolver}, + {}, + nullptr, + NotNull{&updatedDfg}, + limits + }; + + try + { + cs.run(); + } + catch (const TimeLimitError&) + { + stale->timeout = true; + } + catch (const UserCancelError&) + { + stale->cancelled = true; + } + + // In frontend we would forbid internal types + // because this is just for autocomplete, we don't actually care + // We also don't even need to typecheck - just synthesize types as best as we can + + freeze(incrementalModule->internalTypes); + freeze(incrementalModule->interfaceTypes); + return {std::move(incrementalModule), freshChildOfNearestScope.get()}; +} + + +FragmentTypeCheckResult typecheckFragment( + Frontend& frontend, + const ModuleName& moduleName, + const Position& cursorPos, + std::optional opts, + std::string_view src +) +{ + const SourceModule* sourceModule = frontend.getSourceModule(moduleName); + if (!sourceModule) + { + LUAU_ASSERT(!"Expected Source Module for fragment typecheck"); + return {}; + } + + ModulePtr module = frontend.moduleResolver.getModule(moduleName); + const ScopePtr& closestScope = findClosestScope(module, cursorPos); + + + FragmentParseResult r = parseFragment(*sourceModule, src, cursorPos); + FrontendOptions frontendOptions = opts.value_or(frontend.options); + return typeCheckFragmentHelper(frontend, r.root, module, closestScope, cursorPos, std::move(r.alloc), frontendOptions); +} AutocompleteResult fragmentAutocomplete( Frontend& frontend, std::string_view src, const ModuleName& moduleName, Position& cursorPosition, + const FrontendOptions& opts, StringCompletionCallback callback ) { LUAU_ASSERT(FFlag::LuauSolverV2); - // TODO + LUAU_ASSERT(FFlag::LuauAllowFragmentParsing); + LUAU_ASSERT(FFlag::LuauStoreDFGOnModule2); return {}; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 95ad58f35..4072575a5 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -49,7 +49,7 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) LUAU_FASTFLAGVARIABLE(LuauMoreThoroughCycleDetection, false) LUAU_FASTFLAG(StudioReportLuauAny2) -LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule, false) +LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule2, false) namespace Luau { @@ -1315,9 +1315,9 @@ ModulePtr check( } } - DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); + DataFlowGraph oldDfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); DataFlowGraph* dfgForConstraintGeneration = nullptr; - if (FFlag::LuauStoreDFGOnModule) + if (FFlag::LuauStoreDFGOnModule2) { auto [dfg, scopes] = DataFlowGraphBuilder::buildShared(sourceModule.root, iceHandler); result->dataFlowGraph = std::move(dfg); @@ -1326,7 +1326,7 @@ ModulePtr check( } else { - dfgForConstraintGeneration = &dfg; + dfgForConstraintGeneration = &oldDfg; } UnifierSharedState unifierState{iceHandler}; @@ -1365,6 +1365,7 @@ ModulePtr check( moduleResolver, requireCycles, logger.get(), + NotNull{dfgForConstraintGeneration}, limits }; @@ -1418,16 +1419,32 @@ ModulePtr check( switch (mode) { case Mode::Nonstrict: - Luau::checkNonStrict( - builtinTypes, - NotNull{&typeFunctionRuntime}, - iceHandler, - NotNull{&unifierState}, - NotNull{&dfg}, - NotNull{&limits}, - sourceModule, - result.get() - ); + if (FFlag::LuauStoreDFGOnModule2) + { + Luau::checkNonStrict( + builtinTypes, + NotNull{&typeFunctionRuntime}, + iceHandler, + NotNull{&unifierState}, + NotNull{dfgForConstraintGeneration}, + NotNull{&limits}, + sourceModule, + result.get() + ); + } + else + { + Luau::checkNonStrict( + builtinTypes, + NotNull{&typeFunctionRuntime}, + iceHandler, + NotNull{&unifierState}, + NotNull{&oldDfg}, + NotNull{&limits}, + sourceModule, + result.get() + ); + } break; case Mode::Definition: // fallthrough intentional diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index c5e3496f5..1480d2635 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -20,8 +20,9 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTFLAGVARIABLE(LuauUseNormalizeIntersectionLimit, false) LUAU_FASTINTVARIABLE(LuauNormalizeIntersectionLimit, 200) +LUAU_FASTFLAGVARIABLE(LuauNormalizationTracksCyclicPairsThroughInhabitance, false); +LUAU_FASTFLAGVARIABLE(LuauIntersectNormalsNeedsToTrackResourceLimits, false); namespace Luau { @@ -570,10 +571,11 @@ NormalizationResult Normalizer::isInhabited(TypeId ty, Set& seen) NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right) { Set seen{nullptr}; - return isIntersectionInhabited(left, right, seen); + SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}}; + return isIntersectionInhabited(left, right, seenTablePropPairs, seen); } -NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, Set& seenSet) +NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, SeenTablePropPairs& seenTablePropPairs, Set& seenSet) { left = follow(left); right = follow(right); @@ -586,7 +588,7 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ } NormalizedType norm{builtinTypes}; - NormalizationResult res = normalizeIntersections({left, right}, norm, seenSet); + NormalizationResult res = normalizeIntersections({left, right}, norm, seenTablePropPairs, seenSet); if (res != NormalizationResult::True) { if (cacheInhabitance && res == NormalizationResult::False) @@ -937,7 +939,8 @@ std::shared_ptr Normalizer::normalize(TypeId ty) NormalizedType norm{builtinTypes}; Set seenSetTypes{nullptr}; - NormalizationResult res = unionNormalWithTy(norm, ty, seenSetTypes); + SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}}; + NormalizationResult res = unionNormalWithTy(norm, ty, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) return nullptr; @@ -955,7 +958,12 @@ std::shared_ptr Normalizer::normalize(TypeId ty) return shared; } -NormalizationResult Normalizer::normalizeIntersections(const std::vector& intersections, NormalizedType& outType, Set& seenSet) +NormalizationResult Normalizer::normalizeIntersections( + const std::vector& intersections, + NormalizedType& outType, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSet +) { if (!arena) sharedState->iceHandler->ice("Normalizing types outside a module"); @@ -964,7 +972,7 @@ NormalizationResult Normalizer::normalizeIntersections(const std::vector // Now we need to intersect the two types for (auto ty : intersections) { - NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet); + NormalizationResult res = intersectNormalWithTy(norm, ty, seenTablePropPairs, seenSet); if (res != NormalizationResult::True) return res; } @@ -1728,7 +1736,13 @@ NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, N } // See above for an explaination of `ignoreSmallerTyvars`. -NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes, int ignoreSmallerTyvars) +NormalizationResult Normalizer::unionNormalWithTy( + NormalizedType& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes, + int ignoreSmallerTyvars +) { RecursionCounter _rc(&sharedState->counters.recursionCount); if (!withinResourceLimits()) @@ -1760,7 +1774,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) { - NormalizationResult res = unionNormalWithTy(here, *it, seenSetTypes); + NormalizationResult res = unionNormalWithTy(here, *it, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) { seenSetTypes.erase(there); @@ -1781,7 +1795,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t norm.tops = builtinTypes->anyType; for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) { - NormalizationResult res = intersectNormalWithTy(norm, *it, seenSetTypes); + NormalizationResult res = intersectNormalWithTy(norm, *it, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) { seenSetTypes.erase(there); @@ -1881,7 +1895,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t for (auto& [tyvar, intersect] : here.tyvars) { - NormalizationResult res = unionNormalWithTy(*intersect, there, seenSetTypes, tyvarIndex(tyvar)); + NormalizationResult res = unionNormalWithTy(*intersect, there, seenTablePropPairs, seenSetTypes, tyvarIndex(tyvar)); if (res != NormalizationResult::True) return res; } @@ -2491,7 +2505,7 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T return arena->addTypePack({}); } -std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there, Set& seenSet) +std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSet) { if (here == there) return here; @@ -2573,31 +2587,63 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there // if the intersection of the read types of a property is uninhabited, the whole table is `never`. // We've seen these table prop elements before and we're about to ask if their intersection // is inhabited - if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy)) + if (FFlag::LuauNormalizationTracksCyclicPairsThroughInhabitance) { - seenSet.erase(*hprop.readTy); - seenSet.erase(*tprop.readTy); - return {builtinTypes->neverType}; + auto pair1 = std::pair{*hprop.readTy, *tprop.readTy}; + auto pair2 = std::pair{*tprop.readTy, *hprop.readTy}; + if (seenTablePropPairs.contains(pair1) || seenTablePropPairs.contains(pair2)) + { + seenTablePropPairs.erase(pair1); + seenTablePropPairs.erase(pair2); + return {builtinTypes->neverType}; + } + else + { + seenTablePropPairs.insert(pair1); + seenTablePropPairs.insert(pair2); + } + + Set seenSet{nullptr}; + NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenTablePropPairs, seenSet); + + seenTablePropPairs.erase(pair1); + seenTablePropPairs.erase(pair2); + if (NormalizationResult::True != res) + return {builtinTypes->neverType}; + + TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; + prop.readTy = ty; + hereSubThere &= (ty == hprop.readTy); + thereSubHere &= (ty == tprop.readTy); } else { - seenSet.insert(*hprop.readTy); - seenSet.insert(*tprop.readTy); - } - NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy); + if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy)) + { + seenSet.erase(*hprop.readTy); + seenSet.erase(*tprop.readTy); + return {builtinTypes->neverType}; + } + else + { + seenSet.insert(*hprop.readTy); + seenSet.insert(*tprop.readTy); + } - // Cleanup - seenSet.erase(*hprop.readTy); - seenSet.erase(*tprop.readTy); + NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy); - if (NormalizationResult::True != res) - return {builtinTypes->neverType}; + seenSet.erase(*hprop.readTy); + seenSet.erase(*tprop.readTy); - TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; - prop.readTy = ty; - hereSubThere &= (ty == hprop.readTy); - thereSubHere &= (ty == tprop.readTy); + if (NormalizationResult::True != res) + return {builtinTypes->neverType}; + + TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; + prop.readTy = ty; + hereSubThere &= (ty == hprop.readTy); + thereSubHere &= (ty == tprop.readTy); + } } else { @@ -2703,7 +2749,7 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there if (tmtable && hmtable) { // NOTE: this assumes metatables are ivariant - if (std::optional mtable = intersectionOfTables(hmtable, tmtable, seenSet)) + if (std::optional mtable = intersectionOfTables(hmtable, tmtable, seenTablePropPairs, seenSet)) { if (table == htable && *mtable == hmtable) return here; @@ -2733,12 +2779,12 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there return table; } -void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, Set& seenSetTypes) +void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSetTypes) { TypeIds tmp; for (TypeId here : heres) { - if (std::optional inter = intersectionOfTables(here, there, seenSetTypes)) + if (std::optional inter = intersectionOfTables(here, there, seenTablePropPairs, seenSetTypes)) tmp.insert(*inter); } heres.retain(tmp); @@ -2753,7 +2799,8 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres) for (TypeId there : theres) { Set seenSetTypes{nullptr}; - if (std::optional inter = intersectionOfTables(here, there, seenSetTypes)) + SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}}; + if (std::optional inter = intersectionOfTables(here, there, seenTablePropPairs, seenSetTypes)) tmp.insert(*inter); } } @@ -2971,12 +3018,17 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali } } -NormalizationResult Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set& seenSetTypes) +NormalizationResult Normalizer::intersectTyvarsWithTy( + NormalizedTyvars& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes +) { for (auto it = here.begin(); it != here.end();) { NormalizedType& inter = *it->second; - NormalizationResult res = intersectNormalWithTy(inter, there, seenSetTypes); + NormalizationResult res = intersectNormalWithTy(inter, there, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) return res; if (isShallowInhabited(inter)) @@ -2990,6 +3042,13 @@ NormalizationResult Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, Ty // See above for an explaination of `ignoreSmallerTyvars`. NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) { + if (FFlag::LuauIntersectNormalsNeedsToTrackResourceLimits) + { + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits()) + return NormalizationResult::HitLimits; + } + if (!get(there.tops)) { here.tops = intersectionOfTops(here.tops, there.tops); @@ -3001,13 +3060,10 @@ NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const Nor return unionNormals(here, there, ignoreSmallerTyvars); } - if (FFlag::LuauUseNormalizeIntersectionLimit) - { - // Limit based on worst-case expansion of the table intersection - // This restriction can be relaxed when table intersection simplification is improved - if (here.tables.size() * there.tables.size() >= size_t(FInt::LuauNormalizeIntersectionLimit)) - return NormalizationResult::HitLimits; - } + // Limit based on worst-case expansion of the table intersection + // This restriction can be relaxed when table intersection simplification is improved + if (here.tables.size() * there.tables.size() >= size_t(FInt::LuauNormalizeIntersectionLimit)) + return NormalizationResult::HitLimits; here.booleans = intersectionOfBools(here.booleans, there.booleans); @@ -3062,7 +3118,12 @@ NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const Nor return NormalizationResult::True; } -NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes) +NormalizationResult Normalizer::intersectNormalWithTy( + NormalizedType& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes +) { RecursionCounter _rc(&sharedState->counters.recursionCount); if (!withinResourceLimits()) @@ -3078,14 +3139,14 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type else if (!get(here.tops)) { clearNormal(here); - return unionNormalWithTy(here, there, seenSetTypes); + return unionNormalWithTy(here, there, seenTablePropPairs, seenSetTypes); } else if (const UnionType* utv = get(there)) { NormalizedType norm{builtinTypes}; for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) { - NormalizationResult res = unionNormalWithTy(norm, *it, seenSetTypes); + NormalizationResult res = unionNormalWithTy(norm, *it, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) return res; } @@ -3095,7 +3156,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type { for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) { - NormalizationResult res = intersectNormalWithTy(here, *it, seenSetTypes); + NormalizationResult res = intersectNormalWithTy(here, *it, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) return res; } @@ -3124,7 +3185,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type { TypeIds tables = std::move(here.tables); clearNormal(here); - intersectTablesWithTable(tables, there, seenSetTypes); + intersectTablesWithTable(tables, there, seenTablePropPairs, seenSetTypes); here.tables = std::move(tables); } else if (get(there)) @@ -3236,7 +3297,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type return NormalizationResult::True; } else if (auto nt = get(t)) - return intersectNormalWithTy(here, nt->ty, seenSetTypes); + return intersectNormalWithTy(here, nt->ty, seenTablePropPairs, seenSetTypes); else { // TODO negated unions, intersections, table, and function. @@ -3256,7 +3317,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type else LUAU_ASSERT(!"Unreachable"); - NormalizationResult res = intersectTyvarsWithTy(tyvars, there, seenSetTypes); + NormalizationResult res = intersectTyvarsWithTy(tyvars, there, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) return res; here.tyvars = std::move(tyvars); @@ -3456,38 +3517,4 @@ bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, N } } -bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) -{ - LUAU_ASSERT(!FFlag::LuauSolverV2); - - UnifierSharedState sharedState{&ice}; - TypeArena arena; - Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; - - u.tryUnify(subTy, superTy); - const bool ok = u.errors.empty() && u.log.empty(); - return ok; -} - -bool isConsistentSubtype( - TypePackId subPack, - TypePackId superPack, - NotNull scope, - NotNull builtinTypes, - InternalErrorReporter& ice -) -{ - LUAU_ASSERT(!FFlag::LuauSolverV2); - - UnifierSharedState sharedState{&ice}; - TypeArena arena; - Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; - - u.tryUnify(subPack, superPack); - const bool ok = u.errors.empty() && u.log.empty(); - return ok; -} - } // namespace Luau diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index 099e6a0d1..3a1e3bd10 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -2,6 +2,7 @@ #include "Luau/Simplify.h" +#include "Luau/Common.h" #include "Luau/DenseHash.h" #include "Luau/RecursionCounter.h" #include "Luau/Set.h" @@ -14,6 +15,7 @@ LUAU_FASTINT(LuauTypeReductionRecursionLimit) LUAU_FASTFLAG(LuauSolverV2) LUAU_DYNAMIC_FASTINTVARIABLE(LuauSimplificationComplexityLimit, 8); +LUAU_FASTFLAGVARIABLE(LuauFlagBasicIntersectFollows, false); namespace Luau { @@ -1064,6 +1066,12 @@ TypeId TypeSimplifier::intersectIntersectionWithType(TypeId left, TypeId right) std::optional TypeSimplifier::basicIntersect(TypeId left, TypeId right) { + if (FFlag::LuauFlagBasicIntersectFollows) + { + left = follow(left); + right = follow(right); + } + if (get(left) && get(right)) return right; if (get(right) && get(left)) diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index cc1ed7cf0..6c84c9333 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -22,7 +22,6 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompleteNewSolverLimit, false); LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) namespace Luau @@ -512,19 +511,14 @@ struct SeenSetPopper SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId subTy, TypeId superTy, NotNull scope) { - std::optional rc; + UnifierCounters& counters = normalizer->sharedState->counters; + RecursionCounter rc(&counters.recursionCount); - if (FFlag::LuauAutocompleteNewSolverLimit) + if (counters.recursionLimit > 0 && counters.recursionLimit < counters.recursionCount) { - UnifierCounters& counters = normalizer->sharedState->counters; - rc.emplace(&counters.recursionCount); - - if (counters.recursionLimit > 0 && counters.recursionLimit < counters.recursionCount) - { - SubtypingResult result; - result.normalizationTooComplex = true; - return result; - } + SubtypingResult result; + result.normalizationTooComplex = true; + return result; } subTy = follow(subTy); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index bde7751af..e272c6610 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -93,8 +93,8 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) { - TypeId leftTy = arena->addType((*leftRep)->pending); - TypeId rightTy = arena->addType(rightRep->pending); + TypeId leftTy = arena->addType((*leftRep)->pending.clone()); + TypeId rightTy = arena->addType(rightRep->pending.clone()); typeVarChanges[ty]->pending.ty = IntersectionType{{leftTy, rightTy}}; } else @@ -170,8 +170,8 @@ void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) { - TypeId leftTy = arena->addType((*leftRep)->pending); - TypeId rightTy = arena->addType(rightRep->pending); + TypeId leftTy = arena->addType((*leftRep)->pending.clone()); + TypeId rightTy = arena->addType(rightRep->pending.clone()); if (follow(leftTy) == follow(rightTy)) typeVarChanges[ty] = std::move(rightRep); @@ -217,7 +217,7 @@ TxnLog TxnLog::inverse() for (auto& [ty, _rep] : typeVarChanges) { if (!_rep->dead) - inversed.typeVarChanges[ty] = std::make_unique(*ty); + inversed.typeVarChanges[ty] = std::make_unique(ty->clone()); } for (auto& [tp, _rep] : typePackChanges) @@ -292,7 +292,7 @@ PendingType* TxnLog::queue(TypeId ty) auto& pending = typeVarChanges[ty]; if (!pending || (*pending).dead) { - pending = std::make_unique(*ty); + pending = std::make_unique(ty->clone()); pending->pending.owningArena = nullptr; } diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index a77836c5f..1cf9d268f 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -999,6 +999,11 @@ Type& Type::operator=(const Type& rhs) return *this; } +Type Type::clone() const +{ + return *this; +} + TypeId makeFunction( TypeArena& arena, std::optional selfType, diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b681e15cd..fab5a65db 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5037,17 +5037,17 @@ void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c { // First try unifying with the original uninstantiated type // but if that fails, try the instantiated one. - Unifier child = state.makeChildUnifier(); - child.tryUnify(subTy, superTy, /*isFunctionCall*/ false); - if (!child.errors.empty()) + std::unique_ptr child = state.makeChildUnifier(); + child->tryUnify(subTy, superTy, /*isFunctionCall*/ false); + if (!child->errors.empty()) { - TypeId instantiated = instantiate(scope, subTy, state.location, &child.log); + TypeId instantiated = instantiate(scope, subTy, state.location, &child->log); if (subTy == instantiated) { // Instantiating the argument made no difference, so just report any child errors - state.log.concat(std::move(child.log)); + state.log.concat(std::move(child->log)); - state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end()); + state.errors.insert(state.errors.end(), child->errors.begin(), child->errors.end()); } else { @@ -5056,7 +5056,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c } else { - state.log.concat(std::move(child.log)); + state.log.concat(std::move(child->log)); } } } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 25f05a6f6..b1e16c25a 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -749,25 +749,25 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ for (TypeId type : subUnion->options) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, superTy); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(type, superTy); if (useNewSolver) - logs.push_back(std::move(innerState.log)); + logs.push_back(std::move(innerState->log)); - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) unificationTooComplex = e; - else if (innerState.failure) + else if (innerState->failure) { // If errors were suppressed, we store the log up, so we can commit it if no other option succeeds. - if (innerState.errors.empty()) - logs.push_back(std::move(innerState.log)); + if (innerState->errors.empty()) + logs.push_back(std::move(innerState->log)); // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' else if (!firstFailedOption && !isNil(type)) - firstFailedOption = {innerState.errors.front()}; + firstFailedOption = {innerState->errors.front()}; failed = true; - errorsSuppressed &= innerState.errors.empty(); + errorsSuppressed &= innerState->errors.empty(); } } @@ -862,26 +862,26 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp for (size_t i = 0; i < uv->options.size(); ++i) { TypeId type = uv->options[(i + startIndex) % uv->options.size()]; - Unifier innerState = makeChildUnifier(); - innerState.normalize = false; - innerState.tryUnify_(subTy, type, isFunctionCall); + std::unique_ptr innerState = makeChildUnifier(); + innerState->normalize = false; + innerState->tryUnify_(subTy, type, isFunctionCall); - if (!innerState.failure) + if (!innerState->failure) { found = true; if (useNewSolver) - logs.push_back(std::move(innerState.log)); + logs.push_back(std::move(innerState->log)); else { - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); break; } } - else if (innerState.errors.empty()) + else if (innerState->errors.empty()) { errorsSuppressed = true; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (auto e = hasUnificationTooComplex(innerState->errors)) { unificationTooComplex = e; } @@ -890,7 +890,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp failedOptionCount++; if (!failedOption) - failedOption = {innerState.errors.front()}; + failedOption = {innerState->errors.front()}; } } @@ -906,25 +906,25 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp // It is possible that T <: A | B even though T innerState = makeChildUnifier(); std::shared_ptr subNorm = normalizer->normalize(subTy); std::shared_ptr superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) return reportError(location, NormalizationTooComplex{}); else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - innerState.tryUnifyNormalizedTypes( + innerState->tryUnifyNormalizedTypes( subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption ); else - innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + innerState->tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); - if (!innerState.failure) - log.concat(std::move(innerState.log)); - else if (errorsSuppressed || innerState.errors.empty()) + if (!innerState->failure) + log.concat(std::move(innerState->log)); + else if (errorsSuppressed || innerState->errors.empty()) failure = true; else - reportError(std::move(innerState.errors.front())); + reportError(std::move(innerState->errors.front())); } else if (!found && normalize) { @@ -963,22 +963,22 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I // T <: A & B if and only if T <: A and T <: B for (TypeId type : uv->parts) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) unificationTooComplex = e; - else if (!innerState.errors.empty()) + else if (!innerState->errors.empty()) { if (!firstFailedOption) - firstFailedOption = {innerState.errors.front()}; + firstFailedOption = {innerState->errors.front()}; } if (useNewSolver) - logs.push_back(std::move(innerState.log)); + logs.push_back(std::move(innerState->log)); else - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } if (useNewSolver) @@ -1058,27 +1058,27 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* for (size_t i = 0; i < uv->parts.size(); ++i) { TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; - Unifier innerState = makeChildUnifier(); - innerState.normalize = false; - innerState.tryUnify_(type, superTy, isFunctionCall); + std::unique_ptr innerState = makeChildUnifier(); + innerState->normalize = false; + innerState->tryUnify_(type, superTy, isFunctionCall); // TODO: This sets errorSuppressed to true if any of the parts is error-suppressing, // in paricular any & T is error-suppressing. Really, errorSuppressed should be true if // all of the parts are error-suppressing, but that fails to typecheck lua-apps. - if (innerState.errors.empty()) + if (innerState->errors.empty()) { found = true; - errorsSuppressed = innerState.failure; - if (useNewSolver || innerState.failure) - logs.push_back(std::move(innerState.log)); + errorsSuppressed = innerState->failure; + if (useNewSolver || innerState->failure) + logs.push_back(std::move(innerState->log)); else { errorsSuppressed = false; - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); break; } } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (auto e = hasUnificationTooComplex(innerState->errors)) { unificationTooComplex = e; } @@ -1204,16 +1204,16 @@ void Unifier::tryUnifyNormalizedTypes( { for (TypeId superTable : superNorm.tables) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify(subClass, superTable); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify(subClass, superTable); - if (innerState.errors.empty()) + if (innerState->errors.empty()) { found = true; - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); break; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (auto e = hasUnificationTooComplex(innerState->errors)) return reportError(*e); } } @@ -1235,17 +1235,17 @@ void Unifier::tryUnifyNormalizedTypes( break; } - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); - innerState.tryUnify(subTable, superTable); + innerState->tryUnify(subTable, superTable); - if (innerState.errors.empty()) + if (innerState->errors.empty()) { found = true; - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); break; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (auto e = hasUnificationTooComplex(innerState->errors)) return reportError(*e); } if (!found) @@ -1258,15 +1258,15 @@ void Unifier::tryUnifyNormalizedTypes( return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); for (TypeId superFun : superNorm.functions.parts) { - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); const FunctionType* superFtv = get(superFun); if (!superFtv) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); - innerState.tryUnify_(tgt, superFtv->retTypes); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - else if (auto e = hasUnificationTooComplex(innerState.errors)) + TypePackId tgt = innerState->tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); + innerState->tryUnify_(tgt, superFtv->retTypes); + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + else if (auto e = hasUnificationTooComplex(innerState->errors)) return reportError(*e); else return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); @@ -1304,17 +1304,17 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized { if (!firstFun) firstFun = ftv; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(args, ftv->argTypes); - if (innerState.errors.empty()) + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(args, ftv->argTypes); + if (innerState->errors.empty()) { - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); if (result) { - innerState.log.clear(); - innerState.tryUnify_(*result, ftv->retTypes); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + innerState->log.clear(); + innerState->tryUnify_(*result, ftv->retTypes); + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); // Annoyingly, since we don't support intersection of generic type packs, // the intersection may fail. We rather arbitrarily use the first matching overload // in that case. @@ -1324,7 +1324,7 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized else result = ftv->retTypes; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (auto e = hasUnificationTooComplex(innerState->errors)) { reportError(*e); return builtinTypes->errorRecoveryTypePack(args); @@ -1510,18 +1510,18 @@ void Unifier::enableNewSolver() ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) { - Unifier s = makeChildUnifier(); - s.tryUnify_(subTy, superTy); + std::unique_ptr s = makeChildUnifier(); + s->tryUnify_(subTy, superTy); - return s.errors; + return s->errors; } ErrorVec Unifier::canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall) { - Unifier s = makeChildUnifier(); - s.tryUnify_(subTy, superTy, isFunctionCall); + std::unique_ptr s = makeChildUnifier(); + s->tryUnify_(subTy, superTy, isFunctionCall); - return s.errors; + return s->errors; } void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall) @@ -1884,9 +1884,9 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal // generic methods in tables to be marked read-only. if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate) { - Instantiation instantiation{&log, types, builtinTypes, scope->level, scope}; + std::unique_ptr instantiation = std::make_unique(&log, types, builtinTypes, scope->level, scope); - std::optional instantiated = instantiation.substitute(subTy); + std::optional instantiated = instantiation->substitute(subTy); if (instantiated.has_value()) { subFunction = log.getMutable(*instantiated); @@ -1930,54 +1930,54 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (!isFunctionCall) { - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); - innerState.ctx = CountMismatch::Arg; - innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); + innerState->ctx = CountMismatch::Arg; + innerState->tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); - bool reported = !innerState.errors.empty(); + bool reported = !innerState->errors.empty(); - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) reportError(*e); - else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + else if (!innerState->errors.empty() && innerState->firstPackErrorPos) reportError( location, TypeMismatch{ superTy, subTy, - format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front(), + format("Argument #%d type is not compatible.", *innerState->firstPackErrorPos), + innerState->errors.front(), mismatchContext() } ); - else if (!innerState.errors.empty()) - reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); + else if (!innerState->errors.empty()) + reportError(location, TypeMismatch{superTy, subTy, "", innerState->errors.front(), mismatchContext()}); - innerState.ctx = CountMismatch::FunctionResult; - innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); + innerState->ctx = CountMismatch::FunctionResult; + innerState->tryUnify_(subFunction->retTypes, superFunction->retTypes); if (!reported) { - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) reportError(*e); - else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) - reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front(), mismatchContext()}); - else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + else if (!innerState->errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) + reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState->errors.front(), mismatchContext()}); + else if (!innerState->errors.empty() && innerState->firstPackErrorPos) reportError( location, TypeMismatch{ superTy, subTy, - format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front(), + format("Return #%d type is not compatible.", *innerState->firstPackErrorPos), + innerState->errors.front(), mismatchContext() } ); - else if (!innerState.errors.empty()) - reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); + else if (!innerState->errors.empty()) + reportError(location, TypeMismatch{superTy, subTy, "", innerState->errors.front(), mismatchContext()}); } - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); } else { @@ -2115,14 +2115,14 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, if (!literalProperties || !literalProperties->contains(name)) variance = Invariant; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(r->second.type(), prop.type()); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(r->second.type(), prop.type()); - checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (subTable->indexer && maybeString(subTable->indexer->indexType)) { @@ -2132,14 +2132,14 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, if (!literalProperties || !literalProperties->contains(name)) variance = Invariant; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subTable->indexer->indexResultType, prop.type()); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(subTable->indexer->indexResultType, prop.type()); - checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (subTable->state == TableState::Unsealed && isOptional(prop.type())) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` @@ -2210,20 +2210,20 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, if (!literalProperties || !literalProperties->contains(name)) variance = Invariant; - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); if (useNewSolver || FFlag::LuauFixIndexerSubtypingOrdering) - innerState.tryUnify_(prop.type(), superTable->indexer->indexResultType); + innerState->tryUnify_(prop.type(), superTable->indexer->indexResultType); else { // Incredibly, the old solver depends on this bug somehow. - innerState.tryUnify_(superTable->indexer->indexResultType, prop.type()); + innerState->tryUnify_(superTable->indexer->indexResultType, prop.type()); } - checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (superTable->state == TableState::Unsealed) { @@ -2294,22 +2294,22 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, Resetter resetter{&variance}; variance = Invariant; - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); - innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); + innerState->tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); - bool reported = !innerState.errors.empty(); + bool reported = !innerState->errors.empty(); - checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, "[indexer key]", superTy, subTy); - innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); + innerState->tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); if (!reported) - checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, "[indexer value]", superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (superTable->indexer) { @@ -2408,13 +2408,13 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) if (auto it = mttv->props.find("__index"); it != mttv->props.end()) { TypeId ty = it->second.type(); - Unifier child = makeChildUnifier(); - child.tryUnify_(ty, superTy); + std::unique_ptr child = makeChildUnifier(); + child->tryUnify_(ty, superTy); // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check - TypeId newSuperTy = child.log.follow(superTy); + TypeId newSuperTy = child->log.follow(superTy); if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) { @@ -2422,16 +2422,16 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) return; } - if (auto e = hasUnificationTooComplex(child.errors)) + if (auto e = hasUnificationTooComplex(child->errors)) reportError(*e); - else if (!child.errors.empty()) - fail(child.errors.front()); + else if (!child->errors.empty()) + fail(child->errors.front()); - log.concat(std::move(child.log)); + log.concat(std::move(child->log)); // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype - if (child.errors.empty()) + if (child->errors.empty()) log.replace(superTy, BoundType{subTy}); return; @@ -2476,19 +2476,19 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (const MetatableType* subMetatable = log.getMutable(subTy)) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subMetatable->table, superMetatable->table); - innerState.tryUnify_(subMetatable->metatable, superMetatable->metatable); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(subMetatable->table, superMetatable->table); + innerState->tryUnify_(subMetatable->metatable, superMetatable->metatable); - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) reportError(*e); - else if (!innerState.errors.empty()) + else if (!innerState->errors.empty()) reportError( - location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()} + location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState->errors.front(), mismatchContext()} ); - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (TableType* subTable = log.getMutable(subTy)) { @@ -2498,14 +2498,14 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { if (useNewSolver) { - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); bool missingProperty = false; for (const auto& [propName, prop] : subTable->props) { if (std::optional mtPropTy = findTablePropertyRespectingMeta(superTy, propName)) { - innerState.tryUnify(prop.type(), *mtPropTy); + innerState->tryUnify(prop.type(), *mtPropTy); } else { @@ -2520,18 +2520,18 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) // TODO: Unify indexers. } - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) reportError(*e); - else if (!innerState.errors.empty()) + else if (!innerState->errors.empty()) reportError(TypeError{ location, - TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()} + TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState->errors.front(), mismatchContext()} }); else if (!missingProperty) { - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); log.bindTable(subTy, superTy); - failure |= innerState.failure; + failure |= innerState->failure; } } else @@ -2618,15 +2618,15 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) } else { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(classProp->type(), prop.type()); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(classProp->type(), prop.type()); - checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); + checkChildUnifierTypeMismatch(innerState->errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); - if (innerState.errors.empty()) + if (innerState->errors.empty()) { - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else { @@ -2662,9 +2662,9 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy) return reportError(location, NormalizationTooComplex{}); // T state = makeChildUnifier(); + state->tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, ""); + if (state->errors.empty()) reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); } @@ -2889,27 +2889,27 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack, bool reversed) if (occurs) { - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); if (const UnionType* ut = get(haystack)) { if (reversed) - innerState.tryUnifyUnionWithType(haystack, ut, needle); + innerState->tryUnifyUnionWithType(haystack, ut, needle); else - innerState.tryUnifyTypeWithUnion(needle, haystack, ut, /* cacheEnabled = */ false, /* isFunction = */ false); + innerState->tryUnifyTypeWithUnion(needle, haystack, ut, /* cacheEnabled = */ false, /* isFunction = */ false); } else if (const IntersectionType* it = get(haystack)) { if (reversed) - innerState.tryUnifyIntersectionWithType(haystack, it, needle, /* cacheEnabled = */ false, /* isFunction = */ false); + innerState->tryUnifyIntersectionWithType(haystack, it, needle, /* cacheEnabled = */ false, /* isFunction = */ false); else - innerState.tryUnifyTypeWithIntersection(needle, haystack, it); + innerState->tryUnifyTypeWithIntersection(needle, haystack, it); } else { - innerState.failure = true; + innerState->failure = true; } - if (innerState.failure) + if (innerState->failure) { reportError(location, OccursCheckFailed{}); log.replace(needle, BoundType{builtinTypes->errorRecoveryType()}); @@ -3014,14 +3014,14 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ return false; } -Unifier Unifier::makeChildUnifier() +std::unique_ptr Unifier::makeChildUnifier() { - Unifier u = Unifier{normalizer, scope, location, variance, &log}; - u.normalize = normalize; - u.checkInhabited = checkInhabited; + std::unique_ptr u = std::make_unique(normalizer, scope, location, variance, &log); + u->normalize = normalize; + u->checkInhabited = checkInhabited; if (useNewSolver) - u.enableNewSolver(); + u->enableNewSolver(); return u; } diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 099ece2b4..7845cca28 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -1490,6 +1490,7 @@ class AstVisitor } }; +bool isLValue(const AstExpr*); AstName getIdentifier(AstExpr*); Location getLocation(const AstTypeList& typeList); @@ -1520,4 +1521,4 @@ struct hash } }; -} // namespace std \ No newline at end of file +} // namespace std diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index f5deac2fe..a72aca86e 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -1146,6 +1146,14 @@ void AstTypePackGeneric::visit(AstVisitor* visitor) visitor->visit(this); } +bool isLValue(const AstExpr* expr) +{ + return expr->is() + || expr->is() + || expr->is() + || expr->is(); +} + AstName getIdentifier(AstExpr* node) { if (AstExprGlobal* expr = node->as()) @@ -1170,4 +1178,4 @@ Location getLocation(const AstTypeList& typeList) return result; } -} // namespace Luau \ No newline at end of file +} // namespace Luau diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 3d39a2de5..941137e96 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -2,9 +2,13 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details #include "lualib.h" +#include "ldebug.h" #include "lstate.h" #include "lvm.h" +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCoroCheckStack, false) +LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) + #define CO_STATUS_ERROR -1 #define CO_STATUS_BREAK -2 @@ -37,6 +41,12 @@ static int auxresume(lua_State* L, lua_State* co, int narg) luaL_error(L, "too many arguments to resume"); lua_xmove(L, co, narg); } + else if (DFFlag::LuauCoroCheckStack) + { + // coroutine might be completely full already + if ((co->top - co->base) > LUAI_MAXCSTACK) + luaL_error(L, "too many arguments to resume"); + } co->singlestep = L->singlestep; @@ -227,8 +237,22 @@ static int coclose(lua_State* L) else { lua_pushboolean(L, false); - if (lua_gettop(co)) - lua_xmove(co, L, 1); // move error message + + if (DFFlag::LuauStackLimit) + { + if (co->status == LUA_ERRMEM) + lua_pushstring(L, LUA_MEMERRMSG); + else if (co->status == LUA_ERRERR) + lua_pushstring(L, LUA_ERRERRMSG); + else if (lua_gettop(co)) + lua_xmove(co, L, 1); // move error message + } + else + { + if (lua_gettop(co)) + lua_xmove(co, L, 1); // move error message + } + lua_resetthread(co); return 2; } diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 400654b7c..28ab00b63 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,6 +17,11 @@ #include +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauStackLimit, false) + +// keep max stack allocation request under 1GB +#define MAX_STACK_SIZE (int(1024 / sizeof(TValue)) * 1024 * 1024) + /* ** {====================================================== ** Error-recovery functions @@ -176,6 +181,10 @@ static void correctstack(lua_State* L, TValue* oldstack) void luaD_reallocstack(lua_State* L, int newsize) { + // throw 'out of memory' error because space for a custom error message cannot be guaranteed here + if (DFFlag::LuauStackLimit && newsize > MAX_STACK_SIZE) + luaD_throw(L, LUA_ERRMEM); + TValue* oldstack = L->stack; int realsize = newsize + EXTRA_STACK; LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 4473f04f4..6ba758df5 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -14,6 +14,8 @@ #include +LUAU_DYNAMIC_FASTFLAG(LuauCoroCheckStack) + /* * Luau uses an incremental non-generational non-moving mark&sweep garbage collector. * @@ -436,12 +438,27 @@ static void shrinkstack(lua_State* L) int s_used = cast_int(lim - L->stack); // part of stack in use if (L->size_ci > LUAI_MAXCALLS) // handling overflow? return; // do not touch the stacks - if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci) - luaD_reallocCI(L, L->size_ci / 2); // still big enough... - condhardstacktests(luaD_reallocCI(L, ci_used + 1)); - if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) - luaD_reallocstack(L, L->stacksize / 2); // still big enough... - condhardstacktests(luaD_reallocstack(L, s_used)); + + if (DFFlag::LuauCoroCheckStack) + { + if (3 * size_t(ci_used) < size_t(L->size_ci) && 2 * BASIC_CI_SIZE < L->size_ci) + luaD_reallocCI(L, L->size_ci / 2); // still big enough... + condhardstacktests(luaD_reallocCI(L, ci_used + 1)); + + if (3 * size_t(s_used) < size_t(L->stacksize) && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) + luaD_reallocstack(L, L->stacksize / 2); // still big enough... + condhardstacktests(luaD_reallocstack(L, s_used)); + } + else + { + if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci) + luaD_reallocCI(L, L->size_ci / 2); // still big enough... + condhardstacktests(luaD_reallocCI(L, ci_used + 1)); + + if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) + luaD_reallocstack(L, L->stacksize / 2); // still big enough... + condhardstacktests(luaD_reallocstack(L, s_used)); + } } /* diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 53c418cd7..de4049a97 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,9 +15,7 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) -LUAU_FASTFLAG(LuauAutocompleteNewSolverLimit) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauUseNormalizeIntersectionLimit) using namespace Luau; @@ -3824,7 +3822,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_subtyping_recursion_limit") if (!FFlag::LuauSolverV2) return; - ScopedFastFlag luauAutocompleteNewSolverLimit{FFlag::LuauAutocompleteNewSolverLimit, true}; ScopedFastInt luauTypeInferRecursionLimit{FInt::LuauTypeInferRecursionLimit, 10}; const int parts = 100; diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index e135cc528..0a88444a8 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -35,6 +35,7 @@ LUAU_FASTFLAG(LuauMathMap) LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTFLAG(LuauNativeAttribute) +LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) static lua_CompileOptions defaultOptions() { @@ -755,6 +756,8 @@ TEST_CASE("Closure") TEST_CASE("Calls") { + ScopedFastFlag LuauStackLimit{DFFlag::LuauStackLimit, true}; + runConformance("calls.lua"); } @@ -794,6 +797,8 @@ static int cxxthrow(lua_State* L) TEST_CASE("PCall") { + ScopedFastFlag LuauStackLimit{DFFlag::LuauStackLimit, true}; + runConformance( "pcall.lua", [](lua_State* L) diff --git a/tests/ConstraintGeneratorFixture.cpp b/tests/ConstraintGeneratorFixture.cpp index 62efdb686..1b84d4c90 100644 --- a/tests/ConstraintGeneratorFixture.cpp +++ b/tests/ConstraintGeneratorFixture.cpp @@ -44,7 +44,7 @@ void ConstraintGeneratorFixture::solve(const std::string& code) { generateConstraints(code); ConstraintSolver cs{ - NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, {} + NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, NotNull{dfg.get()}, {} }; cs.run(); } diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp index 692d6e0fc..de2e98322 100644 --- a/tests/FragmentAutocomplete.test.cpp +++ b/tests/FragmentAutocomplete.test.cpp @@ -5,14 +5,17 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Common.h" +#include "Luau/Frontend.h" using namespace Luau; LUAU_FASTFLAG(LuauAllowFragmentParsing); +LUAU_FASTFLAG(LuauStoreDFGOnModule2); struct FragmentAutocompleteFixture : Fixture { + ScopedFastFlag sffs[3] = {{FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauSolverV2, true}, {FFlag::LuauStoreDFGOnModule2, true}}; FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos) { @@ -31,11 +34,20 @@ struct FragmentAutocompleteFixture : Fixture FragmentParseResult parseFragment(const std::string& document, const Position& cursorPos) { - ScopedFastFlag sffs[]{{FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauSolverV2, true}}; SourceModule* srcModule = this->getMainSourceModule(); std::string_view srcString = document; return Luau::parseFragment(*srcModule, srcString, cursorPos); } + + FragmentTypeCheckResult checkFragment(const std::string& document, const Position& cursorPos) + { + FrontendOptions options; + options.retainFullTypeGraphs = true; + // Don't strictly need this in the new solver + options.forAutocomplete = true; + options.runLintChecks = false; + return Luau::typecheckFragment(frontend, "MainModule", cursorPos, options, document); + } }; TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); @@ -267,3 +279,56 @@ local y = 5 } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_typecheck_simple_fragment") +{ + auto res = check( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = checkFragment( + R"( +local x = 4 +local y = 5 +local z = x + y +)", + Position{3, 15} + ); + + auto opt = linearSearchForBinding(fragment.freshScope, "z"); + REQUIRE(opt); + CHECK_EQ("number", toString(*opt)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_typecheck_fragment_inserted_inline") +{ + auto res = check( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + auto fragment = checkFragment( + R"( +local x = 4 +local z = x +local y = 5 +)", + Position{2, 11} + ); + + auto correct = linearSearchForBinding(fragment.freshScope, "z"); + REQUIRE(correct); + CHECK_EQ("number", toString(*correct)); +} + +TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index d9c9f46d9..24186c0a0 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -12,6 +12,7 @@ LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTINT(LuauTypeInferRecursionLimit) +LUAU_FASTFLAG(LuauNormalizationTracksCyclicPairsThroughInhabitance) using namespace Luau; namespace @@ -1026,4 +1027,109 @@ TEST_CASE_FIXTURE(NormalizeFixture, "truthy_table_property_and_optional_table_wi CHECK("{ x: number }" == toString(ty)); } +TEST_CASE_FIXTURE(BuiltinsFixture, "normalizer_should_be_able_to_detect_cyclic_tables_and_not_stack_overflow") +{ + if (!FFlag::LuauSolverV2) + return; + ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 0}; + ScopedFastFlag sff{FFlag::LuauNormalizationTracksCyclicPairsThroughInhabitance, true}; + CheckResult result = check(R"( +--!strict + +type Array = { [number] : T} +type Object = { [number] : any} + +type Set = typeof(setmetatable( + {} :: { + size: number, + -- method definitions + add: (self: Set, T) -> Set, + clear: (self: Set) -> (), + delete: (self: Set, T) -> boolean, + has: (self: Set, T) -> boolean, + ipairs: (self: Set) -> any, + }, + {} :: { + __index: Set, + __iter: (self: Set) -> (({ [K]: V }, K?) -> (K, V), T), + } +)) + +type Map = typeof(setmetatable( + {} :: { + size: number, + -- method definitions + set: (self: Map, K, V) -> Map, + get: (self: Map, K) -> V | nil, + clear: (self: Map) -> (), + delete: (self: Map, K) -> boolean, + [K]: V, + has: (self: Map, K) -> boolean, + keys: (self: Map) -> Array, + values: (self: Map) -> Array, + entries: (self: Map) -> Array>, + ipairs: (self: Map) -> any, + _map: { [K]: V }, + _array: { [number]: K }, + __index: (self: Map, key: K) -> V, + __iter: (self: Map) -> (({ [K]: V }, K?) -> (K?, V), V), + __newindex: (self: Map, key: K, value: V) -> (), + }, + {} :: { + __index: Map, + __iter: (self: Map) -> (({ [K]: V }, K?) -> (K, V), V), + __newindex: (self: Map, key: K, value: V) -> (), + } +)) +type mapFn = (element: T, index: number) -> U +type mapFnWithThisArg = (thisArg: any, element: T, index: number) -> U + +function fromSet( + value: Set, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? + -- FIXME Luau: need overloading so the return type on this is more sane and doesn't require manual casts +): Array | Array | Array + + local array : { [number] : string} = {"foo"} + return array +end + +function instanceof(tbl: any, class: any): boolean + return true +end + +function fromArray( + value: Array, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? + -- FIXME Luau: need overloading so the return type on this is more sane and doesn't require manual casts +): Array | Array | Array + local array : {[number] : string} = {} + return array +end + +return function( + value: string | Array | Set | Map, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? + -- FIXME Luau: need overloading so the return type on this is more sane and doesn't require manual casts +): Array | Array | Array + if value == nil then + error("cannot create array from a nil value") + end + local array: Array | Array | Array + + if instanceof(value, Set) then + array = fromSet(value :: Set, mapFn, thisArg) + else + array = {} + end + + + return array +end +)"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 25f3d1132..7f73f8e2c 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" #include "Fixture.h" @@ -8,7 +9,10 @@ using namespace Luau; -LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(LuauSolverV2) +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) +LUAU_FASTFLAG(LuauTypestateBuiltins) +LUAU_FASTFLAG(LuauStringFormatArityFix) TEST_SUITE_BEGIN("BuiltinTests"); @@ -802,6 +806,19 @@ TEST_CASE_FIXTURE(Fixture, "string_format_as_method") CHECK_EQ(tm->givenType, builtinTypes->numberType); } +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_trivial_arity") +{ + ScopedFastFlag sff{FFlag::LuauStringFormatArityFix, true}; + + CheckResult result = check(R"( + string.format() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Argument count mismatch. Function 'string.format' expects at least 1 argument, but none are specified", toString(result.errors[0])); +} + TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument") { CheckResult result = check(R"( @@ -1109,15 +1126,28 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") local c = tf3[2] local d = tf1.b + + local a2 = t1.a + local b2 = t2.b + local c2 = t3[2] )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauSolverV2) + if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins) + CHECK("Key 'b' not found in table '{ read a: number }'" == toString(result.errors[0])); + else if (FFlag::LuauSolverV2) CHECK("Key 'b' not found in table '{ a: number }'" == toString(result.errors[0])); else CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0])); CHECK(Location({13, 18}, {13, 23}) == result.errors[0].location); + if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins) + { + CHECK_EQ("{ read a: number }", toString(requireTypeAtPosition({15, 19}))); + CHECK_EQ("{ read b: string }", toString(requireTypeAtPosition({16, 19}))); + CHECK_EQ("{boolean}", toString(requireTypeAtPosition({17, 19}))); + } + CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("string", toString(requireType("b"))); CHECK_EQ("boolean", toString(requireType("c"))); @@ -1126,6 +1156,86 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") CHECK_EQ("any", toString(requireType("d"))); else CHECK_EQ("*error-type*", toString(requireType("d"))); + + CHECK_EQ("number", toString(requireType("a2"))); + CHECK_EQ("string", toString(requireType("b2"))); + CHECK_EQ("boolean", toString(requireType("c2"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_does_not_retroactively_block_mutation") +{ + CheckResult result = check(R"( + local t1 = {a = 42} + + t1.q = ":3" + + local tf1 = table.freeze(t1) + + local a = tf1.a + local b = t1.a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + + if (FFlag::LuauTypestateBuiltins) + { + CHECK_EQ("t1 | { read a: number, read q: string }", toString(requireType("t1"))); + // before the assignment, it's `t1` + CHECK_EQ("t1", toString(requireTypeAtPosition({3, 8}))); + // after the assignment, it's read-only. + CHECK_EQ("{ read a: number, read q: string }", toString(requireTypeAtPosition({8, 18}))); + } + + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_no_generic_table") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + --!strict + type k = { + read k: string, + } + + function _(): k + return table.freeze({ + k = "", + }) + end + )"); + + if (FFlag::LuauTypestateBuiltins) + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_errors_on_non_tables") +{ + CheckResult result = check(R"( + --!strict + table.freeze(42) + )"); + + // this does not error in the new solver without the typestate builtins functionality. + if (FFlag::LuauSolverV2 && !FFlag::LuauTypestateBuiltins) + { + LUAU_REQUIRE_NO_ERRORS(result); + return; + } + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins) + CHECK_EQ(toString(tm->wantedType), "table"); + else + CHECK_EQ(toString(tm->wantedType), "{- -}"); + CHECK_EQ(toString(tm->givenType), "number"); } TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index f8cc88315..c31f3d8cf 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -12,6 +12,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauTypestateBuiltins) using namespace Luau; @@ -152,6 +153,45 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_a_variadic_function") CHECK(get(*iter.tail())); } +TEST_CASE_FIXTURE(BuiltinsFixture, "cross_module_table_freeze") +{ + fileResolver.source["game/A"] = R"( + --!strict + return { + a = 1, + } + )"; + + fileResolver.source["game/B"] = R"( + --!strict + return table.freeze(require(game.A)) + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ModulePtr a = frontend.moduleResolver.getModule("game/A"); + REQUIRE(a != nullptr); + // confirm that no cross-module mutation happened here! + if (FFlag::LuauSolverV2) + CHECK(toString(a->returnType) == "{ a: number }"); + else + CHECK(toString(a->returnType) == "{| a: number |}"); + + ModulePtr b = frontend.moduleResolver.getModule("game/B"); + REQUIRE(b != nullptr); + // confirm that no cross-module mutation happened here! + if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins) + CHECK(toString(b->returnType) == "{ read a: number }"); + else if (FFlag::LuauSolverV2) + CHECK(toString(b->returnType) == "{ a: number }"); + else + CHECK(toString(b->returnType) == "{| a: number |}"); +} + TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 2943486d3..615bebcdf 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,7 +8,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauUseNormalizeIntersectionLimit) using namespace Luau; @@ -2327,8 +2326,6 @@ end) TEST_CASE_FIXTURE(Fixture, "refinements_table_intersection_limits" * doctest::timeout(0.5)) { - ScopedFastFlag LuauUseNormalizeIntersectionLimit{FFlag::LuauUseNormalizeIntersectionLimit, true}; - CheckResult result = check(R"( --!strict type Dir = { diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index fbc03213e..42963f5e1 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -23,6 +23,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit); +LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues) using namespace Luau; @@ -1706,4 +1707,29 @@ TEST_CASE_FIXTURE(Fixture, "react_lua_follow_free_type_ub") )")); } +TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauNewSolverVisitErrorExprLvalues, true} + }; + + // This should always fail to parse, but shouldn't assert. Previously this + // would assert as we end up _roughly_ parsing this (with a lot of error + // nodes) as: + // + // do + // x :: T, y = z + // end + // + // We assume that `T` has some resolved type that is set up during + // constraint generation and resolved during constraint solving to + // be used during typechecking. We didn't descend into error nodes + // in lvalue positions. + LUAU_REQUIRE_ERRORS(check(R"( + --!strict + (::, + )")); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 10ddd097a..2a0a072a7 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -32,7 +32,7 @@ TEST_SUITE_BEGIN("TryUnifyTests"); TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") { Type numberOne{TypeVariant{PrimitiveType{PrimitiveType::Number}}}; - Type numberTwo = numberOne; + Type numberTwo = numberOne.clone(); state.tryUnify(&numberTwo, &numberOne); @@ -64,13 +64,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") Type functionOne{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) }}; - Type functionOneSaved = functionOne; + Type functionOneSaved = functionOne.clone(); TypePackVar argPackTwo{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; Type functionTwo{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->stringType})) }}; - Type functionTwoSaved = functionTwo; + Type functionTwoSaved = functionTwo.clone(); state.tryUnify(&functionTwo, &functionOne); CHECK(state.failure); diff --git a/tests/conformance/calls.lua b/tests/conformance/calls.lua index 621a921aa..6555f93e1 100644 --- a/tests/conformance/calls.lua +++ b/tests/conformance/calls.lua @@ -236,4 +236,12 @@ if not limitedstack then assert(not err and string.find(msg, "error")) end +-- testing deep nested calls with a large thread stack +do + function recurse(n, ...) return n <= 1 and (1 + #{...}) or recurse(n-1, table.unpack(table.create(4000, 1))) + 1 end + + local ok, msg = pcall(recurse, 19000) + assert(not ok and string.find(msg, "not enough memory")) +end + return('OK') diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index c2be2708a..265c397bd 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -168,6 +168,10 @@ checkresults({ false, "oops" }, xpcall(function() table.create(1e6) end, functio checkresults({ false, "error in error handling" }, xpcall(function() error("oops") end, function(e) table.create(1e6) end)) checkresults({ false, "not enough memory" }, xpcall(function() table.create(1e6) end, function(e) table.create(1e6) end)) +co = coroutine.create(function() table.create(1e6) end) +coroutine.resume(co) +checkresults({ false, "not enough memory" }, coroutine.close(co)) + -- ensure that pcall and xpcall close upvalues when handling error local upclo local function uptest(y)