use crate::ir::function::FunctionModifier;
use crate::ir::id::{FunctionID, ImportsID, LocalID, TypeID};
use crate::ir::module::{GetID, Iter, LocalOrImport, ReIndexable};
use crate::ir::types::{Body, FuncInstrFlag, InstrumentationMode};
use crate::DataType;
use log::warn;
use std::vec::IntoIter;
use wasmparser::Operator;
/// An entry in the module's function table: either a locally defined function or an import
/// (see [`FuncKind`]).
#[derive(Clone, Debug)]
pub struct Function<'a> {
/// Whether the function is local or imported, together with the per-kind data.
pub(crate) kind: FuncKind<'a>,
/// Optional name of the function.
name: Option<String>,
/// Set by `delete()` and cleared again by `set_kind()`; deleting only flips this flag.
pub(crate) deleted: bool,
}
impl GetID for Function<'_> {
    /// Returns the raw function ID, taken from whichever variant this function is.
    fn get_id(&self) -> u32 {
        let fid = match &self.kind {
            FuncKind::Local(local) => &local.func_id,
            FuncKind::Import(import) => &import.import_fn_id,
        };
        // The outer `*` goes through `FunctionID`'s `Deref` to reach the `u32`.
        **fid
    }
}
impl LocalOrImport for Function<'_> {
    /// `true` when this function is defined in the module itself.
    fn is_local(&self) -> bool {
        match self.kind {
            FuncKind::Local(_) => true,
            FuncKind::Import(_) => false,
        }
    }

    /// `true` when this function is provided by an import.
    fn is_import(&self) -> bool {
        match self.kind {
            FuncKind::Import(_) => true,
            FuncKind::Local(_) => false,
        }
    }

    /// `true` once the function has been flagged as deleted.
    fn is_deleted(&self) -> bool {
        self.deleted
    }
}
impl<'a> Function<'a> {
pub fn new(kind: FuncKind<'a>, name: Option<String>) -> Self {
Function {
kind,
name,
deleted: false,
}
}
pub fn get_type_id(&self) -> TypeID {
self.kind.get_type()
}
pub(crate) fn set_kind(&mut self, kind: FuncKind<'a>) {
self.kind = kind;
self.deleted = false;
}
pub fn kind(&self) -> &FuncKind<'a> {
&self.kind
}
pub fn unwrap_local(&self) -> &LocalFunction<'a> {
self.kind.unwrap_local()
}
pub fn unwrap_local_mut(&mut self) -> &mut LocalFunction<'a> {
self.kind.unwrap_local_mut()
}
pub(crate) fn delete(&mut self) {
self.deleted = true;
}
}
/// Distinguishes a function defined in this module from an imported one.
#[derive(Clone, Debug)]
pub enum FuncKind<'a> {
/// A function defined in this module, carrying its body.
Local(LocalFunction<'a>),
/// A function brought in through an import.
Import(ImportedFunction),
}
impl<'a> FuncKind<'a> {
    /// Borrows the local-function data.
    ///
    /// # Panics
    /// If this is an imported function.
    pub fn unwrap_local(&self) -> &LocalFunction<'a> {
        match self {
            FuncKind::Import(_) => panic!("Attempting to unwrap an imported function as a local!!"),
            FuncKind::Local(local) => local,
        }
    }

    /// Mutably borrows the local-function data.
    ///
    /// # Panics
    /// If this is an imported function.
    pub fn unwrap_local_mut(&mut self) -> &mut LocalFunction<'a> {
        match self {
            FuncKind::Import(_) => panic!("Attempting to unwrap an imported function as a local!!"),
            FuncKind::Local(local) => local,
        }
    }

    /// Returns the type ID of the function, for either variant.
    pub fn get_type(&self) -> TypeID {
        let ty = match self {
            FuncKind::Local(local) => &local.ty_id,
            FuncKind::Import(import) => &import.ty_id,
        };
        *ty
    }
}
impl PartialEq for FuncKind<'_> {
    /// Two kinds are equal when they are the same variant with the same `ty_id`.
    ///
    /// Only the type is compared; function IDs, bodies and import IDs are ignored.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (FuncKind::Local(lhs), FuncKind::Local(rhs)) => lhs.ty_id == rhs.ty_id,
            (FuncKind::Import(lhs), FuncKind::Import(rhs)) => lhs.ty_id == rhs.ty_id,
            (FuncKind::Local(_), FuncKind::Import(_))
            | (FuncKind::Import(_), FuncKind::Local(_)) => false,
        }
    }
}
// Marker impl: `PartialEq` for `FuncKind` compares only the variant and `ty_id`, so this
// relies on `TypeID` equality being a full equivalence (presumably an integer ID — confirm).
impl Eq for FuncKind<'_> {}
/// A function defined in this module, with its body and parameter locals.
#[derive(Clone, Debug)]
pub struct LocalFunction<'a> {
/// ID of the function's type.
pub ty_id: TypeID,
/// This function's ID; `Functions::add_local_func` overwrites it with the next free ID.
pub func_id: FunctionID,
/// Instrumentation state; `add_instr` routes instructions here while `current_mode` is set.
pub instr_flag: FuncInstrFlag<'a>,
/// The function body (instructions, declared locals, optional name).
pub body: Body<'a>,
/// Local IDs of the parameters; `new` fills these as `0..num_args`.
pub args: Vec<LocalID>,
}
impl<'a> LocalFunction<'a> {
    /// Creates a local function whose first `num_args` locals are its parameters.
    ///
    /// The instrumentation flag starts from `FuncInstrFlag::default()`.
    pub fn new(type_id: TypeID, function_id: FunctionID, body: Body<'a>, num_args: usize) -> Self {
        // Parameters occupy local slots `0..num_args`. `LocalID` wraps a `u32`;
        // `num_args` is assumed to fit in one.
        let args = (0..num_args).map(|arg| LocalID(arg as u32)).collect();
        LocalFunction {
            ty_id: type_id,
            func_id: function_id,
            instr_flag: FuncInstrFlag::default(),
            body,
            args,
        }
    }

