Skip to content

Commit d4280cf

Browse files
committed
Resolve composite element types during emission, protect loop continue blocks
- Build composite-to-element type map from OpTypeVector/Matrix/Array definitions and use element type for CompositeConstruct components instead of inheriting the composite result type. - Skip RVSDG transform for selections whose branch targets overlap with loop continue blocks.
1 parent ad8bdfb commit d4280cf

2 files changed

Lines changed: 64 additions & 4 deletions

File tree

rust/spirv-tools-opt/src/direct/emit.rs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ pub struct EmitCtx<'a> {
129129
pub glsl_ext_id: Option<Word>,
130130
pub type_widths: &'a HashMap<Word, u32>,
131131
pub id_to_type: &'a HashMap<Word, Word>,
132+
/// Maps composite type IDs (OpTypeVector, OpTypeMatrix, OpTypeArray) to their element type ID.
133+
pub composite_element_types: &'a HashMap<Word, Word>,
132134
}
133135

134136
// ---------------------------------------------------------------------------
@@ -1641,9 +1643,14 @@ fn emit_composite_construct(
16411643
}
16421644
let mut synth = Vec::new();
16431645
let mut operands = Vec::new();
1646+
// For CompositeConstruct, the element type is derived from the composite type.
1647+
let element_type = ctx
1648+
.composite_element_types
1649+
.get(&result_type)
1650+
.copied()
1651+
.unwrap_or(result_type);
16441652
for arg in &args[..expected] {
1645-
// Each component's type differs from result_type (composite type).
1646-
let component_type = resolve_term_type(arg, ctx).unwrap_or(result_type);
1653+
let component_type = resolve_term_type(arg, ctx).unwrap_or(element_type);
16471654
let (arg_id, mut s) = emit_term(arg, component_type, ctx)?;
16481655
synth.append(&mut s);
16491656
operands.push(rspirv::dr::Operand::IdRef(arg_id));
@@ -1667,8 +1674,13 @@ fn flatten_expr_list(
16671674
match term {
16681675
Term::App { op, .. } if op == "ENil" => Some((Vec::new(), Vec::new())),
16691676
Term::App { op, args } if op == "ECons" && args.len() >= 2 => {
1670-
// Each element's type differs from result_type (composite type).
1671-
let head_type = resolve_term_type(&args[0], ctx).unwrap_or(result_type);
1677+
// Element type from the composite type for correct component emission.
1678+
let element_type = ctx
1679+
.composite_element_types
1680+
.get(&result_type)
1681+
.copied()
1682+
.unwrap_or(result_type);
1683+
let head_type = resolve_term_type(&args[0], ctx).unwrap_or(element_type);
16721684
let (head_id, head_synth) = emit_term(&args[0], head_type, ctx)?;
16731685
let (mut rest_ids, mut rest_synth) = flatten_expr_list(&args[1], result_type, ctx)?;
16741686
let mut ids = vec![head_id];
@@ -1863,6 +1875,7 @@ mod tests {
18631875
glsl_ext_id: None,
18641876
type_widths: &type_widths,
18651877
id_to_type: &HashMap::new(),
1878+
composite_element_types: &HashMap::new(),
18661879
};
18671880
let term = parse_sexpr("(Sym \"id5\")").unwrap();
18681881
let (result_id, synth) = emit_term(&term, 10, &mut ctx).unwrap();
@@ -1888,6 +1901,7 @@ mod tests {
18881901
glsl_ext_id: None,
18891902
type_widths: &type_widths,
18901903
id_to_type: &HashMap::new(),
1904+
composite_element_types: &HashMap::new(),
18911905
};
18921906
let term = parse_sexpr("(Const 42)").unwrap();
18931907
let (result_id, synth) = emit_term(&term, 10, &mut ctx).unwrap();
@@ -1915,6 +1929,7 @@ mod tests {
19151929
glsl_ext_id: None,
19161930
type_widths: &type_widths,
19171931
id_to_type: &HashMap::new(),
1932+
composite_element_types: &HashMap::new(),
19181933
};
19191934
let term = parse_sexpr("(Add (Const 3) (Const 5))").unwrap();
19201935
let (result_id, synth) = emit_term(&term, 10, &mut ctx).unwrap();
@@ -1946,6 +1961,7 @@ mod tests {
19461961
glsl_ext_id: None,
19471962
type_widths: &type_widths,
19481963
id_to_type: &HashMap::new(),
1964+
composite_element_types: &HashMap::new(),
19491965
};
19501966
let term = parse_sexpr("(IntToExpr (Sym \"id5\"))").unwrap();
19511967
let (result_id, synth) = emit_term(&term, 10, &mut ctx).unwrap();
@@ -1974,6 +1990,11 @@ mod tests {
19741990
glsl_ext_id: None,
19751991
type_widths,
19761992
id_to_type: EMPTY_ID_TO_TYPE.get_or_init(HashMap::new),
1993+
composite_element_types: {
1994+
static EMPTY_COMP: std::sync::OnceLock<HashMap<Word, Word>> =
1995+
std::sync::OnceLock::new();
1996+
EMPTY_COMP.get_or_init(HashMap::new)
1997+
},
19771998
}
19781999
}
19792000

rust/spirv-tools-opt/src/direct/mod.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,12 +559,35 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
559559
}
560560
}
561561

562+
// Build a set of (func_idx, block_idx) pairs that are loop continue blocks.
563+
// Selections whose branch targets overlap these must also be skipped.
564+
let mut continue_block_set: HashSet<(usize, usize)> = HashSet::new();
565+
for loop_info in &loop_constructs {
566+
if let Some(continue_idx) = loop_info.continue_block_idx {
567+
continue_block_set.insert((loop_info.func_idx, continue_idx));
568+
}
569+
}
570+
562571
// For each selection construct, convert to RVSDG EffGamma
563572
for (sel_idx, sel) in selection_constructs.iter().enumerate() {
564573
// Skip selection constructs inside loop bodies to avoid breaking continue block reachability
565574
if loop_block_set.contains(&(sel.func_idx, sel.header_block_idx)) {
566575
continue;
567576
}
577+
// Skip selections whose branch targets or merge block overlap with loop continue blocks
578+
{
579+
let label_map = &func_block_labels[sel.func_idx];
580+
let touches_continue = [&sel.then_label, &sel.else_label, &sel.merge_label]
581+
.iter()
582+
.any(|label| {
583+
label_map
584+
.get(label)
585+
.map_or(false, |&idx| continue_block_set.contains(&(sel.func_idx, idx)))
586+
});
587+
if touches_continue {
588+
continue;
589+
}
590+
}
568591
let func = &module.functions[sel.func_idx];
569592

570593
// Find the header block to get the condition
@@ -854,6 +877,21 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
854877
let float32_type = find_spirv_type(module, Op::TypeFloat, Some(32));
855878
let float64_type = find_spirv_type(module, Op::TypeFloat, Some(64));
856879

880+
// Build composite → element type mapping for CompositeConstruct emission
881+
let mut composite_element_types: HashMap<Word, Word> = HashMap::new();
882+
for inst in &module.types_global_values {
883+
match inst.class.opcode {
884+
Op::TypeVector | Op::TypeMatrix | Op::TypeArray | Op::TypeRuntimeArray => {
885+
if let (Some(composite_id), Some(rspirv::dr::Operand::IdRef(element_id))) =
886+
(inst.result_id, inst.operands.first())
887+
{
888+
composite_element_types.insert(composite_id, *element_id);
889+
}
890+
}
891+
_ => {}
892+
}
893+
}
894+
857895
// Only extract from IDs that are both:
858896
// 1. True roots (operands of side effects) - these are the outputs we need
859897
// 2. Live (reachable via liveness propagation in the e-graph)
@@ -939,6 +977,7 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
939977
glsl_ext_id: ctx.glsl_ext_id(),
940978
type_widths: &type_widths,
941979
id_to_type: &ctx.id_to_type,
980+
composite_element_types: &composite_element_types,
942981
};
943982
if let Some((final_id, new_insts)) =
944983
emit::emit_term(term_tree, corrected_type, &mut emit_ctx)

0 commit comments

Comments
 (0)