diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 1d7658e..b884b31 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -41,6 +41,10 @@ pub fn raw_command_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("raw_command")] } +pub fn predicate_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("predicate")] +} + /// A [`annot::Resolver`] implementation for resolving function parameters. /// /// The parameter names and their sorts needs to be configured via diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index bd1d87b..cdbc721 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -24,6 +24,7 @@ pub struct Analyzer<'tcx, 'ctx> { tcx: TyCtxt<'tcx>, ctx: &'ctx mut analyze::Analyzer<'tcx>, trusted: HashSet, + predicates: HashSet, } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { @@ -82,6 +83,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.trusted.insert(local_def_id.to_def_id()); } + if analyzer.is_annotated_as_predicate() { + self.predicates.insert(local_def_id.to_def_id()); + analyzer.analyze_predicate_definition(local_def_id); + } + use mir_ty::TypeVisitableExt as _; if sig.has_param() && !analyzer.is_fully_annotated() { self.ctx.register_deferred_def(local_def_id.to_def_id()); @@ -105,6 +111,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { tracing::info!(?local_def_id, "trusted"); continue; } + if self.predicates.contains(&local_def_id.to_def_id()) { + tracing::info!(?local_def_id, "predicate"); + continue; + } let Some(expected) = self.ctx.concrete_def_ty(local_def_id.to_def_id()) else { // when the local_def_id is deferred it would be skipped continue; @@ -212,7 +222,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { pub fn new(ctx: &'ctx mut analyze::Analyzer<'tcx>) -> Self { let tcx = ctx.tcx; let trusted = HashSet::default(); - Self { ctx, tcx, trusted } + let predicates = HashSet::default(); + Self { + ctx, + tcx, + trusted, + predicates, + } } pub fn run(&mut self) { diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index d556ef0..a4bfd64 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -16,6 +16,27 @@ use crate::pretty::PrettyDisplayExt as _; use crate::refine::{BasicBlockType, TypeBuilder}; use crate::rty; +fn stmt_str_literal(stmt: &rustc_hir::Stmt) -> Option { + use rustc_ast::LitKind; + use rustc_hir::{Expr, ExprKind, Stmt, StmtKind}; + + match stmt { + Stmt { + kind: + StmtKind::Semi(Expr { + kind: + ExprKind::Lit(rustc_hir::Lit { + node: LitKind::Str(symbol, _), + .. + }), + .. + }), + .. + } => Some(symbol.to_string()), + _ => None, + } +} + /// An implementation of the typing of local definitions. /// /// The current implementation only applies to function definitions. The entry point is @@ -106,6 +127,49 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { ret_annot } + pub fn analyze_predicate_definition(&self, local_def_id: LocalDefId) { + let pred_name = self.tcx.item_name(local_def_id.to_def_id()).to_string(); + + // function's body + use rustc_hir::{Block, Expr, ExprKind}; + + let hir_map = self.tcx.hir(); + let body_id = hir_map.maybe_body_owned_by(local_def_id).unwrap(); + let hir_body = hir_map.body(body_id); + + let predicate_body = match hir_body.value { + Expr { + kind: ExprKind::Block(Block { stmts, .. }, _), + .. + } => stmts + .iter() + .find_map(stmt_str_literal) + .expect("invalid predicate definition: no string literal was found."), + _ => panic!("expected function body, got: {hir_body:?}"), + }; + + // names and sorts of arguments + let arg_names = self + .tcx + .fn_arg_names(local_def_id.to_def_id()) + .iter() + .map(|ident| ident.to_string()); + + let sig = self.ctx.local_fn_sig(local_def_id); + let arg_sorts = sig + .inputs() + .iter() + .map(|input_ty| self.type_builder.build(*input_ty).to_sort()); + + let arg_name_and_sorts = arg_names.into_iter().zip(arg_sorts).collect::>(); + + self.ctx.system.borrow_mut().push_pred_define( + chc::UserDefinedPred::new(pred_name), + chc::UserDefinedPredSig::from(arg_name_and_sorts), + predicate_body, + ); + } + pub fn is_annotated_as_trusted(&self) -> bool { self.tcx .get_attrs_by_path( @@ -136,6 +200,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .is_some() } + pub fn is_annotated_as_predicate(&self) -> bool { + self.tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::predicate_path(), + ) + .next() + .is_some() + } + // TODO: unify this logic with extraction functions above pub fn is_fully_annotated(&self) -> bool { let has_require = self diff --git a/src/chc.rs b/src/chc.rs index e9741bb..3c72b5c 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -1700,11 +1700,21 @@ pub struct PredVarDef { pub debug_info: DebugInfo, } +pub type UserDefinedPredSig = Vec<(String, Sort)>; + +#[derive(Debug, Clone)] +pub struct UserDefinedPredDef { + symbol: UserDefinedPred, + sig: UserDefinedPredSig, + body: String, +} + /// A CHC system. #[derive(Debug, Clone, Default)] pub struct System { pub raw_commands: Vec, pub datatypes: Vec, + pub user_defined_pred_defs: Vec, pub clauses: IndexVec, pub pred_vars: IndexVec, } @@ -1718,6 +1728,16 @@ impl System { self.raw_commands.push(raw_command) } + pub fn push_pred_define( + &mut self, + symbol: UserDefinedPred, + sig: UserDefinedPredSig, + body: String, + ) { + self.user_defined_pred_defs + .push(UserDefinedPredDef { symbol, sig, body }) + } + pub fn push_clause(&mut self, clause: Clause) -> Option { if clause.is_nop() { return None; diff --git a/src/chc/smtlib2.rs b/src/chc/smtlib2.rs index 1fb7319..e8886ed 100644 --- a/src/chc/smtlib2.rs +++ b/src/chc/smtlib2.rs @@ -562,6 +562,33 @@ impl<'ctx, 'a> MatcherPredFun<'ctx, 'a> { } } +pub struct UserDefinedPredDef<'ctx, 'a> { + ctx: &'ctx FormatContext, + inner: &'a chc::UserDefinedPredDef, +} + +impl<'ctx, 'a> std::fmt::Display for UserDefinedPredDef<'ctx, 'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let params = List::closed( + self.inner + .sig + .iter() + .map(|(name, sort)| format!("({} {})", name, self.ctx.fmt_sort(sort))), + ); + write!( + f, + "(define-fun {name} {params} Bool {body})", + name = self.inner.symbol, + body = &self.inner.body, + ) + } +} + +impl<'ctx, 'a> UserDefinedPredDef<'ctx, 'a> { + pub fn new(ctx: &'ctx FormatContext, inner: &'a chc::UserDefinedPredDef) -> Self { + Self { ctx, inner } + } +} /// A wrapper around a [`chc::System`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct System<'a> { @@ -573,16 +600,25 @@ impl<'a> std::fmt::Display for System<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "(set-logic HORN)\n")?; + writeln!(f, "{}\n", Datatypes::new(&self.ctx, self.ctx.datatypes()))?; + for datatype in self.ctx.datatypes() { + writeln!(f, "{}", DatatypeDiscrFun::new(&self.ctx, datatype))?; + writeln!(f, "{}", MatcherPredFun::new(&self.ctx, datatype))?; + } + // insert command from #![thrust::raw_command()] here for raw_command in &self.inner.raw_commands { writeln!(f, "{}\n", RawCommand::new(raw_command))?; } - writeln!(f, "{}\n", Datatypes::new(&self.ctx, self.ctx.datatypes()))?; - for datatype in self.ctx.datatypes() { - writeln!(f, "{}", DatatypeDiscrFun::new(&self.ctx, datatype))?; - writeln!(f, "{}", MatcherPredFun::new(&self.ctx, datatype))?; + for user_defined_pred_def in &self.inner.user_defined_pred_defs { + writeln!( + f, + "{}\n", + UserDefinedPredDef::new(&self.ctx, user_defined_pred_def) + )?; } + writeln!(f)?; for (p, def) in self.inner.pred_vars.iter_enumerated() { if !def.debug_info.is_empty() { diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index 30ea31c..cb5ce8f 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -152,6 +152,15 @@ fn unbox_datatype(datatype: Datatype) -> Datatype { } } +fn unbox_user_defined_pred_def(user_defined_pred_def: UserDefinedPredDef) -> UserDefinedPredDef { + let UserDefinedPredDef { symbol, sig, body } = user_defined_pred_def; + let sig = sig + .into_iter() + .map(|(name, sort)| (name, unbox_sort(sort))) + .collect(); + UserDefinedPredDef { symbol, sig, body } +} + /// Remove all `Box` sorts and `Box`/`BoxCurrent` terms from the system. /// /// The box values in Thrust represent an owned pointer, but are logically equivalent to the inner type. @@ -159,18 +168,24 @@ fn unbox_datatype(datatype: Datatype) -> Datatype { /// This function traverses a [`System`] and removes all `Box` related constructs. pub fn unbox(system: System) -> System { let System { + raw_commands, + datatypes, + user_defined_pred_defs, clauses, pred_vars, - datatypes, - raw_commands, } = system; let datatypes = datatypes.into_iter().map(unbox_datatype).collect(); let clauses = clauses.into_iter().map(unbox_clause).collect(); let pred_vars = pred_vars.into_iter().map(unbox_pred_var_def).collect(); + let user_defined_pred_defs = user_defined_pred_defs + .into_iter() + .map(unbox_user_defined_pred_def) + .collect(); System { + raw_commands, + datatypes, + user_defined_pred_defs, clauses, pred_vars, - datatypes, - raw_commands, } } diff --git a/tests/ui/pass/annot_preds.rs b/tests/ui/pass/annot_preds.rs new file mode 100644 index 0000000..9fa5bec --- /dev/null +++ b/tests/ui/pass/annot_preds.rs @@ -0,0 +1,21 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +#[thrust::predicate] +fn is_double(x: i64, doubled_x: i64) -> bool { + "(= + (* x 2) + doubled_x + )"; true +} + +#[thrust::requires(true)] +#[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); +}