Skip to content
Merged
4 changes: 4 additions & 0 deletions src/analyze/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ pub fn raw_command_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("raw_command")]
}

pub fn predicate_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("predicate")]
}

/// A [`annot::Resolver`] implementation for resolving function parameters.
///
/// The parameter names and their sorts needs to be configured via
Expand Down
18 changes: 17 additions & 1 deletion src/analyze/crate_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub struct Analyzer<'tcx, 'ctx> {
tcx: TyCtxt<'tcx>,
ctx: &'ctx mut analyze::Analyzer<'tcx>,
trusted: HashSet<DefId>,
predicates: HashSet<DefId>,
}

impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
Expand Down Expand Up @@ -82,6 +83,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self.trusted.insert(local_def_id.to_def_id());
}

if analyzer.is_annotated_as_predicate() {
self.predicates.insert(local_def_id.to_def_id());
analyzer.analyze_predicate_definition(local_def_id);
}

use mir_ty::TypeVisitableExt as _;
if sig.has_param() && !analyzer.is_fully_annotated() {
self.ctx.register_deferred_def(local_def_id.to_def_id());
Expand All @@ -105,6 +111,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
tracing::info!(?local_def_id, "trusted");
continue;
}
if self.predicates.contains(&local_def_id.to_def_id()) {
tracing::info!(?local_def_id, "predicate");
continue;
}
let Some(expected) = self.ctx.concrete_def_ty(local_def_id.to_def_id()) else {
// when the local_def_id is deferred it would be skipped
continue;
Expand Down Expand Up @@ -212,7 +222,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
pub fn new(ctx: &'ctx mut analyze::Analyzer<'tcx>) -> Self {
let tcx = ctx.tcx;
let trusted = HashSet::default();
Self { ctx, tcx, trusted }
let predicates = HashSet::default();
Self {
ctx,
tcx,
trusted,
predicates,
}
}

pub fn run(&mut self) {
Expand Down
74 changes: 74 additions & 0 deletions src/analyze/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@ use crate::pretty::PrettyDisplayExt as _;
use crate::refine::{BasicBlockType, TypeBuilder};
use crate::rty;

fn stmt_str_literal(stmt: &rustc_hir::Stmt) -> Option<String> {
use rustc_ast::LitKind;
use rustc_hir::{Expr, ExprKind, Stmt, StmtKind};

match stmt {
Stmt {
kind:
StmtKind::Semi(Expr {
kind:
ExprKind::Lit(rustc_hir::Lit {
node: LitKind::Str(symbol, _),
..
}),
..
}),
..
} => Some(symbol.to_string()),
_ => None,
}
}

/// An implementation of the typing of local definitions.
///
/// The current implementation only applies to function definitions. The entry point is
Expand Down Expand Up @@ -106,6 +127,49 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
ret_annot
}

pub fn analyze_predicate_definition(&self, local_def_id: LocalDefId) {
let pred_name = self.tcx.item_name(local_def_id.to_def_id()).to_string();

// function's body
use rustc_hir::{Block, Expr, ExprKind};

let hir_map = self.tcx.hir();
let body_id = hir_map.maybe_body_owned_by(local_def_id).unwrap();
let hir_body = hir_map.body(body_id);

let predicate_body = match hir_body.value {
Expr {
kind: ExprKind::Block(Block { stmts, .. }, _),
..
} => stmts
.iter()
.find_map(stmt_str_literal)
.expect("invalid predicate definition: no string literal was found."),
_ => panic!("expected function body, got: {hir_body:?}"),
};

// names and sorts of arguments
let arg_names = self
.tcx
.fn_arg_names(local_def_id.to_def_id())
.iter()
.map(|ident| ident.to_string());

let sig = self.ctx.local_fn_sig(local_def_id);
let arg_sorts = sig
.inputs()
.iter()
.map(|input_ty| self.type_builder.build(*input_ty).to_sort());

let arg_name_and_sorts = arg_names.into_iter().zip(arg_sorts).collect::<Vec<_>>();

self.ctx.system.borrow_mut().push_pred_define(
chc::UserDefinedPred::new(pred_name),
chc::UserDefinedPredSig::from(arg_name_and_sorts),
predicate_body,
);
}

pub fn is_annotated_as_trusted(&self) -> bool {
self.tcx
.get_attrs_by_path(
Expand Down Expand Up @@ -136,6 +200,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
.is_some()
}

pub fn is_annotated_as_predicate(&self) -> bool {
self.tcx
.get_attrs_by_path(
self.local_def_id.to_def_id(),
&analyze::annot::predicate_path(),
)
.next()
.is_some()
}

// TODO: unify this logic with extraction functions above
pub fn is_fully_annotated(&self) -> bool {
let has_require = self
Expand Down
20 changes: 20 additions & 0 deletions src/chc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1700,11 +1700,21 @@ pub struct PredVarDef {
pub debug_info: DebugInfo,
}

pub type UserDefinedPredSig = Vec<(String, Sort)>;

#[derive(Debug, Clone)]
pub struct UserDefinedPredDef {
symbol: UserDefinedPred,
sig: UserDefinedPredSig,
body: String,
}

/// A CHC system.
#[derive(Debug, Clone, Default)]
pub struct System {
pub raw_commands: Vec<RawCommand>,
pub datatypes: Vec<Datatype>,
pub user_defined_pred_defs: Vec<UserDefinedPredDef>,
pub clauses: IndexVec<ClauseId, Clause>,
pub pred_vars: IndexVec<PredVarId, PredVarDef>,
}
Expand All @@ -1718,6 +1728,16 @@ impl System {
self.raw_commands.push(raw_command)
}

pub fn push_pred_define(
&mut self,
symbol: UserDefinedPred,
sig: UserDefinedPredSig,
body: String,
) {
self.user_defined_pred_defs
.push(UserDefinedPredDef { symbol, sig, body })
}

pub fn push_clause(&mut self, clause: Clause) -> Option<ClauseId> {
if clause.is_nop() {
return None;
Expand Down
44 changes: 40 additions & 4 deletions src/chc/smtlib2.rs
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think raw‑command definitions and user‑defined predicates should be emitted before the datatype declarations in the generated .smt2 file.

The reason is that raw commands and predicates written by users may refer to previously declared datatypes(ex. access fieleds of structs), whereas datatype definitions emmitted by Thrust will not refer to those commands or predicates.

Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,33 @@ impl<'ctx, 'a> MatcherPredFun<'ctx, 'a> {
}
}

pub struct UserDefinedPredDef<'ctx, 'a> {
ctx: &'ctx FormatContext,
inner: &'a chc::UserDefinedPredDef,
}

impl<'ctx, 'a> std::fmt::Display for UserDefinedPredDef<'ctx, 'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let params = List::closed(
self.inner
.sig
.iter()
.map(|(name, sort)| format!("({} {})", name, self.ctx.fmt_sort(sort))),
);
write!(
f,
"(define-fun {name} {params} Bool {body})",
name = self.inner.symbol,
body = &self.inner.body,
)
}
}

impl<'ctx, 'a> UserDefinedPredDef<'ctx, 'a> {
pub fn new(ctx: &'ctx FormatContext, inner: &'a chc::UserDefinedPredDef) -> Self {
Self { ctx, inner }
}
}
/// A wrapper around a [`chc::System`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format.
#[derive(Debug, Clone)]
pub struct System<'a> {
Expand All @@ -573,16 +600,25 @@ impl<'a> std::fmt::Display for System<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "(set-logic HORN)\n")?;

writeln!(f, "{}\n", Datatypes::new(&self.ctx, self.ctx.datatypes()))?;
for datatype in self.ctx.datatypes() {
writeln!(f, "{}", DatatypeDiscrFun::new(&self.ctx, datatype))?;
writeln!(f, "{}", MatcherPredFun::new(&self.ctx, datatype))?;
}

// insert command from #![thrust::raw_command()] here
for raw_command in &self.inner.raw_commands {
writeln!(f, "{}\n", RawCommand::new(raw_command))?;
}

writeln!(f, "{}\n", Datatypes::new(&self.ctx, self.ctx.datatypes()))?;
for datatype in self.ctx.datatypes() {
writeln!(f, "{}", DatatypeDiscrFun::new(&self.ctx, datatype))?;
writeln!(f, "{}", MatcherPredFun::new(&self.ctx, datatype))?;
for user_defined_pred_def in &self.inner.user_defined_pred_defs {
writeln!(
f,
"{}\n",
UserDefinedPredDef::new(&self.ctx, user_defined_pred_def)
)?;
}

writeln!(f)?;
for (p, def) in self.inner.pred_vars.iter_enumerated() {
if !def.debug_info.is_empty() {
Expand Down
23 changes: 19 additions & 4 deletions src/chc/unbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,25 +152,40 @@ fn unbox_datatype(datatype: Datatype) -> Datatype {
}
}

fn unbox_user_defined_pred_def(user_defined_pred_def: UserDefinedPredDef) -> UserDefinedPredDef {
let UserDefinedPredDef { symbol, sig, body } = user_defined_pred_def;
let sig = sig
.into_iter()
.map(|(name, sort)| (name, unbox_sort(sort)))
.collect();
UserDefinedPredDef { symbol, sig, body }
}

/// Remove all `Box` sorts and `Box`/`BoxCurrent` terms from the system.
///
/// The box values in Thrust represent an owned pointer, but are logically equivalent to the inner type.
/// This pass removes them to reduce the complexity of the CHCs sent to the solver.
/// This function traverses a [`System`] and removes all `Box` related constructs.
pub fn unbox(system: System) -> System {
let System {
raw_commands,
datatypes,
user_defined_pred_defs,
clauses,
pred_vars,
datatypes,
raw_commands,
} = system;
let datatypes = datatypes.into_iter().map(unbox_datatype).collect();
let clauses = clauses.into_iter().map(unbox_clause).collect();
let pred_vars = pred_vars.into_iter().map(unbox_pred_var_def).collect();
let user_defined_pred_defs = user_defined_pred_defs
.into_iter()
.map(unbox_user_defined_pred_def)
.collect();
System {
raw_commands,
datatypes,
user_defined_pred_defs,
clauses,
pred_vars,
datatypes,
raw_commands,
}
}
21 changes: 21 additions & 0 deletions tests/ui/pass/annot_preds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//@check-pass
//@compile-flags: -Adead_code -C debug-assertions=off

#[thrust::predicate]
fn is_double(x: i64, doubled_x: i64) -> bool {
"(=
(* x 2)
doubled_x
)"; true
}

#[thrust::requires(true)]
#[thrust::ensures(is_double(x, result))]
fn double(x: i64) -> i64 {
x + x
}

fn main() {
let a = 3;
assert!(double(a) == 6);
}