@@ -129,6 +129,8 @@ pub struct EmitCtx<'a> {
129129 pub glsl_ext_id : Option < Word > ,
130130 pub type_widths : & ' a HashMap < Word , u32 > ,
131131 pub id_to_type : & ' a HashMap < Word , Word > ,
132+ /// Maps composite type IDs (OpTypeVector, OpTypeMatrix, OpTypeArray) to their element type ID.
133+ pub composite_element_types : & ' a HashMap < Word , Word > ,
132134}
133135
134136// ---------------------------------------------------------------------------
@@ -1641,9 +1643,14 @@ fn emit_composite_construct(
16411643 }
16421644 let mut synth = Vec :: new ( ) ;
16431645 let mut operands = Vec :: new ( ) ;
1646+ // For CompositeConstruct, the element type is derived from the composite type.
1647+ let element_type = ctx
1648+ . composite_element_types
1649+ . get ( & result_type)
1650+ . copied ( )
1651+ . unwrap_or ( result_type) ;
16441652 for arg in & args[ ..expected] {
1645- // Each component's type differs from result_type (composite type).
1646- let component_type = resolve_term_type ( arg, ctx) . unwrap_or ( result_type) ;
1653+ let component_type = resolve_term_type ( arg, ctx) . unwrap_or ( element_type) ;
16471654 let ( arg_id, mut s) = emit_term ( arg, component_type, ctx) ?;
16481655 synth. append ( & mut s) ;
16491656 operands. push ( rspirv:: dr:: Operand :: IdRef ( arg_id) ) ;
@@ -1667,8 +1674,13 @@ fn flatten_expr_list(
16671674 match term {
16681675 Term :: App { op, .. } if op == "ENil" => Some ( ( Vec :: new ( ) , Vec :: new ( ) ) ) ,
16691676 Term :: App { op, args } if op == "ECons" && args. len ( ) >= 2 => {
1670- // Each element's type differs from result_type (composite type).
1671- let head_type = resolve_term_type ( & args[ 0 ] , ctx) . unwrap_or ( result_type) ;
1677+ // Element type from the composite type for correct component emission.
1678+ let element_type = ctx
1679+ . composite_element_types
1680+ . get ( & result_type)
1681+ . copied ( )
1682+ . unwrap_or ( result_type) ;
1683+ let head_type = resolve_term_type ( & args[ 0 ] , ctx) . unwrap_or ( element_type) ;
16721684 let ( head_id, head_synth) = emit_term ( & args[ 0 ] , head_type, ctx) ?;
16731685 let ( mut rest_ids, mut rest_synth) = flatten_expr_list ( & args[ 1 ] , result_type, ctx) ?;
16741686 let mut ids = vec ! [ head_id] ;
@@ -1863,6 +1875,7 @@ mod tests {
18631875 glsl_ext_id : None ,
18641876 type_widths : & type_widths,
18651877 id_to_type : & HashMap :: new ( ) ,
1878+ composite_element_types : & HashMap :: new ( ) ,
18661879 } ;
18671880 let term = parse_sexpr ( "(Sym \" id5\" )" ) . unwrap ( ) ;
18681881 let ( result_id, synth) = emit_term ( & term, 10 , & mut ctx) . unwrap ( ) ;
@@ -1888,6 +1901,7 @@ mod tests {
18881901 glsl_ext_id : None ,
18891902 type_widths : & type_widths,
18901903 id_to_type : & HashMap :: new ( ) ,
1904+ composite_element_types : & HashMap :: new ( ) ,
18911905 } ;
18921906 let term = parse_sexpr ( "(Const 42)" ) . unwrap ( ) ;
18931907 let ( result_id, synth) = emit_term ( & term, 10 , & mut ctx) . unwrap ( ) ;
@@ -1915,6 +1929,7 @@ mod tests {
19151929 glsl_ext_id : None ,
19161930 type_widths : & type_widths,
19171931 id_to_type : & HashMap :: new ( ) ,
1932+ composite_element_types : & HashMap :: new ( ) ,
19181933 } ;
19191934 let term = parse_sexpr ( "(Add (Const 3) (Const 5))" ) . unwrap ( ) ;
19201935 let ( result_id, synth) = emit_term ( & term, 10 , & mut ctx) . unwrap ( ) ;
@@ -1946,6 +1961,7 @@ mod tests {
19461961 glsl_ext_id : None ,
19471962 type_widths : & type_widths,
19481963 id_to_type : & HashMap :: new ( ) ,
1964+ composite_element_types : & HashMap :: new ( ) ,
19491965 } ;
19501966 let term = parse_sexpr ( "(IntToExpr (Sym \" id5\" ))" ) . unwrap ( ) ;
19511967 let ( result_id, synth) = emit_term ( & term, 10 , & mut ctx) . unwrap ( ) ;
@@ -1974,6 +1990,11 @@ mod tests {
19741990 glsl_ext_id : None ,
19751991 type_widths,
19761992 id_to_type : EMPTY_ID_TO_TYPE . get_or_init ( HashMap :: new) ,
1993+ composite_element_types : {
1994+ static EMPTY_COMP : std:: sync:: OnceLock < HashMap < Word , Word > > =
1995+ std:: sync:: OnceLock :: new ( ) ;
1996+ EMPTY_COMP . get_or_init ( HashMap :: new)
1997+ } ,
19771998 }
19781999 }
19792000
0 commit comments