From 02241b6d247201226a2592cc84860b1aa61a18fb Mon Sep 17 00:00:00 2001 From: aaron Date: Fri, 27 Sep 2024 11:58:21 -0700 Subject: [PATCH] Sync to upstream/release/645 (#1440) In this update, we continue to improve the overall stability of the new type solver. We're also shipping some early bits of two new features, one of the language and one of the analysis API: user-defined type functions and an incremental typechecking API. If you use the new solver and want to use all new fixes included in this release, you have to reference an additional Luau flag: ```c++ LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) ``` And set its value to `645`: ```c++ DFInt::LuauTypeSolverRelease.value = 645; // Or a higher value for future updates ``` ## New Solver * Fix a crash where scopes are incorrectly accessed cross-module after they've been deallocated by appropriately zeroing out associated scope pointers for free types, generic types, table types, etc. * Fix a crash where we were incorrectly caching results for bound types in generalization. * Eliminated some unnecessary intermediate allocations in the constraint solver and type function infrastructure. * Built some initial groundwork for an incremental typecheck API for use by language servers. * Built an initial technical preview for [user-defined type functions](https://rfcs.luau-lang.org/user-defined-type-functions.html), more work still to come (including calling type functions from other type functions), but adventurous folks wanting to experiment with it can try it out by enabling `FFlag::LuauUserDefinedTypeFunctionsSyntax` and `FFlag::LuauUserDefinedTypeFunction` in their local environment. Special thanks to @joonyoo181 who built up all the initial infrastructure for this during his internship! ## Miscellaneous changes * Fix a compilation error on Ubuntu (fixes #1437) --- Internal Contributors: Co-authored-by: Aaron Weiss Co-authored-by: Hunter Goldstein Co-authored-by: Jeremy Yoo Co-authored-by: Vighnesh Vijay Co-authored-by: Vyacheslav Egorov --------- Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Vighnesh Co-authored-by: Aviral Goel Co-authored-by: David Cope Co-authored-by: Lily Brown Co-authored-by: Vyacheslav Egorov Co-authored-by: Junseo Yoo --- Analysis/include/Luau/ConstraintSolver.h | 13 +- Analysis/include/Luau/Error.h | 10 +- Analysis/include/Luau/FragmentAutocomplete.h | 23 + Analysis/include/Luau/OverloadResolution.h | 3 + Analysis/include/Luau/Subtyping.h | 2 + Analysis/include/Luau/TypeChecker2.h | 1 + Analysis/include/Luau/TypeFunction.h | 30 +- Analysis/include/Luau/TypeFunctionRuntime.h | 267 ++ .../include/Luau/TypeFunctionRuntimeBuilder.h | 52 + Analysis/src/Autocomplete.cpp | 4 +- Analysis/src/ConstraintSolver.cpp | 44 +- Analysis/src/Error.cpp | 13 + Analysis/src/FragmentAutocomplete.cpp | 48 + Analysis/src/Frontend.cpp | 2 + Analysis/src/Generalization.cpp | 13 + Analysis/src/IostreamHelpers.cpp | 2 + Analysis/src/Module.cpp | 17 + Analysis/src/NonStrictTypeChecker.cpp | 10 +- Analysis/src/Normalize.cpp | 6 +- Analysis/src/OverloadResolution.cpp | 16 +- Analysis/src/Subtyping.cpp | 4 +- Analysis/src/ToString.cpp | 1 + Analysis/src/TypeChecker2.cpp | 25 +- Analysis/src/TypeFunction.cpp | 186 +- Analysis/src/TypeFunctionReductionGuesser.cpp | 1 + Analysis/src/TypeFunctionRuntime.cpp | 2192 +++++++++++++++++ Analysis/src/TypeFunctionRuntimeBuilder.cpp | 788 ++++++ Analysis/src/TypeInfer.cpp | 12 +- Ast/include/Luau/ParseOptions.h | 12 + Ast/include/Luau/Parser.h | 2 +- Ast/include/Luau/TimeTrace.h | 1 + Ast/src/Lexer.cpp | 14 +- Ast/src/Parser.cpp | 14 +- Ast/src/TimeTrace.cpp | 1 + CMakeLists.txt | 3 +- CodeGen/src/IrLoweringA64.cpp | 4 +- Makefile | 4 +- Sources.cmake | 8 + VM/src/ltm.cpp | 80 +- tests/Conformance.test.cpp | 25 +- tests/ConstraintGeneratorFixture.cpp | 4 +- tests/ConstraintGeneratorFixture.h | 1 + tests/FragmentAutocomplete.test.cpp | 139 ++ tests/Parser.test.cpp | 20 +- tests/Subtyping.test.cpp | 3 +- tests/Transpiler.test.cpp | 4 +- tests/TypeFunction.test.cpp | 14 - tests/TypeFunction.user.test.cpp | 1007 ++++++++ tests/TypeInfer.aliases.test.cpp | 5 +- tests/TypeInfer.builtins.test.cpp | 14 + tests/TypeInfer.loops.test.cpp | 10 - tests/TypeInfer.modules.test.cpp | 78 + 52 files changed, 5034 insertions(+), 218 deletions(-) create mode 100644 Analysis/include/Luau/FragmentAutocomplete.h create mode 100644 Analysis/include/Luau/TypeFunctionRuntime.h create mode 100644 Analysis/include/Luau/TypeFunctionRuntimeBuilder.h create mode 100644 Analysis/src/FragmentAutocomplete.cpp create mode 100644 Analysis/src/TypeFunctionRuntime.cpp create mode 100644 Analysis/src/TypeFunctionRuntimeBuilder.cpp create mode 100644 tests/FragmentAutocomplete.test.cpp create mode 100644 tests/TypeFunction.user.test.cpp diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index c6b4a8288..4d38118af 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -12,6 +12,7 @@ #include "Luau/ToString.h" #include "Luau/Type.h" #include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunction.h" #include "Luau/TypeFwd.h" #include "Luau/Variant.h" @@ -62,6 +63,7 @@ struct ConstraintSolver NotNull builtinTypes; InternalErrorReporter iceReporter; NotNull normalizer; + NotNull typeFunctionRuntime; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; NotNull rootScope; @@ -111,6 +113,7 @@ struct ConstraintSolver explicit ConstraintSolver( NotNull normalizer, + NotNull typeFunctionRuntime, NotNull rootScope, std::vector> constraints, ModuleName moduleName, @@ -278,18 +281,18 @@ struct ConstraintSolver /** * @returns true if the TypeId is in a blocked state. */ - bool isBlocked(TypeId ty); + bool isBlocked(TypeId ty) const; /** * @returns true if the TypePackId is in a blocked state. */ - bool isBlocked(TypePackId tp); + bool isBlocked(TypePackId tp) const; /** * Returns whether the constraint is blocked on anything. * @param constraint the constraint to check. */ - bool isBlocked(NotNull constraint); + bool isBlocked(NotNull constraint) const; /** Pushes a new solver constraint to the solver. * @param cv the body of the constraint. @@ -381,8 +384,8 @@ struct ConstraintSolver TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp); - void throwTimeLimitError(); - void throwUserCancelError(); + void throwTimeLimitError() const; + void throwUserCancelError() const; ToStringOptions opts; }; diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index baf3318c4..fe9d79248 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -448,6 +448,13 @@ struct UnexpectedTypePackInSubtyping bool operator==(const UnexpectedTypePackInSubtyping& rhs) const; }; +struct UserDefinedTypeFunctionError +{ + std::string message; + + bool operator==(const UserDefinedTypeFunctionError& rhs) const; +}; + using TypeErrorData = Variant< TypeMismatch, UnknownSymbol, @@ -496,7 +503,8 @@ using TypeErrorData = Variant< CheckedFunctionIncorrectArgs, UnexpectedTypeInSubtyping, UnexpectedTypePackInSubtyping, - ExplicitFunctionAnnotationRecommended>; + ExplicitFunctionAnnotationRecommended, + UserDefinedTypeFunctionError>; struct TypeErrorSummary { diff --git a/Analysis/include/Luau/FragmentAutocomplete.h b/Analysis/include/Luau/FragmentAutocomplete.h new file mode 100644 index 000000000..53e301c10 --- /dev/null +++ b/Analysis/include/Luau/FragmentAutocomplete.h @@ -0,0 +1,23 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/Ast.h" + +#include + + +namespace Luau +{ + +struct FragmentAutocompleteAncestryResult +{ + DenseHashMap localMap{AstName()}; + std::vector localStack; + std::vector ancestry; + AstStat* nearestStatement; +}; + +FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); + +} // namespace Luau diff --git a/Analysis/include/Luau/OverloadResolution.h b/Analysis/include/Luau/OverloadResolution.h index 9a2974a5b..83a33215a 100644 --- a/Analysis/include/Luau/OverloadResolution.h +++ b/Analysis/include/Luau/OverloadResolution.h @@ -35,6 +35,7 @@ struct OverloadResolver NotNull builtinTypes, NotNull arena, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull scope, NotNull reporter, NotNull limits, @@ -44,6 +45,7 @@ struct OverloadResolver NotNull builtinTypes; NotNull arena; NotNull normalizer; + NotNull typeFunctionRuntime; NotNull scope; NotNull ice; NotNull limits; @@ -109,6 +111,7 @@ SolveResult solveFunctionCall( NotNull arena, NotNull builtinTypes, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull iceReporter, NotNull limits, NotNull scope, diff --git a/Analysis/include/Luau/Subtyping.h b/Analysis/include/Luau/Subtyping.h index 09f46c4df..1e7810560 100644 --- a/Analysis/include/Luau/Subtyping.h +++ b/Analysis/include/Luau/Subtyping.h @@ -135,6 +135,7 @@ struct Subtyping NotNull builtinTypes; NotNull arena; NotNull normalizer; + NotNull typeFunctionRuntime; NotNull iceReporter; TypeCheckLimits limits; @@ -155,6 +156,7 @@ struct Subtyping NotNull builtinTypes, NotNull typeArena, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull iceReporter ); diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index 0faf036d8..e7db9411d 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -83,6 +83,7 @@ struct TypeChecker2 DenseHashSet seenTypeFunctionInstances{nullptr}; Normalizer normalizer; + TypeFunctionRuntime typeFunctionRuntime; Subtyping _subtyping; NotNull subtyping; diff --git a/Analysis/include/Luau/TypeFunction.h b/Analysis/include/Luau/TypeFunction.h index c686f4822..252b4c9af 100644 --- a/Analysis/include/Luau/TypeFunction.h +++ b/Analysis/include/Luau/TypeFunction.h @@ -1,10 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/ConstraintSolver.h" +#include "Luau/Constraint.h" #include "Luau/Error.h" #include "Luau/NotNull.h" #include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunctionRuntime.h" #include "Luau/TypeFwd.h" #include @@ -16,14 +17,23 @@ namespace Luau struct TypeArena; struct TxnLog; +struct ConstraintSolver; class Normalizer; +struct TypeFunctionRuntime +{ + // For user-defined type functions, we store all generated types and packs for the duration of the typecheck + TypedAllocator typeArena; + TypedAllocator typePackArena; +}; + struct TypeFunctionContext { NotNull arena; NotNull builtins; NotNull scope; NotNull normalizer; + NotNull typeFunctionRuntime; NotNull ice; NotNull limits; @@ -35,23 +45,14 @@ struct TypeFunctionContext std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs std::optional userFuncBody; // Body of the user-defined type function; only available for UDTFs - TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint) - : arena(cs->arena) - , builtins(cs->builtinTypes) - , scope(scope) - , normalizer(cs->normalizer) - , ice(NotNull{&cs->iceReporter}) - , limits(NotNull{&cs->limits}) - , solver(cs.get()) - , constraint(constraint.get()) - { - } + TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint); TypeFunctionContext( NotNull arena, NotNull builtins, NotNull scope, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull ice, NotNull limits ) @@ -59,6 +60,7 @@ struct TypeFunctionContext , builtins(builtins) , scope(scope) , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) , ice(ice) , limits(limits) , solver(nullptr) @@ -66,7 +68,7 @@ struct TypeFunctionContext { } - NotNull pushConstraint(ConstraintV&& c); + NotNull pushConstraint(ConstraintV&& c) const; }; /// Represents a reduction result, which may have successfully reduced the type, @@ -88,6 +90,8 @@ struct TypeFunctionReductionResult /// Any type packs that need to be progressed or mutated before the /// reduction may proceed. std::vector blockedPacks; + /// A runtime error message from user-defined type functions + std::optional error; }; template diff --git a/Analysis/include/Luau/TypeFunctionRuntime.h b/Analysis/include/Luau/TypeFunctionRuntime.h new file mode 100644 index 000000000..eb5d19ee9 --- /dev/null +++ b/Analysis/include/Luau/TypeFunctionRuntime.h @@ -0,0 +1,267 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/Variant.h" + +#include +#include +#include +#include + +using lua_State = struct lua_State; + +namespace Luau +{ + +void* typeFunctionAlloc(void* ud, void* ptr, size_t osize, size_t nsize); + +// Replica of types from Type.h +struct TypeFunctionType; +using TypeFunctionTypeId = const TypeFunctionType*; + +struct TypeFunctionTypePackVar; +using TypeFunctionTypePackId = const TypeFunctionTypePackVar*; + +struct TypeFunctionPrimitiveType +{ + enum Type + { + NilType, + Boolean, + Number, + String, + }; + + Type type; + + TypeFunctionPrimitiveType(Type type) + : type(type) + { + } +}; + +struct TypeFunctionBooleanSingleton +{ + bool value = false; +}; + +struct TypeFunctionStringSingleton +{ + std::string value; +}; + +using TypeFunctionSingletonVariant = Variant; + +struct TypeFunctionSingletonType +{ + TypeFunctionSingletonVariant variant; + + explicit TypeFunctionSingletonType(TypeFunctionSingletonVariant variant) + : variant(std::move(variant)) + { + } +}; + +template +const T* get(const TypeFunctionSingletonType* tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&tv->variant) : nullptr; +} + +template +T* getMutable(const TypeFunctionSingletonType* tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&const_cast(tv)->variant) : nullptr; +} + +struct TypeFunctionUnionType +{ + std::vector components; +}; + +struct TypeFunctionIntersectionType +{ + std::vector components; +}; + +struct TypeFunctionAnyType +{ +}; + +struct TypeFunctionUnknownType +{ +}; + +struct TypeFunctionNeverType +{ +}; + +struct TypeFunctionNegationType +{ + TypeFunctionTypeId type; +}; + +struct TypeFunctionTypePack +{ + std::vector head; + std::optional tail; +}; + +struct TypeFunctionVariadicTypePack +{ + TypeFunctionTypeId type; +}; + +using TypeFunctionTypePackVariant = Variant; + +struct TypeFunctionTypePackVar +{ + TypeFunctionTypePackVariant type; + + TypeFunctionTypePackVar(TypeFunctionTypePackVariant type) + : type(std::move(type)) + { + } + + bool operator==(const TypeFunctionTypePackVar& rhs) const; +}; + +struct TypeFunctionFunctionType +{ + TypeFunctionTypePackId argTypes; + TypeFunctionTypePackId retTypes; +}; + +template +const T* get(TypeFunctionTypePackId tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&tv->type) : nullptr; +} + +template +T* getMutable(TypeFunctionTypePackId tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&const_cast(tv)->type) : nullptr; +} + +struct TypeFunctionTableIndexer +{ + TypeFunctionTableIndexer(TypeFunctionTypeId keyType, TypeFunctionTypeId valueType) + : keyType(keyType) + , valueType(valueType) + { + } + + TypeFunctionTypeId keyType; + TypeFunctionTypeId valueType; +}; + +struct TypeFunctionProperty +{ + static TypeFunctionProperty readonly(TypeFunctionTypeId ty); + static TypeFunctionProperty writeonly(TypeFunctionTypeId ty); + static TypeFunctionProperty rw(TypeFunctionTypeId ty); // Shared read-write type. + static TypeFunctionProperty rw(TypeFunctionTypeId read, TypeFunctionTypeId write); // Separate read-write type. + + bool isReadOnly() const; + bool isWriteOnly() const; + + std::optional readTy; + std::optional writeTy; +}; + +struct TypeFunctionTableType +{ + using Name = std::string; + using Props = std::unordered_map; + + Props props; + + std::optional indexer; + + // Should always be a TypeFunctionTableType + std::optional metatable; +}; + +struct TypeFunctionClassType +{ + using Name = std::string; + using Props = std::unordered_map; + + Props props; + + std::optional indexer; + + std::optional metatable; // metaclass? + + std::optional parent; + + std::string name; +}; + +using TypeFunctionTypeVariant = Luau::Variant< + TypeFunctionPrimitiveType, + TypeFunctionAnyType, + TypeFunctionUnknownType, + TypeFunctionNeverType, + TypeFunctionSingletonType, + TypeFunctionUnionType, + TypeFunctionIntersectionType, + TypeFunctionNegationType, + TypeFunctionFunctionType, + TypeFunctionTableType, + TypeFunctionClassType>; + +struct TypeFunctionType +{ + TypeFunctionTypeVariant type; + + TypeFunctionType(TypeFunctionTypeVariant type) + : type(std::move(type)) + { + } + + bool operator==(const TypeFunctionType& rhs) const; +}; + +template +const T* get(TypeFunctionTypeId tv) +{ + LUAU_ASSERT(tv); + + return tv ? Luau::get_if(&tv->type) : nullptr; +} + +template +T* getMutable(TypeFunctionTypeId tv) +{ + LUAU_ASSERT(tv); + + return tv ? Luau::get_if(&const_cast(tv)->type) : nullptr; +} + +std::optional checkResultForError(lua_State* L, const char* typeFunctionName, int luaResult); + +TypeFunctionType* allocateTypeFunctionType(lua_State* L, TypeFunctionTypeVariant type); +TypeFunctionTypePackVar* allocateTypeFunctionTypePack(lua_State* L, TypeFunctionTypePackVariant type); + +void allocTypeUserData(lua_State* L, TypeFunctionTypeVariant type); + +bool isTypeUserData(lua_State* L, int idx); +TypeFunctionTypeId getTypeUserData(lua_State* L, int idx); +std::optional optionalTypeUserData(lua_State* L, int idx); + +void registerTypeUserData(lua_State* L); + +void setTypeFunctionEnvironment(lua_State* L); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h b/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h new file mode 100644 index 000000000..c9e1152f9 --- /dev/null +++ b/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Type.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeFunctionRuntime.h" + +namespace Luau +{ + +using Kind = Variant; + +template +const T* get(const Kind& kind) +{ + return get_if(&kind); +} + +using TypeFunctionKind = Variant; + +template +const T* get(const TypeFunctionKind& tfkind) +{ + return get_if(&tfkind); +} + +struct TypeFunctionRuntimeBuilderState +{ + NotNull ctx; + + // Mapping of class name to ClassType + // Invariant: users can not create a new class types -> any class types that get deserialized must have been an argument to the type function + // Using this invariant, whenever a ClassType is serialized, we can put it into this map + // whenever a ClassType is deserialized, we can use this map to return the corresponding value + DenseHashMap classesSerialized{{}}; + + // List of errors that occur during serialization/deserialization + // At every iteration of serialization/deserialzation, if this list.size() != 0, we halt the process + std::vector errors{}; + + TypeFunctionRuntimeBuilderState(NotNull ctx) + : ctx(ctx) + , classesSerialized({}) + , errors({}) + { + } +}; + +TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state); +TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state); + +} // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 868e31f12..0cb14879a 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -149,13 +149,15 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T if (FFlag::LuauSolverV2) { + TypeFunctionRuntime typeFunctionRuntime; // TODO: maybe subtyping checks should not invoke user-defined type function runtime + if (FFlag::LuauAutocompleteNewSolverLimit) { unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; } - Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&iceReporter}}; + Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; return subtyping.isSubtype(subTy, superTy, scope).isSubtype; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index f7c4fb5eb..7db74cfbd 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -321,6 +321,7 @@ struct InstantiationQueuer : TypeOnceVisitor ConstraintSolver::ConstraintSolver( NotNull normalizer, + NotNull typeFunctionRuntime, NotNull rootScope, std::vector> constraints, ModuleName moduleName, @@ -332,11 +333,12 @@ ConstraintSolver::ConstraintSolver( : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) , constraints(std::move(constraints)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) , moduleResolver(moduleResolver) - , requireCycles(requireCycles) + , requireCycles(std::move(requireCycles)) , logger(logger) , limits(std::move(limits)) { @@ -344,7 +346,7 @@ ConstraintSolver::ConstraintSolver( for (NotNull c : this->constraints) { - unsolvedConstraints.push_back(c); + unsolvedConstraints.emplace_back(c); // initialize the reference counts for the free types in this constraint. for (auto ty : c->getMaybeMutatedFreeTypes()) @@ -1240,7 +1242,14 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, NotNull{&iceReporter}, NotNull{&limits}, constraint->location + builtinTypes, + NotNull{arena}, + normalizer, + typeFunctionRuntime, + constraint->scope, + NotNull{&iceReporter}, + NotNull{&limits}, + constraint->location }; auto [status, overload] = resolver.selectOverload(fn, argsPack); TypeId overloadToUse = fn; @@ -1270,7 +1279,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulllocation, addition)); + upperBoundContributors[expanded].emplace_back(constraint->location, addition); } if (occursCheckPassed && c.callSite) @@ -1437,8 +1446,17 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNulllowerBound; - shiftReferences(c.freeType, bindTo); - bind(constraint, c.freeType, bindTo); + if (DFInt::LuauTypeSolverRelease >= 645) + { + auto ty = follow(c.freeType); + shiftReferences(ty, bindTo); + bind(constraint, ty, bindTo); + } + else + { + shiftReferences(c.freeType, bindTo); + bind(constraint, c.freeType, bindTo); + } return true; } @@ -2603,7 +2621,7 @@ bool ConstraintSolver::unify(NotNull constraint, TID subTy, TI for (const auto& [expanded, additions] : u2.expandedFreeTypes) { for (TypeId addition : additions) - upperBoundContributors[expanded].push_back(std::make_pair(constraint->location, addition)); + upperBoundContributors[expanded].emplace_back(constraint->location, addition); } } else @@ -2820,7 +2838,7 @@ void ConstraintSolver::reproduceConstraints(NotNull scope, const Location } } -bool ConstraintSolver::isBlocked(TypeId ty) +bool ConstraintSolver::isBlocked(TypeId ty) const { ty = follow(ty); @@ -2830,7 +2848,7 @@ bool ConstraintSolver::isBlocked(TypeId ty) return nullptr != get(ty) || nullptr != get(ty); } -bool ConstraintSolver::isBlocked(TypePackId tp) +bool ConstraintSolver::isBlocked(TypePackId tp) const { tp = follow(tp); @@ -2840,7 +2858,7 @@ bool ConstraintSolver::isBlocked(TypePackId tp) return nullptr != get(tp); } -bool ConstraintSolver::isBlocked(NotNull constraint) +bool ConstraintSolver::isBlocked(NotNull constraint) const { auto blockedIt = blockedConstraints.find(constraint); return blockedIt != blockedConstraints.end() && blockedIt->second > 0; @@ -2851,7 +2869,7 @@ NotNull ConstraintSolver::pushConstraint(NotNull scope, const std::unique_ptr c = std::make_unique(scope, location, std::move(cv)); NotNull borrow = NotNull(c.get()); solverConstraints.push_back(std::move(c)); - unsolvedConstraints.push_back(borrow); + unsolvedConstraints.emplace_back(borrow); return borrow; } @@ -2997,12 +3015,12 @@ TypePackId ConstraintSolver::anyifyModuleReturnTypePackGenerics(TypePackId tp) return arena->addTypePack(resultTypes, resultTail); } -LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError() +LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError() const { throw TimeLimitError(currentModuleName); } -LUAU_NOINLINE void ConstraintSolver::throwUserCancelError() +LUAU_NOINLINE void ConstraintSolver::throwUserCancelError() const { throw UserCancelError(currentModuleName); } diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 60058d991..c91ce00d5 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -793,6 +793,11 @@ struct ErrorConverter return "Encountered an unexpected type pack in subtyping: " + toString(e.tp); } + std::string operator()(const UserDefinedTypeFunctionError& e) const + { + return e.message; + } + std::string operator()(const CannotAssignToNever& e) const { std::string result = "Cannot assign a value of type " + toString(e.rhsType) + " to a field of type never"; @@ -1175,6 +1180,11 @@ bool UnexpectedTypePackInSubtyping::operator==(const UnexpectedTypePackInSubtypi return tp == rhs.tp; } +bool UserDefinedTypeFunctionError::operator==(const UserDefinedTypeFunctionError& rhs) const +{ + return message == rhs.message; +} + bool CannotAssignToNever::operator==(const CannotAssignToNever& rhs) const { if (cause.size() != rhs.cause.size()) @@ -1384,6 +1394,9 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState) e.ty = clone(e.ty); else if constexpr (std::is_same_v) e.tp = clone(e.tp); + else if constexpr (std::is_same_v) + { + } else if constexpr (std::is_same_v) { e.rhsType = clone(e.rhsType); diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp new file mode 100644 index 000000000..4088c500c --- /dev/null +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -0,0 +1,48 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/FragmentAutocomplete.h" + +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" + +namespace Luau +{ + +FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos) +{ + std::vector ancestry = findAncestryAtPositionForAutocomplete(root, cursorPos); + DenseHashMap localMap{AstName()}; + std::vector localStack; + AstStat* nearestStatement = nullptr; + for (AstNode* node : ancestry) + { + if (auto block = node->as()) + { + for (auto stat : block->body) + { + if (stat->location.begin <= cursorPos) + nearestStatement = stat; + if (stat->location.begin <= cursorPos) + { + // This statement precedes the current one + if (auto loc = stat->as()) + { + for (auto v : loc->vars) + { + localStack.push_back(v); + localMap[v->name] = v; + } + } + else if (auto locFun = stat->as()) + { + localStack.push_back(locFun->name); + localMap[locFun->name->name] = locFun->name; + } + } + } + } + } + + return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)}; +} + +} // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 8c439181a..ca6277284 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1383,6 +1383,7 @@ ModulePtr check( unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; + TypeFunctionRuntime typeFunctionRuntime; ConstraintGenerator cg{ result, @@ -1402,6 +1403,7 @@ ModulePtr check( ConstraintSolver cs{ NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, NotNull(cg.rootScope), borrowConstraints(cg.constraints), result->name, diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp index d209cb818..a79814ec6 100644 --- a/Analysis/src/Generalization.cpp +++ b/Analysis/src/Generalization.cpp @@ -9,6 +9,8 @@ #include "Luau/TypePack.h" #include "Luau/VisitType.h" +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) + namespace Luau { @@ -871,6 +873,17 @@ struct TypeCacher : TypeOnceVisitor markUncacheable(tp); return false; } + + bool visit(TypePackId tp, const BoundTypePack& btp) override { + if (DFInt::LuauTypeSolverRelease >= 645) { + traverse(btp.boundTo); + if (isUncacheable(btp.boundTo)) + markUncacheable(tp); + return false; + } + return true; + } + }; std::optional generalize( diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index a3d8b4e34..64e059933 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -227,6 +227,8 @@ static void errorToString(std::ostream& stream, const T& err) stream << "UnexpectedTypeInSubtyping { ty = '" + toString(err.ty) + "' }"; else if constexpr (std::is_same_v) stream << "UnexpectedTypePackInSubtyping { tp = '" + toString(err.tp) + "' }"; + else if constexpr (std::is_same_v) + stream << "UserDefinedTypeFunctionError { " << err.message << " }"; else if constexpr (std::is_same_v) { stream << "CannotAssignToNever { rvalueType = '" << toString(err.rhsType) << "', reason = '" << err.reason << "', cause = { "; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 3a0492169..564a3c353 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -15,6 +15,7 @@ #include LUAU_FASTFLAG(LuauSolverV2); +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) namespace Luau { @@ -131,10 +132,26 @@ struct ClonePublicInterface : Substitution } ftv->level = TypeLevel{0, 0}; + if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645) + ftv->scope = nullptr; } else if (TableType* ttv = getMutable(result)) { ttv->level = TypeLevel{0, 0}; + if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645) + ttv->scope = nullptr; + } + + if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645) + { + if (auto freety = getMutable(result)) + { + freety->scope = nullptr; + } + else if (auto genericty = getMutable(result)) + { + genericty->scope = nullptr; + } } return result; diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp index 116cf5cb7..2131887a9 100644 --- a/Analysis/src/NonStrictTypeChecker.cpp +++ b/Analysis/src/NonStrictTypeChecker.cpp @@ -160,6 +160,7 @@ struct NonStrictTypeChecker NotNull arena; Module* module; Normalizer normalizer; + TypeFunctionRuntime typeFunctionRuntime; Subtyping subtyping; NotNull dfg; DenseHashSet noTypeFunctionErrors{nullptr}; @@ -182,7 +183,7 @@ struct NonStrictTypeChecker , arena(arena) , module(module) , normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true} - , subtyping{builtinTypes, arena, NotNull(&normalizer), ice} + , subtyping{builtinTypes, arena, NotNull(&normalizer), NotNull(&typeFunctionRuntime), ice} , dfg(dfg) , limits(limits) { @@ -228,7 +229,12 @@ struct NonStrictTypeChecker return instance; ErrorVec errors = - reduceTypeFunctions(instance, location, TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true) + reduceTypeFunctions( + instance, + location, + TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, ice, limits}, + true + ) .errors; if (errors.empty()) diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index c768f02c0..7ca57e617 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -3434,11 +3434,12 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, N UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + TypeFunctionRuntime typeFunctionRuntime; // TODO: maybe subtyping checks should not invoke user-defined type function runtime // Subtyping under DCR is not implemented using unification! if (FFlag::LuauSolverV2) { - Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}}; + Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}}; return subtyping.isSubtype(subPack, superPack, scope).isSubtype; } diff --git a/Analysis/src/OverloadResolution.cpp b/Analysis/src/OverloadResolution.cpp index 972c9e3ac..fbcce2b7c 100644 --- a/Analysis/src/OverloadResolution.cpp +++ b/Analysis/src/OverloadResolution.cpp @@ -17,6 +17,7 @@ OverloadResolver::OverloadResolver( NotNull builtinTypes, NotNull arena, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull scope, NotNull reporter, NotNull limits, @@ -25,10 +26,11 @@ OverloadResolver::OverloadResolver( : builtinTypes(builtinTypes) , arena(arena) , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) , scope(scope) , ice(reporter) , limits(limits) - , subtyping({builtinTypes, arena, normalizer, ice}) + , subtyping({builtinTypes, arena, normalizer, typeFunctionRuntime, ice}) , callLoc(callLocation) { } @@ -199,8 +201,9 @@ std::pair OverloadResolver::checkOverload_ const std::vector* argExprs ) { - FunctionGraphReductionResult result = - reduceTypeFunctions(fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, ice, limits}, /*force=*/true); + FunctionGraphReductionResult result = reduceTypeFunctions( + fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true + ); if (!result.errors.empty()) return {OverloadIsNonviable, result.errors}; @@ -405,6 +408,7 @@ std::optional selectOverload( NotNull builtinTypes, NotNull arena, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull scope, NotNull iceReporter, NotNull limits, @@ -413,7 +417,7 @@ std::optional selectOverload( TypePackId argsPack ) { - OverloadResolver resolver{builtinTypes, arena, normalizer, scope, iceReporter, limits, location}; + OverloadResolver resolver{builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location}; auto [status, overload] = resolver.selectOverload(fn, argsPack); if (status == OverloadResolver::Analysis::Ok) @@ -429,6 +433,7 @@ SolveResult solveFunctionCall( NotNull arena, NotNull builtinTypes, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull iceReporter, NotNull limits, NotNull scope, @@ -437,7 +442,8 @@ SolveResult solveFunctionCall( TypePackId argsPack ) { - std::optional overloadToUse = selectOverload(builtinTypes, arena, normalizer, scope, iceReporter, limits, location, fn, argsPack); + std::optional overloadToUse = + selectOverload(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack); if (!overloadToUse) return {SolveResult::NoMatchingOverload}; diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index b13a2327b..f8347c720 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -440,11 +440,13 @@ Subtyping::Subtyping( NotNull builtinTypes, NotNull typeArena, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull iceReporter ) : builtinTypes(builtinTypes) , arena(typeArena) , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) , iceReporter(iceReporter) { } @@ -1911,7 +1913,7 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse) std::pair Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull scope) { - TypeFunctionContext context{arena, builtinTypes, scope, normalizer, iceReporter, NotNull{&limits}}; + TypeFunctionContext context{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}}; TypeId function = arena->addType(*functionInstance); FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true); ErrorVec errors; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index f08508354..66d037ed2 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1040,6 +1040,7 @@ struct TypeStringifier state.emit(tfitv.userFuncName->value); else state.emit(tfitv.function->name); + state.emit("<"); bool comma = false; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index ed66453dd..3dc708a2d 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -31,6 +31,7 @@ #include LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctions) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) namespace Luau @@ -306,7 +307,7 @@ TypeChecker2::TypeChecker2( , sourceModule(sourceModule) , module(module) , normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true} - , _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{unifierState->iceHandler}} + , _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{unifierState->iceHandler}} , subtyping(&_subtyping) { } @@ -484,13 +485,16 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l return instance; seenTypeFunctionInstances.insert(instance); - ErrorVec errors = reduceTypeFunctions( - instance, - location, - TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, - true - ) - .errors; + ErrorVec errors = + reduceTypeFunctions( + instance, + location, + TypeFunctionContext{ + NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, ice, limits + }, + true + ) + .errors; if (!isErrorSuppressing(location, instance)) reportErrors(std::move(errors)); return instance; @@ -1194,8 +1198,8 @@ void TypeChecker2::visit(AstStatTypeAlias* stat) void TypeChecker2::visit(AstStatTypeFunction* stat) { // TODO: add type checking for user-defined type functions - - reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}}); + if (!FFlag::LuauUserDefinedTypeFunctions) + reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}}); } void TypeChecker2::visit(AstTypeList types) @@ -1446,6 +1450,7 @@ void TypeChecker2::visitCall(AstExprCall* call) builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, NotNull{stack.back()}, ice, limits, diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index 31154cc24..6d928faaa 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -2,7 +2,9 @@ #include "Luau/TypeFunction.h" +#include "Luau/BytecodeBuilder.h" #include "Luau/Common.h" +#include "Luau/Compiler.h" #include "Luau/ConstraintSolver.h" #include "Luau/DenseHash.h" #include "Luau/Instantiation.h" @@ -12,17 +14,25 @@ #include "Luau/Set.h" #include "Luau/Simplify.h" #include "Luau/Subtyping.h" +#include "Luau/TimeTrace.h" #include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/Type.h" #include "Luau/TypeFunctionReductionGuesser.h" +#include "Luau/TypeFunctionRuntime.h" +#include "Luau/TypeFunctionRuntimeBuilder.h" #include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier2.h" #include "Luau/VecDeque.h" #include "Luau/VisitType.h" +#include "lua.h" +#include "lualib.h" + #include +#include +#include // used to control emitting CodeTooComplex warnings on type function reduction LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); @@ -35,7 +45,8 @@ LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'0 // when this value is set to a negative value, guessing will be totally disabled. LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); -LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false); +LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false) +LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions, false) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) @@ -166,7 +177,7 @@ struct TypeFunctionReducer return SkipTestResult::Okay; } - SkipTestResult testForSkippability(TypePackId ty) + SkipTestResult testForSkippability(TypePackId ty) const { ty = follow(ty); @@ -214,15 +225,18 @@ struct TypeFunctionReducer { irreducible.insert(subject); + if (reduction.error.has_value()) + result.errors.emplace_back(location, UserDefinedTypeFunctionError{*reduction.error}); + if (reduction.uninhabited || force) { if (FFlag::DebugLuauLogTypeFamilies) printf("%s is uninhabited\n", toString(subject, {true}).c_str()); if constexpr (std::is_same_v) - result.errors.push_back(TypeError{location, UninhabitedTypeFunction{subject}}); + result.errors.emplace_back(location, UninhabitedTypeFunction{subject}); else if constexpr (std::is_same_v) - result.errors.push_back(TypeError{location, UninhabitedTypePackFunction{subject}}); + result.errors.emplace_back(location, UninhabitedTypePackFunction{subject}); } else if (!reduction.uninhabited && !force) { @@ -243,7 +257,7 @@ struct TypeFunctionReducer } } - bool done() + bool done() const { return queuedTys.empty() && queuedTps.empty(); } @@ -422,7 +436,7 @@ static FunctionGraphReductionResult reduceFunctionsInternal( ++iterationCount; if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps) { - reducer.result.errors.push_back(TypeError{location, CodeTooComplex{}}); + reducer.result.errors.emplace_back(location, CodeTooComplex{}); break; } } @@ -506,7 +520,7 @@ static std::optional> tryDistributeTypeFunct size_t cartesianProductSize = 1; const UnionType* firstUnion = nullptr; - size_t unionIndex; + size_t unionIndex = 0; std::vector arguments = typeParams; for (size_t i = 0; i < arguments.size(); ++i) @@ -572,6 +586,8 @@ static std::optional> tryDistributeTypeFunct return std::nullopt; } +using StateRef = std::unique_ptr; + TypeFunctionReductionResult userDefinedTypeFunction( TypeId instance, const std::vector& typeParams, @@ -585,9 +601,122 @@ TypeFunctionReductionResult userDefinedTypeFunction( return {std::nullopt, true, {}, {}}; } - // TODO: implementation of user-defined type functions goes here + for (auto typeParam : typeParams) + { + TypeId ty = follow(typeParam); + + // block if we need to + if (isPending(ty, ctx->solver)) + return {std::nullopt, false, {ty}, {}}; + } + + AstName name = *ctx->userFuncName; + AstExprFunction* function = *ctx->userFuncBody; + + // Construct ParseResult containing the type function + Allocator allocator; + AstNameTable names(allocator); + + AstExprGlobal globalName{Location{}, name}; + AstStatFunction typeFunction{Location{}, &globalName, function}; + AstStat* stmtArray[] = {&typeFunction}; + AstArray stmts{stmtArray, 1}; + AstStatBlock exec{Location{}, stmts}; + ParseResult parseResult{&exec, 1}; + + BytecodeBuilder builder; + try + { + compileOrThrow(builder, parseResult, names); + } + catch (CompileError& e) + { + std::string errMsg = format("'%s' type function failed to compile with error message: %s", name.value, e.what()); + return {std::nullopt, true, {}, {}, errMsg}; + } + + std::string bytecode = builder.getBytecode(); + + // Initialize Lua state + StateRef globalState(lua_newstate(typeFunctionAlloc, nullptr), lua_close); + lua_State* L = globalState.get(); + + lua_setthreaddata(L, ctx.get()); + + setTypeFunctionEnvironment(L); + + // Register type userdata + registerTypeUserData(L); + + luaL_sandbox(L); + luaL_sandboxthread(L); + + // Load bytecode into Luau state + if (auto error = checkResultForError(L, name.value, luau_load(L, name.value, bytecode.data(), bytecode.size(), 0))) + return {std::nullopt, true, {}, {}, error}; + + // Execute the loaded chunk to register the function in the global environment + if (auto error = checkResultForError(L, name.value, lua_pcall(L, 0, 0, 0))) + return {std::nullopt, true, {}, {}, error}; + + // Get type function from the global environment + lua_getglobal(L, name.value); + if (!lua_isfunction(L, -1)) + { + std::string errMsg = format("Could not find '%s' type function in the global scope", name.value); + + return {std::nullopt, true, {}, {}, errMsg}; + } + + // Push serialized arguments onto the stack + + // Since there aren't any new class types being created in type functions, there isn't a deserialization function + // class types. Instead, we can keep this map and return the mapping as the "deserialized value" + std::unique_ptr runtimeBuilder = std::make_unique(ctx); + for (auto typeParam : typeParams) + { + TypeId ty = follow(typeParam); + // This is checked at the top of the function, and should still be true. + LUAU_ASSERT(!isPending(ty, ctx->solver)); + + TypeFunctionTypeId serializedTy = serialize(ty, runtimeBuilder.get()); + // Check if there were any errors while serializing + if (runtimeBuilder->errors.size() != 0) + return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; + + allocTypeUserData(L, serializedTy->type); + } + + // Set up an interrupt handler for type functions to respect type checking limits and LSP cancellation requests. + lua_callbacks(L)->interrupt = [](lua_State* L, int gc) + { + auto ctx = static_cast(lua_getthreaddata(lua_mainthread(L))); + if (ctx->limits->finishTime && TimeTrace::getClock() > *ctx->limits->finishTime) + ctx->solver->throwTimeLimitError(); + + if (ctx->limits->cancellationToken && ctx->limits->cancellationToken->requested()) + ctx->solver->throwUserCancelError(); + }; + + if (auto error = checkResultForError(L, name.value, lua_resume(L, nullptr, int(typeParams.size())))) + return {std::nullopt, true, {}, {}, error}; + + // If the return value is not a type userdata, return with error message + if (!isTypeUserData(L, 1)) + return {std::nullopt, true, {}, {}, format("'%s' type function: returned a non-type value", name.value)}; + + TypeFunctionTypeId retTypeFunctionTypeId = getTypeUserData(L, 1); - return {std::nullopt, true, {}, {}}; + // No errors should be present here since we should've returned already if any were raised during serialization. + LUAU_ASSERT(runtimeBuilder->errors.size() == 0); + + TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get()); + + // At least 1 error occured while deserializing + if (runtimeBuilder->errors.size() > 0) + return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; + + return {retTypeId, false, {}, {}}; } TypeFunctionReductionResult notTypeFunction( @@ -711,7 +840,7 @@ TypeFunctionReductionResult lenTypeFunction( if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) return {std::nullopt, true, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? return {std::nullopt, true, {}, {}}; @@ -808,7 +937,7 @@ TypeFunctionReductionResult unmTypeFunction( if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) return {std::nullopt, true, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? return {std::nullopt, true, {}, {}}; @@ -818,7 +947,20 @@ TypeFunctionReductionResult unmTypeFunction( return {std::nullopt, true, {}, {}}; } -NotNull TypeFunctionContext::pushConstraint(ConstraintV&& c) +TypeFunctionContext::TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint) + : arena(cs->arena) + , builtins(cs->builtinTypes) + , scope(scope) + , normalizer(cs->normalizer) + , typeFunctionRuntime(cs->typeFunctionRuntime) + , ice(NotNull{&cs->iceReporter}) + , limits(NotNull{&cs->limits}) + , solver(cs.get()) + , constraint(constraint.get()) +{ +} + +NotNull TypeFunctionContext::pushConstraint(ConstraintV&& c) const { LUAU_ASSERT(solver); NotNull newConstraint = solver->pushConstraint(scope, constraint ? constraint->location : Location{}, std::move(c)); @@ -921,12 +1063,16 @@ TypeFunctionReductionResult numericBinopTypeFunction( SolveResult solveResult; if (!reversed) - solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack); + solveResult = solveFunctionCall( + ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack + ); else { TypePack* p = getMutable(argPack); std::swap(p->head.front(), p->head.back()); - solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack); + solveResult = solveFunctionCall( + ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack + ); } if (!solveResult.typePackId.has_value()) @@ -1156,7 +1302,7 @@ TypeFunctionReductionResult concatTypeFunction( if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) return {std::nullopt, true, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? return {std::nullopt, true, {}, {}}; @@ -1410,7 +1556,7 @@ static TypeFunctionReductionResult comparisonTypeFunction( if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) return {std::nullopt, true, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? return {std::nullopt, true, {}, {}}; @@ -1554,7 +1700,7 @@ TypeFunctionReductionResult eqTypeFunction( if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) return {std::nullopt, true, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? return {std::nullopt, true, {}, {}}; @@ -2004,7 +2150,7 @@ TypeFunctionReductionResult keyofFunctionImpl( if (!computeKeysOf(*classesIter, localKeys, seen, isRaw, ctx)) continue; - for (auto key : keys) + for (auto& key : keys) { // remove any keys that are not present in each class if (!localKeys.contains(key)) @@ -2039,7 +2185,7 @@ TypeFunctionReductionResult keyofFunctionImpl( if (!computeKeysOf(*tablesIter, localKeys, seen, isRaw, ctx)) continue; - for (auto key : keys) + for (auto& key : keys) { // remove any keys that are not present in each table if (!localKeys.contains(key)) @@ -2239,7 +2385,7 @@ TypeFunctionReductionResult indexFunctionImpl( return {std::nullopt, true, {}, {}}; // indexer can be a union —> break them down into a vector - const std::vector* typesToFind; + const std::vector* typesToFind = nullptr; const std::vector singleType{indexerTy}; if (auto unionTy = get(indexerTy)) typesToFind = &unionTy->options; diff --git a/Analysis/src/TypeFunctionReductionGuesser.cpp b/Analysis/src/TypeFunctionReductionGuesser.cpp index d4a7c7c0a..389a797d7 100644 --- a/Analysis/src/TypeFunctionReductionGuesser.cpp +++ b/Analysis/src/TypeFunctionReductionGuesser.cpp @@ -3,6 +3,7 @@ #include "Luau/DenseHash.h" #include "Luau/Normalize.h" +#include "Luau/ToString.h" #include "Luau/TypeFunction.h" #include "Luau/Type.h" #include "Luau/TypePack.h" diff --git a/Analysis/src/TypeFunctionRuntime.cpp b/Analysis/src/TypeFunctionRuntime.cpp new file mode 100644 index 000000000..d3a33d077 --- /dev/null +++ b/Analysis/src/TypeFunctionRuntime.cpp @@ -0,0 +1,2192 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeFunctionRuntime.h" + +#include "Luau/DenseHash.h" +#include "Luau/StringUtils.h" +#include "Luau/TypeFunction.h" + +#include "lua.h" +#include "lualib.h" + +#include +#include +#include + +// defined in TypeFunctionRuntimeBuilder.cpp +LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit); + +namespace Luau +{ + +constexpr int kTypeUserdataTag = 42; + +void* typeFunctionAlloc(void* ud, void* ptr, size_t osize, size_t nsize) +{ + if (nsize == 0) + { + ::operator delete(ptr); + return nullptr; + } + else if (osize == 0) + { + return ::operator new(nsize); + } + else + { + void* data = ::operator new(nsize); + memcpy(data, ptr, nsize < osize ? nsize : osize); + + ::operator delete(ptr); + + return data; + } +} + +std::optional checkResultForError(lua_State* L, const char* typeFunctionName, int luaResult) +{ + switch (luaResult) + { + case LUA_OK: + return std::nullopt; + case LUA_YIELD: + case LUA_BREAK: + return format("'%s' type function errored: unexpected yield or break", typeFunctionName); + default: + if (!lua_gettop(L)) + return format("'%s' type function errored unexpectedly", typeFunctionName); + + if (lua_isstring(L, -1)) + return format("'%s' type function errored at runtime: %s", typeFunctionName, lua_tostring(L, -1)); + + return format("'%s' type function errored at runtime: raised an error of type %s", typeFunctionName, lua_typename(L, -1)); + } +} + +static const TypeFunctionContext* getTypeFunctionContext(lua_State* L) +{ + return static_cast(lua_getthreaddata(lua_mainthread(L))); +} + +TypeFunctionType* allocateTypeFunctionType(lua_State* L, TypeFunctionTypeVariant type) +{ + auto ctx = getTypeFunctionContext(L); + return ctx->typeFunctionRuntime->typeArena.allocate(std::move(type)); +} + +TypeFunctionTypePackVar* allocateTypeFunctionTypePack(lua_State* L, TypeFunctionTypePackVariant type) +{ + auto ctx = getTypeFunctionContext(L); + return ctx->typeFunctionRuntime->typePackArena.allocate(std::move(type)); +} + +// Pushes a new type userdata onto the stack +void allocTypeUserData(lua_State* L, TypeFunctionTypeVariant type) +{ + // allocate a new type userdata + TypeFunctionTypeId* ptr = static_cast(lua_newuserdatatagged(L, sizeof(TypeFunctionTypeId), kTypeUserdataTag)); + *ptr = allocateTypeFunctionType(L, std::move(type)); + + // set the new userdata's metatable to type metatable + luaL_getmetatable(L, "type"); + lua_setmetatable(L, -2); +} + +void deallocTypeUserData(lua_State* L, void* data) +{ + // only non-owning pointers into an arena is stored +} + +bool isTypeUserData(lua_State* L, int idx) +{ + if (!lua_isuserdata(L, idx)) + return false; + + return lua_touserdatatagged(L, idx, kTypeUserdataTag) != nullptr; +} + +TypeFunctionTypeId getTypeUserData(lua_State* L, int idx) +{ + if (auto typ = static_cast(lua_touserdatatagged(L, idx, kTypeUserdataTag))) + return *typ; + + luaL_typeerrorL(L, idx, "type"); +} + +std::optional optionalTypeUserData(lua_State* L, int idx) +{ + if (lua_isnoneornil(L, idx)) + return std::nullopt; + else + return getTypeUserData(L, idx); +} + +// returns a string tag of TypeFunctionTypeId +static std::string getTag(lua_State* L, TypeFunctionTypeId ty) +{ + if (auto n = get(ty); n && n->type == TypeFunctionPrimitiveType::Type::NilType) + return "nil"; + else if (auto b = get(ty); b && b->type == TypeFunctionPrimitiveType::Type::Boolean) + return "boolean"; + else if (auto n = get(ty); n && n->type == TypeFunctionPrimitiveType::Type::Number) + return "number"; + else if (auto s = get(ty); s && s->type == TypeFunctionPrimitiveType::Type::String) + return "string"; + else if (get(ty)) + return "unknown"; + else if (get(ty)) + return "never"; + else if (get(ty)) + return "any"; + else if (auto s = get(ty)) + return "singleton"; + else if (get(ty)) + return "negation"; + else if (get(ty)) + return "union"; + else if (get(ty)) + return "intersection"; + else if (get(ty)) + return "table"; + else if (get(ty)) + return "function"; + else if (get(ty)) + return "class"; + + LUAU_UNREACHABLE(); + luaL_error(L, "VM encountered unexpected type variant when determining tag"); +} + +// Luau: `type.unknown` +// Returns the type instance representing unknown +static int createUnknown(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionUnknownType{}); + + return 1; +} + +// Luau: `type.never` +// Returns the type instance representing never +static int createNever(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionNeverType{}); + + return 1; +} + +// Luau: `type.any` +// Returns the type instance representing any +static int createAny(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionAnyType{}); + + return 1; +} + +// Luau: `type.boolean` +// Returns the type instance representing boolean +static int createBoolean(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::Boolean}); + + return 1; +} + +// Luau: `type.number` +// Returns the type instance representing number +static int createNumber(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::Number}); + + return 1; +} + +// Luau: `type.string` +// Returns the type instance representing string +static int createString(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::String}); + + return 1; +} + +// Luau: `type.singleton(value: string | boolean | nil) -> type` +// Returns the type instance representing string or boolean singleton or nil +static int createSingleton(lua_State* L) +{ + if (lua_isboolean(L, 1)) // Create boolean singleton + { + bool value = luaL_checkboolean(L, 1); + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionBooleanSingleton{value}}); + + return 1; + } + + // n.b. we cannot use lua_isstring here because lua committed the cardinal sin of calling a number a string + if (lua_type(L, 1) == LUA_TSTRING) // Create string singleton + { + const char* value = luaL_checkstring(L, 1); + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionStringSingleton{value}}); + + return 1; + } + + if (lua_isnil(L, 1)) + { + allocTypeUserData(L, TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); + + return 1; + } + + luaL_error(L, "types.singleton: can't create singleton from `%s` type", lua_typename(L, 1)); +} + +// Luau: `self:value() -> type` +// Returns the value of a singleton +static int getSingletonValue(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.value: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tfpt = get(self)) + { + if (tfpt->type != TypeFunctionPrimitiveType::NilType) + luaL_error(L, "type.value: expected self to be a singleton, but got %s instead", getTag(L, self).c_str()); + + lua_pushnil(L); + return 1; + } + + auto tfst = get(self); + if (!tfst) + luaL_error(L, "type.value: expected self to be a singleton, but got %s instead", getTag(L, self).c_str()); + + if (auto tfbst = get(tfst)) + { + lua_pushboolean(L, tfbst->value); + return 1; + } + + if (auto tfsst = get(tfst)) + { + lua_pushlstring(L, tfsst->value.c_str(), tfsst->value.length()); + return 1; + } + + luaL_error(L, "type.value: can't call `value` method on `%s` type", getTag(L, self).c_str()); +} + +// Luau: `types.unionof(...: type) -> type` +// Returns the type instance representing union +static int createUnion(lua_State* L) +{ + // get the number of arguments for union + int argSize = lua_gettop(L); + if (argSize < 2) + luaL_error(L, "types.unionof: expected at least 2 types to union, but got %d", argSize); + + std::vector components; + components.reserve(argSize); + + for (int i = 1; i <= argSize; i++) + components.push_back(getTypeUserData(L, i)); + + allocTypeUserData(L, TypeFunctionUnionType{components}); + + return 1; +} + +// Luau: `types.intersectionof(...: type) -> type` +// Returns the type instance representing intersection +static int createIntersection(lua_State* L) +{ + // get the number of arguments for intersection + int argSize = lua_gettop(L); + if (argSize < 2) + luaL_error(L, "types.intersectionof: expected at least 2 types to intersection, but got %d", argSize); + + std::vector components; + components.reserve(argSize); + + for (int i = 1; i <= argSize; i++) + components.push_back(getTypeUserData(L, i)); + + allocTypeUserData(L, TypeFunctionIntersectionType{components}); + + return 1; +} + +// Luau: `self:components() -> {type}` +// Returns the components of union or intersection +static int getComponents(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.components: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfut = get(self); + if (tfut) + { + int argSize = int(tfut->components.size()); + + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + TypeFunctionTypeId component = tfut->components[i]; + allocTypeUserData(L, component->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + + return 1; + } + + auto tfit = get(self); + if (tfit) + { + int argSize = int(tfit->components.size()); + + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + TypeFunctionTypeId component = tfit->components[i]; + allocTypeUserData(L, component->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + + return 1; + } + + luaL_error(L, "type.components: cannot call components of `%s` type", getTag(L, self).c_str()); +} + +// Luau: `types.negationof(arg: type) -> type` +// Returns the type instance representing negation +static int createNegation(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "types.negationof: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId arg = getTypeUserData(L, 1); + if (get(arg) || get(arg)) + luaL_error(L, "types.negationof: cannot perform negation on `%s` type", getTag(L, arg).c_str()); + + allocTypeUserData(L, TypeFunctionNegationType{arg}); + + return 1; +} + +// Luau: `self:inner() -> type` +// Returns the type instance being negated +static int getNegatedValue(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.inner: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tfnt = get(self); !tfnt) + allocTypeUserData(L, tfnt->type->type); + else + luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str()); + + return 1; +} + +// Luau: `types.newtable(props: {[type]: type | { read: type, write: type }}?, indexer: {index: type, readresult: type, writeresult: type}?, +// metatable: type?) -> type` Returns the type instance representing table +static int createTable(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 3) + luaL_error(L, "types.newtable: expected 0-3 arguments, but got %d", argumentCount); + + // Parse prop + TypeFunctionTableType::Props props{}; + if (lua_istable(L, 1)) + { + lua_pushnil(L); + while (lua_next(L, 1) != 0) + { + TypeFunctionTypeId key = getTypeUserData(L, -2); + + auto tfst = get(key); + if (!tfst) + luaL_error(L, "types.newtable: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "types.newtable: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + if (lua_istable(L, -1)) + { + lua_getfield(L, -1, "read"); + std::optional readTy; + if (!lua_isnil(L, -1)) + readTy = getTypeUserData(L, -1); + lua_pop(L, 1); + + lua_getfield(L, -1, "write"); + std::optional writeTy; + if (!lua_isnil(L, -1)) + writeTy = getTypeUserData(L, -1); + lua_pop(L, 1); + + props[tfsst->value] = TypeFunctionProperty{readTy, writeTy}; + } + else + { + TypeFunctionTypeId value = getTypeUserData(L, -1); + props[tfsst->value] = TypeFunctionProperty::rw(value); + } + + lua_pop(L, 1); + } + } + else if (!lua_isnoneornil(L, 1)) + luaL_typeerrorL(L, 1, "table"); + + // Parse indexer + std::optional indexer; + if (lua_istable(L, 2)) + { + // Parse keyType and valueType + lua_getfield(L, 2, "index"); + TypeFunctionTypeId keyType = getTypeUserData(L, -1); + lua_pop(L, 1); + + lua_getfield(L, 2, "readresult"); + TypeFunctionTypeId valueType = getTypeUserData(L, -1); + lua_pop(L, 1); + + indexer = TypeFunctionTableIndexer(keyType, valueType); + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + // Parse metatable + std::optional metatable = optionalTypeUserData(L, 3); + if (metatable && !get(*metatable)) + luaL_error(L, "types.newtable: expected to be given a table type as a metatable, but got %s instead", getTag(L, *metatable).c_str()); + + allocTypeUserData(L, TypeFunctionTableType{props, indexer, metatable}); + return 1; +} + +// Luau: `self:setproperty(key: type, value: type?)` +// Sets the properties of a table +static int setTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setproperty: expected 2-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setproperty: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.setproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.setproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + if (argumentCount == 2 || lua_isnil(L, 3)) + { + tftt->props.erase(tfsst->value); + return 0; + } + + TypeFunctionTypeId value = getTypeUserData(L, 3); + tftt->props[tfsst->value] = TypeFunctionProperty::rw(value, value); + + return 0; +} + +// Luau: `self:setreadproperty(key: type, value: type?)` +// Sets the properties of a table +static int setReadTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setreadproperty: expected 2-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setreadproperty: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.setreadproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.setreadproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + auto iter = tftt->props.find(tfsst->value); + + if (argumentCount == 2 || lua_isnil(L, 3)) + { + // if it's read-only, remove it altogether + if (iter != tftt->props.end() && iter->second.isReadOnly()) + tftt->props.erase(tfsst->value); + // but if it's not, just null out the read type. + else if (iter != tftt->props.end()) + iter->second.readTy = std::nullopt; + + return 0; + } + + TypeFunctionTypeId value = getTypeUserData(L, 3); + if (iter == tftt->props.end()) + tftt->props[tfsst->value] = TypeFunctionProperty::readonly(value); + else + iter->second.readTy = value; + + return 0; +} + +// Luau: `self:setwriteproperty(key: type, value: type?)` +// Sets the properties of a table +static int setWriteTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setwriteproperty: expected 2-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setwriteproperty: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.setwriteproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.setwriteproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + auto iter = tftt->props.find(tfsst->value); + + if (argumentCount == 2 || lua_isnil(L, 3)) + { + // if it's write-only, remove it altogether + if (iter != tftt->props.end() && iter->second.isWriteOnly()) + tftt->props.erase(tfsst->value); + // but if it's not, just null out the write type. + else if (iter != tftt->props.end()) + iter->second.writeTy = std::nullopt; + + return 0; + } + + TypeFunctionTypeId value = getTypeUserData(L, 3); + if (iter == tftt->props.end()) + tftt->props[tfsst->value] = TypeFunctionProperty::writeonly(value); + else + iter->second.writeTy = value; + + return 0; +} + +// Luau: `self:readproperty(key: type) -> type` +// Returns the property of a table associated with the key +static int readTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.readproperty: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = get(self); + if (!tftt) + luaL_error(L, "type.readproperty: expected self to be either a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.readproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.readproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + // Check if key is a valid prop + if (tftt->props.find(tfsst->value) == tftt->props.end()) + { + lua_pushnil(L); + return 1; + } + + auto prop = tftt->props.at(tfsst->value); + if (prop.readTy) + allocTypeUserData(L, (*prop.readTy)->type); + else + luaL_error(L, "type.readproperty: property %s is write-only, and therefore does not have a read type.", tfsst->value.c_str()); + + return 1; +} +// +// Luau: `self:writeproperty(key: type) -> type` +// Returns the property of a table associated with the key +static int writeTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.writeproperty: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = get(self); + if (!tftt) + luaL_error(L, "type.writeproperty: expected self to be either a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.writeproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.writeproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + // Check if key is a valid prop + if (tftt->props.find(tfsst->value) == tftt->props.end()) + { + lua_pushnil(L); + return 1; + } + + auto prop = tftt->props.at(tfsst->value); + if (prop.writeTy) + allocTypeUserData(L, (*prop.writeTy)->type); + else + luaL_error(L, "type.writeproperty: property %s is read-only, and therefore does not have a write type.", tfsst->value.c_str()); + + return 1; +} + +// Luau: `self:setindexer(key: type, value: type)` +// Sets the indexer of the table +static int setTableIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 3) + luaL_error(L, "type.setindexer: expected 3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setindexer: expected self to be either a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + TypeFunctionTypeId value = getTypeUserData(L, 3); + + tftt->indexer = TypeFunctionTableIndexer{key, value}; + + return 0; +} + +// Luau: `self:setreadindexer(key: type, value: type)` +// Sets the read indexer of the table +static int setTableReadIndexer(lua_State* L) +{ + luaL_error(L, "type.setreadindexer: luau does not yet support separate read/write types for indexers."); +} + +// Luau: `self:setwriteindexer(key: type, value: type)` +// Sets the write indexer of the table +static int setTableWriteIndexer(lua_State* L) +{ + luaL_error(L, "type.setwriteindexer: luau does not yet support separate read/write types for indexers."); +} + +// Luau: `self:setmetatable(arg: type)` +// Sets the metatable of the table +static int setTableMetatable(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.setmetatable: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setmetatable: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId arg = getTypeUserData(L, 2); + if (!get(arg)) + luaL_error(L, "type.setmetatable: expected the argument to be a table, but got %s instead", getTag(L, self).c_str()); + + tftt->metatable = arg; + + return 0; +} + +// Luau: `types.newfunction(parameters: {head: {type}?, tail: type?}, returns: {head: {type}?, tail: type?}) -> type` +// Returns the type instance representing a function +static int createFunction(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 2) + luaL_error(L, "types.newfunction: expected 0-2 arguments, but got %d", argumentCount); + + TypeFunctionTypePackId argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); + if (lua_istable(L, 1)) + { + std::vector head{}; + lua_getfield(L, 1, "head"); + if (lua_istable(L, -1)) + { + int argSize = lua_objlen(L, -1); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, -2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + lua_pop(L, 1); // Pop the "head" field + + std::optional tail; + lua_getfield(L, 1, "tail"); + if (auto type = optionalTypeUserData(L, -1)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + lua_pop(L, 1); // Pop the "tail" field + + if (head.size() == 0 && tail.has_value()) + argTypes = *tail; + else + argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + } + else if (!lua_isnoneornil(L, 1)) + luaL_typeerrorL(L, 1, "table"); + + TypeFunctionTypePackId retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); + if (lua_istable(L, 2)) + { + std::vector head{}; + lua_getfield(L, 2, "head"); + if (lua_istable(L, -1)) + { + int argSize = lua_objlen(L, -1); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, -2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + lua_pop(L, 1); // Pop the "head" field + + std::optional tail; + lua_getfield(L, 2, "tail"); + if (auto type = optionalTypeUserData(L, -1)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + lua_pop(L, 1); // Pop the "tail" field + + if (head.size() == 0 && tail.has_value()) + retTypes = *tail; + else + retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + allocTypeUserData(L, TypeFunctionFunctionType{argTypes, retTypes}); + + return 1; +} + +// Luau: `self:setparameters(head: {type}?, tail: type?)` +// Sets the parameters of the function +static int setFunctionParameters(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 3 || argumentCount < 1) + luaL_error(L, "type.setparameters: expected 1-3, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = getMutable(self); + if (!tfft) + luaL_error(L, "type.setparameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + std::vector head{}; + if (lua_istable(L, 2)) + { + int argSize = lua_objlen(L, 2); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, 2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + std::optional tail; + if (auto type = optionalTypeUserData(L, 3)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + + if (head.size() == 0 && tail.has_value()) // Make argTypes a variadic type pack + tfft->argTypes = *tail; + else // Make argTypes a type pack + tfft->argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + + return 0; +} + +// Luau: `self:parameters() -> {head: {type}?, tail: type?}` +// Returns the parameters of the function +static int getFunctionParameters(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.parameters: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = get(self); + if (!tfft) + luaL_error(L, "type.parameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + if (auto tftp = get(tfft->argTypes)) + { + int size = 0; + if (tftp->head.size() > 0) + size++; + if (tftp->tail.has_value()) + size++; + + lua_createtable(L, 0, size); + + int argSize = (int)tftp->head.size(); + if (argSize > 0) + { + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + allocTypeUserData(L, tftp->head[i]->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + lua_setfield(L, -2, "head"); + } + + if (tftp->tail.has_value()) + { + auto tfvp = get(*tftp->tail); + if (!tfvp) + LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment"); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + } + + return 1; + } + + if (auto tfvp = get(tfft->argTypes)) + { + lua_createtable(L, 0, 1); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + + return 1; + } + + lua_createtable(L, 0, 0); + return 1; +} + +// Luau: `self:setreturns(head: {type}?, tail: type?)` +// Sets the returns of the function +static int setFunctionReturns(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setreturns: expected 1-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = getMutable(self); + if (!tfft) + luaL_error(L, "type.setreturns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + std::vector head{}; + if (lua_istable(L, 2)) + { + int argSize = lua_objlen(L, 2); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, 2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + std::optional tail; + if (auto type = optionalTypeUserData(L, 3)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + + if (head.size() == 0 && tail.has_value()) // Make retTypes a variadic type pack + tfft->retTypes = *tail; + else // Make retTypes a type pack + tfft->retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + + return 0; +} + +// Luau: `self:returns() -> {head: {type}?, tail: type?}` +// Returns the returns of the function +static int getFunctionReturns(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.returns: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = get(self); + if (!tfft) + luaL_error(L, "type.returns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + if (auto tftp = get(tfft->retTypes)) + { + int size = 0; + if (tftp->head.size() > 0) + size++; + if (tftp->tail.has_value()) + size++; + + lua_createtable(L, 0, size); + + int argSize = (int)tftp->head.size(); + if (argSize > 0) + { + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + allocTypeUserData(L, tftp->head[i]->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + lua_setfield(L, -2, "head"); + } + + if (tftp->tail.has_value()) + { + auto tfvp = get(*tftp->tail); + if (!tfvp) + LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment"); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + } + + return 1; + } + + if (auto tfvp = get(tfft->retTypes)) + { + lua_createtable(L, 0, 1); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + + return 1; + } + + lua_createtable(L, 0, 0); + return 1; +} + +// Luau: `self:parent() -> type` +// Returns the parent of a class type +static int getClassParent(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.parent: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfct = get(self); + if (!tfct) + luaL_error(L, "type.parent: expected self to be a class, but got %s instead", getTag(L, self).c_str()); + + // If the parent does not exist, we should return nil + if (!tfct->parent) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfct->parent)->type); + + return 1; +} + +// Luau: `self:properties() -> {[type]: { read: type?, write: type? }}` +// Returns the properties of a table or class type +static int getProps(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.properties: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + lua_createtable(L, int(tftt->props.size()), 0); + for (auto& [name, prop] : tftt->props) + { + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionStringSingleton{name}}); + + int size = 0; + if (prop.readTy) + size++; + if (prop.writeTy) + size++; + + lua_createtable(L, 0, size); + if (prop.readTy) + { + allocTypeUserData(L, (*prop.readTy)->type); + lua_setfield(L, -2, "read"); + } + + if (prop.writeTy) + { + allocTypeUserData(L, (*prop.writeTy)->type); + lua_setfield(L, -2, "write"); + } + + lua_settable(L, -3); + } + + return 1; + } + + if (auto tfct = get(self)) + { + lua_createtable(L, int(tfct->props.size()), 0); + for (auto& [name, prop] : tfct->props) + { + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionStringSingleton{name}}); + + int size = 0; + if (prop.readTy) + size++; + if (prop.writeTy) + size++; + + lua_createtable(L, 0, size); + if (prop.readTy) + { + allocTypeUserData(L, (*prop.readTy)->type); + lua_setfield(L, -2, "read"); + } + + if (prop.writeTy) + { + allocTypeUserData(L, (*prop.writeTy)->type); + lua_setfield(L, -2, "write"); + } + + lua_settable(L, -3); + } + + return 1; + } + + luaL_error(L, "type.properties: expected self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:indexer() -> {index: type, readresult: type, writeresult: type}?` +// Returns the indexer of a table or class type +static int getIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.indexer: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tftt->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 3); + allocTypeUserData(L, tftt->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "readresult"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "writeresult"); + } + + return 1; + } + + if (auto tfct = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tfct->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 3); + allocTypeUserData(L, tfct->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "readresult"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "writeresult"); + } + + return 1; + } + + luaL_error(L, "type.indexer: self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:readindexer() -> {index: type, result: type}?` +// Returns the read indexer of a table or class type +static int getReadIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.readindexer: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tftt->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tftt->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + if (auto tfct = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tfct->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tfct->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + luaL_error(L, "type.readindexer: expected self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:writeindexer() -> {index: type, result: type}?` +// Returns the write indexer of a table or class type +static int getWriteIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.writeindexer: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tftt->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tftt->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + if (auto tfct = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tfct->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tfct->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + luaL_error(L, "type.writeindexer: expected self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:metatable() -> type?` +// Returns the metatable of a table or class type +static int getMetatable(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.metatable: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tfmt = get(self)) + { + // if the metatable does not exist, we should return nil + if (!tfmt->metatable.has_value()) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfmt->metatable)->type); + + return 1; + } + + if (auto tfct = get(self)) + { + // if the metatable does not exist, we should return nil + if (!tfct->metatable.has_value()) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfct->metatable)->type); + + return 1; + } + + luaL_error(L, "type.metatable: expected self to be a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:is(arg: string) -> boolean` +// Returns true if given argument is a tag of self +static int checkTag(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.is: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + std::string arg = luaL_checkstring(L, 2); + + lua_pushboolean(L, getTag(L, self) == arg); + return 1; +} + +TypeFunctionTypeId deepClone(NotNull runtime, TypeFunctionTypeId ty); // Forward declaration + +// Luau: `types.copy(arg: string) -> type` +// Returns a deep copy of the argument +static int deepCopy(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "types.copy: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId arg = getTypeUserData(L, 1); + + TypeFunctionTypeId copy = deepClone(getTypeFunctionContext(L)->typeFunctionRuntime, arg); + allocTypeUserData(L, copy->type); + return 1; +} + +// Luau: `self == arg -> boolean` +// Used to set the __eq metamethod +static int isEqualToType(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + TypeFunctionTypeId arg = getTypeUserData(L, 2); + + lua_pushboolean(L, *self == *arg); + return 1; +} + +// Register the type userdata +void registerTypeUserData(lua_State* L) +{ + // List of fields for type userdata + luaL_Reg typeUserdataFields[] = { + {"unknown", createUnknown}, + {"never", createNever}, + {"any", createAny}, + {"boolean", createBoolean}, + {"number", createNumber}, + {"string", createString}, + {nullptr, nullptr} + }; + + // List of methods for type userdata + luaL_Reg typeUserdataMethods[] = { + {"singleton", createSingleton}, + {"negationof", createNegation}, + {"unionof", createUnion}, + {"intersectionof", createIntersection}, + {"newtable", createTable}, + {"newfunction", createFunction}, + {"copy", deepCopy}, + + // Common methods + {"is", checkTag}, + + // Negation type methods + {"inner", getNegatedValue}, + + // Singleton type methods + {"value", getSingletonValue}, + + // Table type methods + {"setproperty", setTableProp}, + {"setreadproperty", setReadTableProp}, + {"setwriteproperty", setWriteTableProp}, + {"readproperty", readTableProp}, + {"writeproperty", writeTableProp}, + {"properties", getProps}, + {"setindexer", setTableIndexer}, + {"setreadindexer", setTableReadIndexer}, + {"setwriteindexer", setTableWriteIndexer}, + {"indexer", getIndexer}, + {"readindexer", getReadIndexer}, + {"writeindexer", getWriteIndexer}, + {"setmetatable", setTableMetatable}, + {"metatable", getMetatable}, + + // Function type methods + {"setparameters", setFunctionParameters}, + {"parameters", getFunctionParameters}, + {"setreturns", setFunctionReturns}, + {"returns", getFunctionReturns}, + + // Union and Intersection type methods + {"components", getComponents}, + + // Class type methods + {"parent", getClassParent}, + {"indexer", getIndexer}, + {nullptr, nullptr} + }; + + // Create and register metatable for type userdata + luaL_newmetatable(L, "type"); + + // Protect metatable from being fetched. + lua_pushstring(L, "The metatable is locked"); + lua_setfield(L, -2, "__metatable"); + + // Set type userdata metatable's __eq to type_equals() + lua_pushcfunction(L, isEqualToType, "__eq"); + lua_setfield(L, -2, "__eq"); + + // Set type userdata metatable's __index to itself + lua_pushvalue(L, -1); // Push a copy of type userdata metatable + lua_setfield(L, -2, "__index"); + + luaL_register(L, nullptr, typeUserdataMethods); + + // Set fields for type userdata + for (luaL_Reg* l = typeUserdataFields; l->name; l++) + { + l->func(L); + lua_setfield(L, -2, l->name); + } + + // Set types library as a global name "types" + lua_setglobal(L, "types"); + + // Sets up a destructor for the type userdata. + lua_setuserdatadtor(L, kTypeUserdataTag, deallocTypeUserData); +} + +// Used to redirect all the removed global functions to say "this function is unsupported" +int unsupportedFunction(lua_State* L) +{ + luaL_errorL(L, "this function is not supported in type functions"); + return 0; +} + +// Add libraries / globals for type function environment +void setTypeFunctionEnvironment(lua_State* L) +{ + // Register math library + luaopen_math(L); + lua_pop(L, 1); + + // Register table library + luaopen_table(L); + lua_pop(L, 1); + + // Register string library + luaopen_string(L); + lua_pop(L, 1); + + // Register bit32 library + luaopen_bit32(L); + lua_pop(L, 1); + + // Register utf8 library + luaopen_utf8(L); + lua_pop(L, 1); + + // Register buffer library + luaopen_buffer(L); + lua_pop(L, 1); + + // Register base library + luaopen_base(L); + lua_pop(L, 1); + + // Remove certain global functions from the base library + static const std::string unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"}; + for (auto& name : unavailableGlobals) + { + lua_pushcfunction(L, unsupportedFunction, "Removing global function from type function environment"); + lua_setglobal(L, name.c_str()); + } +} + +/* + * Below are helper methods for __eq + * Same as one from Type.cpp + */ +using SeenSet = std::set>; +bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType& rhs); +bool areEqual(SeenSet& seen, const TypeFunctionTypePackVar& lhs, const TypeFunctionTypePackVar& rhs); + +bool seenSetContains(SeenSet& seen, const void* lhs, const void* rhs) +{ + if (lhs == rhs) + return true; + + auto p = std::make_pair(lhs, rhs); + if (seen.find(p) != seen.end()) + return true; + + seen.insert(p); + return false; +} + +bool areEqual(SeenSet& seen, const TypeFunctionSingletonType& lhs, const TypeFunctionSingletonType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + { + const TypeFunctionBooleanSingleton* lp = get(&lhs); + const TypeFunctionBooleanSingleton* rp = get(&lhs); + if (lp && rp) + return lp->value == rp->value; + } + + { + const TypeFunctionStringSingleton* lp = get(&lhs); + const TypeFunctionStringSingleton* rp = get(&lhs); + if (lp && rp) + return lp->value == rp->value; + } + + return false; +} + +bool areEqual(SeenSet& seen, const TypeFunctionUnionType& lhs, const TypeFunctionUnionType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.components.size() != rhs.components.size()) + return false; + + auto l = lhs.components.begin(); + auto r = rhs.components.begin(); + + while (l != lhs.components.end()) + { + if (!areEqual(seen, **l, **r)) + return false; + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionIntersectionType& lhs, const TypeFunctionIntersectionType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.components.size() != rhs.components.size()) + return false; + + auto l = lhs.components.begin(); + auto r = rhs.components.begin(); + + while (l != lhs.components.end()) + { + if (!areEqual(seen, **l, **r)) + return false; + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionNegationType& lhs, const TypeFunctionNegationType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + return areEqual(seen, *lhs.type, *rhs.type); +} + +bool areEqual(SeenSet& seen, const TypeFunctionTableType& lhs, const TypeFunctionTableType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.props.size() != rhs.props.size()) + return false; + + if (bool(lhs.indexer) != bool(rhs.indexer)) + return false; + + if (lhs.indexer && rhs.indexer) + { + if (!areEqual(seen, *lhs.indexer->keyType, *rhs.indexer->keyType)) + return false; + + if (!areEqual(seen, *lhs.indexer->valueType, *rhs.indexer->valueType)) + return false; + } + + auto l = lhs.props.begin(); + auto r = rhs.props.begin(); + + while (l != lhs.props.end()) + { + if ((l->second.readTy && !r->second.readTy) || (!l->second.readTy && r->second.readTy)) + return false; + + if (l->second.readTy && r->second.readTy && !areEqual(seen, **(l->second.readTy), **(r->second.readTy))) + return false; + + if ((l->second.writeTy && !r->second.writeTy) || (!l->second.writeTy && r->second.writeTy)) + return false; + + if (l->second.writeTy && r->second.writeTy && !areEqual(seen, **(l->second.writeTy), **(r->second.writeTy))) + return false; + + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionFunctionType& lhs, const TypeFunctionFunctionType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (bool(lhs.argTypes) != bool(rhs.argTypes)) + return false; + + if (lhs.argTypes && rhs.argTypes) + { + if (!areEqual(seen, *lhs.argTypes, *rhs.argTypes)) + return false; + } + + if (bool(lhs.retTypes) != bool(rhs.retTypes)) + return false; + + if (lhs.retTypes && rhs.retTypes) + { + if (!areEqual(seen, *lhs.retTypes, *rhs.retTypes)) + return false; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionClassType& lhs, const TypeFunctionClassType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + return lhs.name == rhs.name; +} + +bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType& rhs) +{ + + if (lhs.type.index() != rhs.type.index()) + return false; + + { + const TypeFunctionPrimitiveType* lp = get(&lhs); + const TypeFunctionPrimitiveType* rp = get(&rhs); + if (lp && rp) + return lp->type == rp->type; + } + + if (get(&lhs) && get(&rhs)) + return true; + + if (get(&lhs) && get(&rhs)) + return true; + + if (get(&lhs) && get(&rhs)) + return true; + + { + const TypeFunctionSingletonType* lf = get(&lhs); + const TypeFunctionSingletonType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionUnionType* lf = get(&lhs); + const TypeFunctionUnionType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionIntersectionType* lf = get(&lhs); + const TypeFunctionIntersectionType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionNegationType* lf = get(&lhs); + const TypeFunctionNegationType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionTableType* lt = get(&lhs); + const TypeFunctionTableType* rt = get(&rhs); + if (lt && rt) + return areEqual(seen, *lt, *rt); + } + + { + const TypeFunctionFunctionType* lf = get(&lhs); + const TypeFunctionFunctionType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionClassType* lf = get(&lhs); + const TypeFunctionClassType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + return false; +} + +bool areEqual(SeenSet& seen, const TypeFunctionTypePack& lhs, const TypeFunctionTypePack& rhs) +{ + if (lhs.head.size() != rhs.head.size()) + return false; + + auto l = lhs.head.begin(); + auto r = rhs.head.begin(); + + while (l != lhs.head.end()) + { + if (!areEqual(seen, **l, **r)) + return false; + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionVariadicTypePack& lhs, const TypeFunctionVariadicTypePack& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + return areEqual(seen, *lhs.type, *rhs.type); +} + +bool areEqual(SeenSet& seen, const TypeFunctionTypePackVar& lhs, const TypeFunctionTypePackVar& rhs) +{ + { + const TypeFunctionTypePack* lb = get(&lhs); + const TypeFunctionTypePack* rb = get(&rhs); + if (lb && rb) + return areEqual(seen, *lb, *rb); + } + + { + const TypeFunctionVariadicTypePack* lv = get(&lhs); + const TypeFunctionVariadicTypePack* rv = get(&rhs); + if (lv && rv) + return areEqual(seen, *lv, *rv); + } + + return false; +} + +bool TypeFunctionType::operator==(const TypeFunctionType& rhs) const +{ + SeenSet seen; + return areEqual(seen, *this, rhs); +} + +bool TypeFunctionTypePackVar::operator==(const TypeFunctionTypePackVar& rhs) const +{ + SeenSet seen; + return areEqual(seen, *this, rhs); +} + + +TypeFunctionProperty TypeFunctionProperty::readonly(TypeFunctionTypeId ty) +{ + TypeFunctionProperty p; + p.readTy = ty; + return p; +} + +TypeFunctionProperty TypeFunctionProperty::writeonly(TypeFunctionTypeId ty) +{ + TypeFunctionProperty p; + p.writeTy = ty; + return p; +} + +TypeFunctionProperty TypeFunctionProperty::rw(TypeFunctionTypeId ty) +{ + return TypeFunctionProperty::rw(ty, ty); +} + +TypeFunctionProperty TypeFunctionProperty::rw(TypeFunctionTypeId read, TypeFunctionTypeId write) +{ + TypeFunctionProperty p; + p.readTy = read; + p.writeTy = write; + return p; +} + +bool TypeFunctionProperty::isReadOnly() const +{ + return readTy && !writeTy; +} + +bool TypeFunctionProperty::isWriteOnly() const +{ + return writeTy && !readTy; +} + +/* + * Below is a helper class for type.copy() + * Forked version of Clone.cpp + */ +using TypeFunctionKind = Variant; + +template +const T* get(const TypeFunctionKind& kind) +{ + return get_if(&kind); +} + +class TypeFunctionCloner +{ + using SeenTypes = DenseHashMap; + using SeenTypePacks = DenseHashMap; + + NotNull typeFunctionRuntime; + + // A queue of TypeFunctionTypeIds that have been cloned, but whose interior types hasn't + // been updated to point to itself. Once all of its interior types + // has been updated, it gets removed from the queue. + + // queue.back() should always return two of same type in their respective sides + // For example `auto [first, second] = queue.back()`: if first is TypeFunctionPrimitiveType, + // second must be TypeFunctionPrimitiveType; `second` is trying to copy `first` + std::vector> queue; + + SeenTypes types{{}}; // Mapping of TypeFunctionTypeIds that have been shallow cloned to TypeFunctionTypeIds + SeenTypePacks packs{{}}; // Mapping of TypeFunctionTypePackIds that have been shallow cloned to TypeFunctionTypePackIds + + int steps = 0; + +public: + explicit TypeFunctionCloner(TypeFunctionRuntime* typeFunctionRuntime) + : typeFunctionRuntime(typeFunctionRuntime) + { + } + + TypeFunctionTypeId clone(TypeFunctionTypeId ty) + { + shallowClone(ty); + run(); + + if (hasExceededIterationLimit()) + return nullptr; + + return find(ty).value_or(nullptr); + } + + TypeFunctionTypePackId clone(TypeFunctionTypePackId tp) + { + shallowClone(tp); + run(); + + if (hasExceededIterationLimit()) + return nullptr; + + return find(tp).value_or(nullptr); + } + +private: + bool hasExceededIterationLimit() const + { + return steps + queue.size() >= (size_t)DFInt::LuauTypeFunctionSerdeIterationLimit; + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit()) + break; + + auto [ty, tfti] = queue.back(); + queue.pop_back(); + + cloneChildren(ty, tfti); + } + } + + std::optional find(TypeFunctionTypeId ty) const + { + if (auto result = types.find(ty)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionTypePackId tp) const + { + if (auto result = packs.find(tp)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionKind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind?"); + return std::nullopt; + } + } + + TypeFunctionTypeId shallowClone(TypeFunctionTypeId ty) + { + if (auto it = find(ty)) + return *it; + + // Create a shallow serialization + TypeFunctionTypeId target = {}; + if (auto p = get(ty)) + { + switch (p->type) + { + case TypeFunctionPrimitiveType::Type::NilType: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); + break; + case TypeFunctionPrimitiveType::Type::Boolean: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Boolean)); + break; + case TypeFunctionPrimitiveType::Number: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Number)); + break; + case TypeFunctionPrimitiveType::String: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); + break; + default: + break; + } + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnknownType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNeverType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionAnyType{}); + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionBooleanSingleton{bs->value}}); + else if (auto ss = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionStringSingleton{ss->value}}); + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnionType{{}}); + else if (auto i = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionIntersectionType{{}}); + else if (auto n = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNegationType{{}}); + else if (auto t = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt}); + else if (auto f = get(ty)) + { + TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{}); + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{emptyTypePack, emptyTypePack}); + } + else if (auto c = get(ty)) + target = ty; // Don't copy a class since they are immutable + + types[ty] = target; + queue.emplace_back(ty, target); + return target; + } + + TypeFunctionTypePackId shallowClone(TypeFunctionTypePackId tp) + { + if (auto it = find(tp)) + return *it; + + // Create a shallow serialization + TypeFunctionTypePackId target = {}; + if (auto tPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); + else if (auto vPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); + + packs[tp] = target; + queue.emplace_back(tp, target); + return target; + } + + void cloneChildren(TypeFunctionTypeId ty, TypeFunctionTypeId tfti) + { + if (auto [p1, p2] = std::tuple{getMutable(ty), getMutable(tfti)}; p1 && p2) + cloneChildren(p1, p2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + cloneChildren(u1, u2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + cloneChildren(n1, n2); + else if (auto [a1, a2] = std::tuple{getMutable(ty), getMutable(tfti)}; a1 && a2) + cloneChildren(a1, a2); + else if (auto [s1, s2] = std::tuple{getMutable(ty), getMutable(tfti)}; s1 && s2) + cloneChildren(s1, s2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + cloneChildren(u1, u2); + else if (auto [i1, i2] = std::tuple{getMutable(ty), getMutable(tfti)}; i1 && i2) + cloneChildren(i1, i2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + cloneChildren(n1, n2); + else if (auto [t1, t2] = std::tuple{getMutable(ty), getMutable(tfti)}; t1 && t2) + cloneChildren(t1, t2); + else if (auto [f1, f2] = std::tuple{getMutable(ty), getMutable(tfti)}; f1 && f2) + cloneChildren(f1, f2); + else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) + cloneChildren(c1, c2); + else + LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types + } + + void cloneChildren(TypeFunctionTypePackId tp, TypeFunctionTypePackId tftp) + { + if (auto [tPack1, tPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; tPack1 && tPack2) + cloneChildren(tPack1, tPack2); + else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + vPack1 && vPack2) + cloneChildren(vPack1, vPack2); + else + LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types + } + + void cloneChildren(TypeFunctionKind kind, TypeFunctionKind tfkind) + { + if (auto [ty, tfty] = std::tuple{get(kind), get(tfkind)}; ty && tfty) + cloneChildren(*ty, *tfty); + else if (auto [tp, tftp] = std::tuple{get(kind), get(tfkind)}; tp && tftp) + cloneChildren(*tp, *tftp); + else + LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types + } + + void cloneChildren(TypeFunctionPrimitiveType* p1, TypeFunctionPrimitiveType* p2) + { + // noop. + } + + void cloneChildren(TypeFunctionUnknownType* u1, TypeFunctionUnknownType* u2) + { + // noop. + } + + void cloneChildren(TypeFunctionNeverType* n1, TypeFunctionNeverType* n2) + { + // noop. + } + + void cloneChildren(TypeFunctionAnyType* a1, TypeFunctionAnyType* a2) + { + // noop. + } + + void cloneChildren(TypeFunctionSingletonType* s1, TypeFunctionSingletonType* s2) + { + // noop. + } + + void cloneChildren(TypeFunctionUnionType* u1, TypeFunctionUnionType* u2) + { + for (TypeFunctionTypeId& ty : u1->components) + u2->components.push_back(shallowClone(ty)); + } + + void cloneChildren(TypeFunctionIntersectionType* i1, TypeFunctionIntersectionType* i2) + { + for (TypeFunctionTypeId& ty : i1->components) + i2->components.push_back(shallowClone(ty)); + } + + void cloneChildren(TypeFunctionNegationType* n1, TypeFunctionNegationType* n2) + { + n2->type = shallowClone(n1->type); + } + + void cloneChildren(TypeFunctionTableType* t1, TypeFunctionTableType* t2) + { + for (auto& [k, p] : t1->props) + { + std::optional readTy; + if (p.readTy) + readTy = shallowClone(*p.readTy); + + std::optional writeTy; + if (p.writeTy) + writeTy = shallowClone(*p.writeTy); + + t2->props[k] = TypeFunctionProperty{readTy, writeTy}; + } + + if (t1->indexer.has_value()) + t2->indexer = TypeFunctionTableIndexer(shallowClone(t1->indexer->keyType), shallowClone(t1->indexer->valueType)); + + if (t1->metatable.has_value()) + t2->metatable = shallowClone(*t1->metatable); + } + + void cloneChildren(TypeFunctionFunctionType* f1, TypeFunctionFunctionType* f2) + { + f2->argTypes = shallowClone(f1->argTypes); + f2->retTypes = shallowClone(f1->retTypes); + } + + void cloneChildren(TypeFunctionClassType* c1, TypeFunctionClassType* c2) + { + // noop. + } + + void cloneChildren(TypeFunctionTypePack* t1, TypeFunctionTypePack* t2) + { + for (TypeFunctionTypeId& ty : t1->head) + t2->head.push_back(shallowClone(ty)); + } + + void cloneChildren(TypeFunctionVariadicTypePack* v1, TypeFunctionVariadicTypePack* v2) + { + v2->type = shallowClone(v1->type); + } +}; + +TypeFunctionTypeId deepClone(NotNull runtime, TypeFunctionTypeId ty) +{ + return TypeFunctionCloner(runtime).clone(ty); +} + +} // namespace Luau diff --git a/Analysis/src/TypeFunctionRuntimeBuilder.cpp b/Analysis/src/TypeFunctionRuntimeBuilder.cpp new file mode 100644 index 000000000..e14c37739 --- /dev/null +++ b/Analysis/src/TypeFunctionRuntimeBuilder.cpp @@ -0,0 +1,788 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeFunctionRuntimeBuilder.h" + +#include "Luau/Ast.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/DenseHash.h" +#include "Luau/StringUtils.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypeFunctionRuntime.h" +#include "Luau/TypePack.h" +#include "Luau/ToString.h" + +#include + +// used to control the recursion limit of any operations done by user-defined type functions +// currently, controls serialization, deserialization, and `type.copy` +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000); + +namespace Luau +{ + +// Forked version of Clone.cpp +class TypeFunctionSerializer +{ + using SeenTypes = DenseHashMap; + using SeenTypePacks = DenseHashMap; + + TypeFunctionRuntimeBuilderState* state = nullptr; + NotNull typeFunctionRuntime; + + // A queue of TypeFunctionTypeIds that have been serialized, but whose interior types hasn't + // been updated to point to itself. Once all of its interior types + // has been updated, it gets removed from the queue. + + // queue.back() should always return two of same type in their respective sides + // For example `auto [first, second] = queue.back()`: if first is PrimitiveType, + // second must be TypeFunctionPrimitiveType; else there should be an error + std::vector> queue; + + SeenTypes types; // Mapping of TypeIds that have been shallow serialized to TypeFunctionTypeIds + SeenTypePacks packs; // Mapping of TypePackIds that have been shallow serialized to TypeFunctionTypePackIds + + int steps = 0; + +public: + explicit TypeFunctionSerializer(TypeFunctionRuntimeBuilderState* state) + : state(state) + , typeFunctionRuntime(state->ctx->typeFunctionRuntime) + , queue({}) + , types({}) + , packs({}) + { + } + + TypeFunctionTypeId serialize(TypeId ty) + { + shallowSerialize(ty); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + return nullptr; + + return find(ty).value_or(nullptr); + } + + TypeFunctionTypePackId serialize(TypePackId tp) + { + shallowSerialize(tp); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + return nullptr; + + return find(tp).value_or(nullptr); + } + +private: + bool hasExceededIterationLimit() const + { + if (DFInt::LuauTypeFunctionSerdeIterationLimit == 0) + return false; + + return steps + queue.size() >= size_t(DFInt::LuauTypeFunctionSerdeIterationLimit); + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit() || state->errors.size() != 0) + break; + + auto [ty, tfti] = queue.back(); + queue.pop_back(); + + serializeChildren(ty, tfti); + } + } + + std::optional find(TypeId ty) const + { + if (auto result = types.find(ty)) + return *result; + + return std::nullopt; + } + + std::optional find(TypePackId tp) const + { + if (auto result = packs.find(tp)) + return *result; + + return std::nullopt; + } + + std::optional find(Kind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind found at TypeFunctionRuntimeSerializer"); + return std::nullopt; + } + } + + TypeFunctionTypeId shallowSerialize(TypeId ty) + { + ty = follow(ty); + + if (auto it = find(ty)) + return *it; + + // Create a shallow serialization + TypeFunctionTypeId target = {}; + if (auto p = get(ty)) + { + switch (p->type) + { + case PrimitiveType::Type::NilType: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); + break; + case PrimitiveType::Type::Boolean: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Boolean)); + break; + case PrimitiveType::Number: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Number)); + break; + case PrimitiveType::String: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); + break; + case PrimitiveType::Thread: + case PrimitiveType::Function: + case PrimitiveType::Table: + case PrimitiveType::Buffer: + default: + { + std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + } + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnknownType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNeverType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionAnyType{}); + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionBooleanSingleton{bs->value}}); + else if (auto ss = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionStringSingleton{ss->value}}); + else + { + std::string error = format("Argument of singleton type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnionType{{}}); + else if (auto i = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionIntersectionType{{}}); + else if (auto n = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNegationType{{}}); + else if (auto t = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt}); + else if (auto m = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt}); + else if (auto f = get(ty)) + { + TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{}); + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{emptyTypePack, emptyTypePack}); + } + else if (auto c = get(ty)) + { + state->classesSerialized[c->name] = ty; + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionClassType{{}, std::nullopt, std::nullopt, std::nullopt, c->name}); + } + else + { + std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + + types[ty] = target; + queue.emplace_back(ty, target); + return target; + } + + TypeFunctionTypePackId shallowSerialize(TypePackId tp) + { + tp = follow(tp); + + if (auto it = find(tp)) + return *it; + + // Create a shallow serialization + TypeFunctionTypePackId target = {}; + if (auto tPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); + else if (auto vPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); + else + { + std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str()); + state->errors.push_back(error); + } + + packs[tp] = target; + queue.emplace_back(tp, target); + return target; + } + + void serializeChildren(TypeId ty, TypeFunctionTypeId tfti) + { + if (auto [p1, p2] = std::tuple{getMutable(ty), getMutable(tfti)}; p1 && p2) + serializeChildren(p1, p2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + serializeChildren(u1, u2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + serializeChildren(n1, n2); + else if (auto [a1, a2] = std::tuple{getMutable(ty), getMutable(tfti)}; a1 && a2) + serializeChildren(a1, a2); + else if (auto [s1, s2] = std::tuple{getMutable(ty), getMutable(tfti)}; s1 && s2) + serializeChildren(s1, s2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + serializeChildren(u1, u2); + else if (auto [i1, i2] = std::tuple{getMutable(ty), getMutable(tfti)}; i1 && i2) + serializeChildren(i1, i2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + serializeChildren(n1, n2); + else if (auto [t1, t2] = std::tuple{getMutable(ty), getMutable(tfti)}; t1 && t2) + serializeChildren(t1, t2); + else if (auto [m1, m2] = std::tuple{getMutable(ty), getMutable(tfti)}; m1 && m2) + serializeChildren(m1, m2); + else if (auto [f1, f2] = std::tuple{getMutable(ty), getMutable(tfti)}; f1 && f2) + serializeChildren(f1, f2); + else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) + serializeChildren(c1, c2); + else + { // Either this or ty and tfti do not represent the same type + std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + } + + void serializeChildren(TypePackId tp, TypeFunctionTypePackId tftp) + { + if (auto [tPack1, tPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; tPack1 && tPack2) + serializeChildren(tPack1, tPack2); + else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + vPack1 && vPack2) + serializeChildren(vPack1, vPack2); + else + { // Either this or ty and tfti do not represent the same type + std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str()); + state->errors.push_back(error); + } + } + + void serializeChildren(Kind kind, TypeFunctionKind tfkind) + { + if (auto [ty, tfty] = std::tuple{get(kind), get(tfkind)}; ty && tfty) + serializeChildren(*ty, *tfty); + else if (auto [tp, tftp] = std::tuple{get(kind), get(tfkind)}; tp && tftp) + serializeChildren(*tp, *tftp); + else + state->ctx->ice->ice("Serializing user defined type function arguments: kind and tfkind do not represent the same type"); + } + + void serializeChildren(PrimitiveType* p1, TypeFunctionPrimitiveType* p2) + { + // noop. + } + + void serializeChildren(UnknownType* u1, TypeFunctionUnknownType* u2) + { + // noop. + } + + void serializeChildren(NeverType* n1, TypeFunctionNeverType* n2) + { + // noop. + } + + void serializeChildren(AnyType* a1, TypeFunctionAnyType* a2) + { + // noop. + } + + void serializeChildren(SingletonType* s1, TypeFunctionSingletonType* s2) + { + // noop. + } + + void serializeChildren(UnionType* u1, TypeFunctionUnionType* u2) + { + for (TypeId& ty : u1->options) + u2->components.push_back(shallowSerialize(ty)); + } + + void serializeChildren(IntersectionType* i1, TypeFunctionIntersectionType* i2) + { + for (TypeId& ty : i1->parts) + i2->components.push_back(shallowSerialize(ty)); + } + + void serializeChildren(NegationType* n1, TypeFunctionNegationType* n2) + { + n2->type = shallowSerialize(n1->ty); + } + + void serializeChildren(TableType* t1, TypeFunctionTableType* t2) + { + for (const auto& [k, p] : t1->props) + { + std::optional readTy = std::nullopt; + if (p.readTy) + readTy = shallowSerialize(*p.readTy); + + std::optional writeTy = std::nullopt; + if (p.writeTy) + writeTy = shallowSerialize(*p.writeTy); + + t2->props[k] = TypeFunctionProperty{readTy, writeTy}; + } + + if (t1->indexer) + t2->indexer = TypeFunctionTableIndexer(shallowSerialize(t1->indexer->indexType), shallowSerialize(t1->indexer->indexResultType)); + } + + void serializeChildren(MetatableType* m1, TypeFunctionTableType* m2) + { + auto tmpTable = get(shallowSerialize(m1->table)); + if (!tmpTable) + state->ctx->ice->ice("Serializing user defined type function arguments: metatable's table is not a TableType"); + + m2->props = tmpTable->props; + m2->indexer = tmpTable->indexer; + + m2->metatable = shallowSerialize(m1->metatable); + } + + void serializeChildren(FunctionType* f1, TypeFunctionFunctionType* f2) + { + f2->argTypes = shallowSerialize(f1->argTypes); + f2->retTypes = shallowSerialize(f1->retTypes); + } + + void serializeChildren(ClassType* c1, TypeFunctionClassType* c2) + { + for (const auto& [k, p] : c1->props) + { + std::optional readTy = std::nullopt; + if (p.readTy) + readTy = shallowSerialize(*p.readTy); + + std::optional writeTy = std::nullopt; + if (p.writeTy) + writeTy = shallowSerialize(*p.writeTy); + + c2->props[k] = TypeFunctionProperty{readTy, writeTy}; + } + + if (c1->indexer) + c2->indexer = TypeFunctionTableIndexer(shallowSerialize(c1->indexer->indexType), shallowSerialize(c1->indexer->indexResultType)); + + if (c1->metatable) + c2->metatable = shallowSerialize(*c1->metatable); + + if (c1->parent) + c2->parent = shallowSerialize(*c1->parent); + } + + void serializeChildren(TypePack* t1, TypeFunctionTypePack* t2) + { + for (TypeId& ty : t1->head) + t2->head.push_back(shallowSerialize(ty)); + + if (t1->tail.has_value()) + t2->tail = shallowSerialize(*t1->tail); + } + + void serializeChildren(VariadicTypePack* v1, TypeFunctionVariadicTypePack* v2) + { + v2->type = shallowSerialize(v1->ty); + } +}; + +// Complete inverse of TypeFunctionSerializer +class TypeFunctionDeserializer +{ + using SeenTypes = DenseHashMap; + using SeenTypePacks = DenseHashMap; + + TypeFunctionRuntimeBuilderState* state = nullptr; + NotNull typeFunctionRuntime; + + // A queue of TypeIds that have been deserialized, but whose interior types hasn't + // been updated to point to itself. Once all of its interior types + // has been updated, it gets removed from the queue. + + // queue.back() should always return two of same type in their respective sides + // For example `auto [first, second] = queue.back()`: if first is TypeFunctionPrimitiveType, + // second must be PrimitiveType; else there should be an error + std::vector> queue; + + SeenTypes types; // Mapping of TypeFunctionTypeIds that have been shallow deserialized to TypeIds + SeenTypePacks packs; // Mapping of TypeFunctionTypePackIds that have been shallow deserialized to TypePackIds + + int steps = 0; + +public: + explicit TypeFunctionDeserializer(TypeFunctionRuntimeBuilderState* state) + : state(state) + , typeFunctionRuntime(state->ctx->typeFunctionRuntime) + , queue({}) + , types({}) + , packs({}){}; + + TypeId deserialize(TypeFunctionTypeId ty) + { + shallowDeserialize(ty); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + { + TypeId error = state->ctx->builtins->errorRecoveryType(); + types[ty] = error; + return error; + } + + return find(ty).value_or(state->ctx->builtins->errorRecoveryType()); + } + + TypePackId deserialize(TypeFunctionTypePackId tp) + { + shallowDeserialize(tp); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + { + TypePackId error = state->ctx->builtins->errorRecoveryTypePack(); + packs[tp] = error; + return error; + } + + return find(tp).value_or(state->ctx->builtins->errorRecoveryTypePack()); + } + +private: + bool hasExceededIterationLimit() const + { + if (DFInt::LuauTypeFunctionSerdeIterationLimit == 0) + return false; + + return steps + queue.size() >= size_t(DFInt::LuauTypeFunctionSerdeIterationLimit); + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit() || state->errors.size() != 0) + break; + + auto [tfti, ty] = queue.back(); + queue.pop_back(); + + deserializeChildren(tfti, ty); + } + } + + std::optional find(TypeFunctionTypeId ty) const + { + if (auto result = types.find(ty)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionTypePackId tp) const + { + if (auto result = packs.find(tp)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionKind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind found at TypeFunctionDeserializer"); + return std::nullopt; + } + } + + TypeId shallowDeserialize(TypeFunctionTypeId ty) + { + if (auto it = find(ty)) + return *it; + + // Create a shallow deserialization + TypeId target = {}; + if (auto p = get(ty)) + { + switch (p->type) + { + case TypeFunctionPrimitiveType::Type::NilType: + target = state->ctx->builtins->nilType; + break; + case TypeFunctionPrimitiveType::Type::Boolean: + target = state->ctx->builtins->booleanType; + break; + case TypeFunctionPrimitiveType::Type::Number: + target = state->ctx->builtins->numberType; + break; + case TypeFunctionPrimitiveType::Type::String: + target = state->ctx->builtins->stringType; + break; + default: + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + } + else if (auto u = get(ty)) + target = state->ctx->builtins->unknownType; + else if (auto n = get(ty)) + target = state->ctx->builtins->neverType; + else if (auto a = get(ty)) + target = state->ctx->builtins->anyType; + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + target = state->ctx->arena->addType(SingletonType{BooleanSingleton{bs->value}}); + else if (auto ss = get(s)) + target = state->ctx->arena->addType(SingletonType{StringSingleton{ss->value}}); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + else if (auto u = get(ty)) + target = state->ctx->arena->addTV(Type(UnionType{{}})); + else if (auto i = get(ty)) + target = state->ctx->arena->addTV(Type(IntersectionType{{}})); + else if (auto n = get(ty)) + target = state->ctx->arena->addType(NegationType{state->ctx->builtins->unknownType}); + else if (auto t = get(ty); t && !t->metatable.has_value()) + target = state->ctx->arena->addType(TableType{TableType::Props{}, std::nullopt, TypeLevel{}, TableState::Sealed}); + else if (auto m = get(ty); m && m->metatable.has_value()) + { + TypeId emptyTable = state->ctx->arena->addType(TableType{TableType::Props{}, std::nullopt, TypeLevel{}, TableState::Sealed}); + target = state->ctx->arena->addType(MetatableType{emptyTable, emptyTable}); + } + else if (auto f = get(ty)) + { + TypePackId emptyTypePack = state->ctx->arena->addTypePack(TypePack{}); + target = state->ctx->arena->addType(FunctionType{emptyTypePack, emptyTypePack, {}, false}); + } + else if (auto c = get(ty)) + { + if (auto result = state->classesSerialized.find(c->name)) + target = *result; + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious class type is being deserialized"); + } + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + + types[ty] = target; + queue.emplace_back(ty, target); + return target; + } + + TypePackId shallowDeserialize(TypeFunctionTypePackId tp) + { + if (auto it = find(tp)) + return *it; + + // Create a shallow deserialization + TypePackId target = {}; + if (auto tPack = get(tp)) + target = state->ctx->arena->addTypePack(TypePack{}); + else if (auto vPack = get(tp)) + target = state->ctx->arena->addTypePack(VariadicTypePack{}); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + + packs[tp] = target; + queue.emplace_back(tp, target); + return target; + } + + void deserializeChildren(TypeFunctionTypeId tfti, TypeId ty) + { + if (auto [p1, p2] = std::tuple{getMutable(ty), getMutable(tfti)}; p1 && p2) + deserializeChildren(p2, p1); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + deserializeChildren(u2, u1); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + deserializeChildren(n2, n1); + else if (auto [a1, a2] = std::tuple{getMutable(ty), getMutable(tfti)}; a1 && a2) + deserializeChildren(a2, a1); + else if (auto [s1, s2] = std::tuple{getMutable(ty), getMutable(tfti)}; s1 && s2) + deserializeChildren(s2, s1); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + deserializeChildren(u2, u1); + else if (auto [i1, i2] = std::tuple{getMutable(ty), getMutable(tfti)}; i1 && i2) + deserializeChildren(i2, i1); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + deserializeChildren(n2, n1); + else if (auto [t1, t2] = std::tuple{getMutable(ty), getMutable(tfti)}; + t1 && t2 && !t2->metatable.has_value()) + deserializeChildren(t2, t1); + else if (auto [m1, m2] = std::tuple{getMutable(ty), getMutable(tfti)}; + m1 && m2 && m2->metatable.has_value()) + deserializeChildren(m2, m1); + else if (auto [f1, f2] = std::tuple{getMutable(ty), getMutable(tfti)}; f1 && f2) + deserializeChildren(f2, f1); + else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) + deserializeChildren(c2, c1); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + + void deserializeChildren(TypeFunctionTypePackId tftp, TypePackId tp) + { + if (auto [tPack1, tPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; tPack1 && tPack2) + deserializeChildren(tPack2, tPack1); + else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + vPack1 && vPack2) + deserializeChildren(vPack2, vPack1); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + + void deserializeChildren(TypeFunctionKind tfkind, Kind kind) + { + if (auto [ty, tfty] = std::tuple{get(kind), get(tfkind)}; ty && tfty) + deserializeChildren(*tfty, *ty); + else if (auto [tp, tftp] = std::tuple{get(kind), get(tfkind)}; tp && tftp) + deserializeChildren(*tftp, *tp); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: tfkind and kind do not represent the same type"); + } + + void deserializeChildren(TypeFunctionPrimitiveType* p2, PrimitiveType* p1) + { + // noop. + } + + void deserializeChildren(TypeFunctionUnknownType* u2, UnknownType* u1) + { + // noop. + } + + void deserializeChildren(TypeFunctionNeverType* n2, NeverType* n1) + { + // noop. + } + + void deserializeChildren(TypeFunctionAnyType* a2, AnyType* a1) + { + // noop. + } + + void deserializeChildren(TypeFunctionSingletonType* s2, SingletonType* s1) + { + // noop. + } + + void deserializeChildren(TypeFunctionUnionType* u2, UnionType* u1) + { + for (TypeFunctionTypeId& ty : u2->components) + u1->options.push_back(shallowDeserialize(ty)); + } + + void deserializeChildren(TypeFunctionIntersectionType* i2, IntersectionType* i1) + { + for (TypeFunctionTypeId& ty : i2->components) + i1->parts.push_back(shallowDeserialize(ty)); + } + + void deserializeChildren(TypeFunctionNegationType* n2, NegationType* n1) + { + n1->ty = shallowDeserialize(n2->type); + } + + void deserializeChildren(TypeFunctionTableType* t2, TableType* t1) + { + for (const auto& [k, p] : t2->props) + { + if (p.readTy && p.writeTy) + t1->props[k] = Property::rw(shallowDeserialize(*p.readTy), shallowDeserialize(*p.writeTy)); + else if (p.readTy) + t1->props[k] = Property::readonly(shallowDeserialize(*p.readTy)); + else if (p.writeTy) + t1->props[k] = Property::writeonly(shallowDeserialize(*p.writeTy)); + } + + if (t2->indexer.has_value()) + t1->indexer = TableIndexer(shallowDeserialize(t2->indexer->keyType), shallowDeserialize(t2->indexer->valueType)); + } + + void deserializeChildren(TypeFunctionTableType* m2, MetatableType* m1) + { + TypeFunctionTypeId temp = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{m2->props, m2->indexer}); + m1->table = shallowDeserialize(temp); + + if (m2->metatable.has_value()) + m1->metatable = shallowDeserialize(*m2->metatable); + } + + void deserializeChildren(TypeFunctionFunctionType* f2, FunctionType* f1) + { + if (f2->argTypes) + f1->argTypes = shallowDeserialize(f2->argTypes); + + if (f2->retTypes) + f1->retTypes = shallowDeserialize(f2->retTypes); + } + + void deserializeChildren(TypeFunctionClassType* c2, ClassType* c1) + { + // noop. + } + + void deserializeChildren(TypeFunctionTypePack* t2, TypePack* t1) + { + for (TypeFunctionTypeId& ty : t2->head) + t1->head.push_back(shallowDeserialize(ty)); + + if (t2->tail.has_value()) + t1->tail = shallowDeserialize(*t2->tail); + } + + void deserializeChildren(TypeFunctionVariadicTypePack* v2, VariadicTypePack* v1) + { + v1->ty = shallowDeserialize(v2->type); + } +}; + +TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state) +{ + return TypeFunctionSerializer(state).serialize(ty); +} + +TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state) +{ + return TypeFunctionDeserializer(state).deserialize(ty); +} + +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 6b2e861d2..7a7be71d7 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -33,7 +33,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) -LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) LUAU_FASTFLAGVARIABLE(LuauAcceptIndexingTableUnionsIntersections, false) namespace Luau @@ -1284,19 +1283,10 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) for (size_t i = 2; i < varTypes.size(); ++i) unify(nilType, varTypes[i], scope, forin.location); } - else if (isNonstrictMode() || FFlag::LuauOkWithIteratingOverTableProperties) - { - for (TypeId var : varTypes) - unify(unknownType, var, scope, forin.location); - } else { - TypeId varTy = errorRecoveryType(loopScope); - for (TypeId var : varTypes) - unify(varTy, var, scope, forin.location); - - reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); + unify(unknownType, var, scope, forin.location); } return check(loopScope, *forin.body); diff --git a/Ast/include/Luau/ParseOptions.h b/Ast/include/Luau/ParseOptions.h index 01f2a74fa..804d16fca 100644 --- a/Ast/include/Luau/ParseOptions.h +++ b/Ast/include/Luau/ParseOptions.h @@ -1,6 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" + +#include + namespace Luau { @@ -12,10 +17,17 @@ enum class Mode Definition, // Type definition module, has special parsing rules }; +struct FragmentParseResumeSettings +{ + DenseHashMap localMap{AstName()}; + std::vector localStack; +}; + struct ParseOptions { bool allowDeclarationSyntax = false; bool captureComments = false; + std::optional parseFragment = std::nullopt; }; } // namespace Luau diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 4e49028a5..83d6eefda 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -452,4 +452,4 @@ class Parser std::string scratchData; }; -} // namespace Luau \ No newline at end of file +} // namespace Luau diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index bd2ca86ba..2259f21ce 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -7,6 +7,7 @@ #include #include +#include LUAU_FASTFLAG(DebugLuauTimeTracing) diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index a5e1d40ea..545402150 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -7,8 +7,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) - namespace Luau { @@ -434,13 +432,11 @@ Lexeme Lexer::lookahead() lineOffset = currentLineOffset; lexeme = currentLexeme; prevLocation = currentPrevLocation; - if (FFlag::LuauLexerLookaheadRemembersBraceType) - { - if (braceStack.size() < currentBraceStackSize) - braceStack.push_back(currentBraceType); - else if (braceStack.size() > currentBraceStackSize) - braceStack.pop_back(); - } + + if (braceStack.size() < currentBraceStackSize) + braceStack.push_back(currentBraceType); + else if (braceStack.size() > currentBraceStackSize) + braceStack.pop_back(); return result; } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 4b9eddda9..44a40abf4 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -19,7 +19,8 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauSolverV2, false) LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false) LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false) -LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions, false) +LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax, false) +LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing, false) namespace Luau { @@ -211,6 +212,15 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc scratchExpr.reserve(16); scratchLocal.reserve(16); scratchBinding.reserve(16); + + if (FFlag::LuauAllowFragmentParsing) + { + if (options.parseFragment) + { + localMap = options.parseFragment->localMap; + localStack = options.parseFragment->localStack; + } + } } bool Parser::blockFollow(const Lexeme& l) @@ -891,7 +901,7 @@ AstStat* Parser::parseReturn() AstStat* Parser::parseTypeAlias(const Location& start, bool exported) { // parsing a type function - if (FFlag::LuauUserDefinedTypeFunctions) + if (FFlag::LuauUserDefinedTypeFunctionsSyntax) { if (lexer.current().type == Lexeme::ReservedFunction) return parseTypeFunction(start); diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index e8be59eb2..8bccffce2 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -3,6 +3,7 @@ #include "Luau/StringUtils.h" +#include #include #include diff --git a/CMakeLists.txt b/CMakeLists.txt index b18cd5c9d..c8053cc1a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,6 +85,7 @@ target_link_libraries(Luau.Config PUBLIC Luau.Ast) target_compile_features(Luau.Analysis PUBLIC cxx_std_17) target_include_directories(Luau.Analysis PUBLIC Analysis/include) target_link_libraries(Luau.Analysis PUBLIC Luau.Ast Luau.EqSat Luau.Config) +target_link_libraries(Luau.Analysis PRIVATE Luau.Compiler Luau.VM) target_compile_features(Luau.EqSat PUBLIC cxx_std_17) target_include_directories(Luau.EqSat PUBLIC EqSat/include) @@ -276,7 +277,7 @@ foreach(LIB Luau.Ast Luau.Compiler Luau.Config Luau.Analysis Luau.EqSat Luau.Cod if(LIB MATCHES "CodeGen|VM" AND DEPENDS MATCHES "Ast|Analysis|Config|Compiler") message(FATAL_ERROR ${LIB} " is a runtime component but it depends on one of the offline components") endif() - if(LIB MATCHES "Ast|Analysis|EqSat|Compiler" AND DEPENDS MATCHES "CodeGen|VM") + if(LIB MATCHES "Ast|EqSat|Compiler" AND DEPENDS MATCHES "CodeGen|VM") message(FATAL_ERROR ${LIB} " is an offline component but it depends on one of the runtime components") endif() if(LIB MATCHES "Ast|Compiler" AND DEPENDS MATCHES "Analysis|Config") diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index cd73bcbbf..a63655ccc 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -11,8 +11,6 @@ #include "lstate.h" #include "lgc.h" -LUAU_FASTFLAGVARIABLE(LuauCodegenArmNumToVecFix, false) - namespace Luau { namespace CodeGen @@ -1121,7 +1119,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) else { RegisterA64 tempd = tempDouble(inst.a); - RegisterA64 temps = FFlag::LuauCodegenArmNumToVecFix ? regs.allocTemp(KindA64::s) : castReg(KindA64::s, tempd); + RegisterA64 temps = regs.allocTemp(KindA64::s); build.fcvt(temps, tempd); build.dup_4s(inst.regA64, castReg(KindA64::q, temps), 0); diff --git a/Makefile b/Makefile index 3e6b85ad9..cb199de88 100644 --- a/Makefile +++ b/Makefile @@ -142,7 +142,7 @@ endif $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include $(CONFIG_OBJECTS): CXXFLAGS+=-std=c++17 -IConfig/include -ICommon/include -IAst/include -$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include +$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -ICompiler/include -IVM/include $(EQSAT_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IEqSat/include $(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include -IVM/include -IVM/src # Code generation needs VM internals $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include @@ -227,7 +227,7 @@ luau-tests: $(TESTS_TARGET) # executable targets $(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) $(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) -$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET) +$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(COMPILER_TARGET) $(VM_TARGET) $(COMPILE_CLI_TARGET): $(COMPILE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(BYTECODE_CLI_TARGET): $(BYTECODE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) diff --git a/Sources.cmake b/Sources.cmake index 80bcd5b25..103ea2806 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -182,6 +182,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h + Analysis/include/Luau/FragmentAutocomplete.h Analysis/include/Luau/Frontend.h Analysis/include/Luau/Generalization.h Analysis/include/Luau/GlobalTypes.h @@ -223,6 +224,8 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypeFunction.h Analysis/include/Luau/TypeFunctionReductionGuesser.h + Analysis/include/Luau/TypeFunctionRuntime.h + Analysis/include/Luau/TypeFunctionRuntimeBuilder.h Analysis/include/Luau/TypeFwd.h Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypeOrPack.h @@ -253,6 +256,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Differ.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/Error.cpp + Analysis/src/FragmentAutocomplete.cpp Analysis/src/Frontend.cpp Analysis/src/Generalization.cpp Analysis/src/GlobalTypes.cpp @@ -287,6 +291,8 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TypedAllocator.cpp Analysis/src/TypeFunction.cpp Analysis/src/TypeFunctionReductionGuesser.cpp + Analysis/src/TypeFunctionRuntime.cpp + Analysis/src/TypeFunctionRuntimeBuilder.cpp Analysis/src/TypeInfer.cpp Analysis/src/TypeOrPack.cpp Analysis/src/TypePack.cpp @@ -440,6 +446,7 @@ if(TARGET Luau.UnitTest) tests/Error.test.cpp tests/Fixture.cpp tests/Fixture.h + tests/FragmentAutocomplete.test.cpp tests/Frontend.test.cpp tests/Generalization.test.cpp tests/InsertionOrderedMap.test.cpp @@ -474,6 +481,7 @@ if(TARGET Luau.UnitTest) tests/Transpiler.test.cpp tests/TxnLog.test.cpp tests/TypeFunction.test.cpp + tests/TypeFunction.user.test.cpp tests/TypeInfer.aliases.test.cpp tests/TypeInfer.annotations.test.cpp tests/TypeInfer.anyerror.test.cpp diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 16775f9b9..f38ab80bf 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -10,8 +10,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauPreserveLudataRenaming, false) - // clang-format off const char* const luaT_typenames[] = { // ORDER TYPE @@ -124,74 +122,40 @@ const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event) const TString* luaT_objtypenamestr(lua_State* L, const TValue* o) { - if (FFlag::LuauPreserveLudataRenaming) + // Userdata created by the environment can have a custom type name set in the individual metatable + // If there is no custom name, 'userdata' is returned + if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable) { - // Userdata created by the environment can have a custom type name set in the individual metatable - // If there is no custom name, 'userdata' is returned - if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable) - { - const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); - - if (ttisstring(type)) - return tsvalue(type); - - return L->global->ttname[ttype(o)]; - } - - // Tagged lightuserdata can be named using lua_setlightuserdataname - if (ttislightuserdata(o)) - { - int tag = lightuserdatatag(o); - - if (unsigned(tag) < LUA_LUTAG_LIMIT) - { - if (const TString* name = L->global->lightuserdataname[tag]) - return name; - } - } + const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); - // For all types except userdata and table, a global metatable can be set with a global name override - if (Table* mt = L->global->mt[ttype(o)]) - { - const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); - - if (ttisstring(type)) - return tsvalue(type); - } + if (ttisstring(type)) + return tsvalue(type); return L->global->ttname[ttype(o)]; } - else + + // Tagged lightuserdata can be named using lua_setlightuserdataname + if (ttislightuserdata(o)) { - if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable) - { - const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); + int tag = lightuserdatatag(o); - if (ttisstring(type)) - return tsvalue(type); - } - else if (ttislightuserdata(o)) + if (unsigned(tag) < LUA_LUTAG_LIMIT) { - int tag = lightuserdatatag(o); - - if (unsigned(tag) < LUA_LUTAG_LIMIT) - { - const TString* name = L->global->lightuserdataname[tag]; - - if (name) - return name; - } + if (const TString* name = L->global->lightuserdataname[tag]) + return name; } - else if (Table* mt = L->global->mt[ttype(o)]) - { - const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); + } - if (ttisstring(type)) - return tsvalue(type); - } + // For all types except userdata and table, a global metatable can be set with a global name override + if (Table* mt = L->global->mt[ttype(o)]) + { + const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); - return L->global->ttname[ttype(o)]; + if (ttisstring(type)) + return tsvalue(type); } + + return L->global->ttname[ttype(o)]; } const char* luaT_objtypename(lua_State* L, const TValue* o) diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 376caa44d..df6e53320 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -34,8 +34,6 @@ void luaC_validate(lua_State* L); LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTFLAG(LuauNativeAttribute) -LUAU_FASTFLAG(LuauPreserveLudataRenaming) -LUAU_FASTFLAG(LuauCodegenArmNumToVecFix) static lua_CompileOptions defaultOptions() { @@ -825,8 +823,6 @@ TEST_CASE("Pack") TEST_CASE("Vector") { - ScopedFastFlag luauCodegenArmNumToVecFix{FFlag::LuauCodegenArmNumToVecFix, true}; - lua_CompileOptions copts = defaultOptions(); Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions(); @@ -2251,20 +2247,17 @@ TEST_CASE("LightuserdataApi") lua_pop(L, 1); - if (FFlag::LuauPreserveLudataRenaming) - { - // Still possible to rename the global lightuserdata name using a metatable - lua_pushlightuserdata(L, value); - CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0); + // Still possible to rename the global lightuserdata name using a metatable + lua_pushlightuserdata(L, value); + CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0); - lua_createtable(L, 0, 1); - lua_pushstring(L, "luserdata"); - lua_setfield(L, -2, "__type"); - lua_setmetatable(L, -2); + lua_createtable(L, 0, 1); + lua_pushstring(L, "luserdata"); + lua_setfield(L, -2, "__type"); + lua_setmetatable(L, -2); - CHECK(strcmp(luaL_typename(L, -1), "luserdata") == 0); - lua_pop(L, 1); - } + CHECK(strcmp(luaL_typename(L, -1), "luserdata") == 0); + lua_pop(L, 1); globalState.reset(); } diff --git a/tests/ConstraintGeneratorFixture.cpp b/tests/ConstraintGeneratorFixture.cpp index 7f168465b..f595d6ec1 100644 --- a/tests/ConstraintGeneratorFixture.cpp +++ b/tests/ConstraintGeneratorFixture.cpp @@ -42,7 +42,9 @@ void ConstraintGeneratorFixture::generateConstraints(const std::string& code) void ConstraintGeneratorFixture::solve(const std::string& code) { generateConstraints(code); - ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, {}}; + ConstraintSolver cs{ + NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, {} + }; cs.run(); } diff --git a/tests/ConstraintGeneratorFixture.h b/tests/ConstraintGeneratorFixture.h index ff362be11..acf616e00 100644 --- a/tests/ConstraintGeneratorFixture.h +++ b/tests/ConstraintGeneratorFixture.h @@ -20,6 +20,7 @@ struct ConstraintGeneratorFixture : Fixture DcrLogger logger; UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + TypeFunctionRuntime typeFunctionRuntime; std::unique_ptr dfg; std::unique_ptr cg; diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp new file mode 100644 index 000000000..b8b7829da --- /dev/null +++ b/tests/FragmentAutocomplete.test.cpp @@ -0,0 +1,139 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/FragmentAutocomplete.h" +#include "Fixture.h" +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" + + +using namespace Luau; + +struct FragmentAutocompleteFixture : Fixture +{ + + FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos) + { + ParseResult p = tryParse(source); // We don't care about parsing incomplete asts + REQUIRE(p.root); + return findAncestryForFragmentParse(p.root, cursorPos); + } +}; + +TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTest"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +)", + {2, 11} + ); + + CHECK_EQ(3, result.ancestry.size()); + CHECK_EQ(2, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + + AstStatLocal* local = result.nearestStatement->as(); + REQUIRE(local); + CHECK(1 == local->vars.size); + CHECK_EQ("y", std::string(local->vars.data[0]->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cursor_within_scope_tracks_locals_from_previous_scope") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +if x == 4 then + local e = y +end +)", + {4, 15} + ); + + CHECK_EQ(5, result.ancestry.size()); + CHECK_EQ(3, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + CHECK_EQ("e", std::string(result.localStack.back()->name.value)); + + AstStatLocal* local = result.nearestStatement->as(); + REQUIRE(local); + CHECK(1 == local->vars.size); + CHECK_EQ("e", std::string(local->vars.data[0]->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cursor_that_comes_later_shouldnt_capture_locals_in_unavailable_scope") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +if x == 4 then + local e = y +end +local z = x + x +if y == 5 then + local q = x + y + z +end +)", + {8, 23} + ); + + CHECK_EQ(6, result.ancestry.size()); + CHECK_EQ(4, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + CHECK_EQ("q", std::string(result.localStack.back()->name.value)); + + AstStatLocal* local = result.nearestStatement->as(); + REQUIRE(local); + CHECK(1 == local->vars.size); + CHECK_EQ("q", std::string(local->vars.data[0]->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nearest_enclosing_statement_can_be_non_local") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +if x == 4 then +)", + {3, 4} + ); + + CHECK_EQ(4, result.ancestry.size()); + CHECK_EQ(2, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + CHECK_EQ("y", std::string(result.localStack.back()->name.value)); + + AstStatIf* ifS = result.nearestStatement->as(); + CHECK(ifS != nullptr); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_funcs_show_up_in_local_stack") +{ + auto result = runAutocompleteVisitor( + R"( +local function foo() return 4 end +local x = foo() +local function bar() return x + foo() end +)", + {3, 32} + ); + + CHECK_EQ(8, result.ancestry.size()); + CHECK_EQ(3, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + CHECK_EQ("bar", std::string(result.localStack.back()->name.value)); + auto returnSt = result.nearestStatement->as(); + CHECK(returnSt != nullptr); +} + +TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index dfcf0ded5..74d7a9207 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -3,6 +3,7 @@ #include "AstQueryDsl.h" #include "Fixture.h" +#include "Luau/Common.h" #include "ScopedFlags.h" #include "doctest.h" @@ -11,13 +12,12 @@ using namespace Luau; -LUAU_FASTFLAG(LuauLexerLookaheadRemembersBraceType); -LUAU_FASTINT(LuauRecursionLimit); -LUAU_FASTINT(LuauTypeLengthLimit); -LUAU_FASTINT(LuauParseErrorLimit); -LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr); -LUAU_FASTFLAG(LuauUserDefinedTypeFunctions); +LUAU_FASTINT(LuauRecursionLimit) +LUAU_FASTINT(LuauTypeLengthLimit) +LUAU_FASTINT(LuauParseErrorLimit) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax) namespace { @@ -2380,7 +2380,7 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctions, true}; + ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; AstStat* stat = parse(R"( type function foo() @@ -3138,8 +3138,6 @@ TEST_CASE_FIXTURE(Fixture, "do_block_with_no_end") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved") { - ScopedFastFlag sff{FFlag::LuauLexerLookaheadRemembersBraceType, true}; - ParseResult result = tryParse(R"( local x = `{ {y} }` )"); @@ -3149,8 +3147,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved2") { - ScopedFastFlag sff{FFlag::LuauLexerLookaheadRemembersBraceType, true}; - ParseResult result = tryParse(R"( local x = `{ { y{} } }` )"); diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp index a59312ac6..05bea2f73 100644 --- a/tests/Subtyping.test.cpp +++ b/tests/Subtyping.test.cpp @@ -66,6 +66,7 @@ struct SubtypeFixture : Fixture InternalErrorReporter iceReporter; UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + TypeFunctionRuntime typeFunctionRuntime; ScopedFastFlag sff{FFlag::LuauSolverV2, true}; @@ -77,7 +78,7 @@ struct SubtypeFixture : Fixture Subtyping mkSubtyping() { - return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&iceReporter}}; + return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; } TypePackId pack(std::initializer_list tys) diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index edc2bf470..f6208c1b1 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -12,7 +12,7 @@ using namespace Luau; -LUAU_FASTFLAG(LuauUserDefinedTypeFunctions); +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax) TEST_SUITE_BEGIN("TranspilerTests"); @@ -698,7 +698,7 @@ TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape") TEST_CASE_FIXTURE(Fixture, "transpile_type_functions") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctions, true}; + ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; std::string code = R"( type function foo(arg1, arg2) if arg1 == arg2 then return arg1 end return arg2 end )"; diff --git a/tests/TypeFunction.test.cpp b/tests/TypeFunction.test.cpp index d3732d606..18d8f17b8 100644 --- a/tests/TypeFunction.test.cpp +++ b/tests/TypeFunction.test.cpp @@ -1247,18 +1247,4 @@ TEST_CASE_FIXTURE(ClassFixture, "rawget_type_function_errors_w_classes") CHECK(toString(result.errors[0]) == "Property '\"BaseField\"' does not exist on type 'BaseClass'"); } -TEST_CASE_FIXTURE(Fixture, "user_defined_type_function_errors") -{ - if (!FFlag::LuauUserDefinedTypeFunctions) - return; - - CheckResult result = check(R"( - type function foo() - return nil - end - )"); - LUAU_CHECK_ERROR_COUNT(1, result); - CHECK(toString(result.errors[0]) == "This syntax is not supported"); -} - TEST_SUITE_END(); diff --git a/tests/TypeFunction.user.test.cpp b/tests/TypeFunction.user.test.cpp new file mode 100644 index 000000000..fbce4df2f --- /dev/null +++ b/tests/TypeFunction.user.test.cpp @@ -0,0 +1,1007 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "ClassFixture.h" +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctions) + +TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_nil(arg) + return arg + end + type type_being_serialized = nil + local function ok(idx: serialize_nil): nil return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getnil() + local ty = types.singleton(nil) + if ty:is("nil") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getnil<>): nil return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_unknown(arg) + return arg + end + type type_being_serialized = unknown + local function ok(idx: serialize_unknown): unknown return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getunknown() + local ty = types.unknown + if ty:is("unknown") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getunknown<>): unknown return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_never(arg) + return arg + end + type type_being_serialized = never + local function ok(idx: serialize_never): never return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getnever() + local ty = types.never + if ty:is("never") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getnever<>): never return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_any(arg) + return arg + end + type type_being_serialized = any + local function ok(idx: serialize_any): any return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getany() + local ty = types.any + if ty:is("any") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getany<>): any return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_bool(arg) + return arg + end + type type_being_serialized = boolean + local function ok(idx: serialize_bool): boolean return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getboolean() + local ty = types.boolean + if ty:is("boolean") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getboolean<>): boolean return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_num(arg) + return arg + end + type type_being_serialized = number + local function ok(idx: serialize_num): number return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getnumber() + local ty = types.number + if ty:is("number") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getnumber<>): number return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_str(arg) + return arg + end + type type_being_serialized = string + local function ok(idx: serialize_str): string return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getstring() + local ty = types.string + if ty:is("string") then + return ty + end + -- this should never be returned + return types.boolean + end + local function ok(idx: getstring<>): string return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_boolsingleton(arg) + return arg + end + type type_being_serialized = true + local function ok(idx: serialize_boolsingleton): true return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getboolsingleton() + local ty = types.singleton(true) + if ty:is("singleton") and ty:value() then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getboolsingleton<>): true return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_strsingleton(arg) + return arg + end + type type_being_serialized = "popcorn and movies!" + local function ok(idx: serialize_strsingleton): "popcorn and movies!" return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getstrsingleton() + local ty = types.singleton("hungry hippo") + if ty:is("singleton") and ty:value() == "hungry hippo" then + return ty + end + -- this should never be returned + return types.number + end + local function ok(idx: getstrsingleton<>): "hungry hippo" return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_union(arg) + return arg + end + type type_being_serialized = number | string | boolean + -- forcing an error here to check the exact type of the union + local function ok(idx: serialize_union): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "boolean | number | string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getunion() + local ty = types.unionof(types.string, types.number, types.boolean) + if ty:is("union") then + -- creating a copy of `ty` + local arr = {} + for _, value in ty:components() do + table.insert(arr, value) + end + return types.unionof(table.unpack(arr)) + end + -- this should never be returned + return types.number + end + -- forcing an error here to check the exact type of the union + local function ok(idx: getunion<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "boolean | number | string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_intersection(arg) + return arg + end + type type_being_serialized = { boolean: boolean, number: number } & { boolean: boolean, string: string } + -- forcing an error here to check the exact type of the intersection + local function ok(idx: serialize_intersection): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ boolean: boolean, number: number } & { boolean: boolean, string: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getintersection() + local tbl1 = types.newtable(nil, nil, nil) + tbl1:setproperty(types.singleton("boolean"), types.boolean) -- {boolean: boolean} + tbl1:setproperty(types.singleton("number"), types.number) -- {boolean: boolean, number: number} + local tbl2 = types.newtable(nil, nil, nil) + tbl2:setproperty(types.singleton("boolean"), types.boolean) -- {boolean: boolean} + tbl2:setproperty(types.singleton("string"), types.string) -- {boolean: boolean, string: string} + local ty = types.intersectionof(tbl1, tbl2) + if ty:is("intersection") then + -- creating a copy of `ty` + local arr = {} + for index, value in ty:components() do + table.insert(arr, value) + end + return types.intersectionof(table.unpack(arr)) + end + -- this should never be returned + return types.string + end + -- forcing an error here to check the exact type of the intersection + local function ok(idx: getintersection<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ boolean: boolean, number: number } & { boolean: boolean, string: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_negation_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getnegation() + local ty = types.negationof(types.string) + if ty:is("negation") then + return ty + end + -- this should never be returned + return types.number + end + + -- forcing an error here to check the exact type of the negation + local function ok(idx: getnegation<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "~string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_table(arg) + return arg + end + type type_being_serialized = { boolean: boolean, number: number, [string]: number } + -- forcing an error here to check the exact type of the table + local function ok(idx: serialize_table): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ [string]: number, boolean: boolean, number: number }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function gettable() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(nil, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number] = boolean} + ty:setproperty(types.singleton("number"), types.string) -- {string: number, number: string, [number] = boolean} + ty:setproperty(types.singleton("string"), nil) -- {number: string, [number] = boolean} + local ret = types.newtable(nil, nil, nil) -- {} + -- creating a copy of `ty` + for k, v in ty:properties() do + ret:setreadproperty(k, v.read) + ret:setwriteproperty(k, v.write) + end + if ret:is("table") then + ret:setindexer(types.boolean, types.string) -- {number: string, [boolean] = string} + return ret -- {number: string, [boolean] = string} + end + -- this should never be returned + return types.number + end + -- forcing an error here to check the exact type of the table + local function ok(idx: gettable<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ [boolean]: string, number: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_metatable_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getmetatable() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(nil, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number]: boolean} + local metatbl = types.newtable(nil, nil, ty) -- { { }, @metatable { [number]: boolean, string: number } } + metatbl:setmetatable(types.newtable(nil, indexer, nil)) -- { { }, @metatable { [number]: boolean } } + local ret = metatbl:metatable() + if metatbl:is("table") and metatbl:metatable() then + return ret -- { @metatable { [number]: boolean } } + end + -- this should never be returned + return types.number + end + -- forcing an error here to check the exact type of the metatable + local function ok(idx: getmetatable<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{boolean}"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_func(arg) + return arg + end + type type_being_serialized = (boolean, number, nil) -> (...string) + local function ok(idx: serialize_func): (boolean, number, nil) -> (...string) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getfunction() + local ty = types.newfunction(nil, nil) -- () -> () + ty:setparameters({types.string, types.number}, nil) -- (string, number) -> () + ty:setreturns(nil, types.boolean) -- (string, number) -> (...boolean) + if ty:is("function") then + -- creating a copy of `ty` parameters + local arr = {} + for index, val in ty:parameters().head do + table.insert(arr, val) + end + return types.newfunction({head = arr}, ty:returns()) -- (string, number) -> (...boolean) + end + -- this should never be returned + return types.number + end + local function ok(idx: getfunction<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "(string, number) -> (...boolean)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_class_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_class(arg) + return arg + end + local function ok(idx: serialize_class): BaseClass return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_class_methods_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + + CheckResult result = check(R"( + type function getclass(arg) + local props = arg:properties() + local indexer = arg:indexer() + local metatable = arg:metatable() + return types.newtable(props, indexer, metatable) + end + -- forcing an error here to check the exact type of the metatable + local function ok(idx: getclass): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ BaseField: number, read BaseMethod: (BaseClass, number) -> (), read Touched: Connection }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_check_mutability") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function checkmut() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(props, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number]: boolean} + local metatbl = types.newtable(nil, nil, ty) -- { { }, @metatable { [number]: boolean, string: number } } + -- mutate the table + ty:setproperty(types.singleton("string"), nil) -- {[number]: boolean} + if metatbl:is("table") and metatbl:metatable() then + return metatbl -- { @metatable { [number]: boolean }, { } } + end + -- this should never be returned + return types.number + end + local function ok(idx: checkmut<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ @metatable {boolean}, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_copy_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getcopy() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(nil, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number]: boolean} + local metaty = types.newtable(nil, nil, ty) -- { { }, @metatable { [number]: boolean, string: number } } + local copy = types.copy(metaty) + -- mutate the table + ty:setproperty(types.singleton("string"), nil) -- {[number]: boolean} + if copy:is("table") and copy:metatable() then + return copy -- { { }, @metatable { [number]: boolean, string: number } } + end + -- this should never be returned + return types.number + end + local function ok(idx: getcopy<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ @metatable { [number]: boolean, string: number }, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_simple_cyclic_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_cycle(arg) + return arg + end + type basety = { + first: basety2 + } + type basety2 = { + second: basety + } + local function ok(idx: serialize_cycle): basety return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_createtable_bad_metatable") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function badmetatable() + return types.newtable(nil, nil, types.number) + end + local function bad(arg: badmetatable<>) end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK( + e->message == "'badmetatable' type function errored at runtime: [string \"badmetatable\"]:3: types.newtable: expected to be given a table " + "type as a metatable, but got number instead" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_complex_cyclic_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_cycle2(arg) + return arg + end + type Employee = { + name: string, + department: Department? + } + type Department = { + name: string, + manager: Employee?, + employees: { Employee }, + company: Company? + } + type Company = { + name: string, + departments: { Department } + } + local function ok(idx: serialize_cycle2): Company return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_user_error_is_reported") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function errors_if_string(arg) + if arg:is("string") then + local a = 1 + error("We are in a math class! not english") + end + return arg + end + local function ok(idx: errors_if_string): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'errors_if_string' type function errored at runtime: [string \"errors_if_string\"]:5: We are in a math class! not english"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_call_metamethod") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function hello(arg) + error(type(arg)) + end + local function ok(idx: hello): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'hello' type function errored at runtime: [string \"hello\"]:3: userdata"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_eq_metamethod") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function hello() + local p1 = types.string + local p2 = types.string + local t1 = types.newtable(nil, nil, nil) + t1:setproperty(types.singleton("string"), types.boolean) + t1:setmetatable(t1) + local t2 = types.newtable(nil, nil, nil) + t2:setproperty(types.singleton("string"), types.boolean) + t1:setmetatable(t1) + if p1 == p2 and t1 == t2 then + return types.number + end + end + local function ok(idx: hello<>): number return idx end + )"); + + LUAU_CHECK_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_type_cant_call_get_props") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function hello(arg) + local arr = arg:properties() + end + local function ok(idx: hello<() -> ()>): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK( + e->message == "'hello' type function errored at runtime: [string \"hello\"]:3: type.properties: expected self to be either a table or class, " + "but got function instead" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_cannot_call_other") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function foo() + return "hi" + end + local x = true; + type function cannot_call_others() + return foo() + end + local function ok(idx: cannot_call_others<>): string return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'cannot_call_others' type function errored at runtime: [string \"cannot_call_others\"]:7: attempt to call a nil value"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_optionify") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function optionify(tbl) + if not tbl:is("table") then + error("Argument is not a table") + end + for k, v in tbl:properties() do + tbl:setproperty(k, types.unionof(v.read, types.singleton(nil))) + end + return tbl + end + type Person = { + name: string, + age: number, + alive: boolean + } + local function ok(idx: optionify): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ age: number?, alive: boolean?, name: string? }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_illegal_global") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function illegal(arg) + gcinfo() -- this should error + + return arg -- this should not be reached + end + + local function ok(idx: illegal): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'illegal' type function errored at runtime: [string \"illegal\"]:3: this function is not supported in type functions"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recursion_and_gc") +{ + ScopedFastFlag newSolver{ FFlag::LuauSolverV2, true }; + ScopedFastFlag udtfSyntax{ FFlag::LuauUserDefinedTypeFunctionsSyntax, true }; + ScopedFastFlag udtf{ FFlag::LuauUserDefinedTypeFunctions, true }; + + CheckResult result = check(R"( + type function foo(tbl) + local count = 0 + for k,v in tbl:properties() do count += 1 end + if count < 100 then + tbl:setproperty(types.singleton(`m{count}`), types.string) + foo(tbl) + end + for i = 1,100 do table.create(10000) end + return tbl + end + type Test = {} + local function ok(idx: foo): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index bb5a2cdd3..15eed3927 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -9,6 +9,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax) LUAU_FASTFLAG(LuauUserDefinedTypeFunctions) TEST_SUITE_BEGIN("TypeAliases"); @@ -1169,8 +1170,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_adds_reduce_constraint_for_type_f TEST_CASE_FIXTURE(Fixture, "user_defined_type_function_errors") { - if (!FFlag::LuauUserDefinedTypeFunctions) - return; + ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag noUDTFimpl{FFlag::LuauUserDefinedTypeFunctions, false}; CheckResult result = check(R"( type function foo() diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 9912cc356..25f3d1132 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -1427,4 +1427,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") CHECK_EQ(toString(requireType("e")), "number?"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "string_find_should_not_crash") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local function StringSplit(input, separator) + string.find(input, separator) + if not separator then + separator = "%s+" + end + end + )")); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index de79654b7..ec36b30eb 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -15,7 +15,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauOkWithIteratingOverTableProperties) LUAU_DYNAMIC_FASTFLAG(LuauImproveNonFunctionCallError) @@ -699,8 +698,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional") if (FFlag::LuauSolverV2) return; - ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true}; - CheckResult result = check(R"( local t = {} for _ in t do @@ -784,7 +781,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_strict") // CLI-116498 Sometimes you can iterate over tables with no indexers. ScopedFastFlag sff[] = { {FFlag::LuauSolverV2, false}, - {FFlag::LuauOkWithIteratingOverTableProperties, true} }; CheckResult result = check(R"( @@ -937,8 +933,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_68448_iterators_need_not_accept_nil") TEST_CASE_FIXTURE(Fixture, "iterate_over_free_table") { - ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true}; - CheckResult result = check(R"( function print(x) end @@ -1095,8 +1089,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties") // CLI-116498 - Sometimes you can iterate over tables with no indexer. ScopedFastFlag sff0{FFlag::LuauSolverV2, false}; - ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true}; - CheckResult result = check(R"( local function f() local t = { p = 5, q = "hello" } @@ -1118,8 +1110,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties") TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties_nonstrict") { - ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true}; - CheckResult result = check(R"( --!nonstrict local function f() diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 42f1229f7..ba54aca07 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -530,4 +530,82 @@ return l0 CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_scope_is_nullptr_after_shallow_copy") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + frontend.options.retainFullTypeGraphs = false; + + fileResolver.source["game/A"] = R"( +-- Roughly taken from ReactTypes.lua +type CoreBinding = {} +type BindingMap = {} +export type Binding = CoreBinding & BindingMap + +return {} + )"; + + LUAU_REQUIRE_NO_ERRORS(check(R"( +local Types = require(game.A) +type Binding = Types.Binding + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_free_variables_are_generialized_across_function_boundaries") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + fileResolver.source["game/A"] = R"( +-- Roughly taken from react-shallow-renderer +function createUpdater(renderer) + local updater = { + _renderer = renderer, + } + + function updater.enqueueForceUpdate(publicInstance, callback, _callerName) + updater._renderer.render( + updater._renderer, + updater._renderer._element, + updater._renderer._context + ) + end + + function updater.enqueueReplaceState( + publicInstance, + completeState, + callback, + _callerName + ) + updater._renderer.render( + updater._renderer, + updater._renderer._element, + updater._renderer._context + ) + end + + function updater.enqueueSetState(publicInstance, partialState, callback, _callerName) + local currentState = updater._renderer._newState or publicInstance.state + updater._renderer.render( + updater._renderer, + updater._renderer._element, + updater._renderer._context + ) + end + + return updater +end + +local ReactShallowRenderer = {} + +function ReactShallowRenderer:_reset() + self._updater = createUpdater(self) +end + +return ReactShallowRenderer + )"; + + LUAU_REQUIRE_NO_ERRORS(check(R"( +local ReactShallowRenderer = require(game.A); + )")); +} + TEST_SUITE_END();