orca_wasm/iterator/module_iterator.rs

1//! Iterator to traverse a Module
2
3use crate::ir::id::{FunctionID, GlobalID, LocalID};
4use crate::ir::module::module_functions::FuncKind;
5use crate::ir::module::module_globals::Global;
6use crate::ir::module::Module;
7use crate::ir::types::{DataType, FuncInstrMode, InstrumentationMode, Location};
8use crate::iterator::iterator_trait::{IteratingInstrumenter, Iterator};
9use crate::module_builder::AddLocal;
10use crate::opcode::{Inject, InjectAt, Instrumenter, MacroOpcode, Opcode};
11use crate::subiterator::module_subiterator::ModuleSubIterator;
12use wasmparser::Operator;
13
14/// Iterator for a Module.
15// 'b should outlive 'a
16pub struct ModuleIterator<'a, 'b> {
17    /// The Module to Iterate
18    pub module: &'a mut Module<'b>,
19    /// The SubIterator for this Module
20    mod_iterator: ModuleSubIterator,
21}
22
23#[allow(dead_code)]
24impl<'a, 'b> ModuleIterator<'a, 'b> {
25    /// Creates a new ModuleIterator
26    pub fn new(module: &'a mut Module<'b>, skip_funcs: &Vec<FunctionID>) -> Self {
27        let metadata = module.get_func_metadata();
28        ModuleIterator {
29            module,
30            mod_iterator: ModuleSubIterator::new(metadata, skip_funcs.to_owned()),
31        }
32    }
33
34    pub fn curr_op_owned(&self) -> Option<Operator<'b>> {
35        if let (
36            Location::Module {
37                func_idx,
38                instr_idx,
39                ..
40            },
41            ..,
42        ) = self.mod_iterator.curr_loc()
43        {
44            match &self.module.functions.get(func_idx).kind {
45                FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"),
46                FuncKind::Local(l) => Some(l.body.instructions[instr_idx].op.clone()),
47            }
48        } else {
49            panic!("Should have gotten Module Location!")
50        }
51    }
52}
53
54impl<'b> Inject<'b> for ModuleIterator<'_, 'b> {
55    /// Injects an Operator at the current location
56    ///
57    /// # Example
58    /// ```no_run
59    /// use orca_wasm::ir::module::Module;
60    /// use orca_wasm::iterator::module_iterator::ModuleIterator;
61    /// use wasmparser::Operator;
62    /// use orca_wasm::ir::types::{Location};
63    /// use orca_wasm::iterator::iterator_trait::{IteratingInstrumenter, Iterator};
64    /// use orca_wasm::opcode::{Instrumenter, Opcode};
65    ///
66    /// let file = "path_to_file";
67    /// let buff = wat::parse_file(file).expect("couldn't convert the input wat to Wasm");
68    /// // Must use `parse_only_module` here as we are only concerned about a Module and not a module that is inside a Component
69    /// let mut module = Module::parse(&buff, false).expect("Unable to parse");
70    /// let mut module_it = ModuleIterator::new(&mut module, &vec![]);
71    ///
72    /// // Everytime there is a `call 1` instruction we want to inject an `i32.const 0`
73    /// let interested = Operator::Call { function_index: 1 };
74    ///
75    /// loop {
76    ///     let op = module_it.curr_op();
77    ///     let instr_mode = module_it.curr_instrument_mode();
78    ///
79    ///     if let Location::Module {
80    ///         func_idx,
81    ///         instr_idx,
82    ///     } = module_it.curr_loc().0
83    ///     {
84    ///         if *module_it.curr_op().unwrap() == interested {
85    ///             module_it.before().i32_const(1);
86    ///         }
87    ///         if module_it.next().is_none() {
88    ///             break;
89    ///         };
90    ///     } else {
91    ///         // Ensures we only get the location of a module while parsing a component
92    ///         panic!("Should've gotten Module Location!");
93    ///     }
94    /// }
95    /// ```
96    fn inject(&mut self, instr: Operator<'b>) {
97        if let (
98            Location::Module {
99                func_idx,
100                instr_idx,
101                ..
102            },
103            ..,
104        ) = self.curr_loc()
105        {
106            match self.module.functions.get_mut(func_idx as FunctionID).kind {
107                FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"),
108                FuncKind::Local(ref mut l) => l.add_instr(instr, instr_idx),
109            }
110        } else {
111            panic!("Should have gotten Module Location!")
112        }
113    }
114}
115impl<'a> InjectAt<'a> for ModuleIterator<'_, 'a> {
116    fn inject_at(&mut self, idx: usize, mode: InstrumentationMode, instr: Operator<'a>) {
117        if let (Location::Module { func_idx, .. }, ..) = self.curr_loc() {
118            let loc = Location::Module {
119                func_idx,
120                instr_idx: idx,
121            };
122            self.set_instrument_mode_at(mode, loc);
123            self.add_instr_at(loc, instr);
124        } else {
125            panic!("Should have gotten Module Location!")
126        }
127    }
128}
129impl<'a> Opcode<'a> for ModuleIterator<'_, 'a> {}
130impl<'a> MacroOpcode<'a> for ModuleIterator<'_, 'a> {}
131impl<'a> Instrumenter<'a> for ModuleIterator<'_, 'a> {
132    ///Can be called after finishing some instrumentation to reset the mode.
133    fn finish_instr(&mut self) {
134        if let (
135            Location::Module {
136                func_idx,
137                instr_idx,
138                ..
139            },
140            ..,
141        ) = self.mod_iterator.curr_loc()
142        {
143            match &mut self.module.functions.get_mut(func_idx as FunctionID).kind {
144                FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"),
145                FuncKind::Local(l) => l.body.instructions[instr_idx].instr_flag.finish_instr(),
146            }
147        } else {
148            panic!("Should have gotten Module Location and not Module Location!")
149        }
150    }
151    /// Returns the Instrumentation at the current Location
152    fn curr_instrument_mode(&self) -> &Option<InstrumentationMode> {
153        if let (
154            Location::Module {
155                func_idx,
156                instr_idx,
157                ..
158            },
159            ..,
160        ) = self.mod_iterator.curr_loc()
161        {
162            match &self.module.functions.get(func_idx as FunctionID).kind {
163                FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"),
164                FuncKind::Local(l) => &l.body.instructions[instr_idx].instr_flag.current_mode,
165            }
166        } else {
167            panic!("Should have gotten Module Location and not Module Location!")
168        }
169    }
170
171    fn set_instrument_mode_at(&mut self, mode: InstrumentationMode, loc: Location) {
172        if let Location::Module {
173            func_idx,
174            instr_idx,
175            ..
176        } = loc
177        {
178            match self.module.functions.get_mut(func_idx as FunctionID).kind {
179                FuncKind::Import(_) => panic!("Cannot add an instruction to an imported function"),
180                FuncKind::Local(ref mut l) => {
181                    l.body.instructions[instr_idx].instr_flag.current_mode = Some(mode)
182                }
183            }
184        } else {
185            panic!("Should have gotten module location!")
186        }
187    }
188
189    fn curr_func_instrument_mode(&self) -> &Option<FuncInstrMode> {
190        if let (Location::Module { func_idx, .. }, ..) = self.mod_iterator.curr_loc() {
191            match &self.module.functions.get(func_idx as FunctionID).kind {
192                FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"),
193                FuncKind::Local(l) => &l.instr_flag.current_mode,
194            }
195        } else {
196            panic!("Should have gotten Module Location and not Module Location!")
197        }
198    }
199
200    fn set_func_instrument_mode(&mut self, mode: FuncInstrMode) {
201        if let (Location::Module { func_idx, .. }, ..) = self.mod_iterator.curr_loc() {
202            match self.module.functions.get_mut(func_idx as FunctionID).kind {
203                FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"),
204                FuncKind::Local(ref mut l) => l.instr_flag.current_mode = Some(mode),
205            }
206        } else {
207            panic!("Should have gotten Module Location and not Module Location!")
208        }
209    }
210
211    fn clear_instr_at(&mut self, loc: Location, mode: InstrumentationMode) {
212        if let Location::Module {
213            func_idx,
214            instr_idx,
215            ..
216        } = loc
217        {
218            match self.module.functions.get_mut(func_idx as FunctionID).kind {
219                FuncKind::Import(_) => panic!("Cannot add an instruction to an imported function"),
220                FuncKind::Local(ref mut l) => {
221                    l.clear_instr_at(instr_idx, mode);
222                }
223            }
224            // Only injects if it is an instrumented location
225        } else {
226            panic!("Should have gotten Module Location!")
227        }
228    }
229
230    fn add_instr_at(&mut self, loc: Location, instr: Operator<'a>) {
231        if let Location::Module {
232            func_idx,
233            instr_idx,
234            ..
235        } = loc
236        {
237            match self.module.functions.get_mut(func_idx as FunctionID).kind {
238                FuncKind::Import(_) => panic!("Cannot add an instruction to an imported function"),
239                FuncKind::Local(ref mut l) => {
240                    l.add_instr(instr, instr_idx);
241                }
242            }
243            // Only injects if it is an instrumented location
244        } else {
245            panic!("Should have gotten Module Location!")
246        }
247    }
248
249    fn empty_alternate_at(&mut self, loc: Location) -> &mut Self {
250        if let Location::Module {
251            func_idx,
252            instr_idx,
253            ..
254        } = loc
255        {
256            match self.module.functions.get_mut(func_idx).kind {
257                FuncKind::Import(_) => panic!("Cannot instrument an imported function"),
258                FuncKind::Local(ref mut l) => {
259                    l.body.instructions[instr_idx].instr_flag.alternate = Some(vec![])
260                }
261            }
262        } else {
263            panic!("Should have gotten Module Location and not Module Location!")
264        }
265        self
266    }
267
268    fn empty_block_alt_at(&mut self, loc: Location) -> &mut Self {
269        if let Location::Module {
270            func_idx,
271            instr_idx,
272            ..
273        } = loc
274        {
275            match self.module.functions.get_mut(func_idx as FunctionID).kind {
276                FuncKind::Import(_) => panic!("Cannot instrument an imported function"),
277                FuncKind::Local(ref mut l) => {
278                    l.body.instructions[instr_idx].instr_flag.block_alt = Some(vec![]);
279                    l.instr_flag.has_special_instr |= true;
280                }
281            }
282        } else {
283            panic!("Should have gotten Module Location and not Module Location!")
284        }
285        self
286    }
287
288    /// Gets the injected instruction at the current location by index
289    fn get_injected_val(&self, idx: usize) -> &Operator {
290        if let (
291            Location::Module {
292                func_idx,
293                instr_idx,
294                ..
295            },
296            ..,
297        ) = self.mod_iterator.curr_loc()
298        {
299            match &self.module.functions.get(func_idx as FunctionID).kind {
300                FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"),
301                FuncKind::Local(l) => l.body.instructions[instr_idx].instr_flag.get_instr(idx),
302            }
303        } else {
304            panic!("Should have gotten Component Location and not Module Location!")
305        }
306    }
307}
308impl<'a> IteratingInstrumenter<'a> for ModuleIterator<'_, 'a> {
309    fn set_instrument_mode(&mut self, mode: InstrumentationMode) {
310        self.set_instrument_mode_at(mode, self.curr_loc().0);
311    }
312
313    fn add_global(&mut self, global: Global) -> GlobalID {
314        self.module.globals.add(global)
315    }
316}
317
318impl AddLocal for ModuleIterator<'_, '_> {
319    fn add_local(&mut self, val_type: DataType) -> LocalID {
320        let curr_loc = self.curr_loc();
321        if let (Location::Module { func_idx, .. }, ..) = curr_loc {
322            self.module.functions.add_local(func_idx, val_type)
323        } else {
324            panic!("Should have gotten Module Location!")
325        }
326    }
327}
328
329// Note: Marked Trait as the same lifetime as component
330impl<'a> Iterator for ModuleIterator<'_, 'a> {
331    /// Resets the Module Iterator
332    fn reset(&mut self) {
333        self.mod_iterator.reset();
334    }
335
336    /// Goes to the next instruction and returns the instruction
337    fn next(&mut self) -> Option<&Operator> {
338        match self.mod_iterator.next() {
339            false => None,
340            true => self.curr_op(),
341        }
342    }
343
344    /// Returns the Current Location as a Location and a bool value that
345    /// says whether the location is at the end of the function.
346    fn curr_loc(&self) -> (Location, bool) {
347        self.mod_iterator.curr_loc()
348    }
349
350    /// Returns the current instruction
351    fn curr_op(&self) -> Option<&Operator<'a>> {
352        if let (
353            Location::Module {
354                func_idx,
355                instr_idx,
356                ..
357            },
358            ..,
359        ) = self.mod_iterator.curr_loc()
360        {
361            match &self.module.functions.get(func_idx).kind {
362                FuncKind::Import(_) => panic!("Cannot get an instruction to an imported function"),
363                FuncKind::Local(l) => Some(&l.body.instructions[instr_idx].op),
364            }
365        } else {
366            panic!("Should have gotten Module Location!")
367        }
368    }
369}