diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/call.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/call.rs index 084a3bcf7..0be966d06 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/call.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/call.rs @@ -10,8 +10,10 @@ use crate::{ pub fn analyze_call(analyzer: &mut LuaAnalyzer, call_expr: LuaCallExpr) -> Option<()> { let prefix_expr = call_expr.clone().get_prefix_expr()?; - match analyzer.infer_expr(&prefix_expr) { - Ok(LuaType::Signature(signature_id)) => { + // Constructor discovery only needs the callee's declared signature. Full + // flow inference here replays narrowing for every call in call-dense files. + match analyzer.infer_expr_no_flow(&prefix_expr) { + Ok(Some(LuaType::Signature(signature_id))) => { let signature = analyzer.db.get_signature_index().get(&signature_id)?; for (idx, param_info) in signature.param_docs.iter() { if param_info diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/mod.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/mod.rs index 8ad9ae240..3719d8567 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/mod.rs @@ -26,7 +26,7 @@ use crate::{ compilation::analyzer::{AnalysisPipeline, lua::call::analyze_call}, db_index::{DbIndex, LuaType}, profile::Profile, - semantic::infer_expr, + semantic::{infer_expr, try_infer_expr_no_flow}, }; use super::AnalyzeContext; @@ -121,4 +121,9 @@ impl LuaAnalyzer<'_> { let cache = self.context.infer_manager.get_infer_cache(self.file_id); infer_expr(self.db, cache, expr.clone()) } + + fn infer_expr_no_flow(&mut self, expr: &LuaExpr) -> Result, InferFailReason> { + let cache = self.context.infer_manager.get_infer_cache(self.file_id); + try_infer_expr_no_flow(self.db, cache, expr.clone()) + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/attribute_test.rs b/crates/emmylua_code_analysis/src/compilation/test/attribute_test.rs index ed923ad51..a0828719d 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/attribute_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/attribute_test.rs @@ -27,6 +27,9 @@ mod test { "#, ), ]); + + let ty = ws.expr_ty("A"); + assert_eq!(ws.humanize_type(ty), "A"); } #[test] diff --git a/crates/emmylua_code_analysis/src/compilation/test/flow.rs b/crates/emmylua_code_analysis/src/compilation/test/flow.rs index f1ee2121b..903a13d46 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/flow.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/flow.rs @@ -192,6 +192,38 @@ mod test { assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); } + #[test] + fn test_stacked_local_call_alias_type_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = "if not pred(value) then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + ---@param v any + ---@return TypeGuard + local function is_string(v) + return true + end + + local pred = is_string + local value ---@type string|integer|boolean + + {repeated_guards} + after_guard = value + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked local call alias type guard repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + #[test] fn test_stacked_same_var_call_type_guard_eq_false_build_semantic_model() { let mut ws = VirtualWorkspace::new(); @@ -226,6 +258,51 @@ mod test { assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); } + #[test] + fn test_flow_assigned_call_type_guard_prefix_keeps_narrowing() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param inst any + ---@param type `T` + ---@return TypeGuard + local function instance_of(inst, type) + return true + end + + local guard + guard = instance_of + + local value ---@type string|integer|boolean + + if guard(value, "string") then + after_guard = value + end + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + + #[test] + fn test_condition_narrowed_call_type_guard_prefix_keeps_narrowing() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@param guard (fun(v: any): TypeGuard)? + ---@param value string|integer|boolean + local function f(guard, value) + if guard and guard(value) then + after_guard = value + end + end + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + #[test] fn test_branch_join_keeps_union_when_only_one_side_narrows() { let mut ws = VirtualWorkspace::new(); @@ -1928,6 +2005,28 @@ end assert_eq!(b, b_expected); } + #[test] + fn test_feature_initializer_alias_keeps_flow_type() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + local x --- @type string | integer + + if type(x) ~= "string" then + return + end + + local y = x + after = y + "#, + ); + + let after = ws.expr_ty("after"); + let after_expected = ws.ty("string"); + assert_eq!(after, after_expected); + } + #[test] fn test_feature_const_local_alias_chain_does_not_inherit_flow() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -2634,6 +2733,362 @@ _2 = a[1] assert_eq!(ws.expr_ty("after_assign"), ws.ty("Foo|Bar")); } + #[test] + fn test_assignment_from_call_index_rhs_keeps_precise_rhs_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Foo + ---@field kind "foo" + + ---@class Bar + ---@field kind "bar" + + ---@class Baz + ---@field kind "baz" + + ---@class Box + ---@field value Bar + + ---@return Box + local function get_box() + end + + local x ---@type Foo|Bar|Baz + + if x.kind == "foo" then + x = get_box().value + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Bar")); + } + + #[test] + fn test_assignment_table_rhs_keeps_multiple_narrowed_field_values() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class LeftFoo + ---@field kind "foo" + + ---@class LeftBar + ---@field kind "bar" + + ---@class RightBaz + ---@field kind "baz" + + ---@class RightQux + ---@field kind "qux" + + local left ---@type LeftFoo|LeftBar + local right ---@type RightBaz|RightQux + + if left.kind == "foo" and right.kind == "baz" then + local pair = { left = left, right = right } + after_left = pair.left + after_right = pair.right + end + "#, + ); + + assert_eq!(ws.expr_ty("after_left"), ws.ty("LeftFoo")); + assert_eq!(ws.expr_ty("after_right"), ws.ty("RightBaz")); + } + + #[test] + fn test_assignment_and_rhs_keeps_narrowed_index_on_second_operand() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Left + + ---@class RightFoo + ---@field kind "foo" + ---@field value string + + ---@class RightBar + ---@field kind "bar" + ---@field value integer + + local left ---@type Left? + local right ---@type RightFoo|RightBar + + if left and right.kind == "foo" then + local result = left and right.value + after_assign = result + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("string")); + } + + #[test] + fn test_assignment_rhs_keeps_multiple_flow_dependencies() { + let mut ws = VirtualWorkspace::new(); + let left_guards = "if not left then return end\n".repeat(STACKED_TYPE_GUARDS); + let right_guards = "if not right then return end\n".repeat(STACKED_TYPE_GUARDS); + + let block = format!( + r#" + ---@class Pattern + ---@operator mul(Pattern): Pattern + + ---@class PatternFactory + ---@field new fun(value: string): Pattern + + local factory ---@type PatternFactory + local left ---@type Pattern? + local right ---@type Pattern? + + {left_guards} + {right_guards} + left = left * factory.new("x") * right + after_assign = left + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for multi-dependency RHS assignment repro" + ); + let after_assign = ws.expr_ty("after_assign"); + assert_eq!(ws.humanize_type(after_assign), "Pattern"); + } + + #[test] + fn test_eq_uses_branch_narrowed_rhs_ref_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local x ---@type string|integer + local y ---@type string|integer + + if type(y) ~= "string" then + return + end + + if x == y then + after_guard = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + + #[test] + fn test_eq_uses_branch_narrowed_rhs_index_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Foo + ---@field kind "foo" + + ---@class Bar + ---@field kind "bar" + + local x ---@type "foo"|"bar" + local y ---@type Foo|Bar + + if y.kind == "foo" then + if x == y.kind then + after_guard = x + end + end + "#, + ); + + let after_guard = ws.expr_ty("after_guard"); + assert_eq!(ws.humanize_type(after_guard), r#""foo""#); + } + + #[test] + fn test_initializer_uses_branch_narrowed_dynamic_key() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class T + ---@field foo string + ---@field bar integer + + local t ---@type T + local key ---@type "foo"|"bar" + + if true then + key = "foo" + local value = t[key] + after_guard = value + end + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + + #[test] + fn test_eq_uses_branch_narrowed_dynamic_rhs_key() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class T + ---@field foo string + ---@field bar integer + + local t ---@type T + local key ---@type "foo"|"bar" + local x ---@type string|integer + + key = "foo" + if x == t[key] then + after_guard = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + + #[test] + fn test_field_literal_eq_uses_branch_narrowed_dynamic_key() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Foo + ---@field kind "foo" + ---@field value_key string + ---@field value string + + ---@class Bar + ---@field kind "bar" + ---@field value_key "foo" + ---@field value integer + + local obj ---@type Foo|Bar + local key ---@type "kind"|"value_key" + + key = "kind" + if obj[key] == "foo" then + after_guard = obj.value + end + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + + #[test] + fn test_field_truthy_uses_branch_narrowed_dynamic_key() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Present + ---@field present true + ---@field other true + ---@field value string + + ---@class Missing + ---@field present false? + ---@field other true + ---@field value integer + + local obj ---@type Present|Missing + local key ---@type "present"|"other" + + key = "present" + if obj[key] then + after_guard = obj.value + end + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + + #[test] + fn test_stacked_dynamic_field_truthy_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = "if not obj[key] then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + ---@class PresentDynamic + ---@field present true + ---@field other true + ---@field value string + + ---@class MissingDynamic + ---@field present false? + ---@field other true + ---@field value integer + + local obj ---@type PresentDynamic|MissingDynamic + local key ---@type "present"|"other" + + key = "present" + {repeated_guards} + after_guard = obj.value + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked dynamic-field truthiness repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + + #[test] + fn test_field_literal_eq_uses_branch_narrowed_dynamic_key_index_dependency() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class FooIndexKey + ---@field kind "foo" + ---@field value_key string + ---@field value string + + ---@class BarIndexKey + ---@field kind "bar" + ---@field value_key "foo" + ---@field value integer + + local obj ---@type FooIndexKey|BarIndexKey + local keys = { "kind", "value_key" } + local slot ---@type 1|2 + + slot = 1 + if obj[keys[slot]] == "foo" then + after_guard = obj.value + end + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + #[test] fn test_assignment_after_pending_return_cast_guard_drops_branch_narrowing() { let mut ws = VirtualWorkspace::new(); @@ -2741,6 +3196,50 @@ _2 = a[1] assert_eq!(ws.expr_ty("after_assign"), ws.ty("Player|Monster")); } + #[test] + fn test_assignment_missing_rhs_slot_drops_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local cond ---@type boolean + local y = cond and "s" or 1 + + if type(y) == "string" then + local x + x, y = 1 + after_assign = y + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("nil")); + } + + #[test] + fn test_assignment_exhausted_return_slot_drops_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@return string + local function one() + end + + local cond ---@type boolean + local y = cond and "s" or 1 + + if type(y) == "string" then + local x + x, y = one() + after_assign = y + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("nil")); + } + #[test] fn test_assignment_from_nullable_union_keeps_rhs_members() { let mut ws = VirtualWorkspace::new(); @@ -2760,6 +3259,30 @@ _2 = a[1] assert_eq!(ws.expr_ty("after_assign"), ws.ty("number?")); } + #[test] + fn test_index_expr_replay_keeps_literal_field_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class T + ---@field x "foo"|"bar" + + local t ---@type T + local x ---@type "foo"|"bar" + + if t.x == "foo" then + if x == t.x then + after_guard = x + end + end + "#, + ); + + let after_guard = ws.expr_ty("after_guard"); + assert_eq!(ws.humanize_type(after_guard), r#""foo""#); + } + #[test] fn test_assignment_from_partially_overlapping_union_keeps_rhs_members() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs index 647eaadcd..fee33bf17 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs @@ -36,6 +36,24 @@ mod test { assert_eq!(c_ty, union_type); } + #[test] + fn test_exact_missing_table_key_does_not_scan_broad_members() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + local t = { + a = 1, + b = "b", + } + + value = t["missing"] + "#, + ); + + assert_eq!(ws.expr_ty("value"), LuaType::Nil); + } + #[test] fn test_issue_314_generic_inheritance() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -417,6 +435,28 @@ mod test { assert_eq!(ws.expr_ty("result"), LuaType::Integer); } + #[test] + fn test_rawget_alias_guard_narrows_matching_index_expr() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class T + ---@field x? integer + + ---@type T + local t = {} + local get = rawget + + if get(t, "x") then + result = t.x + end + "#, + ); + + assert_eq!(ws.expr_ty("result"), LuaType::Integer); + } + #[test] fn test_type_guard_call_narrows_matching_index_expr() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs index b4c69d9ca..25cca95d7 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs @@ -176,6 +176,31 @@ mod test { assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); } + #[test] + fn test_pcall_alias_callee_narrows_return_after_error_guard() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@return integer + local function foo() + return 1 + end + + local runner = pcall + local ok, result = runner(foo) + + if not ok then + error(result) + end + + narrowed = result + "#, + ); + + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } + #[test] fn test_pcall_any_callable_splits_success_unknown_and_failure_string() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); diff --git a/crates/emmylua_code_analysis/src/db_index/member/lua_member.rs b/crates/emmylua_code_analysis/src/db_index/member/lua_member.rs index 0ac54c4fd..55a395546 100644 --- a/crates/emmylua_code_analysis/src/db_index/member/lua_member.rs +++ b/crates/emmylua_code_analysis/src/db_index/member/lua_member.rs @@ -6,7 +6,10 @@ use serde::{Deserialize, Serialize}; use smol_str::SmolStr; use super::lua_member_feature::LuaMemberFeature; -use crate::{DbIndex, FileId, GlobalId, InferFailReason, LuaInferCache, LuaType, infer_expr}; +use crate::{ + DbIndex, FileId, GlobalId, InferFailReason, LuaInferCache, LuaType, + semantic::try_infer_expr_for_index, +}; #[derive(Debug)] pub struct LuaMember { @@ -116,7 +119,9 @@ impl LuaMemberKey { } LuaIndexKey::Idx(idx) => Ok(LuaMemberKey::Integer(*idx as i64)), LuaIndexKey::Expr(expr) => { - let expr_type = infer_expr(db, cache, expr.clone())?; + let Some(expr_type) = try_infer_expr_for_index(db, cache, expr.clone())? else { + return Err(InferFailReason::None); + }; match expr_type { LuaType::StringConst(s) => Ok(LuaMemberKey::Name(s.deref().clone())), LuaType::DocStringConst(s) => Ok(LuaMemberKey::Name(s.deref().clone())), diff --git a/crates/emmylua_code_analysis/src/semantic/cache/mod.rs b/crates/emmylua_code_analysis/src/semantic/cache/mod.rs index ae03a6098..c6dc256fa 100644 --- a/crates/emmylua_code_analysis/src/semantic/cache/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/cache/mod.rs @@ -2,8 +2,8 @@ mod cache_options; pub use cache_options::{CacheOptions, LuaAnalysisPhase}; use emmylua_parser::{LuaExpr, LuaSyntaxId, LuaVarExpr}; -use hashbrown::HashMap; -use std::{rc::Rc, sync::Arc}; +use hashbrown::{HashMap, HashSet}; +use std::{mem, rc::Rc, sync::Arc}; use crate::{ FileId, FlowId, LuaFunctionType, @@ -17,13 +17,6 @@ pub enum CacheEntry { Cache(T), } -#[derive(Debug, Clone)] -pub(in crate::semantic) struct FlowConditionInfo { - pub expr: LuaExpr, - pub index_var_ref_id: Option, - pub index_prefix_var_ref_id: Option, -} - #[derive(Debug, Clone)] pub(in crate::semantic) struct FlowAssignmentInfo { pub vars: Vec, @@ -39,7 +32,7 @@ pub(in crate::semantic) enum FlowMode { impl FlowMode { pub fn uses_conditions(self) -> bool { - !matches!(self, Self::WithoutConditions) + matches!(self, Self::WithConditions) } } @@ -53,15 +46,20 @@ pub(in crate::semantic) struct FlowVarCache { pub struct LuaInferCache { file_id: FileId, config: CacheOptions, + no_flow_mode: bool, pub expr_cache: HashMap>, + pub(in crate::semantic) expr_no_flow_cache: HashMap>>, pub call_cache: HashMap<(LuaSyntaxId, Option, LuaType), CacheEntry>>, + pub(in crate::semantic) call_no_flow_cache: + HashMap<(LuaSyntaxId, Option, LuaType), CacheEntry>>>, + replay_expr_types: Vec<(LuaSyntaxId, LuaType)>, pub(in crate::semantic) flow_cache_var_ref_ids: HashMap, pub(in crate::semantic) next_flow_cache_var_ref_id: u32, pub(in crate::semantic) flow_var_caches: Vec, pub(in crate::semantic) flow_branch_inputs_cache: Vec>>, - pub(in crate::semantic) flow_condition_info_cache: Vec>>, pub(in crate::semantic) flow_assignment_info_cache: Vec>>, + pub(in crate::semantic) no_flow_table_exprs: HashSet, pub index_ref_origin_type_cache: HashMap>, pub expr_var_ref_id_cache: HashMap, } @@ -71,14 +69,18 @@ impl LuaInferCache { Self { file_id, config, + no_flow_mode: false, expr_cache: HashMap::new(), + expr_no_flow_cache: HashMap::new(), call_cache: HashMap::new(), + call_no_flow_cache: HashMap::new(), + replay_expr_types: Vec::new(), flow_cache_var_ref_ids: HashMap::new(), next_flow_cache_var_ref_id: 0, flow_var_caches: Vec::new(), flow_branch_inputs_cache: Vec::new(), - flow_condition_info_cache: Vec::new(), flow_assignment_info_cache: Vec::new(), + no_flow_table_exprs: HashSet::new(), index_ref_origin_type_cache: HashMap::new(), expr_var_ref_id_cache: HashMap::new(), } @@ -92,19 +94,78 @@ impl LuaInferCache { self.file_id } + pub(in crate::semantic) fn is_no_flow(&self) -> bool { + self.no_flow_mode + } + + pub(in crate::semantic) fn with_no_flow( + &mut self, + f: impl FnOnce(&mut LuaInferCache) -> R, + ) -> R { + let previous_mode = mem::replace(&mut self.no_flow_mode, true); + let result = f(self); + self.no_flow_mode = previous_mode; + result + } + + pub(in crate::semantic) fn replay_expr_type(&self, syntax_id: LuaSyntaxId) -> Option<&LuaType> { + self.replay_expr_types + .iter() + .rev() + .find_map(|(overlay_id, ty)| (*overlay_id == syntax_id).then_some(ty)) + } + + pub(in crate::semantic) fn with_replay_overlay( + &mut self, + expr_types: &[(LuaSyntaxId, LuaType)], + table_exprs: &[LuaSyntaxId], + f: impl FnOnce(&mut LuaInferCache) -> R, + ) -> R { + if expr_types.is_empty() && table_exprs.is_empty() { + return f(self); + } + + // Replay overlays change no-flow answers, so isolate overlay-dependent + // cache writes from the normal no-flow caches. + let overlay_len = self.replay_expr_types.len(); + self.replay_expr_types.extend(expr_types.iter().cloned()); + let mut inserted_table_exprs = Vec::new(); + for syntax_id in table_exprs { + if self.no_flow_table_exprs.insert(*syntax_id) { + inserted_table_exprs.push(*syntax_id); + } + } + let saved_expr_no_flow_cache = mem::take(&mut self.expr_no_flow_cache); + let saved_call_no_flow_cache = mem::take(&mut self.call_no_flow_cache); + + let result = f(self); + + self.expr_no_flow_cache = saved_expr_no_flow_cache; + self.call_no_flow_cache = saved_call_no_flow_cache; + for syntax_id in inserted_table_exprs { + self.no_flow_table_exprs.remove(&syntax_id); + } + self.replay_expr_types.truncate(overlay_len); + result + } + pub fn set_phase(&mut self, phase: LuaAnalysisPhase) { self.config.analysis_phase = phase; } pub fn clear(&mut self) { + self.no_flow_mode = false; self.expr_cache.clear(); + self.expr_no_flow_cache.clear(); self.call_cache.clear(); + self.call_no_flow_cache.clear(); + self.replay_expr_types.clear(); self.flow_cache_var_ref_ids.clear(); self.next_flow_cache_var_ref_id = 0; self.flow_var_caches.clear(); self.flow_branch_inputs_cache.clear(); - self.flow_condition_info_cache.clear(); self.flow_assignment_info_cache.clear(); + self.no_flow_table_exprs.clear(); self.index_ref_origin_type_cache.clear(); self.expr_var_ref_id_cache.clear(); } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index fc5d08a7c..1c8e39996 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -174,6 +174,7 @@ fn infer_callable_return_from_arg_types( call_arg_types, false, None, + &[], ); member_returns.push(callable?.get_ret().clone()); } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/infer_require.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/infer_require.rs index 32b020801..668365d62 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/infer_require.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/infer_require.rs @@ -5,7 +5,7 @@ use crate::{ semantic::infer::InferResult, }; -pub fn infer_require_call( +pub(super) fn infer_require_call( db: &DbIndex, cache: &mut LuaInferCache, call_expr: LuaCallExpr, @@ -14,7 +14,12 @@ pub fn infer_require_call( let first_arg = arg_list.get_args().next().ok_or(InferFailReason::None)?; let require_path_type = infer_expr(db, cache, first_arg)?; let module_path: String = match &require_path_type { - LuaType::StringConst(module_path) => module_path.as_ref().to_string(), + LuaType::StringConst(module_path) | LuaType::DocStringConst(module_path) => { + module_path.as_ref().to_string() + } + _ if cache.is_no_flow() => { + return Err(InferFailReason::None); + } _ => { return Ok(LuaType::Any); } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/infer_setmetatable.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/infer_setmetatable.rs index d5cb2c26a..a61fc5986 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/infer_setmetatable.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/infer_setmetatable.rs @@ -6,7 +6,7 @@ use crate::{ semantic::{infer::InferResult, member::find_members_with_key}, }; -pub fn infer_setmetatable_call( +pub(super) fn infer_setmetatable_call( db: &DbIndex, cache: &mut LuaInferCache, call_expr: LuaCallExpr, @@ -22,6 +22,12 @@ pub fn infer_setmetatable_call( let metatable = args[1].clone(); let (meta_type, is_index) = infer_metatable_index_type(db, cache, metatable)?; + if cache.is_no_flow() && !is_index && !meta_type.is_custom_type() { + // No-flow setmetatable inference is only used as a conservative fallback. + // If the metatable does not resolve to an actual metatable shape, decline + // instead of treating arbitrary static expressions as the result type. + return Err(InferFailReason::None); + } match &basic_table { LuaExpr::TableExpr(table_expr) => { if table_expr.is_empty() && is_index { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index 756d3b9a4..cd1360cf2 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -43,18 +43,31 @@ pub fn infer_call_expr_func( ) -> InferCallFuncResult { let syntax_id = call_expr.get_syntax_id(); let key = (syntax_id, args_count, call_expr_type.clone()); - if let Some(cache) = cache.call_cache.get(&key) { - match cache { + let is_no_flow = cache.is_no_flow(); + if is_no_flow { + if let Some(cache_entry) = cache.call_no_flow_cache.get(&key) { + match cache_entry { + CacheEntry::Cache(Some(ty)) => return Ok(ty.clone()), + CacheEntry::Cache(None) => return Err(InferFailReason::None), + CacheEntry::Ready => return Err(InferFailReason::RecursiveInfer), + } + } + } else if let Some(cache_entry) = cache.call_cache.get(&key) { + match cache_entry { CacheEntry::Cache(ty) => return Ok(ty.clone()), - _ => return Err(InferFailReason::RecursiveInfer), + CacheEntry::Ready => return Err(InferFailReason::RecursiveInfer), } } - cache.call_cache.insert(key.clone(), CacheEntry::Ready); + if is_no_flow { + cache + .call_no_flow_cache + .insert(key.clone(), CacheEntry::Ready); + } else { + cache.call_cache.insert(key.clone(), CacheEntry::Ready); + } let result = match &call_expr_type { - LuaType::DocFunction(func) => { - infer_doc_function(db, cache, func, call_expr.clone(), args_count) - } + LuaType::DocFunction(func) => infer_doc_function(db, cache, func, call_expr.clone()), LuaType::Signature(signature_id) => { infer_signature_doc_function(db, cache, *signature_id, call_expr.clone(), args_count) } @@ -151,15 +164,34 @@ pub fn infer_call_expr_func( match &result { Ok(func_ty) => { + if is_no_flow { + cache + .call_no_flow_cache + .insert(key, CacheEntry::Cache(Some(func_ty.clone()))); + } else { + cache + .call_cache + .insert(key, CacheEntry::Cache(func_ty.clone())); + } + } + Err(InferFailReason::None) | Err(InferFailReason::RecursiveInfer) if is_no_flow => { cache - .call_cache - .insert(key, CacheEntry::Cache(func_ty.clone())); + .call_no_flow_cache + .insert(key, CacheEntry::Cache(None)); } Err(r) if r.is_need_resolve() => { - cache.call_cache.remove(&key); + if is_no_flow { + cache.call_no_flow_cache.remove(&key); + } else { + cache.call_cache.remove(&key); + } } Err(InferFailReason::None) => { - cache.call_cache.remove(&key); + if is_no_flow { + cache.call_no_flow_cache.remove(&key); + } else { + cache.call_cache.remove(&key); + } } _ => {} } @@ -189,7 +221,6 @@ fn infer_doc_function( cache: &mut LuaInferCache, func: &LuaFunctionType, call_expr: LuaCallExpr, - _: Option, ) -> InferCallFuncResult { if func.contain_tpl() { let result = instantiate_func_generic(db, cache, func, call_expr)?; @@ -626,7 +657,7 @@ fn infer_intersection( resolve_signature(db, cache, overloads, call_expr, false, args_count) } -pub(crate) fn unwrapp_return_type( +fn unwrapp_return_type( db: &DbIndex, cache: &mut LuaInferCache, return_type: LuaType, @@ -750,7 +781,8 @@ pub fn infer_call_expr( .get_ret() .clone(); - if let Some(tree) = db.get_flow_index().get_flow_tree(&cache.get_file_id()) + if !cache.is_no_flow() + && let Some(tree) = db.get_flow_index().get_flow_tree(&cache.get_file_id()) && let Some(flow_id) = tree.get_flow_id(call_expr.get_syntax_id()) && let Some(flow_ret_type) = get_type_at_call_expr_inline_cast(db, cache, tree, call_expr, flow_id, ret_type.clone()) diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs index a524d8c31..a155ba0e7 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs @@ -1,99 +1,78 @@ use emmylua_parser::{ LuaAstNode, LuaExpr, LuaForStat, LuaIndexKey, LuaIndexMemberExpr, LuaNameExpr, LuaUnaryExpr, - NumberResult, UnaryOperator, + UnaryOperator, }; use crate::{ - DbIndex, InferFailReason, LuaArrayLen, LuaArrayType, LuaInferCache, LuaType, TypeOps, - infer_expr, semantic::infer::narrow::get_var_expr_var_ref_id, + DbIndex, InferFailReason, LuaArrayLen, LuaArrayType, LuaInferCache, LuaMemberKey, LuaType, + TypeOps, + semantic::infer::{infer_index::infer_expr_for_index, narrow::get_var_expr_var_ref_id}, }; -pub fn infer_array_member( +pub(super) fn infer_array_member_by_key( db: &DbIndex, cache: &mut LuaInferCache, array_type: &LuaArrayType, - index_member_expr: LuaIndexMemberExpr, + index_expr: LuaIndexMemberExpr, + key_type: &LuaType, + key: &LuaMemberKey, ) -> Result { - let key = index_member_expr - .get_index_key() - .ok_or(InferFailReason::None)?; - let index_prefix_expr = match index_member_expr { - LuaIndexMemberExpr::TableField(_) => { - return Ok(array_type.get_base().clone()); - } - _ => index_member_expr - .get_prefix_expr() - .ok_or(InferFailReason::None)?, + let base = array_type.get_base(); + + let index_prefix_expr = match index_expr { + LuaIndexMemberExpr::TableField(_) => return Ok(base.clone()), + _ => index_expr.get_prefix_expr().ok_or(InferFailReason::None)?, }; - match key { - LuaIndexKey::Integer(i) => { - if !db.get_emmyrc().strict.array_index { - return Ok(array_type.get_base().clone()); - } + if let LuaMemberKey::Integer(index_value) = key { + if !db.get_emmyrc().strict.array_index { + return Ok(base.clone()); + } - let base_type = array_type.get_base(); - match array_type.get_len() { - LuaArrayLen::None => {} - LuaArrayLen::Max(max_len) => { - if let NumberResult::Int(index_value) = i.get_number_value() { - if index_value > 0 && index_value <= *max_len { - return Ok(base_type.clone()); - } - } - } - } + if let LuaArrayLen::Max(max_len) = array_type.get_len() + && *index_value > 0 + && *index_value <= *max_len + { + return Ok(base.clone()); + } - let result_type = match &base_type { - LuaType::Any | LuaType::Unknown => base_type.clone(), - _ => TypeOps::Union.apply(db, base_type, &LuaType::Nil), - }; + return Ok(array_member_fallback(db, base)); + } - Ok(result_type) + if !key_type.is_integer() { + if key_type.is_number() { + return Ok(array_member_fallback(db, base)); } - LuaIndexKey::Expr(expr) => { - let expr_type = infer_expr(db, cache, expr.clone())?; - if expr_type.is_integer() { - let base_type = array_type.get_base(); - match (array_type.get_len(), expr_type) { - ( - LuaArrayLen::Max(max_len), - LuaType::IntegerConst(index_value) | LuaType::DocIntegerConst(index_value), - ) => { - if index_value > 0 && index_value <= *max_len { - return Ok(base_type.clone()); - } - } - _ => { - if check_iter_var_range(db, cache, &expr, index_prefix_expr) - .unwrap_or(false) - { - return Ok(base_type.clone()); - } - } - } - - let result_type = match &base_type { - LuaType::Any | LuaType::Unknown => base_type.clone(), - _ => { - if db.get_emmyrc().strict.array_index { - TypeOps::Union.apply(db, base_type, &LuaType::Nil) - } else { - base_type.clone() - } - } - }; - Ok(result_type) - } else { - Err(InferFailReason::FieldNotFound) - } - } - _ => Err(InferFailReason::FieldNotFound), + return Err(InferFailReason::FieldNotFound); + } + + if let LuaArrayLen::Max(max_len) = array_type.get_len() + && let LuaType::IntegerConst(index_value) | LuaType::DocIntegerConst(index_value) = key_type + && *index_value > 0 + && *index_value <= *max_len + { + return Ok(base.clone()); + } + + if let Some(LuaIndexKey::Expr(expr)) = index_expr.get_index_key() + && check_iter_var_range(db, cache, &expr, index_prefix_expr).unwrap_or(false) + { + return Ok(base.clone()); + } + + Ok(array_member_fallback(db, base)) +} + +pub(super) fn array_member_fallback(db: &DbIndex, base: &LuaType) -> LuaType { + match base { + LuaType::Any | LuaType::Unknown => base.clone(), + _ if db.get_emmyrc().strict.array_index => TypeOps::Union.apply(db, base, &LuaType::Nil), + _ => base.clone(), } } -pub fn check_iter_var_range( +pub(super) fn check_iter_var_range( db: &DbIndex, cache: &mut LuaInferCache, may_iter_var: &LuaExpr, @@ -135,7 +114,7 @@ fn check_index_var_in_range( unary_expr } 3 => { - let step_type = infer_expr(db, cache, iter_exprs[2].clone()).ok()?; + let step_type = infer_expr_for_index(db, cache, iter_exprs[2].clone()).ok()?; let LuaType::IntegerConst(step_value) = step_type else { return None; }; diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs index 8bd55fc1d..5a8fac6a1 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs @@ -21,7 +21,9 @@ use crate::{ generic::{TypeSubstitutor, instantiate_type_generic}, infer::{ VarRefId, - infer_index::infer_array::{check_iter_var_range, infer_array_member}, + infer_index::infer_array::{ + array_member_fallback, check_iter_var_range, infer_array_member_by_key, + }, infer_name::get_name_expr_var_ref_id, narrow::infer_expr_narrow_type, }, @@ -31,7 +33,41 @@ use crate::{ }, }; -use super::{InferFailReason, InferResult, infer_expr, infer_name::infer_global_type}; +use super::{ + InferFailReason, InferResult, infer_expr, infer_name::infer_global_type, try_infer_expr_no_flow, +}; + +pub(crate) fn try_infer_expr_for_index( + db: &DbIndex, + cache: &mut LuaInferCache, + expr: LuaExpr, +) -> Result, InferFailReason> { + if cache.is_no_flow() { + return match expr { + LuaExpr::ParenExpr(paren_expr) => { + let Some(expr) = paren_expr.get_expr() else { + return Ok(None); + }; + try_infer_expr_for_index(db, cache, expr) + } + LuaExpr::ClosureExpr(_) => Ok(None), + expr => match try_infer_expr_no_flow(db, cache, expr) { + Ok(result) => Ok(result), + Err(err) if err.is_need_resolve() => Ok(None), + Err(err) => Err(err), + }, + }; + } + + Ok(Some(infer_expr(db, cache, expr)?)) +} + +fn infer_expr_for_index(db: &DbIndex, cache: &mut LuaInferCache, expr: LuaExpr) -> InferResult { + let Some(expr_type) = try_infer_expr_for_index(db, cache, expr)? else { + return Err(InferFailReason::None); + }; + Ok(expr_type) +} pub fn infer_index_expr( db: &DbIndex, @@ -40,67 +76,32 @@ pub fn infer_index_expr( pass_flow: bool, ) -> InferResult { let prefix_expr = index_expr.get_prefix_expr().ok_or(InferFailReason::None)?; - let prefix_type = infer_expr(db, cache, prefix_expr)?; + let prefix_type = infer_expr_for_index(db, cache, prefix_expr)?; let index_member_expr = LuaIndexMemberExpr::IndexExpr(index_expr.clone()); - let reason = match infer_member_by_member_key( - db, - cache, - &prefix_type, - index_member_expr.clone(), - &InferGuard::new(), - ) { - Ok(member_type) => { - if pass_flow { - return infer_member_type_pass_flow( - db, - cache, - index_expr, - // &prefix_type, - member_type, - ); - } - return Ok(member_type); - } - Err(InferFailReason::FieldNotFound) => InferFailReason::FieldNotFound, - Err(err) => return Err(err), - }; - - match infer_member_by_operator( + let member_type = infer_member( db, cache, &prefix_type, index_member_expr, &InferGuard::new(), - ) { - Ok(member_type) => { - if pass_flow { - return infer_member_type_pass_flow( - db, - cache, - index_expr, - // &prefix_type, - member_type, - ); - } - return Ok(member_type); - } - Err(InferFailReason::FieldNotFound) => {} - Err(err) => return Err(err), - } + )?; - Err(reason) + if pass_flow { + infer_member_type_pass_flow(db, cache, index_expr, member_type) + } else { + Ok(member_type) + } } fn infer_member_type_pass_flow( db: &DbIndex, cache: &mut LuaInferCache, index_expr: LuaIndexExpr, - // prefix_type: &LuaType, member_type: LuaType, ) -> InferResult { let Some(var_ref_id) = get_index_expr_var_ref_id(db, cache, &index_expr) else { - return Ok(member_type.clone()); + return Ok(member_type); }; cache @@ -143,56 +144,179 @@ pub fn get_index_expr_var_ref_id( None } +fn infer_index_key_type( + db: &DbIndex, + cache: &mut LuaInferCache, + index_key: &LuaIndexKey, +) -> Result, InferFailReason> { + match index_key { + LuaIndexKey::Name(name) => Ok(Some(LuaType::StringConst( + SmolStr::new(name.get_name_text()).into(), + ))), + LuaIndexKey::String(s) => Ok(Some(LuaType::StringConst( + SmolStr::new(s.get_value()).into(), + ))), + LuaIndexKey::Integer(i) => { + if let NumberResult::Int(index_value) = i.get_number_value() { + Ok(Some(LuaType::IntegerConst(index_value))) + } else { + Err(InferFailReason::FieldNotFound) + } + } + LuaIndexKey::Idx(i) => Ok(Some(LuaType::IntegerConst(*i as i64))), + LuaIndexKey::Expr(expr) => try_infer_expr_for_index(db, cache, expr.clone()), + } +} + +fn infer_index_expr_key_type( + db: &DbIndex, + cache: &mut LuaInferCache, + index_expr: &LuaIndexMemberExpr, +) -> InferResult { + let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; + infer_index_key_type(db, cache, &index_key)?.ok_or(InferFailReason::None) +} + +fn member_key_from_type(key_type: &LuaType) -> LuaMemberKey { + match key_type { + LuaType::StringConst(s) | LuaType::DocStringConst(s) => LuaMemberKey::Name((**s).clone()), + LuaType::IntegerConst(i) | LuaType::DocIntegerConst(i) => LuaMemberKey::Integer(*i), + _ => LuaMemberKey::ExprType(key_type.clone()), + } +} + +#[derive(Debug, Clone)] +struct MemberLookupQuery { + index_expr: LuaIndexMemberExpr, + key_type: LuaType, + key: LuaMemberKey, +} + +impl MemberLookupQuery { + fn from_index_expr( + db: &DbIndex, + cache: &mut LuaInferCache, + index_expr: LuaIndexMemberExpr, + ) -> Result { + let key_type = infer_index_expr_key_type(db, cache, &index_expr)?; + Ok(Self::from_key_type(index_expr, key_type)) + } + + fn from_key_type(index_expr: LuaIndexMemberExpr, key_type: LuaType) -> Self { + let key = member_key_from_type(&key_type); + Self { + index_expr, + key_type, + key, + } + } +} + +pub fn infer_member( + db: &DbIndex, + cache: &mut LuaInferCache, + prefix_type: &LuaType, + index_expr: LuaIndexMemberExpr, + infer_guard: &InferGuardRef, +) -> InferResult { + let lookup = MemberLookupQuery::from_index_expr(db, cache, index_expr)?; + match infer_member_by_lookup(db, cache, prefix_type, &lookup, infer_guard) { + Ok(member_type) => Ok(member_type), + Err(InferFailReason::FieldNotFound) => infer_member_by_operator_key_type( + db, + cache, + prefix_type, + &lookup.key_type, + &InferGuard::new(), + ), + Err(err) => Err(err), + } +} + pub fn infer_member_by_member_key( db: &DbIndex, cache: &mut LuaInferCache, prefix_type: &LuaType, index_expr: LuaIndexMemberExpr, infer_guard: &InferGuardRef, +) -> InferResult { + let lookup = MemberLookupQuery::from_index_expr(db, cache, index_expr)?; + infer_member_by_lookup(db, cache, prefix_type, &lookup, infer_guard) +} + +pub fn infer_member_by_key_type( + db: &DbIndex, + cache: &mut LuaInferCache, + prefix_type: &LuaType, + index_expr: LuaIndexMemberExpr, + key_type: &LuaType, + infer_guard: &InferGuardRef, +) -> InferResult { + let lookup = MemberLookupQuery::from_key_type(index_expr, key_type.clone()); + infer_member_by_lookup(db, cache, prefix_type, &lookup, infer_guard) +} + +fn infer_member_by_lookup( + db: &DbIndex, + cache: &mut LuaInferCache, + prefix_type: &LuaType, + lookup: &MemberLookupQuery, + infer_guard: &InferGuardRef, ) -> InferResult { match &prefix_type { LuaType::Table | LuaType::Any | LuaType::Unknown => Ok(LuaType::Any), LuaType::Nil => Ok(LuaType::Never), - LuaType::TableConst(id) => infer_table_member(db, cache, id.clone(), index_expr), + LuaType::TableConst(id) => { + infer_table_member(db, id.clone(), &lookup.key_type, &lookup.key) + } LuaType::String | LuaType::Io | LuaType::StringConst(_) | LuaType::DocStringConst(_) | LuaType::Language(_) => { let decl_id = get_buildin_type_map_type_id(prefix_type).ok_or(InferFailReason::None)?; - infer_custom_type_member(db, cache, decl_id, index_expr, infer_guard) + infer_custom_type_member(db, cache, decl_id, lookup, infer_guard) } - LuaType::Ref(decl_id) => { - infer_custom_type_member(db, cache, decl_id.clone(), index_expr, infer_guard) - } - LuaType::Def(decl_id) => { - infer_custom_type_member(db, cache, decl_id.clone(), index_expr, infer_guard) + LuaType::Ref(decl_id) | LuaType::Def(decl_id) => { + infer_custom_type_member(db, cache, decl_id.clone(), lookup, infer_guard) } // LuaType::Module(_) => todo!(), - LuaType::Tuple(tuple_type) => infer_tuple_member(db, cache, tuple_type, index_expr), - LuaType::Object(object_type) => infer_object_member(db, cache, object_type, index_expr), + LuaType::Tuple(tuple_type) => infer_tuple_member(db, cache, tuple_type, lookup), + LuaType::Object(object_type) => { + infer_object_member(db, object_type, &lookup.key_type, &lookup.key) + } LuaType::Union(union_type) => { - infer_union_member(db, cache, union_type, index_expr, infer_guard) + infer_union_member(db, cache, union_type, lookup, infer_guard) } LuaType::MultiLineUnion(multi_union) => { let union_type = multi_union.to_union(); if let LuaType::Union(union_type) = union_type { - infer_union_member(db, cache, &union_type, index_expr, infer_guard) + infer_union_member(db, cache, &union_type, lookup, infer_guard) } else { Err(InferFailReason::FieldNotFound) } } LuaType::Intersection(intersection_type) => { - infer_intersection_member(db, cache, intersection_type, index_expr, infer_guard) + infer_intersection_member(db, cache, intersection_type, lookup, infer_guard) } LuaType::Generic(generic_type) => { - infer_generic_member(db, cache, generic_type, index_expr, infer_guard) + infer_generic_member(db, cache, generic_type, lookup, infer_guard) } - LuaType::Global => infer_global_field_member(db, cache, index_expr), - LuaType::Instance(inst) => infer_instance_member(db, cache, inst, index_expr, infer_guard), - LuaType::Namespace(ns) => infer_namespace_member(db, cache, ns, index_expr), - LuaType::Array(array_type) => infer_array_member(db, cache, array_type, index_expr), - LuaType::TplRef(tpl) => infer_tpl_ref_member(db, cache, tpl, index_expr, infer_guard), + LuaType::Global => infer_global_field_member(db, &lookup.key), + LuaType::Instance(inst) => infer_instance_member(db, cache, inst, lookup, infer_guard), + LuaType::Namespace(ns) => infer_namespace_member(db, ns, &lookup.key), + LuaType::Array(array_type) => infer_array_member_by_key( + db, + cache, + array_type, + lookup.index_expr.clone(), + &lookup.key_type, + &lookup.key, + ), + LuaType::TableGeneric(table_generic) => { + infer_table_generic_member_by_key_type(db, table_generic, &lookup.key_type) + } + LuaType::TplRef(tpl) => infer_tpl_ref_member(db, cache, tpl, lookup, infer_guard), LuaType::ModuleRef(file_id) => { let module_info = db.get_module_index().get_module(*file_id); if let Some(module_info) = module_info { @@ -201,13 +325,7 @@ pub fn infer_member_by_member_key( return Err(InferFailReason::RecursiveInfer); } - return infer_member_by_member_key( - db, - cache, - export_type, - index_expr, - infer_guard, - ); + return infer_member_by_lookup(db, cache, export_type, lookup, infer_guard); } else { return Err(InferFailReason::UnResolveModuleExport(*file_id)); } @@ -221,26 +339,32 @@ pub fn infer_member_by_member_key( fn infer_table_member( db: &DbIndex, - cache: &mut LuaInferCache, inst: InFiled, - index_expr: LuaIndexMemberExpr, + key_type: &LuaType, + key: &LuaMemberKey, ) -> InferResult { let owner = LuaMemberOwner::Element(inst); - let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; - let key = LuaMemberKey::from_index_key(db, cache, &index_key)?; - let member_item = match db.get_member_index().get_member_item(&owner, &key) { - Some(member_item) => member_item, - None => return Err(InferFailReason::FieldNotFound), - }; + if let Some(member_item) = db.get_member_index().get_member_item(&owner, key) { + return member_item.resolve_type(db); + } + + if matches!(key, LuaMemberKey::Name(_) | LuaMemberKey::Integer(_)) { + // Exact keys already missed above. The matching scan is only for broad keys. + return Err(InferFailReason::FieldNotFound); + } - member_item.resolve_type(db) + if let Some(result_type) = infer_type_key_member_type(db, key_type, &owner) { + return Ok(result_type); + } + + infer_matching_member_key_type(db, &owner, key_type).ok_or(InferFailReason::FieldNotFound) } fn infer_custom_type_member( db: &DbIndex, cache: &mut LuaInferCache, prefix_type_id: LuaTypeDeclId, - index_expr: LuaIndexMemberExpr, + lookup: &MemberLookupQuery, infer_guard: &InferGuardRef, ) -> InferResult { infer_guard.check(&prefix_type_id)?; @@ -250,18 +374,12 @@ fn infer_custom_type_member( .ok_or(InferFailReason::None)?; if type_decl.is_alias() { if let Some(origin_type) = type_decl.get_alias_origin(db, None) { - return infer_member_by_member_key( - db, - cache, - &origin_type, - index_expr.clone(), - infer_guard, - ); + return infer_member_by_lookup(db, cache, &origin_type, lookup, infer_guard); } else { return Err(InferFailReason::FieldNotFound); } } - if let LuaIndexMemberExpr::IndexExpr(index_expr) = &index_expr + if let LuaIndexMemberExpr::IndexExpr(index_expr) = &lookup.index_expr && enum_variable_is_param(db, cache, index_expr, &LuaType::Ref(prefix_type_id.clone())) .is_some() { @@ -269,44 +387,22 @@ fn infer_custom_type_member( } let owner = LuaMemberOwner::Type(prefix_type_id.clone()); - let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; - let key = LuaMemberKey::from_index_key(db, cache, &index_key)?; - - if let Some(member_item) = db.get_member_index().get_member_item(&owner, &key) { + if let Some(member_item) = db.get_member_index().get_member_item(&owner, &lookup.key) { return member_item.resolve_type(db); } - // 解决`key`为表达式的情况 - if let LuaIndexKey::Expr(expr) = index_key - && let Some(keys) = get_expr_member_key(db, cache, &expr) + // Exact keys may still resolve through super types below; only broad keys need key-type matching here. + if !matches!(lookup.key, LuaMemberKey::Name(_) | LuaMemberKey::Integer(_)) + && let Some(result_type) = infer_type_key_member_type(db, &lookup.key_type, &owner) { - let mut result_types = Vec::new(); - for key in keys { - // 解决 enum[enum] | class[class] 的情况 - if let Some(member_type) = get_expr_key_members(db, &key, &owner) { - result_types.push(member_type); - continue; - } - - if let Some(member_item) = db.get_member_index().get_member_item(&owner, &key) - && let Ok(member_type) = member_item.resolve_type(db) - { - result_types.push(member_type); - } - } - match &result_types[..] { - [] => {} - [first] => return Ok(first.clone()), - _ => return Ok(LuaType::from_vec(result_types)), - } + return Ok(result_type); } if type_decl.is_class() && let Some(super_types) = type_index.get_super_types(&prefix_type_id) { for super_type in super_types { - let result = - infer_member_by_member_key(db, cache, &super_type, index_expr.clone(), infer_guard); + let result = infer_member_by_lookup(db, cache, &super_type, lookup, infer_guard); match result { Ok(member_type) => { @@ -321,119 +417,94 @@ fn infer_custom_type_member( Err(InferFailReason::FieldNotFound) } -fn get_expr_key_members( +fn infer_type_key_member_type( db: &DbIndex, - key: &LuaMemberKey, + key_type: &LuaType, owner: &LuaMemberOwner, ) -> Option { - let LuaMemberKey::ExprType(LuaType::Ref(index_id)) = key else { - return None; - }; - let index_type_decl = db.get_type_index().get_type_decl(index_id)?; - let mut result = Vec::new(); + let keys = get_type_member_key(db, key_type)?; - let origin_type = if index_type_decl.is_alias() { - index_type_decl.get_alias_origin(db, None)? - } else { - LuaType::Ref(index_id.clone()) - }; - - if let Some(member_keys) = get_all_member_key(db, &origin_type) { - for key in member_keys { - if let Some(member_item) = db.get_member_index().get_member_item(owner, &key) - && let Ok(member_type) = member_item.resolve_type(db) - { - result.push(member_type); - } + let mut result_types = Vec::new(); + for key in keys { + if let Some(member_item) = db.get_member_index().get_member_item(owner, &key) + && let Ok(member_type) = member_item.resolve_type(db) + { + result_types.push(member_type); } } - match result.len() { - 0 => None, - 1 => Some(result[0].clone()), - _ => Some(LuaType::from_vec(result)), + match &result_types[..] { + [] => None, + [first] => Some(first.clone()), + _ => Some(LuaType::from_vec(result_types)), } } -fn get_all_member_key(db: &DbIndex, origin_type: &LuaType) -> Option> { - let mut result = Vec::new(); - let mut stack = vec![origin_type.clone()]; // 堆栈用于迭代处理 - let mut visited = HashSet::new(); - - while let Some(current_type) = stack.pop() { - if visited.contains(¤t_type) { - continue; - } - visited.insert(current_type.clone()); - match current_type { - LuaType::MultiLineUnion(types) => { - for (typ, _) in types.get_unions() { - match typ { - LuaType::DocStringConst(s) | LuaType::StringConst(s) => { - result.push((*s).to_string().into()); - } - LuaType::DocIntegerConst(i) | LuaType::IntegerConst(i) => { - result.push((*i).into()); - } - LuaType::Ref(_) => { - stack.push(typ.clone()); // 将 Ref 类型推入堆栈进一步处理 - } - _ => {} - } - } - } - LuaType::Union(union_type) => { - for typ in union_type.into_vec() { - if let LuaType::Ref(_) = typ { - stack.push(typ.clone()); // 推入堆栈 - } - } - } - LuaType::Ref(id) => { - if let Some(type_decl) = db.get_type_index().get_type_decl(&id) - && type_decl.is_enum() - { - let owner = LuaMemberOwner::Type(id.clone()); - if let Some(members) = db.get_member_index().get_members(&owner) { - let is_enum_key = type_decl.is_enum_key(); - for member in members { - if is_enum_key { - result.push(member.get_key().clone()); - } else if let Some(typ) = db - .get_type_index() - .get_type_cache(&member.get_id().into()) - .map(|it| it.as_type()) - { - match typ { - LuaType::DocStringConst(s) | LuaType::StringConst(s) => { - result.push((*s).to_string().into()); - } - LuaType::DocIntegerConst(i) | LuaType::IntegerConst(i) => { - result.push((*i).into()); - } - _ => {} - } - } - } - } - } - } - _ => {} +fn infer_matching_member_key_type( + db: &DbIndex, + owner: &LuaMemberOwner, + key_type: &LuaType, +) -> Option { + let mut members = db.get_member_index().get_members(owner)?; + members.sort_by(|a, b| a.get_key().cmp(b.get_key())); + + let mut result_type = LuaType::Never; + let mut has_match = false; + for member in members { + let member_key_type = match member.get_key() { + LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()), + LuaMemberKey::Integer(i) => LuaType::IntegerConst(*i), + _ => continue, + }; + if check_type_compact(db, key_type, &member_key_type).is_ok() { + let member_type = db + .get_type_index() + .get_type_cache(&member.get_id().into()) + .map(|it| it.as_type()) + .unwrap_or(&LuaType::Unknown); + + has_match = true; + result_type = TypeOps::Union.apply(db, &result_type, member_type); } } - Some(result) + if !has_match { + return None; + } + + if matches!( + key_type, + LuaType::String | LuaType::Number | LuaType::Integer + ) { + result_type = TypeOps::Union.apply(db, &result_type, &LuaType::Nil); + } + + Some(result_type) +} + +fn get_type_member_key(db: &DbIndex, key_type: &LuaType) -> Option> { + let mut keys = HashSet::new(); + collect_type_member_keys(db, key_type, &mut keys); + if keys.is_empty() { + return None; + } + + let mut keys: Vec<_> = keys.into_iter().collect(); + keys.sort(); + Some(keys) } fn infer_tuple_member( db: &DbIndex, cache: &mut LuaInferCache, tuple_type: &LuaTupleType, - index_expr: LuaIndexMemberExpr, + lookup: &MemberLookupQuery, ) -> InferResult { - let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; - let key = LuaMemberKey::from_index_key(db, cache, &index_key)?; - match &key { + let index_key = lookup + .index_expr + .get_index_key() + .ok_or(InferFailReason::None)?; + match &lookup.key { LuaMemberKey::Integer(i) => { let index = if *i > 0 { *i - 1 } else { 0 }; return match tuple_type.get_type(index as usize) { @@ -455,17 +526,20 @@ fn infer_tuple_member( result = TypeOps::Union.apply(db, &result, typ); } - let index_prefix_expr = match index_expr { + let index_prefix_expr = match &lookup.index_expr { LuaIndexMemberExpr::TableField(_) => { return Ok(result); } - _ => index_expr.get_prefix_expr().ok_or(InferFailReason::None)?, + _ => lookup + .index_expr + .get_prefix_expr() + .ok_or(InferFailReason::None)?, }; let maybe_iter_var = match &index_key { LuaIndexKey::Expr(expr) => expr, _ => return Ok(result), }; - if check_iter_var_range(db, cache, &maybe_iter_var, index_prefix_expr) + if check_iter_var_range(db, cache, maybe_iter_var, index_prefix_expr) .unwrap_or(false) { return Ok(result); @@ -484,20 +558,18 @@ fn infer_tuple_member( fn infer_object_member( db: &DbIndex, - cache: &mut LuaInferCache, object_type: &LuaObjectType, - index_expr: LuaIndexMemberExpr, + key_type: &LuaType, + member_key: &LuaMemberKey, ) -> InferResult { - let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; - let member_key = LuaMemberKey::from_index_key(db, cache, &index_key)?; - if let Some(member_type) = object_type.get_field(&member_key) { + if let Some(member_type) = object_type.get_field(member_key) { return Ok(member_type.clone()); } // todo let index_accesses = object_type.get_index_access(); for (key, value) in index_accesses { - let result = infer_index_metamethod(db, cache, &index_key, key, value); + let result = infer_index_metamethod_by_key_type(db, key_type, key, value); match result { Ok(typ) => { return Ok(typ); @@ -512,28 +584,13 @@ fn infer_object_member( Err(InferFailReason::FieldNotFound) } -fn infer_index_metamethod( +fn infer_index_metamethod_by_key_type( db: &DbIndex, - cache: &mut LuaInferCache, - index_key: &LuaIndexKey, + access_key_type: &LuaType, key_type: &LuaType, value_type: &LuaType, ) -> InferResult { - let access_key_type = match &index_key { - LuaIndexKey::Name(name) => LuaType::StringConst(SmolStr::new(name.get_name_text()).into()), - LuaIndexKey::String(s) => LuaType::StringConst(SmolStr::new(s.get_value()).into()), - LuaIndexKey::Integer(i) => { - if let NumberResult::Int(index_value) = i.get_number_value() { - LuaType::IntegerConst(index_value) - } else { - return Err(InferFailReason::FieldNotFound); - } - } - LuaIndexKey::Idx(i) => LuaType::IntegerConst(*i as i64), - LuaIndexKey::Expr(expr) => infer_expr(db, cache, expr.clone())?, - }; - - if check_type_compact(db, key_type, &access_key_type).is_ok() { + if check_type_compact(db, key_type, access_key_type).is_ok() { return Ok(value_type.clone()); } @@ -544,7 +601,7 @@ fn infer_union_member( db: &DbIndex, cache: &mut LuaInferCache, union_type: &LuaUnionType, - index_expr: LuaIndexMemberExpr, + lookup: &MemberLookupQuery, infer_guard: &InferGuardRef, ) -> InferResult { let mut member_type = LuaType::Never; @@ -558,13 +615,7 @@ fn infer_union_member( } meet_string = true; } - let result = infer_member_by_member_key( - db, - cache, - &sub_type, - index_expr.clone(), - &infer_guard.fork(), - ); + let result = infer_member_by_lookup(db, cache, &sub_type, lookup, &infer_guard.fork()); match result { Ok(typ) => { has_member = true; @@ -591,13 +642,12 @@ fn infer_intersection_member( db: &DbIndex, cache: &mut LuaInferCache, intersection_type: &LuaIntersectionType, - index_expr: LuaIndexMemberExpr, + lookup: &MemberLookupQuery, infer_guard: &InferGuardRef, ) -> InferResult { let mut result: Option = None; for member in intersection_type.get_types() { - match infer_member_by_member_key(db, cache, member, index_expr.clone(), &infer_guard.fork()) - { + match infer_member_by_lookup(db, cache, member, lookup, &infer_guard.fork()) { Ok(ty) => { result = Some(match result { Some(prev) => intersect_member_types(db, prev, ty), @@ -621,7 +671,7 @@ fn infer_generic_members_from_super_generics( cache: &mut LuaInferCache, type_decl_id: &LuaTypeDeclId, substitutor: &TypeSubstitutor, - index_expr: LuaIndexMemberExpr, + lookup: &MemberLookupQuery, infer_guard: &InferGuardRef, ) -> Option { let type_index = db.get_type_index(); @@ -635,14 +685,7 @@ fn infer_generic_members_from_super_generics( if let Some(super_types) = type_index.get_super_types(&type_decl_id) { super_types.iter().find_map(|super_type| { let super_type = instantiate_type_generic(db, super_type, substitutor); - infer_member_by_member_key( - db, - cache, - &super_type, - index_expr.clone(), - &infer_guard.fork(), - ) - .ok() + infer_member_by_lookup(db, cache, &super_type, lookup, &infer_guard.fork()).ok() }) } else { None @@ -653,7 +696,7 @@ fn infer_generic_member( db: &DbIndex, cache: &mut LuaInferCache, generic_type: &LuaGenericType, - index_expr: LuaIndexMemberExpr, + lookup: &MemberLookupQuery, infer_guard: &InferGuardRef, ) -> InferResult { let base_type = generic_type.get_base_type(); @@ -667,13 +710,7 @@ fn infer_generic_member( && type_decl.is_alias() && let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) { - return infer_member_by_member_key( - db, - cache, - &origin_type, - index_expr, - &infer_guard.fork(), - ); + return infer_member_by_lookup(db, cache, &origin_type, lookup, &infer_guard.fork()); } let result = infer_generic_members_from_super_generics( @@ -681,7 +718,7 @@ fn infer_generic_member( cache, base_type_decl_id, &substitutor, - index_expr.clone(), + lookup, infer_guard, ); if let Some(result) = result { @@ -689,7 +726,7 @@ fn infer_generic_member( } } - let member_type = infer_member_by_member_key(db, cache, &base_type, index_expr, infer_guard)?; + let member_type = infer_member_by_lookup(db, cache, &base_type, lookup, infer_guard)?; Ok(instantiate_type_generic(db, &member_type, &substitutor)) } @@ -698,16 +735,15 @@ fn infer_instance_member( db: &DbIndex, cache: &mut LuaInferCache, inst: &LuaInstanceType, - index_expr: LuaIndexMemberExpr, + lookup: &MemberLookupQuery, infer_guard: &InferGuardRef, ) -> InferResult { let range = inst.get_range(); let origin_type = inst.get_base(); - let base_result = - infer_member_by_member_key(db, cache, origin_type, index_expr.clone(), infer_guard); + let base_result = infer_member_by_lookup(db, cache, origin_type, lookup, infer_guard); match base_result { - Ok(typ) => match infer_table_member(db, cache, range.clone(), index_expr.clone()) { + Ok(typ) => match infer_table_member(db, range.clone(), &lookup.key_type, &lookup.key) { Ok(table_type) => { return Ok(match TypeOps::Intersect.apply(db, &typ, &table_type) { LuaType::Never => typ, @@ -721,56 +757,58 @@ fn infer_instance_member( Err(err) => return Err(err), } - infer_table_member(db, cache, range.clone(), index_expr.clone()) + infer_table_member(db, range.clone(), &lookup.key_type, &lookup.key) } -pub fn infer_member_by_operator( +fn infer_member_by_operator_key_type( db: &DbIndex, cache: &mut LuaInferCache, prefix_type: &LuaType, - index_expr: LuaIndexMemberExpr, + key_type: &LuaType, infer_guard: &InferGuardRef, ) -> InferResult { match &prefix_type { - LuaType::TableConst(in_filed) => { - infer_member_by_index_table(db, cache, in_filed, index_expr) - } - LuaType::Ref(decl_id) => { - infer_member_by_index_custom_type(db, cache, decl_id, index_expr, infer_guard) - } - LuaType::Def(decl_id) => { - infer_member_by_index_custom_type(db, cache, decl_id, index_expr, infer_guard) + LuaType::TableConst(in_filed) => infer_member_by_index_table(db, in_filed, key_type), + LuaType::Ref(decl_id) | LuaType::Def(decl_id) => { + infer_member_by_index_custom_type(db, cache, decl_id, key_type, infer_guard) } // LuaType::Module(arc) => todo!(), LuaType::Array(array_type) => { - infer_member_by_index_array(db, cache, array_type.get_base(), index_expr) + if key_type.is_number() { + Ok(array_member_fallback(db, array_type.get_base())) + } else { + Err(InferFailReason::FieldNotFound) + } + } + LuaType::Object(object) => { + let key = member_key_from_type(key_type); + infer_object_member(db, object, key_type, &key) } - LuaType::Object(object) => infer_member_by_index_object(db, cache, object, index_expr), LuaType::Union(union) => { - infer_member_by_index_union(db, cache, union, index_expr, infer_guard) + infer_member_by_index_union(db, cache, union, key_type, infer_guard) } LuaType::Intersection(intersection) => { - infer_member_by_index_intersection(db, cache, intersection, index_expr, infer_guard) + infer_member_by_index_intersection(db, cache, intersection, key_type, infer_guard) } LuaType::Generic(generic) => { - infer_member_by_index_generic(db, cache, generic, index_expr, infer_guard) + infer_member_by_index_generic(db, cache, generic, key_type, infer_guard) } LuaType::TableGeneric(table_generic) => { - infer_member_by_index_table_generic(db, cache, table_generic, index_expr) + infer_table_generic_member_by_key_type(db, table_generic, key_type) } LuaType::Instance(inst) => { let base = inst.get_base(); - infer_member_by_operator(db, cache, base, index_expr, infer_guard) + infer_member_by_operator_key_type(db, cache, base, key_type, infer_guard) } LuaType::ModuleRef(file_id) => { let module_info = db.get_module_index().get_module(*file_id); if let Some(module_info) = module_info { if let Some(export_type) = &module_info.export_type { - return infer_member_by_operator( + return infer_member_by_operator_key_type( db, cache, export_type, - index_expr, + key_type, infer_guard, ); } else { @@ -786,9 +824,8 @@ pub fn infer_member_by_operator( fn infer_member_by_index_table( db: &DbIndex, - cache: &mut LuaInferCache, table_range: &InFiled, - index_expr: LuaIndexMemberExpr, + key_type: &LuaType, ) -> InferResult { let metatable = db.get_metatable_index().get(table_range); match metatable { @@ -799,8 +836,6 @@ fn infer_member_by_index_table( .get_operators(&meta_owner, LuaOperatorMetaMethod::Index) .ok_or(InferFailReason::FieldNotFound)?; - let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; - for operator_id in operator_ids { let operator = db .get_operator_index() @@ -809,53 +844,15 @@ fn infer_member_by_index_table( let operand = operator.get_operand(db); let return_type = operator.get_result(db)?; if let Ok(typ) = - infer_index_metamethod(db, cache, &index_key, &operand, &return_type) + infer_index_metamethod_by_key_type(db, key_type, &operand, &return_type) { return Ok(typ); } } } None => { - let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; - if let LuaIndexKey::Expr(expr) = index_key { - let key_type = infer_expr(db, cache, expr.clone())?; - let members = db - .get_member_index() - .get_members(&LuaMemberOwner::Element(table_range.clone())); - if let Some(mut members) = members { - members.sort_by(|a, b| a.get_key().cmp(b.get_key())); - let mut result_type = LuaType::Never; - let mut has_match = false; - for member in members { - let member_key_type = match member.get_key() { - LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()), - LuaMemberKey::Integer(i) => LuaType::IntegerConst(*i), - _ => continue, - }; - if check_type_compact(db, &key_type, &member_key_type).is_ok() { - let member_type = db - .get_type_index() - .get_type_cache(&member.get_id().into()) - .map(|it| it.as_type()) - .unwrap_or(&LuaType::Unknown); - - has_match = true; - result_type = TypeOps::Union.apply(db, &result_type, member_type); - } - } - - if has_match { - if matches!( - key_type, - LuaType::String | LuaType::Number | LuaType::Integer - ) { - result_type = TypeOps::Union.apply(db, &result_type, &LuaType::Nil); - } - - return Ok(result_type); - } - } - } + let key = member_key_from_type(key_type); + return infer_table_member(db, table_range.clone(), key_type, &key); } } @@ -866,7 +863,7 @@ fn infer_member_by_index_custom_type( db: &DbIndex, cache: &mut LuaInferCache, prefix_type_id: &LuaTypeDeclId, - index_expr: LuaIndexMemberExpr, + key_type: &LuaType, infer_guard: &InferGuardRef, ) -> InferResult { infer_guard.check(prefix_type_id)?; @@ -876,12 +873,17 @@ fn infer_member_by_index_custom_type( .ok_or(InferFailReason::None)?; if type_decl.is_alias() { if let Some(origin_type) = type_decl.get_alias_origin(db, None) { - return infer_member_by_operator(db, cache, &origin_type, index_expr, infer_guard); + return infer_member_by_operator_key_type( + db, + cache, + &origin_type, + key_type, + infer_guard, + ); } return Err(InferFailReason::None); } - let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; if let Some(index_operator_ids) = db .get_operator_index() .get_operators(&prefix_type_id.clone().into(), LuaOperatorMetaMethod::Index) @@ -893,7 +895,7 @@ fn infer_member_by_index_custom_type( .ok_or(InferFailReason::None)?; let operand = operator.get_operand(db); let return_type = operator.get_result(db)?; - let typ = infer_index_metamethod(db, cache, &index_key, &operand, &return_type); + let typ = infer_index_metamethod_by_key_type(db, key_type, &operand, &return_type); if let Ok(typ) = typ { return Ok(typ); } @@ -906,7 +908,7 @@ fn infer_member_by_index_custom_type( { for super_type in super_types { let result = - infer_member_by_operator(db, cache, &super_type, index_expr.clone(), infer_guard); + infer_member_by_operator_key_type(db, cache, &super_type, key_type, infer_guard); match result { Ok(member_type) => { return Ok(member_type); @@ -920,64 +922,18 @@ fn infer_member_by_index_custom_type( Err(InferFailReason::FieldNotFound) } -fn infer_member_by_index_array( - db: &DbIndex, - cache: &mut LuaInferCache, - base: &LuaType, - index_expr: LuaIndexMemberExpr, -) -> InferResult { - let member_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; - let expression_type = if db.get_emmyrc().strict.array_index { - TypeOps::Union.apply(db, base, &LuaType::Nil) - } else { - base.clone() - }; - if member_key.is_integer() { - return Ok(expression_type); - } else if member_key.is_expr() { - let expr = member_key.get_expr().ok_or(InferFailReason::None)?; - let expr_type = infer_expr(db, cache, expr.clone())?; - if check_type_compact(db, &LuaType::Number, &expr_type).is_ok() { - return Ok(expression_type); - } - } - - Err(InferFailReason::FieldNotFound) -} - -fn infer_member_by_index_object( - db: &DbIndex, - cache: &mut LuaInferCache, - object: &LuaObjectType, - index_expr: LuaIndexMemberExpr, -) -> InferResult { - let member_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; - let access_member_type = object.get_index_access(); - if member_key.is_expr() { - let expr = member_key.get_expr().ok_or(InferFailReason::None)?; - let expr_type = infer_expr(db, cache, expr.clone())?; - for (key, field) in access_member_type { - if check_type_compact(db, key, &expr_type).is_ok() { - return Ok(field.clone()); - } - } - } - - Err(InferFailReason::FieldNotFound) -} - fn infer_member_by_index_union( db: &DbIndex, cache: &mut LuaInferCache, union: &LuaUnionType, - index_expr: LuaIndexMemberExpr, + key_type: &LuaType, infer_guard: &InferGuardRef, ) -> InferResult { let mut member_type = LuaType::Never; let mut has_member = false; for member in union.into_vec() { let result = - infer_member_by_operator(db, cache, &member, index_expr.clone(), &infer_guard.fork()); + infer_member_by_operator_key_type(db, cache, &member, key_type, &infer_guard.fork()); match result { Ok(typ) => { has_member = true; @@ -1001,12 +957,12 @@ fn infer_member_by_index_intersection( db: &DbIndex, cache: &mut LuaInferCache, intersection: &LuaIntersectionType, - index_expr: LuaIndexMemberExpr, + key_type: &LuaType, infer_guard: &InferGuardRef, ) -> InferResult { let mut result: Option = None; for member in intersection.get_types() { - match infer_member_by_operator(db, cache, member, index_expr.clone(), &infer_guard.fork()) { + match infer_member_by_operator_key_type(db, cache, member, key_type, &infer_guard.fork()) { Ok(ty) => { result = Some(match result { Some(prev) => intersect_member_types(db, prev, ty), @@ -1029,7 +985,7 @@ fn infer_member_by_index_generic( db: &DbIndex, cache: &mut LuaInferCache, generic: &LuaGenericType, - index_expr: LuaIndexMemberExpr, + key_type: &LuaType, infer_guard: &InferGuardRef, ) -> InferResult { let base_type = generic.get_base_type(); @@ -1046,18 +1002,17 @@ fn infer_member_by_index_generic( .ok_or(InferFailReason::None)?; if type_decl.is_alias() { if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) { - return infer_member_by_operator( + return infer_member_by_operator_key_type( db, cache, &instantiate_type_generic(db, &origin_type, &substitutor), - index_expr.clone(), + key_type, &infer_guard.fork(), ); } return Err(InferFailReason::None); } - let member_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; let operator_index = db.get_operator_index(); if let Some(index_operator_ids) = operator_index.get_operators(&type_decl_id.clone().into(), LuaOperatorMetaMethod::Index) @@ -1072,7 +1027,7 @@ fn infer_member_by_index_generic( instantiate_type_generic(db, &index_operator.get_result(db)?, &substitutor); let result = - infer_index_metamethod(db, cache, &member_key, &instianted_operand, &return_type); + infer_index_metamethod_by_key_type(db, key_type, &instianted_operand, &return_type); match result { Ok(member_type) => { @@ -1089,11 +1044,11 @@ fn infer_member_by_index_generic( // for supers if let Some(supers) = type_index.get_super_types(&type_decl_id) { for super_type in supers { - let result = infer_member_by_operator( + let result = infer_member_by_operator_key_type( db, cache, &instantiate_type_generic(db, &super_type, &substitutor), - index_expr.clone(), + key_type, &infer_guard.fork(), ); match result { @@ -1109,44 +1064,27 @@ fn infer_member_by_index_generic( Err(InferFailReason::FieldNotFound) } -fn infer_member_by_index_table_generic( +fn infer_table_generic_member_by_key_type( db: &DbIndex, - cache: &mut LuaInferCache, table_params: &[LuaType], - index_expr: LuaIndexMemberExpr, + key_type: &LuaType, ) -> InferResult { if table_params.len() != 2 { return Err(InferFailReason::None); } - let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; - let key_type = &table_params[0]; + let table_key_type = &table_params[0]; let value_type = &table_params[1]; - infer_index_metamethod(db, cache, &index_key, key_type, value_type) + infer_index_metamethod_by_key_type(db, key_type, table_key_type, value_type) } -fn infer_global_field_member( - db: &DbIndex, - _: &LuaInferCache, - index_expr: LuaIndexMemberExpr, -) -> InferResult { - let member_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; - let name = member_key - .get_name() - .ok_or(InferFailReason::None)? - .get_name_text(); +fn infer_global_field_member(db: &DbIndex, key: &LuaMemberKey) -> InferResult { + let name = key.get_name().ok_or(InferFailReason::None)?; infer_global_type(db, name) } -fn infer_namespace_member( - db: &DbIndex, - cache: &mut LuaInferCache, - ns: &str, - index_expr: LuaIndexMemberExpr, -) -> InferResult { - let index_key = index_expr.get_index_key().ok_or(InferFailReason::None)?; - let member_key = LuaMemberKey::from_index_key(db, cache, &index_key)?; - let member_key = match member_key { +fn infer_namespace_member(db: &DbIndex, ns: &str, key: &LuaMemberKey) -> InferResult { + let member_key = match key { LuaMemberKey::Name(name) => name.to_string(), LuaMemberKey::Integer(i) => i.to_string(), _ => return Err(InferFailReason::None), @@ -1163,14 +1101,8 @@ fn infer_namespace_member( )) } -fn get_expr_member_key( - db: &DbIndex, - cache: &mut LuaInferCache, - expr: &LuaExpr, -) -> Option> { - let expr_type = infer_expr(db, cache, expr.clone()).ok()?; - let mut keys: HashSet = HashSet::new(); - let mut stack = vec![expr_type.clone()]; +fn collect_type_member_keys(db: &DbIndex, key_type: &LuaType, keys: &mut HashSet) { + let mut stack = vec![key_type.clone()]; let mut visited = HashSet::new(); while let Some(current_type) = stack.pop() { @@ -1221,33 +1153,51 @@ fn get_expr_member_key( continue; } } - if type_decl.is_enum() || type_decl.is_alias() { - keys.insert(LuaMemberKey::ExprType(current_type.clone())); + if type_decl.is_enum() { + let owner = LuaMemberOwner::Type(id.clone()); + if let Some(members) = db.get_member_index().get_members(&owner) { + let is_enum_key = type_decl.is_enum_key(); + for member in members { + if is_enum_key { + keys.insert(member.get_key().clone()); + } else if let Some(typ) = db + .get_type_index() + .get_type_cache(&member.get_id().into()) + .map(|it| it.as_type()) + { + match typ { + LuaType::DocStringConst(s) | LuaType::StringConst(s) => { + keys.insert((*s).to_string().into()); + } + LuaType::DocIntegerConst(i) | LuaType::IntegerConst(i) => { + keys.insert((*i).into()); + } + _ => {} + } + } + } + } } } } _ => {} } } - - // 转换为 Vec 并排序以确保顺序确定性 - let mut keys: Vec<_> = keys.into_iter().collect(); - keys.sort(); - Some(keys) } fn infer_tpl_ref_member( db: &DbIndex, cache: &mut LuaInferCache, generic: &GenericTpl, - index_expr: LuaIndexMemberExpr, + lookup: &MemberLookupQuery, infer_guard: &InferGuardRef, ) -> InferResult { let extend_type = get_tpl_ref_extend_type( db, cache, &LuaType::TplRef(generic.clone().into()), - index_expr + lookup + .index_expr .get_index_expr() .ok_or(InferFailReason::None)? .get_prefix_expr() @@ -1255,5 +1205,5 @@ fn infer_tpl_ref_member( 0, ) .ok_or(InferFailReason::None)?; - infer_member_by_member_key(db, cache, &extend_type, index_expr.clone(), infer_guard) + infer_member_by_lookup(db, cache, &extend_type, lookup, infer_guard) } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs index c35be60b1..357b9ae8f 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs @@ -7,7 +7,7 @@ use crate::{ db_index::{DbIndex, LuaDeclOrMemberId}, infer_node_semantic_decl, semantic::{ - infer::narrow::{VarRefId, infer_expr_narrow_type}, + infer::narrow::{VarRefId, get_var_ref_type, infer_expr_narrow_type}, semantic_info::resolve_global_decl_id, }, }; @@ -33,7 +33,7 @@ pub fn infer_name_expr( .ok_or(InferFailReason::None)?; let decl_id = file_ref.get_decl_id(&range); if let Some(decl_id) = decl_id { - infer_expr_narrow_type( + infer_var_ref_type( db, cache, LuaExpr::NameExpr(name_expr), @@ -47,8 +47,7 @@ pub fn infer_name_expr( fn infer_self(db: &DbIndex, cache: &mut LuaInferCache, name_expr: LuaNameExpr) -> InferResult { let decl_or_member_id = find_self_decl_or_member_id(db, cache, &name_expr).ok_or(InferFailReason::None)?; - // LuaDeclOrMemberId::Member(member_id) => find_decl_member_type(db, member_id), - infer_expr_narrow_type( + infer_var_ref_type( db, cache, LuaExpr::NameExpr(name_expr), @@ -56,6 +55,19 @@ fn infer_self(db: &DbIndex, cache: &mut LuaInferCache, name_expr: LuaNameExpr) - ) } +fn infer_var_ref_type( + db: &DbIndex, + cache: &mut LuaInferCache, + expr: LuaExpr, + var_ref_id: VarRefId, +) -> InferResult { + if cache.is_no_flow() { + get_var_ref_type(db, cache, &var_ref_id) + } else { + infer_expr_narrow_type(db, cache, expr, var_ref_id) + } +} + pub fn get_name_expr_var_ref_id( db: &DbIndex, cache: &mut LuaInferCache, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs index 210ca89f5..12e44a895 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs @@ -12,10 +12,7 @@ use crate::{ infer_call_expr_func, infer_expr, }; -use super::{ - InferFailReason, InferResult, - infer_index::{infer_member_by_member_key, infer_member_by_operator}, -}; +use super::{InferFailReason, InferResult, infer_index::infer_member}; pub fn infer_table_expr( db: &DbIndex, @@ -186,19 +183,7 @@ pub fn infer_table_field_value_should_be( .ok_or(InferFailReason::None)?; let parent_table_expr_type = infer_table_should_be(db, cache, parnet_table_expr)?; let index = LuaIndexMemberExpr::TableField(table_field.clone()); - let reason = match infer_member_by_member_key( - db, - cache, - &parent_table_expr_type, - index.clone(), - &InferGuard::new(), - ) { - Ok(member_type) => return Ok(member_type), - Err(InferFailReason::FieldNotFound) => InferFailReason::FieldNotFound, - Err(err) => return Err(err), - }; - - match infer_member_by_operator( + match infer_member( db, cache, &parent_table_expr_type, @@ -215,7 +200,7 @@ pub fn infer_table_field_value_should_be( return Ok(type_cache.as_type().clone()); }; - Err(reason) + Err(InferFailReason::FieldNotFound) } fn infer_table_type_by_callee( @@ -323,31 +308,13 @@ fn infer_table_field_type_by_parent( let parent_table_expr_type = infer_table_should_be(db, cache, parnet_table_expr)?; let index = LuaIndexMemberExpr::TableField(field); - let reason = match infer_member_by_member_key( - db, - cache, - &parent_table_expr_type, - index.clone(), - &InferGuard::new(), - ) { - Ok(member_type) => return Ok(member_type), - Err(InferFailReason::FieldNotFound) => InferFailReason::FieldNotFound, - Err(err) => return Err(err), - }; - - match infer_member_by_operator( + infer_member( db, cache, &parent_table_expr_type, index, &InferGuard::new(), - ) { - Ok(member_type) => return Ok(member_type), - Err(InferFailReason::FieldNotFound) => {} - Err(err) => return Err(err), - } - - Err(reason) + ) } fn infer_table_type_by_local( diff --git a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs index daea5d672..eed7743ff 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs @@ -13,7 +13,7 @@ use std::ops::Deref; use emmylua_parser::{ LuaAst, LuaAstNode, LuaCallExpr, LuaClosureExpr, LuaExpr, LuaLiteralExpr, LuaLiteralToken, - LuaTableExpr, LuaVarExpr, NumberResult, + LuaSyntaxId, LuaTableExpr, LuaVarExpr, NumberResult, }; use infer_binary::infer_binary_expr; use infer_call::infer_call_expr; @@ -21,6 +21,7 @@ pub use infer_call::infer_call_expr_func; pub use infer_doc_type::{DocTypeInferContext, infer_doc_type}; pub use infer_fail_reason::InferFailReason; pub use infer_index::infer_index_expr; +pub(crate) use infer_index::try_infer_expr_for_index; use infer_name::infer_name_expr; pub use infer_name::{find_self_decl_or_member_id, infer_param}; use infer_table::infer_table_expr; @@ -42,30 +43,81 @@ use super::{CacheEntry, LuaInferCache, member::infer_raw_member_type}; pub type InferResult = Result; pub use infer_call::InferCallFuncResult; -pub fn infer_expr(db: &DbIndex, cache: &mut LuaInferCache, expr: LuaExpr) -> InferResult { - let syntax_id = expr.get_syntax_id(); - let key = syntax_id; - if let Some(cache) = cache.expr_cache.get(&key) { - match cache { - CacheEntry::Cache(ty) => return Ok(ty.clone()), - _ => return Err(InferFailReason::RecursiveInfer), +fn prepare_expr_cache( + db: &DbIndex, + cache: &mut LuaInferCache, + syntax_id: LuaSyntaxId, +) -> Result, InferFailReason> { + let file_id = cache.get_file_id(); + if cache.is_no_flow() { + if let Some(ty) = cache.replay_expr_type(syntax_id) { + return Ok(Some(ty.clone())); + } + + if let Some(cache_entry) = cache.expr_no_flow_cache.get(&syntax_id) { + match cache_entry { + CacheEntry::Cache(Some(ty)) => return Ok(Some(ty.clone())), + CacheEntry::Cache(None) => return Err(InferFailReason::None), + CacheEntry::Ready => return Err(InferFailReason::RecursiveInfer), + } + } + + let in_filed_syntax_id = InFiled::new(file_id, syntax_id); + if let Some(bind_type_cache) = db + .get_type_index() + .get_type_cache(&in_filed_syntax_id.into()) + { + let ty = bind_type_cache.as_type().clone(); + cache + .expr_no_flow_cache + .insert(syntax_id, CacheEntry::Cache(Some(ty.clone()))); + return Ok(Some(ty)); + } + + cache + .expr_no_flow_cache + .insert(syntax_id, CacheEntry::Ready); + return Ok(None); + } + + if let Some(cache_entry) = cache.expr_cache.get(&syntax_id) { + match cache_entry { + CacheEntry::Cache(ty) => return Ok(Some(ty.clone())), + CacheEntry::Ready => return Err(InferFailReason::RecursiveInfer), } } - // for @as - let file_id = cache.get_file_id(); let in_filed_syntax_id = InFiled::new(file_id, syntax_id); if let Some(bind_type_cache) = db .get_type_index() .get_type_cache(&in_filed_syntax_id.into()) { + let ty = bind_type_cache.as_type().clone(); cache .expr_cache - .insert(key, CacheEntry::Cache(bind_type_cache.as_type().clone())); - return Ok(bind_type_cache.as_type().clone()); + .insert(syntax_id, CacheEntry::Cache(ty.clone())); + return Ok(Some(ty)); } - cache.expr_cache.insert(key, CacheEntry::Ready); + cache.expr_cache.insert(syntax_id, CacheEntry::Ready); + Ok(None) +} + +pub fn infer_expr(db: &DbIndex, cache: &mut LuaInferCache, expr: LuaExpr) -> InferResult { + let no_flow = cache.is_no_flow(); + let syntax_id = expr.get_syntax_id(); + if let Some(result_type) = prepare_expr_cache(db, cache, syntax_id)? { + return Ok(result_type); + } + if no_flow + && matches!(expr, LuaExpr::TableExpr(_)) + && !cache.no_flow_table_exprs.contains(&syntax_id) + { + cache + .expr_no_flow_cache + .insert(syntax_id, CacheEntry::Cache(None)); + return Err(InferFailReason::None); + } let result_type = match expr { LuaExpr::CallExpr(call_expr) => infer_call_expr(db, cache, call_expr), LuaExpr::TableExpr(table_expr) => infer_table_expr(db, cache, table_expr), @@ -79,39 +131,69 @@ pub fn infer_expr(db: &DbIndex, cache: &mut LuaInferCache, expr: LuaExpr) -> Inf paren_expr.get_expr().ok_or(InferFailReason::None)?, ), LuaExpr::NameExpr(name_expr) => infer_name_expr(db, cache, name_expr), - LuaExpr::IndexExpr(index_expr) => infer_index_expr(db, cache, index_expr, true), + LuaExpr::IndexExpr(index_expr) => infer_index_expr(db, cache, index_expr, !no_flow), }; match &result_type { Ok(result_type) => { - cache - .expr_cache - .insert(key, CacheEntry::Cache(result_type.clone())); + if no_flow { + cache + .expr_no_flow_cache + .insert(syntax_id, CacheEntry::Cache(Some(result_type.clone()))); + } else { + cache + .expr_cache + .insert(syntax_id, CacheEntry::Cache(result_type.clone())); + } } Err(InferFailReason::None) | Err(InferFailReason::RecursiveInfer) => { - cache - .expr_cache - .insert(key, CacheEntry::Cache(LuaType::Unknown)); - return Ok(LuaType::Unknown); + if no_flow { + cache + .expr_no_flow_cache + .insert(syntax_id, CacheEntry::Cache(None)); + } else { + cache + .expr_cache + .insert(syntax_id, CacheEntry::Cache(LuaType::Unknown)); + return Ok(LuaType::Unknown); + } } Err(InferFailReason::FieldNotFound) => { - if cache.get_config().analysis_phase.is_force() { + if no_flow { + cache.expr_no_flow_cache.remove(&syntax_id); + } else if cache.get_config().analysis_phase.is_force() { cache .expr_cache - .insert(key, CacheEntry::Cache(LuaType::Nil)); + .insert(syntax_id, CacheEntry::Cache(LuaType::Nil)); return Ok(LuaType::Nil); } else { - cache.expr_cache.remove(&key); + cache.expr_cache.remove(&syntax_id); } } _ => { - cache.expr_cache.remove(&key); + if no_flow { + cache.expr_no_flow_cache.remove(&syntax_id); + } else { + cache.expr_cache.remove(&syntax_id); + } } } result_type } +pub(crate) fn try_infer_expr_no_flow( + db: &DbIndex, + cache: &mut LuaInferCache, + expr: LuaExpr, +) -> Result, InferFailReason> { + match cache.with_no_flow(|cache| infer_expr(db, cache, expr)) { + Ok(result_type) => Ok(Some(result_type)), + Err(InferFailReason::None) | Err(InferFailReason::RecursiveInfer) => Ok(None), + Err(err) => Err(err), + } +} + fn infer_literal_expr(db: &DbIndex, config: &LuaInferCache, expr: LuaLiteralExpr) -> InferResult { match expr.get_literal().ok_or(InferFailReason::None)? { LuaLiteralToken::Nil(_) => Ok(LuaType::Nil), @@ -177,32 +259,6 @@ fn get_custom_type_operator( } } -pub fn infer_expr_list_value_type_at( - db: &DbIndex, - cache: &mut LuaInferCache, - exprs: &[LuaExpr], - value_idx: usize, -) -> Result, InferFailReason> { - let exprs_len = exprs.len(); - if exprs_len == 0 { - Ok(None) - } else if value_idx < exprs_len { - Ok( - infer_expr_list_types(db, cache, &exprs[value_idx..], Some(1), infer_expr)? - .first() - .map(|(ty, _)| ty.clone()), - ) - } else { - let last_idx = exprs_len - 1; - let offset = value_idx - last_idx; - Ok( - infer_expr_list_types(db, cache, &exprs[last_idx..], Some(offset + 1), infer_expr)? - .get(offset) - .map(|(ty, _)| ty.clone()), - ) - } -} - pub fn infer_expr_list_types( db: &DbIndex, cache: &mut LuaInferCache, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs index 26d2739aa..0a0fc35ea 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs @@ -4,18 +4,20 @@ use emmylua_parser::{ }; use crate::{ - DbIndex, FlowNode, FlowTree, InferFailReason, LuaInferCache, LuaType, TypeOps, infer_expr, + DbIndex, FlowNode, FlowTree, InferFailReason, LuaInferCache, LuaType, TypeOps, semantic::infer::{ VarRefId, narrow::{ condition_flow::{ - ConditionFlowAction, ConditionSubquery, CorrelatedDiscriminantNarrow, + ConditionFlowAction, CorrelatedDiscriminantNarrow, CorrelatedSubquery, + ExprTypeContinuation, FieldConditionKind, FieldLiteralSiblingSubquery, InferConditionFlow, PendingConditionNarrow, always_literal_equal, call_flow::get_type_at_call_expr, }, get_single_antecedent, var_ref_id::get_var_expr_var_ref_id, }, + try_infer_expr_no_flow, }, }; @@ -59,10 +61,7 @@ pub fn get_type_at_binary_expr( flow_node, left_expr, right_expr, - match condition_flow { - InferConditionFlow::TrueCondition => InferConditionFlow::FalseCondition, - InferConditionFlow::FalseCondition => InferConditionFlow::TrueCondition, - }, + condition_flow.invert(), ), BinaryOperator::OpGt => try_get_at_gt_or_ge_expr( db, @@ -185,17 +184,15 @@ fn try_get_at_gt_or_ge_expr( return Ok(ConditionFlowAction::Continue); } - let right_expr_type = infer_expr(db, cache, right_expr)?; let antecedent_flow_id = get_single_antecedent(flow_node)?; - Ok(ConditionFlowAction::NeedSubquery( - ConditionSubquery::ArrayLen { - var_ref_id: var_ref_id.clone(), - antecedent_flow_id, + Ok(ConditionFlowAction::NeedExprType { + flow_id: antecedent_flow_id, + expr: right_expr, + resume: ExprTypeContinuation::ArrayLen { subquery_condition_flow: condition_flow, - right_expr_type, max_adjustment: if gt { 1 } else { 0 }, }, - )) + }) } _ => Ok(ConditionFlowAction::Continue), } @@ -285,7 +282,7 @@ fn maybe_type_guard_binary_action( let antecedent_flow_id = get_single_antecedent(flow_node)?; Ok(Some(ConditionFlowAction::NeedSubquery( - ConditionSubquery::Correlated { + CorrelatedSubquery { var_ref_id: maybe_var_ref_id, antecedent_flow_id, subquery_condition_flow: condition_flow, @@ -386,44 +383,28 @@ fn get_var_eq_condition_action( return Ok(ConditionFlowAction::Continue); } let antecedent_flow_id = get_single_antecedent(flow_node)?; - let right_expr_type = infer_expr(db, cache, right_expr)?; - return Ok(ConditionFlowAction::NeedSubquery( - ConditionSubquery::Correlated { - var_ref_id: maybe_ref_id, - antecedent_flow_id, + return Ok(ConditionFlowAction::NeedExprType { + flow_id: antecedent_flow_id, + expr: right_expr, + resume: ExprTypeContinuation::CorrelatedEq { + var_ref_id: maybe_ref_id.clone(), subquery_condition_flow: condition_flow, discriminant_decl_id, condition_position: left_name_expr.get_position(), - narrow: CorrelatedDiscriminantNarrow::Eq { - right_expr_type, - allow_literal_equivalence: true, - }, - fallback_expr: None, + allow_literal_equivalence: true, }, - )); + }); } - let right_expr_type = infer_expr(db, cache, right_expr)?; - let result_type = match condition_flow { - InferConditionFlow::TrueCondition => { - // self 是特殊的, 我们删除其 nil 类型 - if var_ref_id.is_self_ref() && !right_expr_type.is_nil() { - TypeOps::Remove.apply(db, &right_expr_type, &LuaType::Nil) - } else { - return Ok(ConditionFlowAction::Pending(PendingConditionNarrow::Eq { - right_expr_type, - condition_flow, - })); - } - } - InferConditionFlow::FalseCondition => { - return Ok(ConditionFlowAction::Pending(PendingConditionNarrow::Eq { - right_expr_type, - condition_flow, - })); - } - }; - Ok(ConditionFlowAction::Result(result_type)) + let antecedent_flow_id = get_single_antecedent(flow_node)?; + Ok(ConditionFlowAction::NeedExprType { + flow_id: antecedent_flow_id, + expr: right_expr, + resume: ExprTypeContinuation::Eq { + condition_flow, + true_result_is_rhs: false, + }, + }) } LuaExpr::CallExpr(left_call_expr) => { if let LuaExpr::LiteralExpr(literal_expr) = right_expr { @@ -432,17 +413,17 @@ fn get_var_eq_condition_action( let flow = if b.is_true() { condition_flow } else { - match condition_flow { - InferConditionFlow::TrueCondition => { - InferConditionFlow::FalseCondition - } - InferConditionFlow::FalseCondition => { - InferConditionFlow::TrueCondition - } - } + condition_flow.invert() }; - return get_type_at_call_expr(db, cache, var_ref_id, left_call_expr, flow); + return get_type_at_call_expr( + db, + cache, + var_ref_id, + flow_node, + left_call_expr, + flow, + ); } _ => return Ok(ConditionFlowAction::Continue), } @@ -462,15 +443,15 @@ fn get_var_eq_condition_action( return Ok(ConditionFlowAction::Continue); } - let right_expr_type = infer_expr(db, cache, right_expr)?; - if matches!(condition_flow, InferConditionFlow::FalseCondition) { - return Ok(ConditionFlowAction::Pending(PendingConditionNarrow::Eq { - right_expr_type, + let antecedent_flow_id = get_single_antecedent(flow_node)?; + Ok(ConditionFlowAction::NeedExprType { + flow_id: antecedent_flow_id, + expr: right_expr, + resume: ExprTypeContinuation::Eq { condition_flow, - })); - } - - Ok(ConditionFlowAction::Result(right_expr_type)) + true_result_is_rhs: true, + }, + }) } LuaExpr::UnaryExpr(unary_expr) => { let Some(op) = unary_expr.get_op_token() else { @@ -495,17 +476,15 @@ fn get_var_eq_condition_action( return Ok(ConditionFlowAction::Continue); } - let right_expr_type = infer_expr(db, cache, right_expr)?; let antecedent_flow_id = get_single_antecedent(flow_node)?; - Ok(ConditionFlowAction::NeedSubquery( - ConditionSubquery::ArrayLen { - var_ref_id: var_ref_id.clone(), - antecedent_flow_id, + Ok(ConditionFlowAction::NeedExprType { + flow_id: antecedent_flow_id, + expr: right_expr, + resume: ExprTypeContinuation::ArrayLen { subquery_condition_flow: condition_flow, - right_expr_type, max_adjustment: 0, }, - )) + }) } _ => { // If the left expression is not a name or call expression, we cannot narrow it @@ -541,9 +520,11 @@ fn maybe_field_literal_eq_action( return Ok(None); }; - let index_var_ref_id = - get_var_expr_var_ref_id(db, cache, LuaExpr::IndexExpr(index_expr.clone())); - if index_var_ref_id.as_ref() == Some(var_ref_id) { + // The exact index ref should use normal equality narrowing; this field + // path is for narrowing the prefix object through one of its fields. + if get_var_expr_var_ref_id(db, cache, LuaExpr::IndexExpr(index_expr.clone())) + .is_some_and(|index_ref_id| index_ref_id == *var_ref_id) + { return Ok(None); } @@ -554,13 +535,17 @@ fn maybe_field_literal_eq_action( if maybe_var_ref_id != *var_ref_id { if var_ref_id.start_with(&maybe_var_ref_id) { - let right_type = infer_expr(db, cache, LuaExpr::LiteralExpr(literal_expr))?; - return Ok(Some(ConditionFlowAction::NeedSubquery( - ConditionSubquery::FieldLiteralSibling { + let Some(right_type) = + try_infer_expr_no_flow(db, cache, LuaExpr::LiteralExpr(literal_expr))? + else { + return Ok(None); + }; + return Ok(Some(ConditionFlowAction::NeedFieldLiteralSibling( + FieldLiteralSiblingSubquery { var_ref_id: var_ref_id.clone(), discriminant_prefix_var_ref_id: maybe_var_ref_id, antecedent_flow_id: get_single_antecedent(flow_node)?, - subquery_condition_flow: condition_flow, + condition_flow, idx: LuaIndexMemberExpr::IndexExpr(index_expr), right_expr_type: right_type, }, @@ -570,15 +555,19 @@ fn maybe_field_literal_eq_action( return Ok(None); } - let antecedent_flow_id = get_single_antecedent(flow_node)?; - let right_type = infer_expr(db, cache, LuaExpr::LiteralExpr(literal_expr))?; - Ok(Some(ConditionFlowAction::NeedSubquery( - ConditionSubquery::FieldLiteralEq { - var_ref_id: var_ref_id.clone(), - antecedent_flow_id, - subquery_condition_flow: condition_flow, - idx: LuaIndexMemberExpr::IndexExpr(index_expr), - right_expr_type: right_type, + let Some(right_type) = try_infer_expr_no_flow(db, cache, LuaExpr::LiteralExpr(literal_expr))? + else { + return Ok(None); + }; + let idx = LuaIndexMemberExpr::IndexExpr(index_expr); + Ok(Some(ConditionFlowAction::Pending( + PendingConditionNarrow::Field { + idx, + key_type: None, + condition_flow, + kind: FieldConditionKind::LiteralEq { + right_expr_type: right_type, + }, }, ))) } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs index 49df32d29..e915915b1 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs @@ -3,23 +3,27 @@ use std::{ops::Deref, sync::Arc}; use emmylua_parser::{LuaCallExpr, LuaExpr, LuaIndexMemberExpr}; use crate::{ - DbIndex, InferFailReason, InferGuard, LuaAliasCallKind, LuaAliasCallType, LuaFunctionType, - LuaInferCache, LuaSignatureId, LuaType, infer_call_expr_func, infer_expr, + DbIndex, FlowNode, InferFailReason, LuaAliasCallKind, LuaAliasCallType, LuaFunctionType, + LuaInferCache, LuaSignatureId, LuaType, semantic::infer::{ VarRefId, - infer_index::infer_member_by_member_key, narrow::{ - condition_flow::{ConditionFlowAction, InferConditionFlow, PendingConditionNarrow}, - get_var_ref_type, narrow_false_or_nil, remove_false_or_nil, + condition_flow::{ + ConditionFlowAction, ExprTypeContinuation, InferConditionFlow, + PendingConditionNarrow, + }, + get_single_antecedent, get_var_ref_type, narrow_false_or_nil, remove_false_or_nil, var_ref_id::get_var_expr_var_ref_id, }, }, + semantic::instantiate_func_generic, }; pub fn get_type_at_call_expr( db: &DbIndex, cache: &mut LuaInferCache, var_ref_id: &VarRefId, + flow_node: &FlowNode, call_expr: LuaCallExpr, condition_flow: InferConditionFlow, ) -> Result { @@ -27,43 +31,77 @@ pub fn get_type_at_call_expr( return Ok(ConditionFlowAction::Continue); }; - let maybe_func = if call_expr.is_colon_call() { - match &prefix_expr { - LuaExpr::IndexExpr(index_expr) => { - if let Some(self_expr) = index_expr.get_prefix_expr() - && let Some(self_var_ref_id) = get_var_expr_var_ref_id(db, cache, self_expr) - && self_var_ref_id == *var_ref_id - { - let self_type = get_var_ref_type(db, cache, var_ref_id)?; - let member_type = infer_member_by_member_key( - db, - cache, - &self_type, - LuaIndexMemberExpr::IndexExpr(index_expr.clone()), - &InferGuard::new(), - )?; - - if needs_antecedent_same_var_colon_lookup(&member_type) { - // Keep the dedicated pending case here: replay needs the antecedent type - // for member lookup itself, not just for applying a cast after lookup. - return Ok(ConditionFlowAction::Pending( - PendingConditionNarrow::SameVarColonCall { - idx: LuaIndexMemberExpr::IndexExpr(index_expr.clone()), - condition_flow, - }, - )); - } else { - member_type - } - } else { - infer_expr(db, cache, prefix_expr.clone())? - } + let mut receiver_method_idx = None; + let mut targets_var = false; + if let LuaExpr::IndexExpr(index_expr) = &prefix_expr { + if let Some(self_expr) = index_expr.get_prefix_expr() { + let self_ref_id = get_var_expr_var_ref_id(db, cache, self_expr.clone()); + targets_var |= self_ref_id + .as_ref() + .is_some_and(|self_ref_id| refs_overlap(self_ref_id, var_ref_id)); + + if call_expr.is_colon_call() && self_ref_id.as_ref() == Some(var_ref_id) { + receiver_method_idx = + Some((LuaIndexMemberExpr::IndexExpr(index_expr.clone()), self_expr)); } - _ => infer_expr(db, cache, prefix_expr.clone())?, } - } else { - infer_expr(db, cache, prefix_expr.clone())? - }; + } + + targets_var |= call_expr.get_args_list().is_some_and(|arg_list| { + arg_list + .get_args() + .any(|arg| expr_targets_var(db, cache, arg, var_ref_id)) + }); + if !targets_var { + return Ok(ConditionFlowAction::Continue); + } + + if let Some((idx, receiver_expr)) = receiver_method_idx { + let antecedent_flow_id = get_single_antecedent(flow_node)?; + return Ok(ConditionFlowAction::NeedExprType { + flow_id: antecedent_flow_id, + expr: receiver_expr, + resume: ExprTypeContinuation::ReceiverMethodCall { + condition_flow, + idx, + call_expr: call_expr.clone(), + }, + }); + } + + let antecedent_flow_id = get_single_antecedent(flow_node)?; + Ok(ConditionFlowAction::NeedExprType { + flow_id: antecedent_flow_id, + expr: prefix_expr.clone(), + resume: ExprTypeContinuation::Call { + call_expr: call_expr.clone(), + condition_flow, + }, + }) +} + +fn expr_targets_var( + db: &DbIndex, + cache: &mut LuaInferCache, + expr: LuaExpr, + var_ref_id: &VarRefId, +) -> bool { + get_var_expr_var_ref_id(db, cache, expr) + .is_some_and(|expr_ref_id| refs_overlap(&expr_ref_id, var_ref_id)) +} + +fn refs_overlap(left: &VarRefId, right: &VarRefId) -> bool { + left == right || left.start_with(right) || right.start_with(left) +} + +pub(super) fn get_type_at_call_expr_by_func( + db: &DbIndex, + cache: &mut LuaInferCache, + var_ref_id: &VarRefId, + call_expr: LuaCallExpr, + maybe_func: LuaType, + condition_flow: InferConditionFlow, +) -> Result { match maybe_func { LuaType::DocFunction(f) => { let return_type = f.get_ret(); @@ -120,6 +158,9 @@ pub fn get_type_at_call_expr( let Some(signature_cast) = db.get_flow_index().get_signature_cast(&signature_id) else { return Ok(ConditionFlowAction::Continue); }; + let Some(prefix_expr) = call_expr.get_prefix_expr() else { + return Ok(ConditionFlowAction::Continue); + }; match signature_cast.name.as_str() { "self" => get_type_at_call_expr_by_signature_self( @@ -145,7 +186,7 @@ pub fn get_type_at_call_expr( } } -fn needs_antecedent_same_var_colon_lookup(member_type: &LuaType) -> bool { +pub(super) fn needs_deferred_receiver_method_lookup(member_type: &LuaType) -> bool { let candidate_members = match member_type { LuaType::Union(union_type) => union_type.into_vec(), LuaType::MultiLineUnion(multi_union) => match multi_union.to_union() { @@ -184,16 +225,11 @@ fn get_type_guard_call_info( let mut return_type = func_type.get_ret().clone(); if return_type.contain_tpl() { - let call_expr_type = LuaType::DocFunction(func_type); - let inst_func = infer_call_expr_func( - db, - cache, - call_expr, - call_expr_type, - &InferGuard::new(), - None, - )?; - + let Ok(inst_func) = cache.with_no_flow(|cache| { + instantiate_func_generic(db, cache, func_type.as_ref(), call_expr) + }) else { + return Ok(None); + }; return_type = inst_func.get_ret().clone(); } @@ -339,7 +375,7 @@ fn get_type_at_call_expr_by_call( } if alias_call_type.get_call_kind() == LuaAliasCallKind::RawGet { - let antecedent_type = infer_expr(db, cache, LuaExpr::CallExpr(call_expr))?; + let antecedent_type = get_var_ref_type(db, cache, var_ref_id)?; let result_type = match condition_flow { InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_type), InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_type), diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs index 7d8b9a458..78c464852 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -4,8 +4,11 @@ use emmylua_parser::{LuaAstPtr, LuaCallExpr, LuaChunk}; use crate::{ DbIndex, FlowId, FlowTree, InferFailReason, LuaDeclId, LuaFunctionType, LuaInferCache, - LuaSignature, LuaType, TypeOps, infer_expr, instantiate_func_generic, - semantic::infer::{InferResult, VarRefId, narrow::narrow_down_type}, + LuaSignature, LuaType, TypeOps, + semantic::{ + infer::{InferResult, VarRefId, narrow::narrow_down_type, try_infer_expr_no_flow}, + instantiate_func_generic, + }, }; use super::{ConditionFlowAction, PendingConditionNarrow}; @@ -448,7 +451,6 @@ fn correlated_type_contains(db: &DbIndex, container: &LuaType, target: &LuaType) TypeOps::Union.apply(db, container, target) == *container } -#[allow(clippy::too_many_arguments)] fn collect_matching_correlated_types( db: &DbIndex, cache: &mut LuaInferCache, @@ -547,8 +549,8 @@ fn infer_signature_for_call_ptr<'a>( let Some(prefix_expr) = call_expr.get_prefix_expr() else { return Ok(None); }; - let signature_id = match infer_expr(db, cache, prefix_expr)? { - LuaType::Signature(signature_id) => signature_id, + let signature_id = match try_infer_expr_no_flow(db, cache, prefix_expr)? { + Some(LuaType::Signature(signature_id)) => signature_id, _ => return Ok(None), }; let Some(signature) = db.get_signature_index().get(&signature_id) else { @@ -564,23 +566,28 @@ fn instantiate_return_rows( call_expr: LuaCallExpr, signature: &LuaSignature, ) -> Vec> { + let mut instantiate_return_type = |return_type: LuaType| { + if !return_type.contain_tpl() { + return return_type; + } + + let func = LuaFunctionType::new( + signature.async_state, + signature.is_colon_define, + signature.is_vararg, + signature.get_type_params(), + return_type.clone(), + ); + match cache + .with_no_flow(|cache| instantiate_func_generic(db, cache, &func, call_expr.clone())) + { + Ok(instantiated) => instantiated.get_ret().clone(), + Err(_) => return_type, + } + }; + if signature.return_overloads.is_empty() { - let return_type = signature.get_return_type(); - let instantiated_return_type = if return_type.contain_tpl() { - let func = LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - return_type.clone(), - ); - match instantiate_func_generic(db, cache, &func, call_expr) { - Ok(instantiated) => instantiated.get_ret().clone(), - Err(_) => return_type, - } - } else { - return_type - }; + let instantiated_return_type = instantiate_return_type(signature.get_return_type()); return vec![LuaSignature::return_type_to_row(instantiated_return_type)]; } @@ -588,22 +595,7 @@ fn instantiate_return_rows( for overload in &signature.return_overloads { let type_refs = &overload.type_refs; let overload_return_type = LuaSignature::row_to_return_type(type_refs.to_vec()); - let instantiated_return_type = if overload_return_type.contain_tpl() { - let overload_func = LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - overload_return_type.clone(), - ); - match instantiate_func_generic(db, cache, &overload_func, call_expr.clone()) { - Ok(instantiated) => instantiated.get_ret().clone(), - Err(_) => overload_return_type, - } - } else { - overload_return_type - }; - + let instantiated_return_type = instantiate_return_type(overload_return_type); rows.push(LuaSignature::return_type_to_row(instantiated_return_type)); } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs index 4a99386e1..81db8af42 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs @@ -5,7 +5,9 @@ use crate::{ semantic::infer::{ VarRefId, narrow::{ - condition_flow::{ConditionFlowAction, InferConditionFlow, PendingConditionNarrow}, + condition_flow::{ + ConditionFlowAction, FieldConditionKind, InferConditionFlow, PendingConditionNarrow, + }, var_ref_id::get_var_expr_var_ref_id, }, }, @@ -43,10 +45,13 @@ pub fn get_type_at_index_expr( return Ok(ConditionFlowAction::Continue); } + let idx = LuaIndexMemberExpr::IndexExpr(index_expr); Ok(ConditionFlowAction::Pending( - PendingConditionNarrow::FieldTruthy { - idx: LuaIndexMemberExpr::IndexExpr(index_expr), + PendingConditionNarrow::Field { + idx, + key_type: None, condition_flow, + kind: FieldConditionKind::Truthy, }, )) } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs index 505452131..aa2a5c385 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs @@ -12,17 +12,23 @@ use self::{ prepare_var_from_return_overload_condition, }, }; -use emmylua_parser::{LuaAstNode, LuaChunk, LuaExpr, LuaIndexMemberExpr, UnaryOperator}; +use emmylua_parser::{ + LuaAstNode, LuaCallExpr, LuaChunk, LuaExpr, LuaIndexMemberExpr, UnaryOperator, +}; use crate::{ DbIndex, FlowId, FlowNode, FlowTree, InferFailReason, InferGuard, LuaArrayLen, LuaArrayType, LuaDeclId, LuaInferCache, LuaSignatureCast, LuaSignatureId, LuaType, TypeOps, semantic::infer::{ - VarRefId, - infer_index::infer_member_by_member_key, + InferResult, VarRefId, + infer_index::{infer_member_by_key_type, infer_member_by_member_key}, narrow::{ condition_flow::{ - call_flow::get_type_at_call_expr, index_flow::get_type_at_index_expr, + call_flow::{ + get_type_at_call_expr, get_type_at_call_expr_by_func, + needs_deferred_receiver_method_lookup, + }, + index_flow::get_type_at_index_expr, }, get_single_antecedent, get_type_at_cast_flow::cast_type, @@ -39,43 +45,62 @@ pub enum InferConditionFlow { FalseCondition, } +impl InferConditionFlow { + fn invert(self) -> Self { + match self { + Self::TrueCondition => Self::FalseCondition, + Self::FalseCondition => Self::TrueCondition, + } + } +} + #[derive(Debug, Clone)] -pub(in crate::semantic) enum ConditionSubquery { - ArrayLen { - var_ref_id: VarRefId, - antecedent_flow_id: FlowId, - // This is the effective narrowing polarity after rewrites like `not` and `~=`. - subquery_condition_flow: InferConditionFlow, - right_expr_type: LuaType, - max_adjustment: i64, +pub(in crate::semantic) enum ExprTypeContinuation { + Call { + call_expr: LuaCallExpr, + condition_flow: InferConditionFlow, }, - FieldLiteralEq { - var_ref_id: VarRefId, - antecedent_flow_id: FlowId, - subquery_condition_flow: InferConditionFlow, + ReceiverMethodCall { idx: LuaIndexMemberExpr, - right_expr_type: LuaType, + call_expr: LuaCallExpr, + condition_flow: InferConditionFlow, }, - // 查询 `target.handle` 时遇到 `target.type == "point"` 这类同级判别字段条件. - // 这里先对子查询中的 `target` 应用当前判别条件, 再把结果投影回 `target.handle` - // 避免直接禁用条件导致外层对 `target.handle` 的 nil/赋值 guard 被丢掉. - FieldLiteralSibling { - var_ref_id: VarRefId, - discriminant_prefix_var_ref_id: VarRefId, - antecedent_flow_id: FlowId, + ArrayLen { subquery_condition_flow: InferConditionFlow, - idx: LuaIndexMemberExpr, - right_expr_type: LuaType, + max_adjustment: i64, }, - Correlated { + CorrelatedEq { var_ref_id: VarRefId, - antecedent_flow_id: FlowId, subquery_condition_flow: InferConditionFlow, discriminant_decl_id: LuaDeclId, condition_position: rowan::TextSize, - narrow: CorrelatedDiscriminantNarrow, - fallback_expr: Option, + allow_literal_equivalence: bool, }, + Eq { + condition_flow: InferConditionFlow, + true_result_is_rhs: bool, + }, +} + +#[derive(Debug, Clone)] +pub(in crate::semantic) struct CorrelatedSubquery { + var_ref_id: VarRefId, + antecedent_flow_id: FlowId, + subquery_condition_flow: InferConditionFlow, + discriminant_decl_id: LuaDeclId, + condition_position: rowan::TextSize, + narrow: CorrelatedDiscriminantNarrow, + fallback_expr: Option, +} + +#[derive(Debug, Clone)] +pub(in crate::semantic) struct FieldLiteralSiblingSubquery { + var_ref_id: VarRefId, + discriminant_prefix_var_ref_id: VarRefId, + antecedent_flow_id: FlowId, + condition_flow: InferConditionFlow, + idx: LuaIndexMemberExpr, + right_expr_type: LuaType, } #[derive(Debug, Clone)] @@ -95,18 +120,20 @@ pub(in crate::semantic) enum ConditionFlowAction { Continue, Result(LuaType), Pending(PendingConditionNarrow), - NeedSubquery(ConditionSubquery), + NeedExprType { + flow_id: FlowId, + expr: LuaExpr, + resume: ExprTypeContinuation, + }, + NeedSubquery(CorrelatedSubquery), + NeedFieldLiteralSibling(FieldLiteralSiblingSubquery), NeedCorrelated(PendingCorrelatedCondition), } #[derive(Debug, Clone)] pub(in crate::semantic) enum PendingConditionNarrow { Truthiness(InferConditionFlow), - FieldTruthy { - idx: LuaIndexMemberExpr, - condition_flow: InferConditionFlow, - }, - SameVarColonCall { + ReceiverMethodCall { idx: LuaIndexMemberExpr, condition_flow: InferConditionFlow, }, @@ -118,6 +145,17 @@ pub(in crate::semantic) enum PendingConditionNarrow { right_expr_type: LuaType, condition_flow: InferConditionFlow, }, + Field { + idx: LuaIndexMemberExpr, + key_type: Option, + condition_flow: InferConditionFlow, + kind: FieldConditionKind, + }, + ArrayLen { + right_expr_type: LuaType, + condition_flow: InferConditionFlow, + max_adjustment: i64, + }, TypeGuard { narrow: LuaType, condition_flow: InferConditionFlow, @@ -126,6 +164,12 @@ pub(in crate::semantic) enum PendingConditionNarrow { Correlated(Rc), } +#[derive(Debug, Clone)] +pub(in crate::semantic) enum FieldConditionKind { + Truthy, + LiteralEq { right_expr_type: LuaType }, +} + impl PendingConditionNarrow { pub(in crate::semantic::infer::narrow) fn apply( &self, @@ -134,60 +178,23 @@ impl PendingConditionNarrow { antecedent_type: LuaType, ) -> LuaType { match self { - PendingConditionNarrow::Truthiness(condition_flow) => match condition_flow.clone() { + PendingConditionNarrow::Truthiness(condition_flow) => match *condition_flow { InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_type), InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_type), }, - PendingConditionNarrow::FieldTruthy { + PendingConditionNarrow::ReceiverMethodCall { idx, condition_flow, } => { - let LuaType::Union(union_type) = &antecedent_type else { - return antecedent_type; - }; - - let union_types = union_type.into_vec(); - let mut result = vec![]; - for sub_type in &union_types { - let member_type = match infer_member_by_member_key( + let Ok(member_type) = cache.with_no_flow(|cache| { + infer_member_by_member_key( db, cache, - sub_type, + &antecedent_type, idx.clone(), &InferGuard::new(), - ) { - Ok(member_type) => member_type, - Err(_) => continue, - }; - - if !member_type.is_always_falsy() { - result.push(sub_type.clone()); - } - } - - if result.is_empty() { - antecedent_type - } else { - match condition_flow.clone() { - InferConditionFlow::TrueCondition => LuaType::from_vec(result), - InferConditionFlow::FalseCondition => { - let target = LuaType::from_vec(result); - TypeOps::Remove.apply(db, &antecedent_type, &target) - } - } - } - } - PendingConditionNarrow::SameVarColonCall { - idx, - condition_flow, - } => { - let Ok(member_type) = infer_member_by_member_key( - db, - cache, - &antecedent_type, - idx.clone(), - &InferGuard::new(), - ) else { + ) + }) else { return antecedent_type; }; @@ -209,7 +216,7 @@ impl PendingConditionNarrow { antecedent_type, signature_id.clone(), signature_cast, - condition_flow.clone(), + *condition_flow, ) } PendingConditionNarrow::SignatureCast { @@ -226,32 +233,79 @@ impl PendingConditionNarrow { antecedent_type, signature_id.clone(), signature_cast, - condition_flow.clone(), + *condition_flow, ) } PendingConditionNarrow::Eq { right_expr_type, condition_flow, - } => match condition_flow.clone() { - InferConditionFlow::TrueCondition => { - let maybe_type = - TypeOps::Intersect.apply(db, &antecedent_type, right_expr_type); - if maybe_type.is_never() { - antecedent_type - } else { - maybe_type + } => narrow_eq_condition( + db, + antecedent_type, + right_expr_type.clone(), + *condition_flow, + false, + ), + PendingConditionNarrow::Field { + idx, + key_type, + condition_flow, + kind, + } => match kind { + FieldConditionKind::Truthy => { + let narrowed = narrow_field_truthy( + db, + cache, + antecedent_type.clone(), + idx, + key_type.as_ref(), + ); + + match narrowed { + Some(truthy_type) => apply_field_truthy_condition( + db, + antecedent_type, + truthy_type, + *condition_flow, + ), + None => antecedent_type, } } - InferConditionFlow::FalseCondition => { - TypeOps::Remove.apply(db, &antecedent_type, right_expr_type) + FieldConditionKind::LiteralEq { right_expr_type } => narrow_field_literal_eq( + db, + cache, + antecedent_type.clone(), + idx, + key_type.as_ref(), + right_expr_type, + *condition_flow, + ) + .unwrap_or(antecedent_type), + }, + PendingConditionNarrow::ArrayLen { + right_expr_type, + condition_flow, + max_adjustment, + } => match (&antecedent_type, right_expr_type) { + ( + LuaType::Array(array_type), + LuaType::IntegerConst(i) | LuaType::DocIntegerConst(i), + ) if matches!(condition_flow, InferConditionFlow::TrueCondition) => { + let new_array_type = LuaArrayType::new( + array_type.get_base().clone(), + LuaArrayLen::Max(*i + *max_adjustment), + ); + LuaType::Array(new_array_type.into()) } + _ => antecedent_type, }, PendingConditionNarrow::TypeGuard { narrow, condition_flow, - } => match condition_flow.clone() { + } => match *condition_flow { InferConditionFlow::TrueCondition => { - narrow_type_guard(db, antecedent_type, narrow.clone()).unwrap_or(narrow.clone()) + narrow_type_guard(db, antecedent_type, narrow.clone()) + .unwrap_or_else(|| narrow.clone()) } InferConditionFlow::FalseCondition => { TypeOps::Remove.apply(db, &antecedent_type, narrow) @@ -325,6 +379,103 @@ fn narrow_type_guard(db: &DbIndex, antecedent_type: LuaType, narrow: LuaType) -> narrow_down_type(db, antecedent_type, narrow, None) } +pub(super) fn eq_condition_action( + db: &DbIndex, + var_ref_id: &VarRefId, + right_expr_type: LuaType, + condition_flow: InferConditionFlow, + true_result_is_rhs: bool, +) -> ConditionFlowAction { + if matches!(condition_flow, InferConditionFlow::FalseCondition) { + return ConditionFlowAction::Pending(PendingConditionNarrow::Eq { + right_expr_type, + condition_flow, + }); + } + + if true_result_is_rhs { + return ConditionFlowAction::Result(right_expr_type); + } + + // self is special; drop nil directly instead of replaying a normal equality narrow. + if var_ref_id.is_self_ref() && !right_expr_type.is_nil() { + return ConditionFlowAction::Result(TypeOps::Remove.apply( + db, + &right_expr_type, + &LuaType::Nil, + )); + } + + ConditionFlowAction::Pending(PendingConditionNarrow::Eq { + right_expr_type, + condition_flow, + }) +} + +fn narrow_field_truthy( + db: &DbIndex, + cache: &mut LuaInferCache, + antecedent_type: LuaType, + idx: &LuaIndexMemberExpr, + key_type: Option<&LuaType>, +) -> Option { + let LuaType::Union(union_type) = &antecedent_type else { + return None; + }; + + let union_types = union_type.into_vec(); + let mut result = vec![]; + for sub_type in &union_types { + let member_type = match infer_pending_field_member(db, cache, &sub_type, idx, key_type) { + Ok(member_type) => member_type, + Err(_) => continue, + }; + + if !member_type.is_always_falsy() { + result.push(sub_type.clone()); + } + } + + (!result.is_empty()).then(|| LuaType::from_vec(result)) +} + +fn infer_pending_field_member( + db: &DbIndex, + cache: &mut LuaInferCache, + prefix_type: &LuaType, + idx: &LuaIndexMemberExpr, + key_type: Option<&LuaType>, +) -> InferResult { + cache.with_no_flow(|cache| { + if let Some(key_type) = key_type { + infer_member_by_key_type( + db, + cache, + prefix_type, + idx.clone(), + key_type, + &InferGuard::new(), + ) + } else { + infer_member_by_member_key(db, cache, prefix_type, idx.clone(), &InferGuard::new()) + } + }) +} + +fn apply_field_truthy_condition( + db: &DbIndex, + antecedent_type: LuaType, + truthy_type: LuaType, + condition_flow: InferConditionFlow, +) -> LuaType { + match condition_flow { + InferConditionFlow::TrueCondition => truthy_type, + InferConditionFlow::FalseCondition => { + TypeOps::Remove.apply(db, &antecedent_type, &truthy_type) + } + } +} + fn apply_signature_cast( db: &DbIndex, antecedent_type: LuaType, @@ -410,17 +561,15 @@ pub(super) fn get_type_at_condition_flow( let fallback_expr = tree .get_decl_ref_expr(&decl_id) .and_then(|expr_ptr| expr_ptr.to_node(root)); - return Ok(ConditionFlowAction::NeedSubquery( - ConditionSubquery::Correlated { - var_ref_id: VarRefId::VarRef(decl_id), - antecedent_flow_id, - subquery_condition_flow: condition_flow, - discriminant_decl_id: decl_id, - condition_position: name_expr.get_position(), - narrow: CorrelatedDiscriminantNarrow::Truthiness, - fallback_expr, - }, - )); + return Ok(ConditionFlowAction::NeedSubquery(CorrelatedSubquery { + var_ref_id: VarRefId::VarRef(decl_id), + antecedent_flow_id, + subquery_condition_flow: condition_flow, + discriminant_decl_id: decl_id, + condition_position: name_expr.get_position(), + narrow: CorrelatedDiscriminantNarrow::Truthiness, + fallback_expr, + })); } let Some(expr_ptr) = tree.get_decl_ref_expr(&decl_id) else { @@ -433,7 +582,14 @@ pub(super) fn get_type_at_condition_flow( continue; } LuaExpr::CallExpr(call_expr) => { - return get_type_at_call_expr(db, cache, var_ref_id, call_expr, condition_flow); + return get_type_at_call_expr( + db, + cache, + var_ref_id, + flow_node, + call_expr, + condition_flow, + ); } LuaExpr::IndexExpr(index_expr) => { return get_type_at_index_expr(db, cache, var_ref_id, index_expr, condition_flow); @@ -465,10 +621,7 @@ pub(super) fn get_type_at_condition_flow( } condition = inner_expr; - condition_flow = match condition_flow { - InferConditionFlow::TrueCondition => InferConditionFlow::FalseCondition, - InferConditionFlow::FalseCondition => InferConditionFlow::TrueCondition, - }; + condition_flow = condition_flow.invert(); continue; } LuaExpr::ParenExpr(paren_expr) => { @@ -482,188 +635,175 @@ pub(super) fn get_type_at_condition_flow( } } -#[allow(clippy::too_many_arguments)] -pub(in crate::semantic::infer::narrow) fn resolve_condition_subquery( +struct CorrelatedSubqueryCtx<'a> { + db: &'a DbIndex, + tree: &'a FlowTree, + cache: &'a mut LuaInferCache, + root: &'a LuaChunk, + var_ref_id: &'a VarRefId, + flow_node: &'a FlowNode, +} + +pub(in crate::semantic::infer::narrow) fn resolve_correlated_subquery( db: &DbIndex, tree: &FlowTree, cache: &mut LuaInferCache, root: &LuaChunk, var_ref_id: &VarRefId, flow_node: &FlowNode, - subquery: ConditionSubquery, - antecedent_type: LuaType, + subquery: CorrelatedSubquery, + antecedent_result: InferResult, ) -> Result { - match subquery { - ConditionSubquery::ArrayLen { - subquery_condition_flow, - right_expr_type, - max_adjustment, - .. - } => match (&antecedent_type, &right_expr_type) { - ( - LuaType::Array(array_type), - LuaType::IntegerConst(i) | LuaType::DocIntegerConst(i), - ) if matches!(subquery_condition_flow, InferConditionFlow::TrueCondition) => { - let new_array_type = LuaArrayType::new( - array_type.get_base().clone(), - LuaArrayLen::Max(*i + max_adjustment), - ); - Ok(ConditionFlowAction::Result(LuaType::Array( - new_array_type.into(), - ))) - } - _ => Ok(ConditionFlowAction::Continue), - }, - ConditionSubquery::FieldLiteralEq { - subquery_condition_flow, + let mut ctx = CorrelatedSubqueryCtx { + db, + tree, + cache, + root, + var_ref_id, + flow_node, + }; + + subquery.resolve(&mut ctx, antecedent_result) +} + +pub(in crate::semantic::infer::narrow) fn resolve_expr_type_continuation( + db: &DbIndex, + cache: &mut LuaInferCache, + var_ref_id: &VarRefId, + antecedent_flow_id: FlowId, + resume: ExprTypeContinuation, + expr_type: LuaType, +) -> Result { + match resume { + ExprTypeContinuation::Call { + call_expr, + condition_flow, + } => get_type_at_call_expr_by_func( + db, + cache, + var_ref_id, + call_expr, + expr_type, + condition_flow, + ), + ExprTypeContinuation::ReceiverMethodCall { idx, - right_expr_type, - .. - } => Ok(narrow_union_by_field_literal_condition( + call_expr, + condition_flow, + } => resolve_receiver_method_call( db, cache, - antecedent_type, + var_ref_id, + expr_type, idx, - right_expr_type, + call_expr, + condition_flow, + ), + ExprTypeContinuation::ArrayLen { subquery_condition_flow, - )? - .map(ConditionFlowAction::Result) - .unwrap_or(ConditionFlowAction::Continue)), - ConditionSubquery::FieldLiteralSibling { + max_adjustment, + } => Ok(ConditionFlowAction::Pending( + PendingConditionNarrow::ArrayLen { + right_expr_type: expr_type, + condition_flow: subquery_condition_flow, + max_adjustment, + }, + )), + ExprTypeContinuation::CorrelatedEq { var_ref_id, - discriminant_prefix_var_ref_id, subquery_condition_flow, - idx, - right_expr_type, - .. - } => { - let Some(narrowed_prefix_type) = narrow_union_by_field_literal_condition( - db, - cache, - antecedent_type, - idx, - right_expr_type, - subquery_condition_flow, - )? - else { - return Ok(ConditionFlowAction::Continue); - }; - - let Some(projected_type) = project_relative_member_type( - db, - &narrowed_prefix_type, - &var_ref_id, - &discriminant_prefix_var_ref_id, - )? - else { - return Ok(ConditionFlowAction::Continue); - }; - Ok(ConditionFlowAction::Pending( - PendingConditionNarrow::NarrowTo(projected_type), - )) - } - ConditionSubquery::Correlated { + discriminant_decl_id, + condition_position, + allow_literal_equivalence, + } => Ok(ConditionFlowAction::NeedSubquery(CorrelatedSubquery { + var_ref_id, antecedent_flow_id, subquery_condition_flow, discriminant_decl_id, condition_position, - narrow, - fallback_expr, - .. - } => { - let narrowed_discriminant_type = match narrow { - CorrelatedDiscriminantNarrow::Truthiness => match subquery_condition_flow { - InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_type), - InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_type), - }, - CorrelatedDiscriminantNarrow::TypeGuard { narrow } => match subquery_condition_flow - { - InferConditionFlow::TrueCondition => { - narrow_down_type(db, antecedent_type, narrow.clone(), None) - .unwrap_or(narrow) - } - InferConditionFlow::FalseCondition => { - TypeOps::Remove.apply(db, &antecedent_type, &narrow) - } - }, - CorrelatedDiscriminantNarrow::Eq { - right_expr_type, - allow_literal_equivalence, - } => narrow_eq_condition( - db, - antecedent_type, - right_expr_type, - subquery_condition_flow, - allow_literal_equivalence, - ), - }; + narrow: CorrelatedDiscriminantNarrow::Eq { + right_expr_type: expr_type, + allow_literal_equivalence, + }, + fallback_expr: None, + })), + ExprTypeContinuation::Eq { + condition_flow, + true_result_is_rhs, + } => Ok(eq_condition_action( + db, + var_ref_id, + expr_type, + condition_flow, + true_result_is_rhs, + )), + } +} - let action = prepare_var_from_return_overload_condition( - db, - tree, - cache, - root, - var_ref_id, - discriminant_decl_id, - condition_position, - antecedent_flow_id, - &narrowed_discriminant_type, - )?; - - let Some(fallback_expr) = fallback_expr else { - return Ok(action); - }; - - if !matches!(action, ConditionFlowAction::Continue) { - return Ok(action); - } +fn resolve_receiver_method_call( + db: &DbIndex, + cache: &mut LuaInferCache, + var_ref_id: &VarRefId, + receiver_type: LuaType, + idx: LuaIndexMemberExpr, + call_expr: LuaCallExpr, + condition_flow: InferConditionFlow, +) -> Result { + let member_type = match cache.with_no_flow(|cache| { + infer_member_by_member_key(db, cache, &receiver_type, idx.clone(), &InferGuard::new()) + }) { + Ok(member_type) => member_type, + Err(_) => return Ok(ConditionFlowAction::Continue), + }; - get_type_at_condition_flow( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - fallback_expr, - subquery_condition_flow, - ) - } + if needs_deferred_receiver_method_lookup(&member_type) { + return Ok(ConditionFlowAction::Pending( + PendingConditionNarrow::ReceiverMethodCall { + idx, + condition_flow, + }, + )); } + + get_type_at_call_expr_by_func( + db, + cache, + var_ref_id, + call_expr, + member_type, + condition_flow, + ) } -fn narrow_union_by_field_literal_condition( +fn narrow_field_literal_eq( db: &DbIndex, cache: &mut LuaInferCache, antecedent_type: LuaType, - idx: LuaIndexMemberExpr, - right_expr_type: LuaType, + idx: &LuaIndexMemberExpr, + key_type: Option<&LuaType>, + right_expr_type: &LuaType, condition_flow: InferConditionFlow, -) -> Result, InferFailReason> { +) -> Option { let LuaType::Union(union_type) = antecedent_type else { - return Ok(None); + return None; }; - let union_types = union_type.into_vec(); let mut matched = Vec::new(); let mut unmatched = Vec::new(); let mut has_matched = false; - for sub_type in union_types { - let member_type = - match infer_member_by_member_key(db, cache, &sub_type, idx.clone(), &InferGuard::new()) - { - Ok(member_type) => member_type, - Err(_) => { - unmatched.push(sub_type); - continue; - } - }; - - if always_literal_equal(&member_type, &right_expr_type) { + for sub_type in union_type.into_vec() { + let member_type = match infer_pending_field_member(db, cache, &sub_type, idx, key_type) { + Ok(member_type) => member_type, + Err(_) => { + unmatched.push(sub_type); + continue; + } + }; + if always_literal_equal(&member_type, right_expr_type) { has_matched = true; - matched.push(sub_type); + matched.push(sub_type.clone()); } else { - unmatched.push(sub_type); + unmatched.push(sub_type.clone()); } } @@ -672,11 +812,53 @@ fn narrow_union_by_field_literal_condition( InferConditionFlow::FalseCondition => unmatched, }; if !has_matched { - Ok(None) + None } else if result.is_empty() { - Ok(Some(LuaType::Never)) + Some(LuaType::Never) } else { - Ok(Some(LuaType::from_vec(result))) + Some(LuaType::from_vec(result)) + } +} + +impl FieldLiteralSiblingSubquery { + pub(in crate::semantic::infer::narrow) fn next_flow_query(&self) -> (&VarRefId, FlowId) { + ( + &self.discriminant_prefix_var_ref_id, + self.antecedent_flow_id, + ) + } + + pub(in crate::semantic::infer::narrow) fn resolve( + self, + db: &DbIndex, + cache: &mut LuaInferCache, + antecedent_type: LuaType, + ) -> Result { + let Some(narrowed_prefix_type) = narrow_field_literal_eq( + db, + cache, + antecedent_type, + &self.idx, + None, + &self.right_expr_type, + self.condition_flow, + ) else { + return Ok(ConditionFlowAction::Continue); + }; + + let Some(projected_type) = project_relative_member_type( + db, + &narrowed_prefix_type, + &self.var_ref_id, + &self.discriminant_prefix_var_ref_id, + )? + else { + return Ok(ConditionFlowAction::Continue); + }; + + Ok(ConditionFlowAction::Pending( + PendingConditionNarrow::NarrowTo(projected_type), + )) } } @@ -770,6 +952,79 @@ fn project_union_member_type( Ok(Some(result_type)) } +impl CorrelatedSubquery { + pub(in crate::semantic::infer::narrow) fn next_flow_query(&self) -> (&VarRefId, FlowId) { + (&self.var_ref_id, self.antecedent_flow_id) + } + + fn resolve( + self, + ctx: &mut CorrelatedSubqueryCtx<'_>, + antecedent_result: InferResult, + ) -> Result { + let correlated = self; + let antecedent_type = antecedent_result?; + let narrowed_discriminant_type = match correlated.narrow { + CorrelatedDiscriminantNarrow::Truthiness => match correlated.subquery_condition_flow { + InferConditionFlow::FalseCondition => narrow_false_or_nil(ctx.db, antecedent_type), + InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_type), + }, + CorrelatedDiscriminantNarrow::TypeGuard { narrow } => { + match correlated.subquery_condition_flow { + InferConditionFlow::TrueCondition => { + narrow_down_type(ctx.db, antecedent_type, narrow.clone(), None) + .unwrap_or(narrow) + } + InferConditionFlow::FalseCondition => { + TypeOps::Remove.apply(ctx.db, &antecedent_type, &narrow) + } + } + } + CorrelatedDiscriminantNarrow::Eq { + right_expr_type, + allow_literal_equivalence, + } => narrow_eq_condition( + ctx.db, + antecedent_type, + right_expr_type, + correlated.subquery_condition_flow, + allow_literal_equivalence, + ), + }; + + let action = prepare_var_from_return_overload_condition( + ctx.db, + ctx.tree, + &mut *ctx.cache, + ctx.root, + ctx.var_ref_id, + correlated.discriminant_decl_id, + correlated.condition_position, + correlated.antecedent_flow_id, + &narrowed_discriminant_type, + )?; + + let Some(fallback_expr) = correlated.fallback_expr else { + return Ok(action); + }; + + if !matches!(action, ConditionFlowAction::Continue) { + return Ok(action); + } + + get_type_at_condition_flow( + ctx.db, + ctx.tree, + &mut *ctx.cache, + ctx.root, + ctx.var_ref_id, + ctx.flow_node, + fallback_expr, + correlated.subquery_condition_flow, + ) + } +} + pub(super) fn always_literal_equal(left: &LuaType, right: &LuaType) -> bool { match (left, right) { (LuaType::Union(union), other) => union diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs index d8c7b2c14..702b2ab73 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs @@ -1,30 +1,34 @@ use emmylua_parser::{ - LuaAssignStat, LuaAstNode, LuaChunk, LuaDocOpType, LuaExpr, LuaVarExpr, UnaryOperator, + LuaAssignStat, LuaAstNode, LuaChunk, LuaDocOpType, LuaExpr, LuaIndexKey, LuaIndexMemberExpr, + LuaSyntaxId, LuaTableExpr, LuaVarExpr, }; use hashbrown::HashSet; use std::{rc::Rc, sync::Arc}; use crate::{ CacheEntry, DbIndex, FlowId, FlowNode, FlowNodeKind, FlowTree, InferFailReason, LuaDeclId, - LuaInferCache, LuaMemberId, LuaSignatureId, LuaType, TypeOps, check_type_compact, infer_expr, + LuaInferCache, LuaMemberId, LuaSignatureId, LuaType, TypeOps, check_type_compact, semantic::{ - cache::{FlowAssignmentInfo, FlowConditionInfo, FlowMode, FlowVarCache}, + cache::{FlowAssignmentInfo, FlowMode, FlowVarCache}, infer::{ - InferResult, VarRefId, infer_expr_list_value_type_at, + InferResult, VarRefId, narrow::{ condition_flow::{ - ConditionFlowAction, ConditionSubquery, InferConditionFlow, - PendingConditionNarrow, + ConditionFlowAction, CorrelatedSubquery, ExprTypeContinuation, + FieldConditionKind, FieldLiteralSiblingSubquery, InferConditionFlow, + PendingConditionNarrow, always_literal_equal, correlated_flow::{ PendingCorrelatedCondition, advance_pending_correlated_condition, }, - get_type_at_condition_flow, resolve_condition_subquery, + get_type_at_condition_flow, resolve_correlated_subquery, + resolve_expr_type_continuation, }, get_multi_antecedents, get_single_antecedent, get_type_at_cast_flow::cast_type, get_var_ref_type, narrow_down_type, var_ref_id::get_var_expr_var_ref_id, }, + try_infer_expr_no_flow, }, member::find_members, }, @@ -87,13 +91,21 @@ enum Continuation { merged_type: LuaType, }, // Resume an assignment once we know the pre-assignment type of the same ref. - // Example: for `x = rhs`, first query `x` just before the assignment, then - // combine that antecedent type with the RHS type here. + // Example: for `x = expr`, first query `x` just before the assignment, + // then combine that antecedent type with the expression type here. AssignmentAntecedent { walk: QueryWalk, antecedent_flow_id: FlowId, expr_type: LuaType, - reuse_source_narrowing: bool, + reuse_antecedent_narrowing: bool, + }, + // Resume structural expression replay after resolving the flow-aware refs + // it depends on. The replay itself stays no-flow; only this continuation + // may schedule the dependency queries. + ExprReplay { + walk: QueryWalk, + replay: FlowExprReplay, + replay_query: FlowReplayQuery, }, // Resume a tag cast after reading the antecedent value that the cast rewrites. // Example: `---@cast x Foo` first queries `x` before the cast node, then @@ -109,7 +121,13 @@ enum Continuation { walk: QueryWalk, flow_id: FlowId, condition_flow: InferConditionFlow, - subquery: ConditionSubquery, + subquery: CorrelatedSubquery, + }, + FieldLiteralSiblingDependency { + walk: QueryWalk, + flow_id: FlowId, + condition_flow: InferConditionFlow, + subquery: FieldLiteralSiblingSubquery, }, // Resume correlated return-overload narrowing after querying one pending root. // Example: `local ok, value = f(); if ok then ... value ... end` may need to @@ -122,6 +140,209 @@ enum Continuation { }, } +enum FlowExprReplay { + Assignment { + antecedent_flow_id: FlowId, + explicit_var_type: Option, + result_slot: usize, + }, + DeclInitializer { + fail_reason: InferFailReason, + }, + Condition { + condition_flow_id: FlowId, + condition_flow: InferConditionFlow, + resume: ExprTypeContinuation, + }, + FieldConditionKey { + condition_flow_id: FlowId, + condition_flow: InferConditionFlow, + idx: LuaIndexMemberExpr, + field_condition_flow: InferConditionFlow, + kind: FieldConditionKind, + }, +} + +// Dependency queries are flow-aware, but the final expression replay is not. +// This owns both phases so replay cannot accidentally re-enter flow. +struct FlowReplayQuery { + flow_id: FlowId, + expr: LuaExpr, + allow_table_exprs: bool, + dependency_queries: Vec, + next_dependency_idx: usize, + dependency_types: Vec<(LuaSyntaxId, LuaType)>, +} + +impl FlowReplayQuery { + fn new( + db: &DbIndex, + tree: Option<&FlowTree>, + cache: &mut LuaInferCache, + flow_id: FlowId, + expr: LuaExpr, + allow_table_exprs: bool, + ) -> Self { + let mut dependency_queries = Vec::new(); + collect_expr_dependency_queries(db, tree, cache, flow_id, &expr, &mut dependency_queries); + Self { + flow_id, + expr, + allow_table_exprs, + dependency_queries, + next_dependency_idx: 0, + dependency_types: Vec::new(), + } + } + + fn next_query(&self) -> Option<&FlowExprTypeQuery> { + self.dependency_queries.get(self.next_dependency_idx) + } + + fn accept_result(&mut self, dependency_result: InferResult) -> Result<(), InferFailReason> { + let dependency_query = self + .dependency_queries + .get(self.next_dependency_idx) + .ok_or(InferFailReason::None)?; + + match dependency_result { + Ok(mut expr_type) => { + if let Some(literal_shape_type) = &dependency_query.literal_shape_type { + expr_type = literal_equivalent_type(literal_shape_type, &expr_type) + .unwrap_or(expr_type); + } + + self.dependency_types + .push((dependency_query.syntax_id, expr_type)); + } + Err( + InferFailReason::None + | InferFailReason::RecursiveInfer + | InferFailReason::FieldNotFound, + ) => {} + Err(err) => return Err(err), + } + + self.next_dependency_idx += 1; + Ok(()) + } + + fn replay_type( + self, + db: &DbIndex, + cache: &mut LuaInferCache, + ) -> Result, InferFailReason> { + let Self { + expr, + allow_table_exprs, + dependency_types, + .. + } = self; + replay_expr_no_flow(db, cache, expr, &dependency_types, allow_table_exprs) + } +} + +// The replay overlay should preserve declared doc literals when a flow query +// proves the same runtime literal value for an index expression. +fn literal_equivalent_type(source_type: &LuaType, target_type: &LuaType) -> Option { + match source_type { + LuaType::Union(union) => { + let matches = union + .into_vec() + .into_iter() + .filter(|candidate| always_literal_equal(candidate, target_type)) + .collect::>(); + (!matches.is_empty()).then(|| LuaType::from_vec(matches)) + } + _ if always_literal_equal(source_type, target_type) => Some(source_type.clone()), + _ => None, + } +} + +#[derive(Debug, Clone)] +struct FlowExprTypeQuery { + var_ref_id: VarRefId, + flow_id: FlowId, + syntax_id: LuaSyntaxId, + literal_shape_type: Option, +} + +fn collect_expr_dependency_queries( + db: &DbIndex, + tree: Option<&FlowTree>, + cache: &mut LuaInferCache, + fallback_flow_id: FlowId, + expr: &LuaExpr, + dependency_queries: &mut Vec, +) { + if matches!(expr, LuaExpr::ClosureExpr(_)) { + return; + } + + if matches!(expr, LuaExpr::NameExpr(_)) { + let flow_id = tree + .and_then(|tree| tree.get_flow_id(expr.get_syntax_id())) + .unwrap_or(fallback_flow_id); + if let Some(var_ref_id) = get_var_expr_var_ref_id(db, cache, expr.clone()) { + dependency_queries.push(FlowExprTypeQuery { + var_ref_id, + flow_id, + syntax_id: expr.get_syntax_id(), + literal_shape_type: None, + }); + } + return; + } + + if let LuaExpr::IndexExpr(index_expr) = expr { + if let Some(prefix_expr) = index_expr.get_prefix_expr() { + collect_expr_dependency_queries( + db, + tree, + cache, + fallback_flow_id, + &prefix_expr, + dependency_queries, + ); + } + if let Some(LuaIndexKey::Expr(expr)) = index_expr.get_index_key() { + collect_expr_dependency_queries( + db, + tree, + cache, + fallback_flow_id, + &expr, + dependency_queries, + ); + } + // Keep prefix/key deps first. The direct IndexRef repairs guards on + // the member itself, and should use the replay point's flow. + if let Some(var_ref_id) = get_var_expr_var_ref_id(db, cache, expr.clone()) { + let literal_shape_type = try_infer_expr_no_flow(db, cache, expr.clone()) + .ok() + .flatten(); + dependency_queries.push(FlowExprTypeQuery { + var_ref_id, + flow_id: fallback_flow_id, + syntax_id: expr.get_syntax_id(), + literal_shape_type, + }); + } + return; + } + + for child_expr in expr.children::() { + collect_expr_dependency_queries( + db, + tree, + cache, + fallback_flow_id, + &child_expr, + dependency_queries, + ); + } +} + // The top-loop scheduler decision. // `StartQuery` begins one query, optionally saving the current query first. // `ContinueWalk` keeps scanning backward through the current query. @@ -194,14 +415,19 @@ impl<'a> FlowTypeEngine<'a> { walk, antecedent_flow_id, expr_type, - reuse_source_narrowing, + reuse_antecedent_narrowing, }) => self.resume_assignment_antecedent( walk, antecedent_flow_id, expr_type, - reuse_source_narrowing, + reuse_antecedent_narrowing, query_result, ), + Some(Continuation::ExprReplay { + walk, + replay, + replay_query, + }) => self.resume_expr_replay(walk, replay, replay_query, query_result), Some(Continuation::TagCastAntecedent { walk, cast_op_types, @@ -218,6 +444,18 @@ impl<'a> FlowTypeEngine<'a> { subquery, query_result, ), + Some(Continuation::FieldLiteralSiblingDependency { + walk, + flow_id, + condition_flow, + subquery, + }) => self.resume_field_literal_sibling_subquery( + walk, + flow_id, + condition_flow, + subquery, + query_result, + ), Some(Continuation::CorrelatedSearchRoot { walk, flow_id, @@ -308,23 +546,23 @@ impl<'a> FlowTypeEngine<'a> { // Finish one assignment dependency query by reading the pre-assignment type // of the same ref, optionally retrying without condition narrows, then - // combining that antecedent type with the RHS type to finish the suspended + // combining that antecedent type with the expression type to finish the suspended // query. fn resume_assignment_antecedent( &mut self, walk: QueryWalk, antecedent_flow_id: FlowId, expr_type: LuaType, - reuse_source_narrowing: bool, - source_result: InferResult, + reuse_antecedent_narrowing: bool, + antecedent_result: InferResult, ) -> Result { - let source_type = match source_result { - Ok(source_type) => source_type, + let antecedent_type = match antecedent_result { + Ok(antecedent_type) => antecedent_type, Err(err) => return self.fail_query(&walk.query, err), }; - if reuse_source_narrowing - && !can_reuse_narrowed_assignment_source(self.db, &source_type, &expr_type) + if reuse_antecedent_narrowing + && !can_reuse_narrowed_assignment_source(self.db, &antecedent_type, &expr_type) { let next_query = walk .query @@ -335,7 +573,7 @@ impl<'a> FlowTypeEngine<'a> { walk, antecedent_flow_id, expr_type, - reuse_source_narrowing: false, + reuse_antecedent_narrowing: false, }), }); } @@ -343,15 +581,184 @@ impl<'a> FlowTypeEngine<'a> { let result_type = finish_assignment_result( self.db, self.cache, - &source_type, + &antecedent_type, &expr_type, &walk.query.var_ref_id, - reuse_source_narrowing, + reuse_antecedent_narrowing, None, ); Ok(self.finish_walk(walk, result_type)) } + fn start_expr_replay( + &mut self, + walk: QueryWalk, + replay: FlowExprReplay, + replay_query: FlowReplayQuery, + ) -> Result { + let next_query = replay_query + .next_query() + .map(|query| (query.var_ref_id.clone(), query.flow_id)); + let Some((var_ref_id, flow_id)) = next_query else { + return self.finish_expr_replay(walk, replay, replay_query); + }; + + Ok(SchedulerStep::StartQuery { + query: FlowQuery::new(self.cache, &var_ref_id, flow_id), + continuation: Some(Continuation::ExprReplay { + walk, + replay, + replay_query, + }), + }) + } + + fn resume_expr_replay( + &mut self, + walk: QueryWalk, + replay: FlowExprReplay, + mut replay_query: FlowReplayQuery, + query_result: InferResult, + ) -> Result { + match replay_query.accept_result(query_result) { + Ok(()) => {} + Err(err) => return self.finish_expr_replay_error(walk, replay, err), + } + + self.start_expr_replay(walk, replay, replay_query) + } + + fn finish_expr_replay( + &mut self, + walk: QueryWalk, + replay: FlowExprReplay, + replay_query: FlowReplayQuery, + ) -> Result { + match replay { + FlowExprReplay::Assignment { + antecedent_flow_id, + explicit_var_type, + result_slot, + } => self.finish_assignment_expr( + walk, + antecedent_flow_id, + explicit_var_type, + result_slot, + replay_query, + ), + FlowExprReplay::DeclInitializer { fail_reason } => { + let query = walk.query.clone(); + let expr_type = match replay_query.replay_type(self.db, self.cache) { + Ok(Some(expr_type)) => expr_type, + Ok(None) => return self.fail_query(&query, fail_reason), + Err(err) => return self.fail_query(&query, err), + }; + + let Some(init_type) = expr_type.get_result_slot_type(0) else { + return self.fail_query(&query, fail_reason); + }; + + Ok(self.finish_walk(walk, init_type)) + } + FlowExprReplay::Condition { + condition_flow_id, + condition_flow, + resume, + } => { + let expr_flow_id = replay_query.flow_id; + let query = walk.query.clone(); + let var_ref_id = query.var_ref_id.clone(); + let action_result = match replay_query.replay_type(self.db, self.cache) { + Ok(Some(expr_type)) => resolve_expr_type_continuation( + self.db, + self.cache, + &var_ref_id, + expr_flow_id, + resume, + expr_type, + ), + Ok(None) => Ok(ConditionFlowAction::Continue), + Err(err) => Err(err), + }; + let action = match action_result { + Ok(action) => action, + Err(err) => { + return self.fail_condition_query( + &query, + condition_flow_id, + condition_flow, + err, + ); + } + }; + self.apply_condition_action(walk, condition_flow_id, condition_flow, action) + .or_else(|err| { + self.fail_condition_query(&query, condition_flow_id, condition_flow, err) + }) + } + FlowExprReplay::FieldConditionKey { + condition_flow_id, + condition_flow, + idx, + field_condition_flow, + kind, + } => { + let query = walk.query.clone(); + let key_type = match replay_query.replay_type(self.db, self.cache) { + Ok(key_type) => key_type, + Err(err) => { + return self.fail_condition_query( + &query, + condition_flow_id, + condition_flow, + err, + ); + } + }; + + Ok(self.push_pending_condition( + walk, + condition_flow_id, + condition_flow, + PendingConditionNarrow::Field { + idx, + key_type, + condition_flow: field_condition_flow, + kind, + }, + )) + } + } + } + + fn finish_expr_replay_error( + &mut self, + walk: QueryWalk, + replay: FlowExprReplay, + err: InferFailReason, + ) -> Result { + match replay { + FlowExprReplay::Assignment { + antecedent_flow_id, + explicit_var_type, + .. + } => { + self.finish_assignment_expr_error(walk, antecedent_flow_id, explicit_var_type, err) + } + FlowExprReplay::DeclInitializer { .. } => self.fail_query(&walk.query, err), + FlowExprReplay::Condition { + condition_flow_id, + condition_flow, + .. + } + | FlowExprReplay::FieldConditionKey { + condition_flow_id, + condition_flow, + .. + } => self.fail_condition_query(&walk.query, condition_flow_id, condition_flow, err), + } + } + // Finish one tag-cast dependency query by reading the antecedent type and // replaying the cast operators in source order, then finish the suspended // query with the cast result. @@ -392,17 +799,16 @@ impl<'a> FlowTypeEngine<'a> { walk: QueryWalk, flow_id: FlowId, condition_flow: InferConditionFlow, - subquery: ConditionSubquery, + subquery: CorrelatedSubquery, antecedent_result: InferResult, ) -> Result { let query = walk.query.clone(); let result = (|| { - let antecedent_type = antecedent_result?; let flow_node = self .tree .get_flow_node(flow_id) .ok_or(InferFailReason::None)?; - let action = resolve_condition_subquery( + let action = resolve_correlated_subquery( self.db, self.tree, self.cache, @@ -410,17 +816,143 @@ impl<'a> FlowTypeEngine<'a> { &query.var_ref_id, flow_node, subquery, - antecedent_type, + antecedent_result, )?; self.apply_condition_action(walk, flow_id, condition_flow, action) })(); - result.or_else(|err| { - get_flow_var_cache(self.cache, query.var_cache_idx) - .condition_cache - .remove(&(flow_id, condition_flow)); - self.fail_query(&query, err) - }) + result.or_else(|err| self.fail_condition_query(&query, flow_id, condition_flow, err)) + } + + fn start_condition_subquery( + &mut self, + walk: QueryWalk, + flow_id: FlowId, + condition_flow: InferConditionFlow, + subquery: CorrelatedSubquery, + ) -> SchedulerStep { + let (subquery_var_ref_id, subquery_flow_id) = subquery.next_flow_query(); + let query = FlowQuery::new(self.cache, subquery_var_ref_id, subquery_flow_id); + SchedulerStep::StartQuery { + query, + continuation: Some(Continuation::ConditionDependency { + walk, + flow_id, + condition_flow, + subquery, + }), + } + } + + fn resume_field_literal_sibling_subquery( + &mut self, + walk: QueryWalk, + flow_id: FlowId, + condition_flow: InferConditionFlow, + subquery: FieldLiteralSiblingSubquery, + antecedent_result: InferResult, + ) -> Result { + let query = walk.query.clone(); + let result = (|| { + let antecedent_type = antecedent_result?; + let action = subquery.resolve(self.db, self.cache, antecedent_type)?; + self.apply_condition_action(walk, flow_id, condition_flow, action) + })(); + + result.or_else(|err| self.fail_condition_query(&query, flow_id, condition_flow, err)) + } + + fn start_field_literal_sibling_subquery( + &mut self, + walk: QueryWalk, + flow_id: FlowId, + condition_flow: InferConditionFlow, + subquery: FieldLiteralSiblingSubquery, + ) -> SchedulerStep { + let (subquery_var_ref_id, subquery_flow_id) = subquery.next_flow_query(); + let query = FlowQuery::new(self.cache, subquery_var_ref_id, subquery_flow_id); + SchedulerStep::StartQuery { + query, + continuation: Some(Continuation::FieldLiteralSiblingDependency { + walk, + flow_id, + condition_flow, + subquery, + }), + } + } + + fn push_pending_condition( + &mut self, + mut walk: QueryWalk, + flow_id: FlowId, + condition_flow: InferConditionFlow, + pending_condition_narrow: PendingConditionNarrow, + ) -> SchedulerStep { + get_flow_var_cache(self.cache, walk.query.var_cache_idx) + .condition_cache + .insert( + (flow_id, condition_flow), + CacheEntry::Cache(ConditionFlowAction::Pending( + pending_condition_narrow.clone(), + )), + ); + walk.pending_condition_narrows + .push(pending_condition_narrow); + SchedulerStep::ContinueWalk(walk) + } + + fn start_pending_condition( + &mut self, + walk: QueryWalk, + flow_id: FlowId, + condition_flow: InferConditionFlow, + pending_condition_narrow: PendingConditionNarrow, + ) -> Result { + let (idx, field_condition_flow, kind) = match pending_condition_narrow { + PendingConditionNarrow::Field { + idx, + key_type: None, + condition_flow: field_condition_flow, + kind, + } => (idx, field_condition_flow, kind), + pending_condition_narrow => { + return Ok(self.push_pending_condition( + walk, + flow_id, + condition_flow, + pending_condition_narrow, + )); + } + }; + let antecedent_flow_id = walk.antecedent_flow_id; + let Some(LuaIndexKey::Expr(expr)) = idx.get_index_key() else { + return Ok(self.push_pending_condition( + walk, + flow_id, + condition_flow, + PendingConditionNarrow::Field { + idx, + key_type: None, + condition_flow: field_condition_flow, + kind, + }, + )); + }; + let replay_query = + FlowReplayQuery::new(self.db, None, self.cache, antecedent_flow_id, expr, false); + + self.start_expr_replay( + walk, + FlowExprReplay::FieldConditionKey { + condition_flow_id: flow_id, + condition_flow, + idx, + field_condition_flow, + kind, + }, + replay_query, + ) } fn step_assignment( @@ -456,42 +988,67 @@ impl<'a> FlowTypeEngine<'a> { .filter(|tc| tc.is_doc()) .map(|tc| tc.as_type().clone()); - let expr_type = - match infer_expr_list_value_type_at(self.db, self.cache, &assignment_info.exprs, i) { - Ok(expr_type) => expr_type, - Err(err) => { - if let Some(explicit_var_type) = explicit_var_type.as_ref() { - return Ok(self.finish_walk(walk, explicit_var_type.clone())); - } - - if matches!(var_ref_id, VarRefId::IndexRef(_, _)) - && let Ok(origin_type) = get_var_ref_type(self.db, self.cache, &var_ref_id) - { - let non_nil_origin = - TypeOps::Remove.apply(self.db, &origin_type, &LuaType::Nil); - return Ok(self.finish_walk( - walk, - if non_nil_origin.is_never() { - origin_type - } else { - non_nil_origin - }, - )); - } + if let Some(last_expr_idx) = assignment_info.exprs.len().checked_sub(1) { + let expr_idx = i.min(last_expr_idx); + let result_slot = i.saturating_sub(last_expr_idx); + let expr = assignment_info.exprs[expr_idx].clone(); + let replay_query = FlowReplayQuery::new( + self.db, + Some(self.tree), + self.cache, + antecedent_flow_id, + expr, + true, + ); + return self.start_expr_replay( + walk, + FlowExprReplay::Assignment { + antecedent_flow_id, + explicit_var_type, + result_slot, + }, + replay_query, + ); + } - if matches!(err, InferFailReason::FieldNotFound | InferFailReason::None) { - return Ok(self.finish_walk(walk, LuaType::Nil)); - } + self.finish_assignment_expr_type(walk, antecedent_flow_id, explicit_var_type, LuaType::Nil) + } - walk.antecedent_flow_id = antecedent_flow_id; - return Ok(SchedulerStep::ContinueWalk(walk)); - } - }; - let Some(expr_type) = expr_type else { - walk.antecedent_flow_id = antecedent_flow_id; - return Ok(SchedulerStep::ContinueWalk(walk)); + fn finish_assignment_expr( + &mut self, + walk: QueryWalk, + antecedent_flow_id: FlowId, + explicit_var_type: Option, + result_slot: usize, + replay_query: FlowReplayQuery, + ) -> Result { + let expr_type = match replay_query.replay_type(self.db, self.cache) { + Ok(Some(expr_type)) => expr_type + .get_result_slot_type(result_slot) + .unwrap_or(LuaType::Nil), + Ok(None) => LuaType::Unknown, + Err(err) => { + return self.finish_assignment_expr_error( + walk, + antecedent_flow_id, + explicit_var_type, + err, + ); + } }; + self.finish_assignment_expr_type(walk, antecedent_flow_id, explicit_var_type, expr_type) + } + + fn finish_assignment_expr_type( + &mut self, + walk: QueryWalk, + antecedent_flow_id: FlowId, + explicit_var_type: Option, + expr_type: LuaType, + ) -> Result { + let var_ref_id = walk.query.var_ref_id.clone(); + if let Some(explicit_var_type) = explicit_var_type { let result_type = finish_assignment_result( self.db, @@ -505,8 +1062,8 @@ impl<'a> FlowTypeEngine<'a> { return Ok(self.finish_walk(walk, result_type)); } - let reuse_source_narrowing = preserves_assignment_expr_type(&expr_type); - let mode = if reuse_source_narrowing { + let reuse_antecedent_narrowing = preserves_assignment_expr_type(&expr_type); + let mode = if reuse_antecedent_narrowing { FlowMode::WithConditions } else { FlowMode::WithoutConditions @@ -518,11 +1075,45 @@ impl<'a> FlowTypeEngine<'a> { walk, antecedent_flow_id, expr_type, - reuse_source_narrowing, + reuse_antecedent_narrowing, }), }) } + fn finish_assignment_expr_error( + &mut self, + mut walk: QueryWalk, + antecedent_flow_id: FlowId, + explicit_var_type: Option, + err: InferFailReason, + ) -> Result { + if let Some(explicit_var_type) = explicit_var_type { + return Ok(self.finish_walk(walk, explicit_var_type)); + } + + let var_ref_id = walk.query.var_ref_id.clone(); + if matches!(var_ref_id, VarRefId::IndexRef(_, _)) + && let Ok(origin_type) = get_var_ref_type(self.db, self.cache, &var_ref_id) + { + let non_nil_origin = TypeOps::Remove.apply(self.db, &origin_type, &LuaType::Nil); + return Ok(self.finish_walk( + walk, + if non_nil_origin.is_never() { + origin_type + } else { + non_nil_origin + }, + )); + } + + if matches!(err, InferFailReason::FieldNotFound | InferFailReason::None) { + return Ok(self.finish_walk(walk, LuaType::Nil)); + } + + walk.antecedent_flow_id = antecedent_flow_id; + Ok(SchedulerStep::ContinueWalk(walk)) + } + fn step_condition( &mut self, mut walk: QueryWalk, @@ -536,32 +1127,31 @@ impl<'a> FlowTypeEngine<'a> { return Ok(SchedulerStep::ContinueWalk(walk)); } - let condition_info = - get_flow_condition_info(self.db, self.cache, self.root, flow_node.id, condition_ptr)?; walk.antecedent_flow_id = antecedent_flow_id; let q = &walk.query; let var_ref_id = &q.var_ref_id; - if condition_info.index_var_ref_id.is_some() - && condition_info.index_var_ref_id.as_ref() != Some(var_ref_id) - && condition_info.index_prefix_var_ref_id.as_ref() != Some(var_ref_id) - { - return Ok(SchedulerStep::ContinueWalk(walk)); - } let cache_id = q.var_cache_idx; let flow_id = flow_node.id; let cache_key = (flow_id, condition_flow); + let mut cached_action = false; let action = match self .cache .flow_var_caches .get(cache_id as usize) .and_then(|var_cache| var_cache.condition_cache.get(&cache_key)) { - Some(CacheEntry::Cache(action)) => action.clone(), + Some(CacheEntry::Cache(action)) => { + cached_action = true; + action.clone() + } Some(CacheEntry::Ready) => { return self.fail_query(q, InferFailReason::RecursiveInfer); } None => { + let condition = condition_ptr + .to_node(self.root) + .ok_or(InferFailReason::None)?; get_flow_var_cache(self.cache, cache_id) .condition_cache .insert(cache_key, CacheEntry::Ready); @@ -572,20 +1162,31 @@ impl<'a> FlowTypeEngine<'a> { self.root, var_ref_id, flow_node, - condition_info.expr.clone(), + condition, condition_flow, ) { Ok(action) => action, Err(err) => { - get_flow_var_cache(self.cache, cache_id) - .condition_cache - .remove(&cache_key); - return self.fail_query(q, err); + return self.fail_condition_query(q, flow_id, condition_flow, err); } } } }; + if cached_action { + return match action { + ConditionFlowAction::Continue => Ok(SchedulerStep::ContinueWalk(walk)), + ConditionFlowAction::Result(result_type) => Ok(self.finish_walk(walk, result_type)), + ConditionFlowAction::Pending(pending_condition_narrow) => { + let mut walk = walk; + walk.pending_condition_narrows + .push(pending_condition_narrow); + Ok(SchedulerStep::ContinueWalk(walk)) + } + action => self.apply_condition_action(walk, flow_id, condition_flow, action), + }; + } + self.apply_condition_action(walk, flow_id, condition_flow, action) } @@ -680,10 +1281,36 @@ impl<'a> FlowTypeEngine<'a> { return Ok(self.finish_walk(walk, var_type)); } Err(err) => { - if let Some(init_type) = try_infer_decl_initializer_type( - self.db, self.cache, self.root, var_ref_id, - )? { - return Ok(self.finish_walk(walk, init_type)); + let Some(decl_id) = var_ref_id.get_decl_id_ref() else { + return self.fail_query(&walk.query, err); + }; + let decl = self + .db + .get_decl_index() + .get_decl(&decl_id) + .ok_or(InferFailReason::None)?; + if let Some(value_syntax_id) = decl.get_value_syntax_id() + && let Some(node) = + value_syntax_id.to_node_from_root(self.root.syntax()) + && let Some(expr) = LuaExpr::cast(node) + { + let expr_flow_id = self + .tree + .get_flow_id(expr.get_syntax_id()) + .unwrap_or(walk.antecedent_flow_id); + let replay_query = FlowReplayQuery::new( + self.db, + Some(self.tree), + self.cache, + expr_flow_id, + expr, + false, + ); + return self.start_expr_replay( + walk, + FlowExprReplay::DeclInitializer { fail_reason: err }, + replay_query, + ); } return self.fail_query(&walk.query, err); @@ -754,7 +1381,7 @@ impl<'a> FlowTypeEngine<'a> { fn apply_condition_action( &mut self, - mut walk: QueryWalk, + walk: QueryWalk, flow_id: FlowId, condition_flow: InferConditionFlow, action: ConditionFlowAction, @@ -778,60 +1405,41 @@ impl<'a> FlowTypeEngine<'a> { ); Ok(self.finish_walk(walk, result_type)) } - ConditionFlowAction::Pending(pending_condition_narrow) => { - get_flow_var_cache(self.cache, walk.query.var_cache_idx) - .condition_cache - .insert( - (flow_id, condition_flow), - CacheEntry::Cache(ConditionFlowAction::Pending( - pending_condition_narrow.clone(), - )), - ); - walk.pending_condition_narrows - .push(pending_condition_narrow); - Ok(SchedulerStep::ContinueWalk(walk)) + ConditionFlowAction::Pending(pending_condition_narrow) => self.start_pending_condition( + walk, + flow_id, + condition_flow, + pending_condition_narrow, + ), + ConditionFlowAction::NeedExprType { + flow_id: expr_flow_id, + expr, + resume, + } => { + let replay_query = FlowReplayQuery::new( + self.db, + Some(self.tree), + self.cache, + expr_flow_id, + expr, + false, + ); + self.start_expr_replay( + walk, + FlowExprReplay::Condition { + condition_flow_id: flow_id, + condition_flow, + resume, + }, + replay_query, + ) } ConditionFlowAction::NeedSubquery(subquery) => { - let (subquery_var_ref_id, subquery_antecedent_flow_id, subquery_mode) = - match &subquery { - ConditionSubquery::ArrayLen { - var_ref_id, - antecedent_flow_id, - .. - } - | ConditionSubquery::FieldLiteralEq { - var_ref_id, - antecedent_flow_id, - .. - } - | ConditionSubquery::Correlated { - var_ref_id, - antecedent_flow_id, - .. - } => (var_ref_id, *antecedent_flow_id, FlowMode::WithConditions), - ConditionSubquery::FieldLiteralSibling { - discriminant_prefix_var_ref_id, - antecedent_flow_id, - .. - } => ( - discriminant_prefix_var_ref_id, - *antecedent_flow_id, - FlowMode::WithConditions, - ), - }; - let subquery_query = - FlowQuery::new(self.cache, subquery_var_ref_id, subquery_antecedent_flow_id) - .at_flow(subquery_antecedent_flow_id, subquery_mode); - Ok(SchedulerStep::StartQuery { - query: subquery_query, - continuation: Some(Continuation::ConditionDependency { - walk, - flow_id, - condition_flow, - subquery, - }), - }) + Ok(self.start_condition_subquery(walk, flow_id, condition_flow, subquery)) } + ConditionFlowAction::NeedFieldLiteralSibling(subquery) => Ok( + self.start_field_literal_sibling_subquery(walk, flow_id, condition_flow, subquery) + ), ConditionFlowAction::NeedCorrelated(pending_correlated_condition) => { let subquery = walk.query.at_flow( pending_correlated_condition.current_search_root_flow_id, @@ -881,6 +1489,19 @@ impl<'a> FlowTypeEngine<'a> { .remove(&(query.flow_id, query.mode)); Err(err) } + + fn fail_condition_query( + &mut self, + query: &FlowQuery, + flow_id: FlowId, + condition_flow: InferConditionFlow, + err: InferFailReason, + ) -> Result { + get_flow_var_cache(self.cache, query.var_cache_idx) + .condition_cache + .remove(&(flow_id, condition_flow)); + self.fail_query(query, err) + } } pub(super) fn get_type_at_flow( @@ -923,6 +1544,29 @@ fn get_flow_var_cache(cache: &mut LuaInferCache, var_ref_cache_id: u32) -> &mut &mut cache.flow_var_caches[outer_index] } +fn replay_expr_no_flow( + db: &DbIndex, + cache: &mut LuaInferCache, + expr: LuaExpr, + dependency_types: &[(LuaSyntaxId, LuaType)], + allow_table_exprs: bool, +) -> Result, InferFailReason> { + let mut table_exprs = Vec::new(); + if allow_table_exprs { + if let LuaExpr::TableExpr(table_expr) = &expr { + table_exprs.push(table_expr.get_syntax_id()); + } + table_exprs.extend( + expr.descendants::() + .map(|table_expr| table_expr.get_syntax_id()), + ); + } + + cache.with_replay_overlay(dependency_types, &table_exprs, |cache| { + try_infer_expr_no_flow(db, cache, expr) + }) +} + fn can_reuse_narrowed_assignment_source( db: &DbIndex, narrowed_source_type: &LuaType, @@ -995,75 +1639,6 @@ fn is_exact_assignment_expr_type(typ: &LuaType) -> bool { } } -fn get_flow_condition_info( - db: &DbIndex, - cache: &mut LuaInferCache, - root: &LuaChunk, - flow_id: FlowId, - condition_ptr: &emmylua_parser::LuaAstPtr, -) -> Result, InferFailReason> { - let flow_index = flow_id.0 as usize; - if let Some(Some(info)) = cache.flow_condition_info_cache.get(flow_index) { - return Ok(info.clone()); - } - - let expr = condition_ptr.to_node(root).ok_or(InferFailReason::None)?; - let (index_var_ref_id, index_prefix_var_ref_id) = - get_condition_index_var_refs(db, cache, expr.clone()); - let info = Rc::new(FlowConditionInfo { - expr, - index_var_ref_id, - index_prefix_var_ref_id, - }); - if cache.flow_condition_info_cache.len() <= flow_index { - cache - .flow_condition_info_cache - .resize_with(flow_index + 1, || None); - } - cache.flow_condition_info_cache[flow_index] = Some(info.clone()); - Ok(info) -} - -fn get_condition_index_var_refs( - db: &DbIndex, - cache: &mut LuaInferCache, - condition: LuaExpr, -) -> (Option, Option) { - match condition { - LuaExpr::IndexExpr(index_expr) => { - let index_var_ref_id = - get_var_expr_var_ref_id(db, cache, LuaExpr::IndexExpr(index_expr.clone())); - let index_prefix_var_ref_id = if index_var_ref_id.is_some() { - index_expr - .get_prefix_expr() - .and_then(|prefix_expr| get_var_expr_var_ref_id(db, cache, prefix_expr)) - } else { - None - }; - (index_var_ref_id, index_prefix_var_ref_id) - } - LuaExpr::ParenExpr(paren_expr) => paren_expr - .get_expr() - .map(|expr| get_condition_index_var_refs(db, cache, expr)) - .unwrap_or((None, None)), - LuaExpr::UnaryExpr(unary_expr) => { - let Some(op_token) = unary_expr.get_op_token() else { - return (None, None); - }; - - if op_token.get_op() != UnaryOperator::OpNot { - return (None, None); - } - - unary_expr - .get_expr() - .map(|expr| get_condition_index_var_refs(db, cache, expr)) - .unwrap_or((None, None)) - } - _ => (None, None), - } -} - fn get_branch_label_flow_ids( tree: &FlowTree, cache: &mut LuaInferCache, @@ -1173,73 +1748,3 @@ fn finish_assignment_result( expr_type.clone() } } - -fn try_infer_decl_initializer_type( - db: &DbIndex, - cache: &mut LuaInferCache, - root: &LuaChunk, - var_ref_id: &VarRefId, -) -> Result, InferFailReason> { - let Some(decl_id) = var_ref_id.get_decl_id_ref() else { - return Ok(None); - }; - - let decl = db - .get_decl_index() - .get_decl(&decl_id) - .ok_or(InferFailReason::None)?; - - let Some(value_syntax_id) = decl.get_value_syntax_id() else { - return Ok(None); - }; - - let Some(node) = value_syntax_id.to_node_from_root(root.syntax()) else { - return Ok(None); - }; - - let Some(expr) = LuaExpr::cast(node) else { - return Ok(None); - }; - - let expr_type = infer_expr(db, cache, expr.clone())?; - let init_type = expr_type.get_result_slot_type(0); - - Ok(init_type) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{CacheOptions, FileId}; - - #[test] - fn test_flow_caches_stay_sparse_for_large_flow_ids() { - let mut cache = LuaInferCache::new(FileId::new(0), CacheOptions::default()); - let var_ref_id = VarRefId::VarRef(LuaDeclId::new(FileId::new(0), 0.into())); - let query = FlowQuery::new(&mut cache, &var_ref_id, FlowId(10_000)); - - get_flow_var_cache(&mut cache, 0) - .type_cache - .insert((query.flow_id, query.mode), CacheEntry::Ready); - get_flow_var_cache(&mut cache, 0).condition_cache.insert( - (FlowId(20_000), InferConditionFlow::FalseCondition), - CacheEntry::Ready, - ); - - assert_eq!(cache.flow_var_caches.len(), 1); - assert_eq!(cache.flow_var_caches[0].type_cache.len(), 1); - assert_eq!(cache.flow_var_caches[0].condition_cache.len(), 1); - assert!(matches!( - cache.flow_var_caches[0] - .type_cache - .get(&(query.flow_id, query.mode)), - Some(CacheEntry::Ready) - )); - assert!(matches!( - cache.flow_var_caches[0] - .condition_cache - .get(&(FlowId(20_000), InferConditionFlow::FalseCondition)), - Some(CacheEntry::Ready) - )); - } -} diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs index 2d655ee35..d7e2e887b 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs @@ -37,7 +37,11 @@ pub fn infer_expr_narrow_type( get_type_at_flow::get_type_at_flow(db, flow_tree, cache, &root, &var_ref_id, flow_id) } -fn get_var_ref_type(db: &DbIndex, cache: &mut LuaInferCache, var_ref_id: &VarRefId) -> InferResult { +pub(in crate::semantic) fn get_var_ref_type( + db: &DbIndex, + cache: &mut LuaInferCache, + var_ref_id: &VarRefId, +) -> InferResult { if let Some(decl_id) = var_ref_id.get_decl_id_ref() { let decl = db .get_decl_index() @@ -96,3 +100,32 @@ fn get_multi_antecedents(tree: &FlowTree, flow: &FlowNode) -> Result None => Err(InferFailReason::None), } } + +#[cfg(test)] +mod tests { + use crate::{CacheEntry, LuaType, VirtualWorkspace}; + use emmylua_parser::{LuaAstNode, LuaTableExpr}; + + use super::*; + + #[test] + fn test_replay_overlay_is_scoped_without_cache_seed() { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def("local value = {}"); + let syntax_id = ws.get_node::(file_id).get_syntax_id(); + let mut cache = LuaInferCache::new(file_id, Default::default()); + + cache.with_replay_overlay(&[(syntax_id, LuaType::Table)], &[syntax_id], |cache| { + assert_eq!(cache.replay_expr_type(syntax_id), Some(&LuaType::Table)); + assert!(cache.no_flow_table_exprs.contains(&syntax_id)); + assert!(!cache.expr_no_flow_cache.contains_key(&syntax_id)); + cache + .expr_no_flow_cache + .insert(syntax_id, CacheEntry::Cache(Some(LuaType::Table))); + }); + + assert!(cache.replay_expr_type(syntax_id).is_none()); + assert!(!cache.no_flow_table_exprs.contains(&syntax_id)); + assert!(!cache.expr_no_flow_cache.contains_key(&syntax_id)); + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/var_ref_id.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/var_ref_id.rs index 6ccaddde1..fdf9f6087 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/var_ref_id.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/var_ref_id.rs @@ -7,9 +7,10 @@ use smol_str::SmolStr; use crate::{ DbIndex, LuaAliasCallKind, LuaDeclId, LuaDeclOrMemberId, LuaInferCache, LuaMemberId, - LuaMemberKey, LuaType, infer_expr, + LuaMemberKey, LuaType, semantic::infer::{ infer_index::get_index_expr_var_ref_id, infer_name::get_name_expr_var_ref_id, + try_infer_expr_no_flow, }, }; @@ -63,9 +64,6 @@ impl VarRefId { } } - // 计算从 prefix 到当前索引引用的相对字段路径。 - // 例如 `target.handle.name` 相对 `target` 得到 `handle.name`, - // 后续可在已经被判别字段窄化过的 prefix 类型上逐级投影。 pub fn relative_index_path(&self, prefix: &VarRefId) -> Option> { let (decl_or_member_id, path) = match self { VarRefId::IndexRef(decl_or_member_id, path) => { @@ -127,7 +125,7 @@ fn get_call_expr_var_ref_id( call_expr: &LuaCallExpr, ) -> Option { let prefix_expr = call_expr.get_prefix_expr()?; - let maybe_func = infer_expr(db, cache, prefix_expr.clone()).ok()?; + let maybe_func = try_infer_expr_no_flow(db, cache, prefix_expr.clone()).ok()??; let ret = match maybe_func { LuaType::DocFunction(f) => f.get_ret().clone(), diff --git a/crates/emmylua_code_analysis/src/semantic/infer/test.rs b/crates/emmylua_code_analysis/src/semantic/infer/test.rs index 0d41d1748..f3fa9ba6a 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/test.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/test.rs @@ -1,6 +1,7 @@ #[cfg(test)] mod test { use crate::{DiagnosticCode, VirtualWorkspace}; + use emmylua_parser::{LuaCallExpr, LuaExpr}; #[test] fn test_custom_binary() { @@ -82,6 +83,71 @@ mod test { assert_eq!(ws.expr_ty("F()"), ws.ty("string")); } + #[test] + fn test_no_flow_overload_call_keeps_shared_return_when_arg_declines() { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def( + r#" + ---@overload fun(value: string): boolean + ---@overload fun(value: integer): boolean + ---@param value string|integer + ---@return boolean + local function classify(value) + end + + local result = classify({}) + "#, + ); + + let call_expr = ws.get_node::(file_id); + let semantic_model = ws + .analysis + .compilation + .get_semantic_model(file_id) + .expect("Semantic model must exist"); + let ty = crate::semantic::infer::try_infer_expr_no_flow( + semantic_model.get_db(), + &mut semantic_model.get_cache().borrow_mut(), + LuaExpr::CallExpr(call_expr), + ) + .expect("no-flow call replay should not error") + .expect("no-flow call replay should keep shared overload return"); + + assert_eq!(ty, ws.ty("boolean")); + } + + #[test] + fn test_no_flow_overload_call_declines_when_declined_arg_returns_differ() { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def( + r#" + ---@overload fun(value: string): string + ---@overload fun(value: integer): integer + ---@param value string|integer + ---@return string|integer + local function classify(value) + end + + local result = classify({}) + "#, + ); + + let call_expr = ws.get_node::(file_id); + let semantic_model = ws + .analysis + .compilation + .get_semantic_model(file_id) + .expect("Semantic model must exist"); + let ty = crate::semantic::infer::try_infer_expr_no_flow( + semantic_model.get_db(), + &mut semantic_model.get_cache().borrow_mut(), + LuaExpr::CallExpr(call_expr), + ) + .expect("no-flow call replay should not error"); + + assert!(ty.is_none()); + } + #[test] fn test_infer_expr_list_types_tolerates_infer_failures() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/mod.rs b/crates/emmylua_code_analysis/src/semantic/mod.rs index ae9f3e713..b9cd692f9 100644 --- a/crates/emmylua_code_analysis/src/semantic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/mod.rs @@ -55,8 +55,9 @@ pub use generic::*; pub use guard::{InferGuard, InferGuardRef}; pub use infer::InferFailReason; pub use infer::infer_call_expr_func; -pub(crate) use infer::infer_expr; pub use infer::infer_param; +pub(crate) use infer::try_infer_expr_for_index; +pub(crate) use infer::{infer_expr, try_infer_expr_no_flow}; use overload_resolve::resolve_signature; pub use semantic_info::SemanticDeclLevel; pub use type_check::{TypeCheckFailReason, TypeCheckResult}; diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs index 11d3d3595..a6447a91c 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs @@ -1,19 +1,15 @@ mod resolve_signature_by_args; -use std::{ops::Deref, sync::Arc}; +use std::sync::Arc; -use emmylua_parser::{LuaCallExpr, LuaExpr}; +use emmylua_parser::{LuaAstNode, LuaCallExpr}; -use crate::{ - VariadicType, - db_index::{DbIndex, LuaFunctionType, LuaType}, - infer_expr, -}; +use crate::db_index::{DbIndex, LuaFunctionType, LuaType}; use super::{ LuaInferCache, generic::instantiate_func_generic, - infer::{InferCallFuncResult, InferFailReason}, + infer::{InferCallFuncResult, InferFailReason, infer_expr_list_types, try_infer_expr_no_flow}, }; pub(crate) use resolve_signature_by_args::{callable_accepts_args, resolve_signature_by_args}; @@ -27,12 +23,37 @@ pub fn resolve_signature( arg_count: Option, ) -> InferCallFuncResult { let args = call_expr.get_args_list().ok_or(InferFailReason::None)?; - let expr_types = infer_expr_list_types( + let mut declined_no_flow_arg_ranges = Vec::new(); + let expr_types_with_ranges = infer_expr_list_types( db, cache, args.get_args().collect::>().as_slice(), arg_count, - ); + |db, cache, expr| { + if cache.is_no_flow() { + let expr_range = expr.get_range(); + let Some(expr_type) = try_infer_expr_no_flow(db, cache, expr)? else { + if !is_generic { + declined_no_flow_arg_ranges.push(expr_range); + return Ok(LuaType::Unknown); + } + return Err(InferFailReason::None); + }; + Ok(expr_type) + } else { + Ok(crate::infer_expr(db, cache, expr).unwrap_or(LuaType::Unknown)) + } + }, + )?; + let declined_no_flow_args = expr_types_with_ranges + .iter() + .map(|(_, range)| declined_no_flow_arg_ranges.contains(range)) + .collect::>(); + let expr_types: Vec<_> = expr_types_with_ranges + .into_iter() + .map(|(ty, _)| ty) + .collect(); + if is_generic { resolve_signature_by_generic(db, cache, overloads, call_expr, expr_types, arg_count) } else { @@ -42,6 +63,7 @@ pub fn resolve_signature( &expr_types, call_expr.is_colon_call(), arg_count, + &declined_no_flow_args, ) } } @@ -65,48 +87,6 @@ fn resolve_signature_by_generic( &expr_types, call_expr.is_colon_call(), arg_count, + &[], ) } - -fn infer_expr_list_types( - db: &DbIndex, - cache: &mut LuaInferCache, - exprs: &[LuaExpr], - var_count: Option, -) -> Vec { - let mut value_types = Vec::new(); - for (idx, expr) in exprs.iter().enumerate() { - let expr_type = infer_expr(db, cache, expr.clone()).unwrap_or(LuaType::Unknown); - match expr_type { - LuaType::Variadic(variadic) => { - if let Some(var_count) = var_count { - if idx < var_count { - for i in idx..var_count { - if let Some(typ) = variadic.get_type(i - idx) { - value_types.push(typ.clone()); - } else { - break; - } - } - } - } else { - match variadic.deref() { - VariadicType::Base(base) => { - value_types.push(base.clone()); - } - VariadicType::Multi(vecs) => { - for typ in vecs { - value_types.push(typ.clone()); - } - } - } - } - - break; - } - _ => value_types.push(expr_type.clone()), - } - } - - value_types -} diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs index 2fda0c1a4..7e2217f27 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs @@ -19,10 +19,11 @@ pub(crate) fn callable_accepts_args( } for (arg_index, expr_type) in expr_types.iter().enumerate() { - let param_type = match get_call_arg_param(func, arg_index, is_colon_call) { - CallArgParam::Skip => continue, - CallArgParam::Present { param_type, .. } => param_type, - CallArgParam::Missing => return false, + let Some(param_index) = get_call_param_index(func, arg_index, is_colon_call) else { + continue; + }; + let Some(param_type) = get_call_arg_param_type(func, param_index) else { + return false; }; if !param_type.is_any() && check_type_compact(db, ¶m_type, expr_type).is_err() { @@ -39,9 +40,11 @@ pub fn resolve_signature_by_args( expr_types: &[LuaType], is_colon_call: bool, arg_count: Option, + declined_no_flow_args: &[bool], ) -> InferCallFuncResult { let expr_len = expr_types.len(); let arg_count = arg_count.unwrap_or(expr_len); + let has_declined_no_flow_arg = declined_no_flow_args.iter().any(|declined| *declined); let mut need_resolve_funcs = match overloads.len() { 0 => return Err(InferFailReason::None), 1 => return Ok(Arc::clone(&overloads[0])), @@ -65,30 +68,34 @@ pub fn resolve_signature_by_args( .expect("Match result should exist"); for (arg_index, expr_type) in expr_types.iter().enumerate() { let mut current_match_result = ParamMatchResult::Not; + let declined_no_flow_arg = declined_no_flow_args + .get(arg_index) + .copied() + .unwrap_or(false); for opt_func in &mut need_resolve_funcs { let func = match opt_func.as_ref() { None => continue, Some(func) => func, }; - if func.get_params().len() < arg_count && !is_func_last_param_variadic(func) { + let param_len = func.get_params().len(); + if param_len < arg_count && !is_func_last_param_variadic(func) { *opt_func = None; continue; } - let (param_index, param_type) = match get_call_arg_param(func, arg_index, is_colon_call) - { - CallArgParam::Skip => continue, - CallArgParam::Present { - param_index, - param_type, - } => (param_index, param_type), - CallArgParam::Missing => { - *opt_func = None; - continue; - } + let Some(param_index) = get_call_param_index(func, arg_index, is_colon_call) else { + continue; + }; + let Some(param_type) = get_call_arg_param_type(func, param_index) else { + *opt_func = None; + continue; }; - let match_result = if param_type.is_any() { + let match_result = if declined_no_flow_arg && expr_type.is_unknown() { + // Declined no-flow args are compatible with any overload, but they do + // not prove that a specific row won. + ParamMatchResult::Any + } else if param_type.is_any() { ParamMatchResult::Any } else if check_type_compact(db, ¶m_type, expr_type).is_ok() { ParamMatchResult::Type @@ -106,7 +113,8 @@ pub fn resolve_signature_by_args( continue; } - if match_result > ParamMatchResult::Any + if !has_declined_no_flow_arg + && match_result > ParamMatchResult::Any && arg_index + 1 == expr_len && param_index + 1 == func.get_params().len() { @@ -127,7 +135,12 @@ pub fn resolve_signature_by_args( let rest_len = rest_need_resolve_funcs.len(); match rest_len { - 0 => return Ok(best_match_result), + 0 => { + if has_declined_no_flow_arg { + return Err(InferFailReason::None); + } + return Ok(best_match_result); + } 1 => { return Ok(rest_need_resolve_funcs[0] .clone() @@ -152,22 +165,35 @@ pub fn resolve_signature_by_args( None => continue, Some(func) => func, }; - let (param_index, param_type) = - match get_call_arg_param(func, param_index, is_colon_call) { - CallArgParam::Skip => continue, - CallArgParam::Present { - param_index, - param_type, - } => (param_index, param_type), - CallArgParam::Missing => return Ok(func.clone()), - }; - - let match_result = if param_type.is_any() { - ParamMatchResult::Any - } else if param_type.is_nullable() { - ParamMatchResult::Type + let param_len = func.get_params().len(); + let Some(param_index) = get_call_param_index(func, param_index, is_colon_call) else { + continue; + }; + let match_result = if param_index >= param_len { + if func + .get_params() + .last() + .is_some_and(|last_param_info| last_param_info.0 == "...") + { + ParamMatchResult::Any + } else if has_declined_no_flow_arg { + ParamMatchResult::Type + } else { + return Ok(func.clone()); + } } else { - ParamMatchResult::Not + let param_info = func + .get_params() + .get(param_index) + .expect("Param index should exist"); + let param_type = param_info.1.clone().unwrap_or(LuaType::Any); + if param_type.is_any() { + ParamMatchResult::Any + } else if param_type.is_nullable() { + ParamMatchResult::Type + } else { + ParamMatchResult::Not + } }; if match_result > current_match_result { @@ -180,7 +206,8 @@ pub fn resolve_signature_by_args( continue; } - if match_result >= ParamMatchResult::Any + if !has_declined_no_flow_arg + && match_result >= ParamMatchResult::Any && i + 1 == rest_len && param_index + 1 == func.get_params().len() { @@ -193,19 +220,40 @@ pub fn resolve_signature_by_args( } } - Ok(best_match_result) + if !has_declined_no_flow_arg { + return Ok(best_match_result); + } + + let mut remaining_funcs = rest_need_resolve_funcs.into_iter().flatten(); + let Some(first) = remaining_funcs.next() else { + return Err(InferFailReason::None); + }; + + if remaining_funcs.all(|func| func.get_ret() == first.get_ret()) { + Ok(first) + } else { + Err(InferFailReason::None) + } +} + +fn is_func_last_param_variadic(func: &LuaFunctionType) -> bool { + if let Some(last_param) = func.get_params().last() { + last_param.0 == "..." + } else { + false + } } -fn get_call_arg_param( +fn get_call_param_index( func: &LuaFunctionType, arg_index: usize, is_colon_call: bool, -) -> CallArgParam { +) -> Option { let mut param_index = arg_index; match (func.is_colon_define(), is_colon_call) { (true, false) => { if param_index == 0 { - return CallArgParam::Skip; + return None; } param_index -= 1; } @@ -214,31 +262,19 @@ fn get_call_arg_param( } _ => {} } + Some(param_index) +} - if let Some((_, ty)) = func.get_params().get(param_index) { - return CallArgParam::Present { - param_index, - param_type: ty.clone().unwrap_or(LuaType::Any), - }; - } - - if let Some((name, ty)) = func.get_params().last() - && name == "..." - { - return CallArgParam::Present { - param_index, - param_type: ty.clone().unwrap_or(LuaType::Any), - }; +fn get_call_arg_param_type(func: &LuaFunctionType, param_index: usize) -> Option { + if let Some(param_info) = func.get_params().get(param_index) { + return Some(param_info.1.clone().unwrap_or(LuaType::Any)); } - CallArgParam::Missing -} - -fn is_func_last_param_variadic(func: &LuaFunctionType) -> bool { - if let Some(last_param) = func.get_params().last() { - last_param.0 == "..." + let last_param_info = func.get_params().last()?; + if last_param_info.0 == "..." { + Some(last_param_info.1.clone().unwrap_or(LuaType::Any)) } else { - false + None } } @@ -248,12 +284,3 @@ enum ParamMatchResult { Any, Type, } - -enum CallArgParam { - Skip, - Present { - param_index: usize, - param_type: LuaType, - }, - Missing, -} diff --git a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs index 04b3c217f..8ea744c64 100644 --- a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs +++ b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs @@ -511,11 +511,13 @@ fn build_node_semantic_token( } LuaAst::LuaCallExpr(call_expr) => { let prefix = call_expr.get_prefix_expr()?; - let prefix_type = semantic_model.infer_expr(prefix.clone()).ok(); match prefix { LuaExpr::NameExpr(name_expr) => { let name = name_expr.get_name_token()?; + let prefix_type = semantic_model + .infer_expr(LuaExpr::NameExpr(name_expr.clone())) + .ok(); if let Some(prefix_type) = prefix_type { match prefix_type { LuaType::Signature(signature) => {