Skip to content
4 changes: 4 additions & 0 deletions src/analyze/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ pub fn raw_command_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("raw_command")]
}

pub fn predicate_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("predicate")]
}

/// A [`annot::Resolver`] implementation for resolving function parameters.
///
/// The parameter names and their sorts needs to be configured via
Expand Down
18 changes: 17 additions & 1 deletion src/analyze/crate_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub struct Analyzer<'tcx, 'ctx> {
tcx: TyCtxt<'tcx>,
ctx: &'ctx mut analyze::Analyzer<'tcx>,
trusted: HashSet<DefId>,
predicates: HashSet<DefId>,
}

impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
Expand Down Expand Up @@ -82,6 +83,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self.trusted.insert(local_def_id.to_def_id());
}

if analyzer.is_annotated_as_predicate() {
self.predicates.insert(local_def_id.to_def_id());
analyzer.analyze_predicate_definition(local_def_id);
}

use mir_ty::TypeVisitableExt as _;
if sig.has_param() && !analyzer.is_fully_annotated() {
self.ctx.register_deferred_def(local_def_id.to_def_id());
Expand All @@ -105,6 +111,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
tracing::info!(?local_def_id, "trusted");
continue;
}
if self.predicates.contains(&local_def_id.to_def_id()) {
tracing::info!(?local_def_id, "predicate");
continue;
}
let Some(expected) = self.ctx.concrete_def_ty(local_def_id.to_def_id()) else {
// when the local_def_id is deferred it would be skipped
continue;
Expand Down Expand Up @@ -212,7 +222,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
pub fn new(ctx: &'ctx mut analyze::Analyzer<'tcx>) -> Self {
let tcx = ctx.tcx;
let trusted = HashSet::default();
Self { ctx, tcx, trusted }
let predicates = HashSet::default();
Self {
ctx,
tcx,
trusted,
predicates,
}
}

pub fn run(&mut self) {
Expand Down
74 changes: 74 additions & 0 deletions src/analyze/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@ use crate::pretty::PrettyDisplayExt as _;
use crate::refine::{BasicBlockType, TypeBuilder};
use crate::rty;

fn stmt_str_literal(stmt: &rustc_hir::Stmt) -> Option<String> {
use rustc_ast::LitKind;
use rustc_hir::{Expr, ExprKind, Stmt, StmtKind};

match stmt {
Stmt {
kind:
StmtKind::Semi(Expr {
kind:
ExprKind::Lit(rustc_hir::Lit {
node: LitKind::Str(symbol, _),
..
}),
..
}),
..
} => Some(symbol.to_string()),
_ => None,
}
}

/// An implementation of the typing of local definitions.
///
/// The current implementation only applies to function definitions. The entry point is
Expand Down Expand Up @@ -106,6 +127,49 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
ret_annot
}

pub fn analyze_predicate_definition(&self, local_def_id: LocalDefId) {
let pred_name = self.tcx.item_name(local_def_id.to_def_id()).to_string();

// function's body
use rustc_hir::{Block, Expr, ExprKind};

let hir_map = self.tcx.hir();
let body_id = hir_map.maybe_body_owned_by(local_def_id).unwrap();
let hir_body = hir_map.body(body_id);

let predicate_body = match hir_body.value {
Expr {
kind: ExprKind::Block(Block { stmts, .. }, _),
..
} => stmts
.iter()
.find_map(stmt_str_literal)
.expect("invalid predicate definition: no string literal was found."),
_ => panic!("expected function body, got: {hir_body:?}"),
};

// names and sorts of arguments
let arg_names = self
.tcx
.fn_arg_names(local_def_id.to_def_id())
.iter()
.map(|ident| ident.to_string());

let sig = self.ctx.local_fn_sig(local_def_id);
let arg_sorts = sig
.inputs()
.iter()
.map(|input_ty| self.type_builder.build(*input_ty).to_sort());

let arg_name_and_sorts = arg_names.into_iter().zip(arg_sorts).collect::<Vec<_>>();

self.ctx.system.borrow_mut().push_pred_define(
chc::UserDefinedPred::new(pred_name),
chc::UserDefinedPredSig::from(arg_name_and_sorts),
predicate_body,
);
}

pub fn is_annotated_as_trusted(&self) -> bool {
self.tcx
.get_attrs_by_path(
Expand Down Expand Up @@ -136,6 +200,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
.is_some()
}

pub fn is_annotated_as_predicate(&self) -> bool {
self.tcx
.get_attrs_by_path(
self.local_def_id.to_def_id(),
&analyze::annot::predicate_path(),
)
.next()
.is_some()
}

// TODO: unify this logic with extraction functions above
pub fn is_fully_annotated(&self) -> bool {
let has_require = self
Expand Down
20 changes: 20 additions & 0 deletions src/chc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1700,11 +1700,21 @@ pub struct PredVarDef {
pub debug_info: DebugInfo,
}

pub type UserDefinedPredSig = Vec<(String, Sort)>;

#[derive(Debug, Clone)]
pub struct UserDefinedPredDef {
symbol: UserDefinedPred,
sig: UserDefinedPredSig,
body: String,
}

/// A CHC system.
#[derive(Debug, Clone, Default)]
pub struct System {
pub raw_commands: Vec<RawCommand>,
pub datatypes: Vec<Datatype>,
pub user_defined_pred_defs: Vec<UserDefinedPredDef>,
pub clauses: IndexVec<ClauseId, Clause>,
pub pred_vars: IndexVec<PredVarId, PredVarDef>,
}
Expand All @@ -1718,6 +1728,16 @@ impl System {
self.raw_commands.push(raw_command)
}

pub fn push_pred_define(
&mut self,
symbol: UserDefinedPred,
sig: UserDefinedPredSig,
body: String,
) {
self.user_defined_pred_defs
.push(UserDefinedPredDef { symbol, sig, body })
}

pub fn push_clause(&mut self, clause: Clause) -> Option<ClauseId> {
if clause.is_nop() {
return None;
Expand Down
35 changes: 35 additions & 0 deletions src/chc/smtlib2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,33 @@ impl<'ctx, 'a> MatcherPredFun<'ctx, 'a> {
}
}

pub struct UserDefinedPredDef<'ctx, 'a> {
ctx: &'ctx FormatContext,
inner: &'a chc::UserDefinedPredDef,
}

impl<'ctx, 'a> std::fmt::Display for UserDefinedPredDef<'ctx, 'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let params = List::closed(
self.inner
.sig
.iter()
.map(|(name, sort)| format!("({} {})", name, self.ctx.fmt_sort(sort))),
);
write!(
f,
"(define-fun {name} {params} Bool {body})",
name = self.inner.symbol,
body = &self.inner.body,
)
}
}

impl<'ctx, 'a> UserDefinedPredDef<'ctx, 'a> {
pub fn new(ctx: &'ctx FormatContext, inner: &'a chc::UserDefinedPredDef) -> Self {
Self { ctx, inner }
}
}
/// A wrapper around a [`chc::System`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format.
#[derive(Debug, Clone)]
pub struct System<'a> {
Expand All @@ -578,6 +605,14 @@ impl<'a> std::fmt::Display for System<'a> {
writeln!(f, "{}\n", RawCommand::new(raw_command))?;
}

for user_defined_pred_def in &self.inner.user_defined_pred_defs {
writeln!(
f,
"{}\n",
UserDefinedPredDef::new(&self.ctx, user_defined_pred_def)
)?;
}

writeln!(f, "{}\n", Datatypes::new(&self.ctx, self.ctx.datatypes()))?;
for datatype in self.ctx.datatypes() {
writeln!(f, "{}", DatatypeDiscrFun::new(&self.ctx, datatype))?;
Expand Down
23 changes: 19 additions & 4 deletions src/chc/unbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,25 +152,40 @@ fn unbox_datatype(datatype: Datatype) -> Datatype {
}
}

fn unbox_user_defined_pred_def(user_defined_pred_def: UserDefinedPredDef) -> UserDefinedPredDef {
let UserDefinedPredDef { symbol, sig, body } = user_defined_pred_def;
let sig = sig
.into_iter()
.map(|(name, sort)| (name, unbox_sort(sort)))
.collect();
UserDefinedPredDef { symbol, sig, body }
}

/// Remove all `Box` sorts and `Box`/`BoxCurrent` terms from the system.
///
/// The box values in Thrust represent an owned pointer, but are logically equivalent to the inner type.
/// This pass removes them to reduce the complexity of the CHCs sent to the solver.
/// This function traverses a [`System`] and removes all `Box` related constructs.
pub fn unbox(system: System) -> System {
let System {
raw_commands,
datatypes,
user_defined_pred_defs,
clauses,
pred_vars,
datatypes,
raw_commands,
} = system;
let datatypes = datatypes.into_iter().map(unbox_datatype).collect();
let clauses = clauses.into_iter().map(unbox_clause).collect();
let pred_vars = pred_vars.into_iter().map(unbox_pred_var_def).collect();
let user_defined_pred_defs = user_defined_pred_defs
.into_iter()
.map(unbox_user_defined_pred_def)
.collect();
System {
raw_commands,
datatypes,
user_defined_pred_defs,
clauses,
pred_vars,
datatypes,
raw_commands,
}
}
21 changes: 21 additions & 0 deletions tests/ui/pass/annot_preds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//@check-pass
//@compile-flags: -Adead_code -C debug-assertions=off

#[thrust::predicate]
fn is_double(x: i64, doubled_x: i64) -> bool {
"(=
(* x 2)
doubled_x
)"; true
}

#[thrust::requires(true)]
#[thrust::ensures(is_double(x, result))]
fn double(x: i64) -> i64 {
x + x
}

fn main() {
let a = 3;
assert!(double(a) == 6);
}