    /// Declares a new local of type `ty` and returns its ID.
    ///
    /// The ID is numbered after all parameters and all previously declared locals.
    pub fn add_local(&mut self, ty: DataType) -> LocalID {
        // Resolves to the free function `add_local`, not this method.
        add_local(
            ty,
            self.args.len(),
            &mut self.body.num_locals,
            &mut self.body.locals,
        )
    }

    /// Adds `instr` to this function.
    ///
    /// While `instr_flag.current_mode` is set, the instruction is handed to the
    /// function-level `instr_flag`. Otherwise it is attached to the instruction at
    /// `instr_idx`, and `has_special_instr` records whether that addition reported
    /// itself as special.
    ///
    /// # Panics
    /// If no mode is active and `instr_idx` is out of range.
    pub fn add_instr(&mut self, instr: Operator<'a>, instr_idx: usize) {
        if self.instr_flag.current_mode.is_some() {
            self.instr_flag.add_instr(instr);
        } else {
            let is_special = self.body.instructions[instr_idx].add_instr(instr);
            // Sticky: once set, later non-special additions do not clear it.
            self.instr_flag.has_special_instr |= is_special;
        }
    }

    /// Delegates to `Body::clear_instr` for instruction `instr_idx` and `mode`.
    pub fn clear_instr_at(&mut self, instr_idx: usize, mode: InstrumentationMode) {
        self.body.clear_instr(instr_idx, mode);
    }
}
/// Declares one more local of type `ty` and returns its `LocalID`.
///
/// Locals are stored run-length encoded as `(count, type)` pairs. A new local either
/// extends the last run, when it has the same type, or starts a new run of length 1.
/// The returned index comes after all `num_params` parameters and the `*num_locals`
/// locals declared so far; `*num_locals` is incremented.
pub(crate) fn add_local(
    ty: DataType,
    num_params: usize,
    num_locals: &mut usize,
    locals: &mut Vec<(u32, DataType)>,
) -> LocalID {
    let new_index = num_params + *num_locals;
    *num_locals += 1;
    match locals.last_mut() {
        Some((count, last_ty)) if *last_ty == ty => *count += 1,
        _ => locals.push((1, ty)),
    }
    LocalID(new_index as u32)
}
/// A function provided by an import rather than defined in this module.
#[derive(Clone, Debug)]
pub struct ImportedFunction {
    /// The import entry this function comes from.
    pub import_id: ImportsID,
    /// The ID this function occupies in the function table.
    pub(crate) import_fn_id: FunctionID,
    /// ID of the function's type.
    pub ty_id: TypeID,
}
impl ImportedFunction {
pub fn new(id: ImportsID, type_id: TypeID, function_id: FunctionID) -> Self {
ImportedFunction {
import_id: id,
ty_id: type_id,
import_fn_id: function_id,
}
}
}
/// The module's function table; entry `i` is the function with ID `i`.
// NOTE(review): the reason for `allow(dead_code)` is not visible here — confirm it is still needed.
#[allow(dead_code)]
#[derive(Clone, Debug, Default)]
pub struct Functions<'a> {
    /// Every function, imported or local, including ones flagged as deleted.
    functions: Vec<Function<'a>>,
    /// Set by operations that add or delete functions; presumably read elsewhere to decide
    /// whether IDs must be recomputed.
    pub(crate) recalculate_ids: bool,
}
impl<'a> Iter<Function<'a>> for Functions<'a> {
    /// Borrowing iterator over every function, deleted ones included.
    fn iter(&self) -> std::slice::Iter<'_, Function<'a>> {
        self.functions.as_slice().iter()
    }

    /// Owning iterator over a fresh copy of every function.
    fn get_into_iter(&self) -> IntoIter<Function<'a>> {
        let snapshot: Vec<Function<'a>> = self.functions.to_vec();
        snapshot.into_iter()
    }
}
impl<'a> ReIndexable<Function<'a>> for Functions<'a> {
    /// Number of entries in the table, deleted entries included.
    fn len(&self) -> usize {
        self.functions.len()
    }

    /// Removes and returns the function at `function_id`, shifting later entries down.
    ///
    /// # Panics
    /// If `function_id` is out of range.
    fn remove(&mut self, function_id: u32) -> Function<'a> {
        let index = function_id as usize;
        self.functions.remove(index)
    }

    /// Inserts `func` at `function_id`, shifting later entries up.
    ///
    /// # Panics
    /// If `function_id` is greater than the current length.
    fn insert(&mut self, function_id: u32, func: Function<'a>) {
        let index = function_id as usize;
        self.functions.insert(index, func)
    }

    /// Appends `func` to the end of the table.
    fn push(&mut self, func: Function<'a>) {
        self.functions.push(func)
    }
}
impl<'a> Functions<'a> {
    /// Wraps `functions` as a table in which entry `i` is function ID `i`.
    pub fn new(functions: Vec<Function<'a>>) -> Self {
        Functions {
            functions,
            recalculate_ids: false,
        }
    }

