orca_wasm/ir/module/
module_functions.rs

1//! Intermediate Representation of a Function
2
3use crate::ir::function::FunctionModifier;
4use crate::ir::id::{FunctionID, ImportsID, LocalID, TypeID};
5use crate::ir::module::{GetID, Iter, LocalOrImport, ReIndexable};
6use crate::ir::types::{Body, FuncInstrFlag, InstrumentationMode};
7use crate::DataType;
8use log::warn;
9use std::vec::IntoIter;
10use wasmparser::Operator;
11
/// Represents a function. Local or Imported depends on the `FuncKind`.
#[derive(Clone, Debug)]
pub struct Function<'a> {
    /// Whether this is a local or an imported function, together with the data for that kind.
    pub(crate) kind: FuncKind<'a>,
    /// Optional name attached to the function.
    name: Option<String>,
    /// Deletion marker: set by `delete`, cleared again by `set_kind`.
    pub(crate) deleted: bool,
}
19
20impl GetID for Function<'_> {
21    /// Get the ID of the function
22    fn get_id(&self) -> u32 {
23        match &self.kind {
24            FuncKind::Import(i) => *i.import_fn_id,
25            FuncKind::Local(l) => *l.func_id,
26        }
27    }
28}
29
30impl LocalOrImport for Function<'_> {
31    /// Check if it's a local function
32    fn is_local(&self) -> bool {
33        matches!(&self.kind, FuncKind::Local(_))
34    }
35
36    /// Check if it's an imported function
37    fn is_import(&self) -> bool {
38        matches!(&self.kind, FuncKind::Import(_))
39    }
40
41    /// Check if this function has been deleted
42    fn is_deleted(&self) -> bool {
43        self.deleted
44    }
45}
46
47impl<'a> Function<'a> {
48    /// Create a new function
49    pub fn new(kind: FuncKind<'a>, name: Option<String>) -> Self {
50        Function {
51            kind,
52            name,
53            deleted: false,
54        }
55    }
56
57    /// Get the TypeID of the function
58    pub fn get_type_id(&self) -> TypeID {
59        self.kind.get_type()
60    }
61
62    /// Change the kind of the Function
63    pub(crate) fn set_kind(&mut self, kind: FuncKind<'a>) {
64        self.kind = kind;
65        // Resets deletion
66        self.deleted = false;
67    }
68
69    /// Get the kind of the function
70    pub fn kind(&self) -> &FuncKind<'a> {
71        &self.kind
72    }
73
74    /// Unwrap a local function. If it is an imported function, it panics.
75    pub fn unwrap_local(&self) -> &LocalFunction<'a> {
76        self.kind.unwrap_local()
77    }
78
79    /// Unwrap a local function as mutable. If it is an imported function, it panics.
80    pub fn unwrap_local_mut(&mut self) -> &mut LocalFunction<'a> {
81        self.kind.unwrap_local_mut()
82    }
83
84    pub(crate) fn delete(&mut self) {
85        self.deleted = true;
86    }
87}
88
/// Represents whether a function is a Local Function or an Imported Function
#[derive(Clone, Debug)]
pub enum FuncKind<'a> {
    /// A function defined in this module; carries its `LocalFunction` data.
    Local(LocalFunction<'a>),
    /// A function imported into this module; carries its `ImportedFunction` data.
    Import(ImportedFunction),
}
95
96impl<'a> FuncKind<'a> {
97    /// Unwrap a local function as a read-only reference. If it is an imported function, it panics.
98    pub fn unwrap_local(&self) -> &LocalFunction<'a> {
99        match &self {
100            FuncKind::Local(l) => l,
101            FuncKind::Import(_) => panic!("Attempting to unwrap an imported function as a local!!"),
102        }
103    }
104    /// Unwrap a local function as a mutable reference. If it is an imported function, it panics.
105    pub fn unwrap_local_mut(&mut self) -> &mut LocalFunction<'a> {
106        match self {
107            FuncKind::Local(l) => l,
108            FuncKind::Import(_) => panic!("Attempting to unwrap an imported function as a local!!"),
109        }
110    }
111
112    /// Get the TypeID of the function
113    pub fn get_type(&self) -> TypeID {
114        match &self {
115            FuncKind::Local(l) => l.ty_id,
116            FuncKind::Import(i) => i.ty_id,
117        }
118    }
119}
120
121impl PartialEq for FuncKind<'_> {
122    fn eq(&self, other: &Self) -> bool {
123        match (self, other) {
124            (FuncKind::Import(i1), FuncKind::Import(i2)) => i1.ty_id == i2.ty_id,
125            (FuncKind::Local(l1), FuncKind::Local(l2)) => l1.ty_id == l2.ty_id,
126            _ => false,
127        }
128    }
129}
130
131impl Eq for FuncKind<'_> {}
132
/// Intermediate Representation of a Local Function
#[derive(Clone, Debug)]
pub struct LocalFunction<'a> {
    /// Type ID of the function.
    pub ty_id: TypeID,
    /// ID of the function.
    pub func_id: FunctionID,
    /// Function-level instrumentation state; `add_instr` routes instructions here while
    /// `current_mode` is set.
    pub instr_flag: FuncInstrFlag<'a>,
    /// Body of the function (its locals, instructions and optional name).
    pub body: Body<'a>,
    /// Local IDs of the parameters; `new` fills this with `0..num_args`.
    pub args: Vec<LocalID>,
}
142
143impl<'a> LocalFunction<'a> {
144    /// Creates a new local function
145    pub fn new(type_id: TypeID, function_id: FunctionID, body: Body<'a>, num_args: usize) -> Self {
146        let mut args = vec![];
147        for arg in 0..num_args {
148            args.push(LocalID(arg as u32));
149        }
150        LocalFunction {
151            ty_id: type_id,
152            func_id: function_id,
153            instr_flag: FuncInstrFlag::default(),
154            body,
155            args,
156        }
157    }
158    pub fn add_local(&mut self, ty: DataType) -> LocalID {
159        add_local(
160            ty,
161            self.args.len(),
162            &mut self.body.num_locals,
163            &mut self.body.locals,
164        )
165    }
166
167    pub fn add_instr(&mut self, instr: Operator<'a>, instr_idx: usize) {
168        if self.instr_flag.current_mode.is_some() {
169            // inject at function level
170            self.instr_flag.add_instr(instr);
171        } else {
172            // inject at instruction level
173            let is_special = self.body.instructions[instr_idx].add_instr(instr);
174            // remember if we injected a special instrumentation (to be resolved before encoding)
175            self.instr_flag.has_special_instr |= is_special;
176        }
177    }
178
179    pub fn clear_instr_at(&mut self, instr_idx: usize, mode: InstrumentationMode) {
180        self.body.clear_instr(instr_idx, mode);
181    }
182}
183
184// Must split this out so that the Rust compiler knows that we're not mutating data being iterated
185// over in `resolve_special_instrumentation` func.
186pub(crate) fn add_local(
187    ty: DataType,
188    num_params: usize,
189    num_locals: &mut u32,
190    locals: &mut Vec<(u32, DataType)>,
191) -> LocalID {
192    let index = num_params + *num_locals as usize;
193
194    let len = locals.len();
195    *num_locals += 1;
196    if len > 0 {
197        let last = len - 1;
198        if locals[last].1 == ty {
199            locals[last].0 += 1;
200        } else {
201            locals.push((1, ty));
202        }
203    } else {
204        // If no locals, just append
205        locals.push((1, ty));
206    }
207
208    LocalID(index as u32)
209}
210
211pub(crate) fn add_locals(
212    types: &[DataType],
213    num_params: usize,
214    num_locals: &mut u32,
215    locals: &mut Vec<(u32, DataType)>,
216) {
217    // TODO: Make this more efficient instead of just iterating
218    for ty in types.iter() {
219        add_local(*ty, num_params, num_locals, locals);
220    }
221}
222
/// Intermediate representation of an Imported Function. The actual Import is stored in the Imports field of the module.
#[derive(Clone, Debug)]
pub struct ImportedFunction {
    /// Maps to the location in a module's imports.
    pub import_id: ImportsID,
    /// Maps to the location in a module's imported functions.
    pub(crate) import_fn_id: FunctionID,
    /// Type ID of the imported function.
    pub ty_id: TypeID,
}
230
231impl ImportedFunction {
232    /// Create a new imported function
233    pub fn new(id: ImportsID, type_id: TypeID, function_id: FunctionID) -> Self {
234        ImportedFunction {
235            import_id: id,
236            ty_id: type_id,
237            import_fn_id: function_id,
238        }
239    }
240}
241
/// Intermediate representation of all the functions in a module.
#[derive(Clone, Debug, Default)]
pub struct Functions<'a> {
    /// All functions; the lookups in this file use a `FunctionID` as the index into this list.
    functions: Vec<Function<'a>>,
    /// Set to `true` whenever a function is added or deleted.
    pub(crate) recalculate_ids: bool,
}
248
249impl<'a> Functions<'a> {
250    /// Iterate over functions present in the module
251    ///
252    /// Note: Functions returned by this iterator *may* be deleted.
253    pub fn iter(&self) -> impl Iterator<Item = &Function<'a>> {
254        Iter::<Function<'a>>::iter(self)
255    }
256}
257
258impl<'a> Iter<Function<'a>> for Functions<'a> {
259    /// Get an iterator for the functions.
260    fn iter(&self) -> std::slice::Iter<'_, Function<'a>> {
261        self.functions.iter()
262    }
263
264    fn get_into_iter(&self) -> IntoIter<Function<'a>> {
265        self.functions.clone().into_iter()
266    }
267}
268
269impl<'a> ReIndexable<Function<'a>> for Functions<'a> {
270    /// Get the number of functions
271    fn len(&self) -> usize {
272        self.functions.len()
273    }
274    fn remove(&mut self, function_id: u32) -> Function<'a> {
275        self.functions.remove(function_id as usize)
276    }
277
278    fn insert(&mut self, function_id: u32, func: Function<'a>) {
279        self.functions.insert(function_id as usize, func);
280    }
281    /// Add a new function
282    fn push(&mut self, func: Function<'a>) {
283        self.functions.push(func);
284    }
285}
286
287impl<'a> Functions<'a> {
288    /// Create a new functions section
289    pub fn new(functions: Vec<Function<'a>>) -> Self {
290        Functions {
291            functions,
292            recalculate_ids: false,
293        }
294    }
295
296    /// Get a function by its FunctionID
297    pub fn get_fn_by_id(&self, function_id: FunctionID) -> Option<&Function<'a>> {
298        if *function_id < self.functions.len() as u32 {
299            return Some(&self.functions[*function_id as usize]);
300        }
301        None
302    }
303
304    /// Checks if there are no functions
305    pub fn is_empty(&self) -> bool {
306        self.functions.is_empty()
307    }
308
309    // =======================
310    // ==== FIELD GETTERS ====
311    // =======================
312
313    /// Get kind of function
314    pub fn get_kind(&self, function_id: FunctionID) -> &FuncKind<'a> {
315        &self.functions[*function_id as usize].kind
316    }
317
318    /// Get kind of function
319    // TODO -- can this be removed?
320    pub fn get_kind_mut(&mut self, function_id: FunctionID) -> &mut FuncKind<'a> {
321        &mut self.functions[*function_id as usize].kind
322    }
323
324    /// Get the name of a function
325    pub fn get_name(&self, function_id: FunctionID) -> &Option<String> {
326        &self.functions[*function_id as usize].name
327    }
328
329    /// Check if a function is a local
330    pub fn is_local(&self, function_id: FunctionID) -> bool {
331        self.functions[*function_id as usize].is_local()
332    }
333
334    /// Check if a function is an import
335    pub fn is_import(&self, function_id: FunctionID) -> bool {
336        self.functions[*function_id as usize].is_import()
337    }
338
339    /// Get the type ID of a function
340    pub fn get_type_id(&self, id: FunctionID) -> TypeID {
341        self.functions[*id as usize].get_type_id()
342    }
343
344    /// Check if it's deleted
345    pub fn is_deleted(&self, function_id: FunctionID) -> bool {
346        self.functions[*function_id as usize].is_deleted()
347    }
348
349    // ======================
350    // ==== FUNC GETTERS ====
351    // ======================
352
353    /// Get by ID
354    pub fn get(&self, function_id: FunctionID) -> &Function<'a> {
355        &self.functions[*function_id as usize]
356    }
357
358    /// Get mutable function by ID
359    pub fn get_mut(&mut self, function_id: FunctionID) -> &mut Function<'a> {
360        &mut self.functions[*function_id as usize]
361    }
362
363    /// Unwrap local function.
364    pub fn unwrap_local(&mut self, function_id: FunctionID) -> &mut LocalFunction<'a> {
365        self.functions[*function_id as usize].unwrap_local_mut()
366    }
367
368    /// Get local Function ID by name
369    pub fn get_local_fid_by_name(&self, name: &str) -> Option<FunctionID> {
370        for (idx, func) in self.functions.iter().enumerate() {
371            if let FuncKind::Local(l) = &func.kind {
372                if let Some(n) = &l.body.name {
373                    if n == name {
374                        return Some(FunctionID(idx as u32));
375                    }
376                }
377            }
378        }
379        None
380    }
381
382    // =======================
383    // ==== MANIPULATIONS ====
384    // =======================
385
386    /// Get a function modifier from a function index
387    pub fn get_fn_modifier<'b>(
388        &'b mut self,
389        func_id: FunctionID,
390    ) -> Option<FunctionModifier<'b, 'a>> {
391        // grab type and section and code section
392        match &mut self.functions.get_mut(*func_id as usize)?.kind {
393            FuncKind::Local(ref mut l) => {
394                // the instrflag should be reset!
395                l.instr_flag.finish_instr();
396                Some(FunctionModifier::init(
397                    &mut l.instr_flag,
398                    &mut l.body,
399                    &mut l.args,
400                ))
401            }
402            _ => None,
403        }
404    }
405
406    /// Delete a function
407    pub(crate) fn delete(&mut self, id: FunctionID) {
408        self.recalculate_ids = true;
409        if *id < self.functions.len() as u32 {
410            self.functions[*id as usize].delete();
411        }
412    }
413
414    fn next_id(&self) -> FunctionID {
415        FunctionID(self.functions.len() as u32)
416    }
417
418    pub(crate) fn add_local_func(
419        &mut self,
420        mut local_function: LocalFunction<'a>,
421        name: Option<String>,
422    ) -> FunctionID {
423        self.recalculate_ids = true;
424        // fix the ID of the function
425        let id = self.next_id();
426        local_function.func_id = id;
427
428        self.push(Function::new(FuncKind::Local(local_function), name.clone()));
429        if let Some(name) = name {
430            self.set_local_fn_name(id, name);
431        }
432        id
433    }
434
435    pub(crate) fn add_import_func(
436        &mut self,
437        imp_id: ImportsID,
438        ty_id: TypeID,
439        name: Option<String>,
440        // The id of the function we're using (at least until re-indexing)
441        imp_fn_id: u32,
442    ) {
443        self.recalculate_ids = true;
444        assert_eq!(*self.next_id(), imp_fn_id);
445        self.functions.push(Function::new(
446            FuncKind::Import(ImportedFunction::new(imp_id, ty_id, FunctionID(imp_fn_id))),
447            name,
448        ));
449    }
450
451    pub(crate) fn add_local(&mut self, func_idx: FunctionID, ty: DataType) -> LocalID {
452        let local_func = self.functions[*func_idx as usize].unwrap_local_mut();
453        local_func.add_local(ty)
454    }
455
456    /// Set the name for a local function. Returns false if it is an imported function.
457    pub fn set_local_fn_name(&mut self, func_idx: FunctionID, name: String) -> bool {
458        match &mut self.functions[*func_idx as usize].kind {
459            FuncKind::Import(_) => {
460                warn!("is an imported function!");
461                return false;
462            }
463            FuncKind::Local(ref mut l) => l.body.name = Some(name.clone()),
464        }
465        self.functions[*func_idx as usize].name = Some(name);
466        true
467    }
468
469    /// Set the name for an imported function. Returns false if it is a local function.
470    pub(crate) fn set_imported_fn_name(&mut self, func_idx: FunctionID, name: String) -> bool {
471        if self.functions[*func_idx as usize].is_local() {
472            warn!("is a local function!");
473            return false;
474        }
475        self.functions[*func_idx as usize].name = Some(name);
476        true
477    }
478}