    /// Returns the function with ID `function_id`, or `None` if the ID is out of range.
    ///
    /// Deleted functions are still returned; check `is_deleted` if that matters.
    pub fn get_fn_by_id(&self, function_id: FunctionID) -> Option<&Function<'a>> {
        // `Vec::get` bounds-checks against the true length. Comparing with `len() as u32`
        // would silently truncate on a table with more than `u32::MAX` entries.
        self.functions.get(*function_id as usize)
    }

    /// Returns `true` if the table holds no entries (deleted entries count as present).
    pub fn is_empty(&self) -> bool {
        self.functions.is_empty()
    }

    /// Returns the kind of function `function_id`.
    ///
    /// # Panics
    /// If `function_id` is out of range. The other ID-indexed accessors below
    /// (`get_kind_mut`, `get_name`, `is_local`, `is_import`, `get_type_id`, `is_deleted`,
    /// `get`, `get_mut`, `unwrap_local`) panic the same way.
    pub fn get_kind(&self, function_id: FunctionID) -> &FuncKind<'a> {
        &self.functions[*function_id as usize].kind
    }

    /// Mutably borrows the kind of function `function_id`.
    pub fn get_kind_mut(&mut self, function_id: FunctionID) -> &mut FuncKind<'a> {
        &mut self.functions[*function_id as usize].kind
    }

    /// Returns the table-level name of function `function_id`, if any.
    pub fn get_name(&self, function_id: FunctionID) -> &Option<String> {
        &self.functions[*function_id as usize].name
    }

    /// `true` if function `function_id` is defined in this module.
    pub fn is_local(&self, function_id: FunctionID) -> bool {
        self.functions[*function_id as usize].is_local()
    }

    /// `true` if function `function_id` is an import.
    pub fn is_import(&self, function_id: FunctionID) -> bool {
        self.functions[*function_id as usize].is_import()
    }

    /// Returns the type ID of function `id`.
    pub fn get_type_id(&self, id: FunctionID) -> TypeID {
        self.functions[*id as usize].get_type_id()
    }

    /// `true` if function `function_id` has been flagged as deleted.
    pub fn is_deleted(&self, function_id: FunctionID) -> bool {
        self.functions[*function_id as usize].is_deleted()
    }

    /// Borrows function `function_id`.
    pub fn get(&self, function_id: FunctionID) -> &Function<'a> {
        &self.functions[*function_id as usize]
    }

    /// Mutably borrows function `function_id`.
    pub fn get_mut(&mut self, function_id: FunctionID) -> &mut Function<'a> {
        &mut self.functions[*function_id as usize]
    }

    /// Mutably borrows the local-function data of `function_id`.
    ///
    /// # Panics
    /// Also panics if the function is an import.
    pub fn unwrap_local(&mut self, function_id: FunctionID) -> &mut LocalFunction<'a> {
        self.functions[*function_id as usize].unwrap_local_mut()
    }

    /// Returns the ID of the first local function whose body name equals `name`.
    ///
    /// Imports are never matched; local functions flagged as deleted can still match.
    pub fn get_local_fid_by_name(&self, name: &str) -> Option<FunctionID> {
        self.functions
            .iter()
            .position(|func| match &func.kind {
                FuncKind::Local(l) => l.body.name.as_deref() == Some(name),
                FuncKind::Import(_) => false,
            })
            .map(|idx| FunctionID(idx as u32))
    }

    /// Returns a `FunctionModifier` over the body and arguments of local function `func_id`.
    ///
    /// Returns `None` if the ID is out of range or refers to an import.
    pub fn get_fn_modifier<'b>(
        &'b mut self,
        func_id: FunctionID,
    ) -> Option<FunctionModifier<'b, 'a>> {
        // Default binding mode already yields `&mut LocalFunction`; an explicit `ref mut`
        // here is redundant and rejected by the Rust 2024 edition.
        match &mut self.functions.get_mut(*func_id as usize)?.kind {
            FuncKind::Local(l) => Some(FunctionModifier::init(&mut l.body, &mut l.args)),
            FuncKind::Import(_) => None,
        }
    }

    /// Flags function `id` as deleted (its slot is kept) and marks IDs for recalculation.
    ///
    /// Out-of-range IDs are ignored, although `recalculate_ids` is still set.
    pub(crate) fn delete(&mut self, id: FunctionID) {
        self.recalculate_ids = true;
        if let Some(func) = self.functions.get_mut(*id as usize) {
            func.delete();
        }
    }

    /// The ID the next appended function will receive.
    fn next_id(&self) -> FunctionID {
        FunctionID(self.functions.len() as u32)
    }

    /// Appends `local_function` under the next free ID and returns that ID.
    ///
    /// The function's `func_id` is overwritten with the new ID. When `name` is given, it is
    /// recorded both on the table entry and on the function body. When `name` is `None`, an
    /// existing body name is left untouched.
    pub(crate) fn add_local_func(
        &mut self,
        mut local_function: LocalFunction<'a>,
        name: Option<String>,
    ) -> FunctionID {
        self.recalculate_ids = true;
        let id = self.next_id();
        local_function.func_id = id;
        if let Some(n) = &name {
            local_function.body.name = Some(n.clone());
        }
        self.functions
            .push(Function::new(FuncKind::Local(local_function), name));
        id
    }

    /// Appends an imported function and marks IDs for recalculation.
    ///
    /// # Panics
    /// If `imp_fn_id` is not the next free ID, i.e. imports must be appended in ID order.
    pub(crate) fn add_import_func(
        &mut self,
        imp_id: ImportsID,
        ty_id: TypeID,
        name: Option<String>,
        imp_fn_id: u32,
    ) {
        self.recalculate_ids = true;
        assert_eq!(*self.next_id(), imp_fn_id);
        self.functions.push(Function::new(
            FuncKind::Import(ImportedFunction::new(imp_id, ty_id, FunctionID(imp_fn_id))),
            name,
        ));
    }

    /// Declares a new local of type `ty` in local function `func_idx` and returns its ID.
    ///
    /// # Panics
    /// If `func_idx` is out of range or refers to an import.
    pub(crate) fn add_local(&mut self, func_idx: FunctionID, ty: DataType) -> LocalID {
        self.functions[*func_idx as usize]
            .unwrap_local_mut()
            .add_local(ty)
    }

    /// Sets the name of local function `func_idx` on both the table entry and its body.
    ///
    /// Returns `false` (and logs a warning) without changing anything if the function is an
    /// import.
    ///
    /// # Panics
    /// If `func_idx` is out of range.
    pub fn set_local_fn_name(&mut self, func_idx: FunctionID, name: String) -> bool {
        let func = &mut self.functions[*func_idx as usize];
        match &mut func.kind {
            FuncKind::Import(_) => {
                warn!("is an imported function!");
                false
            }
            FuncKind::Local(l) => {
                l.body.name = Some(name.clone());
                func.name = Some(name);
                true
            }
        }
    }

    /// Sets the table-level name of imported function `func_idx`.
    ///
    /// Returns `false` (and logs a warning) without changing anything if the function is
    /// local.
    ///
    /// # Panics
    /// If `func_idx` is out of range.
    pub(crate) fn set_imported_fn_name(&mut self, func_idx: FunctionID, name: String) -> bool {
        let func = &mut self.functions[*func_idx as usize];
        if func.is_local() {
            warn!("is a local function!");
            return false;
        }
        func.name = Some(name);
        true
    }
}