orca_wasm/ir/module/
mod.rs

1//! Intermediate Representation of a wasm module.
2
3use super::types::{DataType, InitExpr, Instruction, InstrumentationMode};
4use crate::error::Error;
5use crate::ir::function::FunctionModifier;
6use crate::ir::id::{DataSegmentID, FunctionID, GlobalID, ImportsID, LocalID, MemoryID, TypeID};
7use crate::ir::module::module_exports::{Export, ModuleExports};
8use crate::ir::module::module_functions::{
9    add_local, FuncKind, Function, Functions, ImportedFunction, LocalFunction,
10};
11use crate::ir::module::module_globals::{
12    Global, GlobalKind, ImportedGlobal, LocalGlobal, ModuleGlobals,
13};
14use crate::ir::module::module_imports::{Import, ModuleImports};
15use crate::ir::module::module_memories::{ImportedMemory, LocalMemory, MemKind, Memories, Memory};
16use crate::ir::module::module_tables::ModuleTables;
17use crate::ir::module::module_types::{ModuleTypes, Types};
18use crate::ir::types::InstrumentationMode::{BlockAlt, BlockEntry, BlockExit, SemanticAfter};
19use crate::ir::types::{
20    BlockType, Body, CustomSections, DataSegment, DataSegmentKind, ElementItems, ElementKind,
21    InstrumentationFlag,
22};
23use crate::ir::wrappers::{
24    indirect_namemap_parser2encoder, namemap_parser2encoder, refers_to_func, refers_to_global,
25    refers_to_memory, update_fn_instr, update_global_instr, update_memory_instr,
26};
27use crate::opcode::{Inject, Instrumenter};
28use crate::{Location, Opcode};
29use log::{error, warn};
30use std::borrow::Cow;
31use std::collections::HashMap;
32use std::vec::IntoIter;
33use wasm_encoder::reencode::{Reencode, RoundtripReencoder};
34use wasm_encoder::TagSection;
35use wasmparser::Operator::Block;
36use wasmparser::{
37    CompositeInnerType, ExternalKind, GlobalType, MemoryType, Operator, Parser, Payload, TagType,
38    TypeRef,
39};
40
41pub mod module_exports;
42pub mod module_functions;
43pub mod module_globals;
44pub mod module_imports;
45pub mod module_memories;
46pub mod module_tables;
47pub mod module_types;
48#[cfg(test)]
49mod test;
50
#[derive(Debug, Default)]
/// Intermediate Representation of a wasm module. See the [WASM Spec] for different sections.
///
/// A `Module` is normally produced by [`Module::parse`], which walks every
/// payload of a wasm binary and fills in the fields below.
///
/// [WASM Spec]: https://webassembly.github.io/spec/core/binary/modules.html
pub struct Module<'a> {
    /// Name of the module, taken from the `name` custom section when present.
    pub module_name: Option<String>,
    /// Types declared in the type section (together with the rec-group mapping).
    pub types: ModuleTypes,
    /// Imports, in the order they appear in the import section.
    pub imports: ModuleImports<'a>,
    /// All functions of the module: imported functions come first, followed
    /// by the locally defined ones.
    /// Note that `|functions| == num_functions + num_imported_functions`
    pub functions: Functions<'a>,
    /// Each table has a type and optional initialization expression.
    pub tables: ModuleTables<'a>,
    /// Memories: imported memories first, followed by the locally defined ones.
    pub memories: Memories,
    /// Globals (built from the imports and the global section).
    pub globals: ModuleGlobals,
    /// Data segments of the data section.
    pub data: Vec<DataSegment>,
    // Whether the parsed binary contained a data count section.
    data_count_section_exists: bool,
    /// Exports
    pub exports: ModuleExports,
    /// Index of the start function.
    pub start: Option<FunctionID>,
    /// Element segments, stored as `(kind, items)` pairs.
    pub elements: Vec<(ElementKind<'a>, ElementItems<'a>)>,
    /// Tags declared in the tag section.
    pub tags: Vec<TagType>,
    /// Custom sections that are carried along as raw `(name, data)` entries.
    pub custom_sections: CustomSections<'a>,
    /// Number of local functions (not counting imported functions)
    pub(crate) num_local_functions: u32,
    /// Number of local globals (not counting imported globals)
    pub(crate) num_local_globals: u32,
    /// Number of local tables (not counting imported tables)
    #[allow(dead_code)]
    pub(crate) num_local_tables: u32,
    /// Number of local memories (not counting imported memories)
    #[allow(dead_code)]
    pub(crate) num_local_memories: u32,

    // The name maps below are just placeholders for round-tripping: they are
    // converted straight from the parser representation to the encoder
    // representation and are not otherwise modelled in the IR.
    pub(crate) local_names: wasm_encoder::IndirectNameMap,
    pub(crate) label_names: wasm_encoder::IndirectNameMap,
    pub(crate) type_names: wasm_encoder::NameMap,
    pub(crate) table_names: wasm_encoder::NameMap,
    pub(crate) memory_names: wasm_encoder::NameMap,
    pub(crate) global_names: wasm_encoder::NameMap,
    pub(crate) elem_names: wasm_encoder::NameMap,
    pub(crate) data_names: wasm_encoder::NameMap,
    pub(crate) field_names: wasm_encoder::IndirectNameMap,
    pub(crate) tag_names: wasm_encoder::NameMap,
}
107
108impl<'a> Module<'a> {
109    /// Parses a `Module` from a wasm binary.
110    ///
111    /// # Example
112    ///
113    /// ```no_run
114    /// use orca_wasm::Module;
115    ///
116    /// let file = "path_to_file";
117    /// let buff = wat::parse_file(file).expect("couldn't convert the input wat to Wasm");
118    /// let module = Module::parse(&buff, false).unwrap();
119    /// ```
120    pub fn parse(wasm: &'a [u8], enable_multi_memory: bool) -> Result<Self, Error> {
121        let parser = Parser::new(0);
122        Module::parse_internal(wasm, enable_multi_memory, parser)
123    }
124
125    pub(crate) fn parse_internal(
126        wasm: &'a [u8],
127        enable_multi_memory: bool,
128        parser: Parser,
129    ) -> Result<Self, Error> {
130        let mut imports: ModuleImports = ModuleImports::default();
131        let mut types: Vec<Types> = vec![];
132        let mut data = vec![];
133        let mut tables = vec![];
134        let mut memories = vec![];
135        let mut functions = vec![];
136        let mut elements = vec![];
137        let mut code_section_count = 0;
138        let mut code_sections = vec![];
139        let mut globals = vec![];
140        let mut exports = vec![];
141        let mut start = None;
142        let mut data_section_count = None;
143        let mut custom_sections = vec![];
144        let mut tags: Vec<TagType> = vec![];
145
146        let mut module_name: Option<String> = None;
147        // for the other names, we directly encode it without passing them into the IR
148        let mut local_names = wasm_encoder::IndirectNameMap::new();
149        let mut label_names = wasm_encoder::IndirectNameMap::new();
150        let mut type_names = wasm_encoder::NameMap::new();
151        let mut table_names = wasm_encoder::NameMap::new();
152        let mut memory_names = wasm_encoder::NameMap::new();
153        let mut global_names = wasm_encoder::NameMap::new();
154        let mut elem_names = wasm_encoder::NameMap::new();
155        let mut data_names = wasm_encoder::NameMap::new();
156        let mut field_names = wasm_encoder::IndirectNameMap::new();
157        let mut tag_names = wasm_encoder::NameMap::new();
158        let mut recgroup_map = HashMap::new();
159
160        for payload in parser.parse_all(wasm) {
161            let payload = payload?;
162            match payload {
163                Payload::ImportSection(import_section_reader) => {
164                    let mut temp = vec![];
165                    // count number of imported functions
166                    for import in import_section_reader.into_iter() {
167                        let imp = Import::from(import?);
168                        temp.push(imp);
169                    }
170                    imports = ModuleImports::new(temp);
171                }
172                Payload::TypeSection(type_section_reader) => {
173                    let mut ty_idx: u32 = 0;
174                    for (id, ty) in type_section_reader.into_iter().enumerate() {
175                        let rec_group = ty.clone()?.is_explicit_rec_group();
176                        for subtype in ty?.types() {
177                            match subtype.composite_type.inner.clone() {
178                                CompositeInnerType::Func(fty) => {
179                                    let fun_ty = fty;
180                                    let params = fun_ty
181                                        .params()
182                                        .iter()
183                                        .map(|x| DataType::from(*x))
184                                        .collect::<Vec<_>>()
185                                        .into_boxed_slice();
186                                    let results = fun_ty
187                                        .results()
188                                        .iter()
189                                        .map(|x| DataType::from(*x))
190                                        .collect::<Vec<_>>()
191                                        .into_boxed_slice();
192                                    let final_ty = Types::FuncType {
193                                        params,
194                                        results,
195                                        super_type: subtype.supertype_idx,
196                                        is_final: subtype.is_final,
197                                        shared: subtype.composite_type.shared,
198                                    };
199                                    types.push(final_ty.clone());
200
201                                    if rec_group {
202                                        recgroup_map.insert(ty_idx, id as u32);
203                                    }
204                                }
205                                CompositeInnerType::Array(aty) => {
206                                    let array_ty = Types::ArrayType {
207                                        mutable: aty.0.mutable,
208                                        fields: DataType::from(aty.0.element_type),
209                                        super_type: subtype.supertype_idx,
210                                        is_final: subtype.is_final,
211                                        shared: subtype.composite_type.shared,
212                                    };
213                                    types.push(array_ty.clone());
214
215                                    if rec_group {
216                                        recgroup_map.insert(ty_idx, id as u32);
217                                    }
218                                }
219                                CompositeInnerType::Struct(sty) => {
220                                    let struct_ty = Types::StructType {
221                                        mutable: sty
222                                            .fields
223                                            .iter()
224                                            .map(|field| field.mutable)
225                                            .collect::<Vec<_>>(),
226                                        fields: sty
227                                            .fields
228                                            .iter()
229                                            .map(|field| DataType::from(field.element_type))
230                                            .collect::<Vec<_>>(),
231                                        super_type: subtype.supertype_idx,
232                                        is_final: subtype.is_final,
233                                        shared: subtype.composite_type.shared,
234                                    };
235                                    types.push(struct_ty.clone());
236                                    if rec_group {
237                                        recgroup_map.insert(ty_idx, id as u32);
238                                    }
239                                }
240                                CompositeInnerType::Cont(cty) => {
241                                    let cont_ty = Types::ContType {
242                                        packed_index: cty.0,
243                                        super_type: subtype.supertype_idx,
244                                        is_final: subtype.is_final,
245                                        shared: subtype.composite_type.shared,
246                                    };
247                                    types.push(cont_ty.clone());
248                                    if rec_group {
249                                        recgroup_map.insert(ty_idx, id as u32);
250                                    }
251                                }
252                            }
253                            ty_idx += 1;
254                        }
255                    }
256                }
257                Payload::DataSection(data_section_reader) => {
258                    data = data_section_reader
259                        .into_iter()
260                        .map(|sec| {
261                            sec.map_err(Error::from)
262                                .and_then(DataSegment::from_wasmparser)
263                        })
264                        .collect::<Result<_, _>>()?;
265                }
266                Payload::TableSection(table_section_reader) => {
267                    tables = table_section_reader
268                        .into_iter()
269                        .map(|t| {
270                            t.map_err(Error::from).map(|t| match t.init {
271                                wasmparser::TableInit::RefNull => (t.ty, None),
272                                wasmparser::TableInit::Expr(e) => (t.ty, Some(e)),
273                            })
274                        })
275                        .collect::<Result<_, _>>()?;
276                }
277                Payload::MemorySection(memory_section_reader) => {
278                    memories = memory_section_reader
279                        .into_iter()
280                        .collect::<Result<_, _>>()?;
281                }
282                Payload::FunctionSection(function_section_reader) => {
283                    let temp: Vec<u32> = function_section_reader
284                        .into_iter()
285                        .collect::<Result<_, _>>()?;
286                    functions.extend(temp.iter().map(|id| TypeID(*id)));
287                }
288                Payload::GlobalSection(global_section_reader) => {
289                    globals = global_section_reader
290                        .into_iter()
291                        .map(|g| Global::from_wasmparser(g?))
292                        .collect::<Result<_, _>>()?;
293                }
294                Payload::ExportSection(export_section_reader) => {
295                    for exp in export_section_reader.into_iter() {
296                        exports.push(Export::from(exp?));
297                    }
298                }
299                Payload::StartSection { func, range: _ } => {
300                    if start.is_some() {
301                        return Err(Error::MultipleStartSections);
302                    }
303                    start = Some(FunctionID(func));
304                }
305                Payload::ElementSection(element_section_reader) => {
306                    for element in element_section_reader.into_iter() {
307                        let element = element?;
308                        let items = ElementItems::from_wasmparser(element.items.clone())?;
309                        elements.push((ElementKind::from_wasmparser(element.kind)?, items));
310                    }
311                }
312                Payload::DataCountSection { count, range: _ } => {
313                    data_section_count = Some(count);
314                }
315                Payload::CodeSectionStart {
316                    count,
317                    range: _,
318                    size: _,
319                } => {
320                    code_section_count = count as usize;
321                }
322                Payload::CodeSectionEntry(body) => {
323                    let locals_reader = body.get_locals_reader()?;
324                    let locals = locals_reader.into_iter().collect::<Result<Vec<_>, _>>()?;
325                    let mut num_locals = 0;
326                    let locals: Vec<(u32, DataType)> = locals
327                        .iter()
328                        .map(|(count, val_type)| {
329                            num_locals += count;
330                            (*count, DataType::from(*val_type))
331                        })
332                        .collect();
333
334                    let instructions = body
335                        .get_operators_reader()?
336                        .into_iter()
337                        .collect::<Result<Vec<_>, _>>()?;
338                    if let Some(last) = instructions.last() {
339                        if let Operator::End = last {
340                        } else {
341                            return Err(Error::MissingFunctionEnd {
342                                func_range: body.range(),
343                            });
344                        }
345                    }
346                    if !enable_multi_memory
347                        && instructions.iter().any(|i| match i {
348                            Operator::MemoryGrow { mem, .. } | Operator::MemorySize { mem, .. } => {
349                                *mem != 0x00
350                            }
351                            _ => false,
352                        })
353                    {
354                        return Err(Error::InvalidMemoryReservedByte {
355                            func_range: body.range(),
356                        });
357                    }
358                    let instructions_bool: Vec<_> =
359                        instructions.into_iter().map(Instruction::new).collect();
360                    code_sections.push(Body {
361                        locals,
362                        num_locals,
363                        instructions: instructions_bool.clone(),
364                        num_instructions: instructions_bool.len(),
365                        name: None,
366                    });
367                }
368                Payload::TagSection(tag_section_reader) => {
369                    for tag in tag_section_reader.into_iter() {
370                        match tag {
371                            Ok(t) => tags.push(t),
372                            Err(e) => panic!("Error encored in tag section!: {}", e),
373                        }
374                    }
375                }
376                Payload::CustomSection(custom_section_reader) => {
377                    match custom_section_reader.as_known() {
378                        wasmparser::KnownCustom::Name(name_section_reader) => {
379                            for subsection in name_section_reader {
380                                #[allow(clippy::single_match)]
381                                match subsection? {
382                                    wasmparser::Name::Function(names) => {
383                                        for name in names {
384                                            let naming = name?;
385                                            let abs_idx = naming.index;
386                                            if abs_idx < imports.num_funcs {
387                                                let mut import_func_count = 0;
388                                                // TODO: this is very expensive, can we optimize this?
389                                                for import in imports.iter_mut() {
390                                                    if import.is_function() {
391                                                        if import_func_count == abs_idx {
392                                                            import.custom_name =
393                                                                Some(naming.name.to_string());
394                                                            break;
395                                                        }
396                                                        import_func_count += 1;
397                                                    }
398                                                }
399                                            } else {
400                                                let rel_idx = abs_idx - imports.num_funcs;
401                                                // assert!(0 < rel_idx && rel_idx < code_sections.len() as u32);
402                                                code_sections[rel_idx as usize].name =
403                                                    Some(naming.name.to_string());
404                                            }
405                                        }
406                                    }
407                                    wasmparser::Name::Module { name, .. } => {
408                                        module_name = Some(name.to_string());
409                                    }
410                                    wasmparser::Name::Local(names) => {
411                                        local_names = indirect_namemap_parser2encoder(names);
412                                    }
413                                    wasmparser::Name::Label(names) => {
414                                        label_names = indirect_namemap_parser2encoder(names);
415                                    }
416                                    wasmparser::Name::Type(names) => {
417                                        type_names = namemap_parser2encoder(names);
418                                    }
419                                    wasmparser::Name::Table(names) => {
420                                        table_names = namemap_parser2encoder(names);
421                                    }
422                                    wasmparser::Name::Memory(names) => {
423                                        memory_names = namemap_parser2encoder(names);
424                                    }
425                                    wasmparser::Name::Global(names) => {
426                                        global_names = namemap_parser2encoder(names);
427                                    }
428                                    wasmparser::Name::Element(names) => {
429                                        elem_names = namemap_parser2encoder(names);
430                                    }
431                                    wasmparser::Name::Data(names) => {
432                                        data_names = namemap_parser2encoder(names);
433                                    }
434                                    wasmparser::Name::Field(names) => {
435                                        field_names = indirect_namemap_parser2encoder(names);
436                                    }
437                                    wasmparser::Name::Tag(names) => {
438                                        tag_names = namemap_parser2encoder(names);
439                                    }
440                                    wasmparser::Name::Unknown { .. } => {}
441                                }
442                            }
443                        }
444                        wasmparser::KnownCustom::Producers(producer_section_reader) => {
445                            let field = producer_section_reader
446                                .into_iter()
447                                .next()
448                                .unwrap()
449                                .expect("producers field");
450                            let _value = field
451                                .values
452                                .into_iter()
453                                .collect::<Result<Vec<_>, _>>()
454                                .expect("values");
455                            custom_sections
456                                .push((custom_section_reader.name(), custom_section_reader.data()));
457                        }
458                        _ => {
459                            custom_sections
460                                .push((custom_section_reader.name(), custom_section_reader.data()));
461                        }
462                    }
463                }
464                Payload::Version {
465                    num,
466                    encoding: _,
467                    range: _,
468                } => {
469                    if num != 1 {
470                        return Err(Error::UnknownVersion(num as u32));
471                    }
472                }
473                Payload::UnknownSection {
474                    id,
475                    contents: _,
476                    range: _,
477                } => return Err(Error::UnknownSection { section_id: id }),
478                Payload::ModuleSection {
479                    parser: _,
480                    unchecked_range: _,
481                }
482                | Payload::InstanceSection(_)
483                | Payload::CoreTypeSection(_)
484                | Payload::ComponentSection {
485                    parser: _,
486                    unchecked_range: _,
487                }
488                | Payload::ComponentInstanceSection(_)
489                | Payload::ComponentAliasSection(_)
490                | Payload::ComponentTypeSection(_)
491                | Payload::ComponentCanonicalSection(_)
492                | Payload::ComponentStartSection { start: _, range: _ }
493                | Payload::ComponentImportSection(_)
494                | Payload::ComponentExportSection(_)
495                | Payload::End(_) => {}
496                _ => todo!(),
497            }
498        }
499        if code_section_count != code_sections.len() || code_section_count != functions.len() {
500            return Err(Error::IncorrectCodeCounts {
501                function_section_count: functions.len(),
502                code_section_declared_count: code_section_count,
503                code_section_actual_count: code_sections.len(),
504            });
505        }
506        if let Some(data_count) = data_section_count {
507            if data_count as usize != data.len() {
508                return Err(Error::IncorrectDataCount {
509                    declared_count: data_count as usize,
510                    actual_count: data.len(),
511                });
512            }
513        }
514
515        // Add all the functions. First add all the imported functions as they have the first IDs
516        let mut final_funcs = vec![];
517        let mut imp_fn_id = 0;
518        for (index, imp) in imports.iter().enumerate() {
519            if let TypeRef::Func(u) = imp.ty {
520                final_funcs.push(Function::new(
521                    FuncKind::Import(ImportedFunction::new(
522                        ImportsID(index as u32),
523                        TypeID(u),
524                        FunctionID(imp_fn_id),
525                    )),
526                    Some(imp.name.parse().unwrap()),
527                ));
528                imp_fn_id += 1;
529            }
530        }
531        // Local Functions
532        for (index, code_sec) in code_sections.iter().enumerate() {
533            final_funcs.push(Function::new(
534                FuncKind::Local(LocalFunction::new(
535                    functions[index],
536                    FunctionID(imports.num_funcs + index as u32),
537                    (*code_sec).clone(),
538                    types[*functions[index] as usize].params().len(),
539                )),
540                (*code_sec).clone().name,
541            ));
542        }
543
544        // Process the imported memories
545        let mut final_mems = vec![];
546        let mut imp_mem_id = 0;
547        for (index, imp) in imports.iter().enumerate() {
548            if let TypeRef::Memory(ty) = imp.ty {
549                final_mems.push(Memory::new(
550                    ty,
551                    MemKind::Import(ImportedMemory {
552                        import_id: ImportsID(index as u32),
553                        import_mem_id: MemoryID(imp_mem_id),
554                    }),
555                ));
556                imp_mem_id += 1;
557            }
558        }
559        // Process the Local memories
560        for (index, ty) in memories.iter().enumerate() {
561            final_mems.push(Memory::new(
562                ty.to_owned(),
563                MemKind::Local(LocalMemory {
564                    mem_id: MemoryID(imports.num_memories + index as u32),
565                }),
566            ));
567        }
568
569        let num_globals = globals.len() as u32;
570        let num_memories = memories.len() as u32;
571        let num_tables = tables.len() as u32;
572        let module_globals = ModuleGlobals::new(&imports, globals);
573        Ok(Module {
574            types: ModuleTypes::new(types, recgroup_map),
575            imports,
576            functions: Functions::new(final_funcs),
577            tables: ModuleTables::new(tables),
578            memories: Memories::new(final_mems),
579            globals: module_globals,
580            exports: ModuleExports::new(exports),
581            start,
582            elements,
583            data_count_section_exists: data_section_count.is_some(),
584            // code_sections: code_sections.clone(),
585            data,
586            tags,
587            custom_sections: CustomSections::new(custom_sections),
588            num_local_functions: code_sections.len() as u32,
589            num_local_globals: num_globals,
590            num_local_tables: num_tables,
591            num_local_memories: num_memories,
592            module_name,
593            local_names,
594            type_names,
595            table_names,
596            elem_names,
597            memory_names,
598            global_names,
599            data_names,
600            field_names,
601            tag_names,
602            label_names,
603        })
604    }
605
606    /// Creates Vec of (Function, Number of Instructions)
607    pub fn get_func_metadata(&self) -> Vec<(FunctionID, usize)> {
608        let mut metadata = vec![];
609        for func in self.functions.iter() {
610            match &func.kind {
611                FuncKind::Import(_) => {}
612                FuncKind::Local(LocalFunction { func_id, body, .. }) => {
613                    metadata.push((*func_id, body.num_instructions));
614                }
615            }
616        }
617        metadata
618    }
619
620    /// Emit the module into a wasm binary file.
621    pub fn emit_wasm(&mut self, file_name: &str) -> Result<(), std::io::Error> {
622        let module = self.encode_internal();
623        let wasm = module.finish();
624        std::fs::write(file_name, wasm)?;
625        Ok(())
626    }
627
628    /// Encode the module into a wasm binary.
629    ///
630    /// # Example
631    ///
632    /// ```no_run
633    /// use orca_wasm::Module;
634    ///
635    /// let file = "path_to_file";
636    /// let buff = wat::parse_file(file).expect("couldn't convert the input wat to Wasm");
637    /// let mut module = Module::parse(&buff, false).unwrap();
638    /// let result = module.encode();
639    /// ```
640    pub fn encode(&mut self) -> Vec<u8> {
641        self.encode_internal().finish()
642    }
643
644    /// Visits the Orca Module and resolves the special instrumentation by
645    /// translating them into the straightforward before/after/alt modes.
646    fn resolve_special_instrumentation(&mut self) {
647        if !self.num_local_functions > 0 {
648            for rel_func_idx in (self.imports.num_funcs - self.imports.num_funcs_added) as usize
649                ..self.functions.len()
650            {
651                let func_idx = FunctionID(rel_func_idx as u32);
652                if let FuncKind::Import(..) = &self.functions.get_kind(func_idx) {
653                    // skip imports
654                    continue;
655                }
656
657                let mut instr_func_on_entry = None;
658                let mut instr_func_on_exit = None;
659                if let FuncKind::Local(LocalFunction { instr_flag, .. }) =
660                    self.functions.get_kind_mut(func_idx)
661                {
662                    if !instr_flag.has_special_instr {
663                        // skip functions without special instrumentation!
664                        continue;
665                    }
666
667                    // save off the function entry/exit special mode bodies
668                    if !instr_flag.entry.is_empty() {
669                        instr_func_on_entry = Some(instr_flag.entry.to_owned());
670                        instr_flag.entry = vec![];
671                    }
672                    if !instr_flag.exit.is_empty() {
673                        instr_func_on_exit = Some(instr_flag.exit.to_owned());
674                        instr_flag.exit = vec![];
675                    }
676                }
677
678                // initialize with 0 to store the func block!
679                let mut block_stack: Vec<BlockID> = vec![0];
680                let mut delete_block: Option<BlockID> = None;
681                let mut retain_end = true;
682                let mut resolve_on_else_or_end: HashMap<InstrumentationMode, InstrToInject> =
683                    HashMap::new();
684                let mut resolve_on_end: HashMap<
685                    BlockID,
686                    HashMap<InstrumentationMode, InstrToInject>,
687                > = HashMap::new();
688                if let Some(on_exit) = &mut instr_func_on_exit {
689                    if !on_exit.is_empty() {
690                        let on_entry = if let Some(on_entry) = &mut instr_func_on_entry {
691                            on_entry
692                        } else {
693                            let on_entry = vec![];
694                            instr_func_on_entry = Some(on_entry);
695                            if let Some(ref mut on_entry) = instr_func_on_entry {
696                                on_entry
697                            } else {
698                                panic!()
699                            }
700                        };
701
702                        let func_ty = self.functions.get_type_id(func_idx);
703                        let func_results = self.types.get(func_ty).unwrap().results();
704                        let block_ty = self.types.add_func_type(&[], &func_results);
705                        resolve_function_exit_with_block_wrapper(on_entry, block_ty);
706                    }
707                }
708                let mut builder = self.functions.get_fn_modifier(func_idx).unwrap();
709
710                // Must make copy to be able to iterate over body while calling builder.* methods that mutate the instrumentation flag!
711                let readable_copy_of_body = builder.body.instructions.clone();
712                for (
713                    idx,
714                    Instruction {
715                        op,
716                        instr_flag: instrumentation,
717                    },
718                ) in readable_copy_of_body.iter().enumerate()
719                {
720                    // resolve function-level instrumentation
721                    if let Some(on_entry) = &mut instr_func_on_entry {
722                        if !on_entry.is_empty() {
723                            resolve_function_entry(&mut builder, on_entry, idx);
724                        }
725                    }
726                    if let Some(on_exit) = &mut instr_func_on_exit {
727                        if !on_exit.is_empty() {
728                            resolve_function_exit(on_exit, &mut builder, op, idx);
729                        }
730                    }
731
732                    // resolve instruction-level instrumentation
733                    match op {
734                        Operator::Block { .. } | Operator::Loop { .. } | Operator::If { .. } => {
735                            // The block ID will just be the curr len of the stack!
736                            block_stack.push(block_stack.len() as u32);
737
738                            // Handle block alt
739                            if let Some(block_alt) = &instrumentation.block_alt {
740                                // only plan to handle if we're not already removing the block this instr is in
741                                if delete_block.is_none()
742                                    && plan_resolution_block_alt(
743                                        block_alt,
744                                        &mut builder,
745                                        &mut retain_end,
746                                        op,
747                                        idx,
748                                    )
749                                {
750                                    builder.clear_instr_at(
751                                        Location::Module {
752                                            func_idx: FunctionID(0), // not used
753                                            instr_idx: idx,
754                                        },
755                                        BlockAlt,
756                                    );
757                                    // we've got a match, which injected the alt body. continue to the next instruction
758                                    delete_block = Some(*block_stack.last().unwrap());
759                                    continue;
760                                }
761                            }
762
763                            if delete_block.is_some() {
764                                // delete this block and skip all instrumentation handling (like below)
765                                builder.empty_alternate_at(Location::Module {
766                                    func_idx: FunctionID(0), // not used
767                                    instr_idx: idx,
768                                });
769                                continue;
770                            }
771                        }
772                        Operator::Else => {
773                            // necessary for if statements with block_exit instrumentation
774                            for (mode, instr_to_inject) in resolve_on_else_or_end.iter() {
775                                // resolve bodies at the else
776                                resolve_bodies(&mut builder, mode, instr_to_inject, idx);
777                            }
778                            resolve_on_else_or_end.clear();
779
780                            // Handle block alt
781                            if let Some(block_alt) = &instrumentation.block_alt {
782                                // only plan to handle if we're not already removing the block this instr is in
783                                if delete_block.is_none()
784                                    && plan_resolution_block_alt(
785                                        block_alt,
786                                        &mut builder,
787                                        &mut retain_end,
788                                        op,
789                                        idx,
790                                    )
791                                {
792                                    builder.clear_instr_at(
793                                        Location::Module {
794                                            func_idx: FunctionID(0), // not used
795                                            instr_idx: idx,
796                                        },
797                                        BlockAlt,
798                                    );
799                                    // we've got a match, which injected the alt body. continue to the next instruction
800                                    delete_block = Some(*block_stack.last().unwrap());
801                                    continue;
802                                }
803                            }
804
805                            if delete_block.is_some() {
806                                // delete this block and skip all instrumentation handling (like below)
807                                builder.empty_alternate_at(Location::Module {
808                                    func_idx: FunctionID(0), // not used
809                                    instr_idx: idx,
810                                });
811                                continue;
812                            }
813                        }
814                        Operator::End => {
815                            // Pop the stack and check to see if we have instrumentation to inject!
816                            if let Some(block_id) = block_stack.pop() {
817                                if let Some(delete_block_id) = delete_block.as_mut() {
818                                    // Delete the block, but don't remove the end if we say not to
819                                    // should still process instrumentation on the end though...
820                                    // (consider if/else where the else has an alt block)
821                                    if (*delete_block_id).eq(&block_id) {
822                                        // completing the alt block logic, clear state
823                                        delete_block = None;
824                                        if !retain_end {
825                                            // delete this end and skip all instrumentation handling (like below)
826                                            builder.empty_alternate_at(Location::Module {
827                                                func_idx: FunctionID(0), // not used
828                                                instr_idx: idx,
829                                            });
830                                            retain_end = true;
831                                            continue;
832                                        }
833                                        // fall through to the instrumentation handling
834                                        retain_end = true;
835                                    } else {
836                                        // delete this instruction and skip all instrumentation handling (like below)
837                                        builder.empty_alternate_at(Location::Module {
838                                            func_idx: FunctionID(0), // not used
839                                            instr_idx: idx,
840                                        });
841                                        continue;
842                                    }
843                                }
844
845                                // we've reached an end, make sure resolve_on_else is cleared!
846                                // resolve bodies for else OR end
847                                for (mode, instr_to_inject) in resolve_on_else_or_end.iter() {
848                                    resolve_bodies(&mut builder, mode, instr_to_inject, idx);
849                                }
850                                resolve_on_else_or_end.clear();
851
852                                // remove top of stack! (end of vec)
853                                // remove it, so we don't try to re-inject!
854                                if let Some(to_resolve) = resolve_on_end.remove(&block_id) {
855                                    for (mode, instr_to_inject) in to_resolve.iter() {
856                                        // resolve bodies at the end
857                                        resolve_bodies(&mut builder, mode, instr_to_inject, idx);
858                                    }
859                                }
860                            }
861                        }
862                        _ => {
863                            // non block-structured opcodes
864                            if delete_block.is_some() {
865                                // delete this instruction and skip all instrumentation handling (like below)
866                                builder.empty_alternate_at(Location::Module {
867                                    func_idx: FunctionID(0), // not used
868                                    instr_idx: idx,
869                                });
870                                continue;
871                            }
872                        }
873                    }
874
875                    // plan instruction-level instrumentation resolution
876                    // this must go after the above logic to ensure the block_id is on the top of the stack!
877                    if instrumentation.has_instr() {
878                        // this instruction has instrumentation, check if there is any to resolve!
879                        let InstrumentationFlag {
880                            semantic_after,
881                            block_entry,
882                            block_exit,
883                            block_alt: _, // handled before here!
884                            before: _,
885                            after: _,
886                            alternate: _,
887                            current_mode: _,
888                            // exhaustive to help identify where to add code to handle other special modes.
889                        } = instrumentation;
890
891                        // Handle block entry
892                        if !block_entry.is_empty() {
893                            resolve_block_entry(block_entry, &mut builder, op, idx);
894                            builder.clear_instr_at(
895                                Location::Module {
896                                    func_idx: FunctionID(0), // not used
897                                    instr_idx: idx,
898                                },
899                                BlockEntry,
900                            );
901                        }
902
903                        // Handle block exit
904                        if !block_exit.is_empty() {
905                            plan_resolution_block_exit(
906                                block_exit,
907                                &block_stack,
908                                &mut resolve_on_else_or_end,
909                                &mut resolve_on_end,
910                                op,
911                            );
912                            builder.clear_instr_at(
913                                Location::Module {
914                                    func_idx: FunctionID(0), // not used
915                                    instr_idx: idx,
916                                },
917                                BlockExit,
918                            );
919                        }
920
921                        // Handle semantic_after!
922                        if !semantic_after.is_empty() {
923                            plan_resolution_semantic_after(
924                                semantic_after,
925                                &mut builder,
926                                &block_stack,
927                                &mut resolve_on_end,
928                                op,
929                                idx,
930                            );
931                            builder.clear_instr_at(
932                                Location::Module {
933                                    func_idx: FunctionID(0), // not used
934                                    instr_idx: idx,
935                                },
936                                SemanticAfter,
937                            );
938                        }
939                    }
940                }
941            }
942        }
943    }
944
945    /// Reorganises items (both local and imports) in the correct ordering after any potential modifications
946    pub(crate) fn reorganise_generic<T: LocalOrImport, U: ReIndexable<T>>(
947        orig_num_imported: u32,
948        items: &mut U,
949        items_read_only: IntoIter<T>,
950    ) {
951        // Location where we may have to move an import (converted from local) to
952        let mut num_imported = orig_num_imported;
953        let mut num_deleted = 0;
954
955        // Iterate over cloned list
956        for (idx, val) in items_read_only.enumerate() {
957            // If the index is less than < imported
958            if idx < orig_num_imported as usize {
959                // If it is a local, that means it was an import before
960                if val.is_local() {
961                    let f = items.remove((idx - num_deleted) as u32);
962                    items.push(f);
963                    // decrement as this is the place where we might have to move an import to
964                    num_imported -= 1;
965                    // We update it here for the following case. A , B. A is moved to a position later than B, indices will reduce by 1 and we need the offset
966                    num_deleted += 1;
967                } else if val.is_deleted() {
968                    // If val was import but was deleted
969                    items.remove((idx - num_deleted) as u32);
970                    num_imported -= 1;
971                    num_deleted += 1;
972                }
973            } else {
974                // If it's an import, was a local before
975                if val.is_import() {
976                    let i = items.remove((idx - num_deleted) as u32);
977                    items.insert(num_imported, i);
978                    // increment as this is the place where we might have to move an import to
979                    num_imported += 1;
980                    // We do not update it here for the following case. A , B. A is moved to a position earlier than B, indices will not change and hence no need to update
981                    // num_deleted += 1;
982                }
983                // If val was local but was deleted
984                else if val.is_deleted() {
985                    items.remove((idx - num_deleted) as u32);
986                    num_deleted += 1;
987                }
988            }
989        }
990    }
991
992    /// Get the mapping of old ID -> new ID in module
993    pub(crate) fn get_mapping_generic<T: GetID>(
994        slice: std::slice::Iter<'_, T>,
995    ) -> HashMap<u32, u32> {
996        let mut mapping = HashMap::new();
997        for (new_id, item) in slice.enumerate() {
998            let old_id = item.get_id();
999            mapping.insert(old_id, new_id as u32);
1000        }
1001        mapping
1002    }
1003
1004    pub(crate) fn recalculate_ids<T: LocalOrImport + GetID, U: Iter<T> + ReIndexable<T>>(
1005        orig_num_imported: u32,
1006        items: &mut U,
1007    ) -> HashMap<u32, u32> {
1008        let items_read_only = items.get_into_iter();
1009        Self::reorganise_generic(orig_num_imported, items, items_read_only);
1010        let id_mapping = Self::get_mapping_generic(items.iter());
1011        assert_eq!(items.len(), id_mapping.len());
1012        id_mapping
1013    }
1014
1015    fn encode_type(&self, ty: &Types) -> wasm_encoder::SubType {
1016        match ty {
1017            Types::FuncType {
1018                params,
1019                results,
1020                super_type,
1021                is_final,
1022                shared,
1023            } => {
1024                let params = params
1025                    .iter()
1026                    .map(wasm_encoder::ValType::from)
1027                    .collect::<Vec<_>>();
1028                let results = results
1029                    .iter()
1030                    .map(wasm_encoder::ValType::from)
1031                    .collect::<Vec<_>>();
1032                let fty = wasm_encoder::FuncType::new(params, results);
1033                wasm_encoder::SubType {
1034                    is_final: *is_final,
1035                    supertype_idx: match super_type {
1036                        None => None,
1037                        Some(idx) => idx.as_module_index(),
1038                    },
1039                    composite_type: wasm_encoder::CompositeType {
1040                        inner: wasm_encoder::CompositeInnerType::Func(fty),
1041                        shared: *shared,
1042                    },
1043                }
1044            }
1045            Types::ArrayType {
1046                fields,
1047                mutable,
1048                super_type,
1049                is_final,
1050                shared,
1051            } => wasm_encoder::SubType {
1052                is_final: *is_final,
1053                supertype_idx: match super_type {
1054                    None => None,
1055                    Some(idx) => idx.as_module_index(),
1056                },
1057                composite_type: wasm_encoder::CompositeType {
1058                    inner: wasm_encoder::CompositeInnerType::Array(wasm_encoder::ArrayType(
1059                        wasm_encoder::FieldType {
1060                            element_type: wasm_encoder::StorageType::from(*fields),
1061                            mutable: *mutable,
1062                        },
1063                    )),
1064                    shared: *shared,
1065                },
1066            },
1067            Types::StructType {
1068                fields,
1069                mutable,
1070                super_type,
1071                is_final,
1072                shared,
1073            } => {
1074                let mut encoded_fields: Vec<wasm_encoder::FieldType> = vec![];
1075                for (idx, sty) in fields.iter().enumerate() {
1076                    encoded_fields.push(wasm_encoder::FieldType {
1077                        element_type: wasm_encoder::StorageType::from(*sty),
1078                        mutable: mutable[idx],
1079                    });
1080                }
1081                wasm_encoder::SubType {
1082                    is_final: *is_final,
1083                    supertype_idx: match super_type {
1084                        None => None,
1085                        Some(idx) => idx.as_module_index(),
1086                    },
1087                    composite_type: wasm_encoder::CompositeType {
1088                        inner: wasm_encoder::CompositeInnerType::Struct(wasm_encoder::StructType {
1089                            fields: Box::from(encoded_fields),
1090                        }),
1091                        shared: *shared,
1092                    },
1093                }
1094            }
1095            Types::ContType { .. } => {
1096                todo!()
1097            }
1098        }
1099    }
1100
1101    /// Encodes an Orca Module to a wasm_encoder Module.
1102    /// This requires a mutable reference to self due to the special instrumentation resolution step.
1103    pub(crate) fn encode_internal(&mut self) -> wasm_encoder::Module {
1104        // First resolve any instrumentation that needs to be translated to before/after/alt
1105        self.resolve_special_instrumentation();
1106
1107        let func_mapping = if self.functions.recalculate_ids {
1108            Self::recalculate_ids(
1109                self.imports.num_funcs - self.imports.num_funcs_added,
1110                &mut self.functions,
1111            )
1112        } else {
1113            Self::get_mapping_generic(Iter::<Function<'a>>::iter(&self.functions))
1114        };
1115        let global_mapping = if self.globals.recalculate_ids {
1116            Self::recalculate_ids(
1117                self.imports.num_globals - self.imports.num_globals_added,
1118                &mut self.globals,
1119            )
1120        } else {
1121            Self::get_mapping_generic(self.globals.iter())
1122        };
1123        let memory_mapping = if self.memories.recalculate_ids {
1124            Self::recalculate_ids(
1125                self.imports.num_memories - self.imports.num_memories_added,
1126                &mut self.memories,
1127            )
1128        } else {
1129            Self::get_mapping_generic(self.memories.iter())
1130        };
1131
1132        let mut module = wasm_encoder::Module::new();
1133        let mut reencode = RoundtripReencoder;
1134
1135        let new_start = if let Some(start_fn) = self.start {
1136            // fix the start function mapping
1137            match func_mapping.get(&*start_fn) {
1138                Some(new_index) => Some(FunctionID(*new_index)),
1139                None => {
1140                    warn!("Deleted the start function!");
1141                    None
1142                }
1143            }
1144        } else {
1145            None
1146        };
1147        self.start = new_start;
1148
1149        if !self.types.is_empty() {
1150            let mut types = wasm_encoder::TypeSection::new();
1151            let mut last_rg = None;
1152            let mut rg_types = vec![];
1153            for (idx, ty) in self.types.iter().enumerate() {
1154                let curr_rg = self.types.recgroup_map.get(&(idx as u32));
1155                // If current one is not the same as last one and it is not the first rg, encode it
1156                // If it is a new one
1157                if curr_rg != last_rg {
1158                    // If the previous one was an explicit rec group
1159                    if last_rg.is_some() {
1160                        // Encode the last one as a recgroup
1161                        types.ty().rec(rg_types.clone());
1162                        // Reset the vector
1163                        rg_types.clear();
1164                    }
1165                    // If it was not, then it was already encoded
1166                }
1167                match curr_rg {
1168                    // If it is part of an explicit rec group
1169                    Some(_) => {
1170                        rg_types.push(self.encode_type(ty));
1171                        // first_rg = false;
1172                    }
1173                    None => types.ty().subtype(&self.encode_type(ty)),
1174                }
1175                last_rg = curr_rg;
1176            }
1177            // If the last rg was a none, it was encoded in the binary, if it was an explicit rec group, was not encoded
1178            if last_rg.is_some() {
1179                types.ty().rec(rg_types.clone());
1180            }
1181            module.section(&types);
1182        }
1183
1184        // initialize function name section
1185        let mut function_names = wasm_encoder::NameMap::new();
1186        if !self.imports.is_empty() {
1187            let mut imports = wasm_encoder::ImportSection::new();
1188            let mut import_func_idx = 0;
1189            for import in self.imports.iter() {
1190                if !import.deleted {
1191                    if import.is_function() {
1192                        if let Some(import_name) = &import.custom_name {
1193                            function_names.append(import_func_idx as u32, import_name);
1194                        }
1195                        import_func_idx += 1;
1196                    }
1197                    imports.import(
1198                        import.module,
1199                        import.name,
1200                        reencode.entity_type(import.ty).unwrap(),
1201                    );
1202                }
1203            }
1204            module.section(&imports);
1205        }
1206
1207        if !self.functions.is_empty() {
1208            let mut functions = wasm_encoder::FunctionSection::new();
1209            for func in self.functions.iter() {
1210                if !func.deleted {
1211                    if let FuncKind::Local(l) = func.kind() {
1212                        functions.function(*l.ty_id);
1213                    }
1214                }
1215            }
1216            module.section(&functions);
1217        }
1218
1219        if !self.tables.is_empty() {
1220            let mut tables = wasm_encoder::TableSection::new();
1221            for (table_ty, init) in self.tables.iter() {
1222                let table_ty = wasm_encoder::TableType {
1223                    element_type: wasm_encoder::RefType {
1224                        nullable: table_ty.element_type.is_nullable(),
1225                        heap_type: reencode
1226                            .heap_type(table_ty.element_type.heap_type())
1227                            .unwrap(),
1228                    },
1229                    table64: table_ty.table64,
1230                    minimum: table_ty.initial, // TODO - Check if this maps
1231                    maximum: table_ty.maximum,
1232                    shared: table_ty.shared,
1233                };
1234                match init {
1235                    None => tables.table(table_ty),
1236                    Some(const_expr) => tables.table_with_init(
1237                        table_ty,
1238                        &reencode
1239                            .const_expr((*const_expr).clone())
1240                            .expect("Error in Converting Const Expr"),
1241                    ),
1242                };
1243            }
1244            module.section(&tables);
1245        }
1246
1247        if !self.memories.is_empty() {
1248            let mut memories = wasm_encoder::MemorySection::new();
1249            for memory in self.memories.iter() {
1250                if memory.is_local() {
1251                    memories.memory(wasm_encoder::MemoryType::from(memory.ty));
1252                }
1253            }
1254            module.section(&memories);
1255        }
1256
1257        if !self.globals.is_empty() {
1258            let mut globals = wasm_encoder::GlobalSection::new();
1259            for global in self.globals.iter() {
1260                if !global.deleted {
1261                    if let GlobalKind::Local(LocalGlobal { ty, init_expr, .. }) = &global.kind {
1262                        globals.global(
1263                            wasm_encoder::GlobalType {
1264                                val_type: reencode.val_type(ty.content_type).unwrap(),
1265                                mutable: ty.mutable,
1266                                shared: ty.shared,
1267                            },
1268                            &init_expr.to_wasmencoder_type(),
1269                        );
1270                    }
1271                }
1272                // skip imported globals
1273            }
1274            module.section(&globals);
1275        }
1276
1277        if !self.exports.is_empty() {
1278            let mut exports = wasm_encoder::ExportSection::new();
1279            for export in self.exports.iter() {
1280                if !export.deleted {
1281                    match export.kind {
1282                        ExternalKind::Func => {
1283                            // Update the function indices
1284                            exports.export(
1285                                &export.name,
1286                                wasm_encoder::ExportKind::from(export.kind),
1287                                *func_mapping.get(&(export.index)).unwrap(),
1288                            );
1289                        }
1290                        ExternalKind::Memory => {
1291                            // Update the memory indices
1292                            exports.export(
1293                                &export.name,
1294                                wasm_encoder::ExportKind::from(export.kind),
1295                                *memory_mapping.get(&(export.index)).unwrap(),
1296                            );
1297                        }
1298                        _ => {
1299                            exports.export(
1300                                &export.name,
1301                                wasm_encoder::ExportKind::from(export.kind),
1302                                export.index,
1303                            );
1304                        }
1305                    }
1306                }
1307            }
1308            module.section(&exports);
1309        }
1310
1311        if let Some(function_index) = self.start {
1312            module.section(&wasm_encoder::StartSection {
1313                function_index: *function_index,
1314            });
1315        }
1316
1317        if !self.elements.is_empty() {
1318            let mut elements = wasm_encoder::ElementSection::new();
1319            let mut temp_const_exprs = vec![];
1320            let mut element_items = vec![];
1321            for (kind, items) in self.elements.iter() {
1322                temp_const_exprs.clear();
1323                element_items.clear();
1324                let element_items = match &items {
1325                    // TODO: Update the elements section based on additions/deletion
1326                    ElementItems::Functions(funcs) => {
1327                        element_items = funcs
1328                            .iter()
1329                            .map(|f| *func_mapping.get(f).unwrap())
1330                            .collect();
1331                        wasm_encoder::Elements::Functions(Cow::from(element_items.as_slice()))
1332                    }
1333                    ElementItems::ConstExprs { ty, exprs } => {
1334                        temp_const_exprs.reserve(exprs.len());
1335                        for e in exprs.iter() {
1336                            temp_const_exprs.push(
1337                                reencode
1338                                    .const_expr((*e).clone())
1339                                    .expect("Unable to convert element constant expr"),
1340                            );
1341                        }
1342                        wasm_encoder::Elements::Expressions(
1343                            wasm_encoder::RefType {
1344                                nullable: ty.is_nullable(),
1345                                heap_type: reencode.heap_type(ty.heap_type()).unwrap(),
1346                            },
1347                            Cow::from(&temp_const_exprs),
1348                        )
1349                    }
1350                };
1351
1352                match kind {
1353                    ElementKind::Passive => {
1354                        elements.passive(element_items);
1355                    }
1356                    ElementKind::Active {
1357                        table_index,
1358                        offset_expr,
1359                    } => {
1360                        elements.active(
1361                            *table_index,
1362                            &reencode
1363                                .const_expr((*offset_expr).clone())
1364                                .expect("Unable to convert offset expr"),
1365                            element_items,
1366                        );
1367                    }
1368                    ElementKind::Declared => {
1369                        elements.declared(element_items);
1370                    }
1371                }
1372            }
1373            module.section(&elements);
1374        }
1375
1376        if self.data_count_section_exists {
1377            let data_count = wasm_encoder::DataCountSection {
1378                count: self.data.len() as u32,
1379            };
1380            module.section(&data_count);
1381        }
1382
1383        if !self.tags.is_empty() {
1384            let mut tags = TagSection::new();
1385            for tag in self.tags.iter() {
1386                tags.tag(wasm_encoder::TagType {
1387                    kind: wasm_encoder::TagKind::from(tag.kind),
1388                    func_type_idx: tag.func_type_idx,
1389                });
1390            }
1391            module.section(&tags);
1392        }
1393
1394        if !self.num_local_functions > 0 {
1395            let mut code = wasm_encoder::CodeSection::new();
1396            for rel_func_idx in 0..self.functions.len() {
1397                if self.functions.is_deleted(FunctionID(rel_func_idx as u32)) {
1398                    continue;
1399                }
1400                if let FuncKind::Import(_) =
1401                    &self.functions.get_kind(FunctionID(rel_func_idx as u32))
1402                {
1403                    continue;
1404                }
1405
1406                let func = self
1407                    .functions
1408                    .get_mut(FunctionID(rel_func_idx as u32))
1409                    .unwrap_local_mut();
1410                let Body {
1411                    instructions,
1412                    locals,
1413                    name,
1414                    ..
1415                } = &mut func.body;
1416                let mut converted_locals = Vec::with_capacity(locals.len());
1417                for (c, ty) in locals {
1418                    converted_locals.push((*c, wasm_encoder::ValType::from(&*ty)));
1419                }
1420                let mut function = wasm_encoder::Function::new(converted_locals);
1421                let instr_len = instructions.len() - 1;
1422                for (
1423                    idx,
1424                    Instruction {
1425                        op,
1426                        instr_flag: instrument,
1427                    },
1428                ) in instructions.iter_mut().enumerate()
1429                {
1430                    if refers_to_func(op) {
1431                        update_fn_instr(op, &func_mapping);
1432                    }
1433                    if refers_to_global(op) {
1434                        update_global_instr(op, &global_mapping);
1435                    }
1436                    if refers_to_memory(op) {
1437                        update_memory_instr(op, &memory_mapping);
1438                    }
1439                    if !instrument.has_instr() {
1440                        encode(&op.clone(), &mut function, &mut reencode);
1441                    } else {
1442                        // this instruction has instrumentation, handle it!
1443                        let InstrumentationFlag {
1444                            current_mode: _current_mode,
1445                            before,
1446                            after,
1447                            alternate,
1448                            semantic_after,
1449                            block_entry,
1450                            block_exit,
1451                            block_alt,
1452                        } = instrument;
1453
1454                        // Check if special instrumentation modes have been resolved!
1455                        if !semantic_after.is_empty() {
1456                            error!("BUG: Semantic after instrumentation should be resolved already, please report.");
1457                        }
1458                        if !block_entry.is_empty() {
1459                            error!("BUG: Block entry instrumentation should be resolved already, please report.");
1460                        }
1461                        if !block_exit.is_empty() {
1462                            error!("BUG: Block exit instrumentation should be resolved already, please report.");
1463                        }
1464                        if !block_alt.is_none() {
1465                            error!("BUG: Block alt instrumentation should be resolved already, please report.");
1466                        }
1467                        // If we're at the `end` of the function, drop this instrumentation
1468                        let at_end = idx >= instr_len;
1469
1470                        // First encode before instructions
1471                        update_ids_and_encode(
1472                            before,
1473                            &func_mapping,
1474                            &global_mapping,
1475                            &memory_mapping,
1476                            &mut function,
1477                            &mut reencode,
1478                        );
1479
1480                        // If there are any alternate, encode the alternate
1481                        if !at_end && !alternate.is_none() {
1482                            if let Some(alt) = alternate {
1483                                update_ids_and_encode(
1484                                    alt,
1485                                    &func_mapping,
1486                                    &global_mapping,
1487                                    &memory_mapping,
1488                                    &mut function,
1489                                    &mut reencode,
1490                                );
1491                            }
1492                        } else {
1493                            encode(&op.clone(), &mut function, &mut reencode);
1494                        }
1495
1496                        // Now encode the after instructions
1497                        if !at_end {
1498                            update_ids_and_encode(
1499                                after,
1500                                &func_mapping,
1501                                &global_mapping,
1502                                &memory_mapping,
1503                                &mut function,
1504                                &mut reencode,
1505                            );
1506                        }
1507                    }
1508
1509                    fn update_ids_and_encode(
1510                        instrs: &mut Vec<Operator>,
1511                        func_mapping: &HashMap<u32, u32>,
1512                        global_mapping: &HashMap<u32, u32>,
1513                        memory_mapping: &HashMap<u32, u32>,
1514                        function: &mut wasm_encoder::Function,
1515                        reencode: &mut RoundtripReencoder,
1516                    ) {
1517                        for instr in instrs {
1518                            if refers_to_func(instr) {
1519                                update_fn_instr(instr, func_mapping);
1520                            }
1521                            if refers_to_global(instr) {
1522                                update_global_instr(instr, global_mapping);
1523                            }
1524                            if refers_to_memory(instr) {
1525                                update_memory_instr(instr, memory_mapping);
1526                            }
1527                            encode(instr, function, reencode);
1528                        }
1529                    }
1530                    fn encode(
1531                        instr: &Operator,
1532                        function: &mut wasm_encoder::Function,
1533                        reencode: &mut RoundtripReencoder,
1534                    ) {
1535                        function.instruction(
1536                            &reencode
1537                                .instruction(instr.clone())
1538                                .expect("Unable to convert Instruction"),
1539                        );
1540                    }
1541                }
1542                if let Some(name) = name {
1543                    function_names.append(rel_func_idx as u32, name.as_str());
1544                }
1545                code.function(&function);
1546            }
1547            module.section(&code);
1548        }
1549
1550        if !self.data.is_empty() {
1551            let mut data = wasm_encoder::DataSection::new();
1552            for segment in self.data.iter_mut() {
1553                let segment_data = segment.data.iter().copied();
1554                match &mut segment.kind {
1555                    DataSegmentKind::Passive => data.passive(segment_data),
1556                    DataSegmentKind::Active {
1557                        memory_index,
1558                        offset_expr,
1559                    } => {
1560                        let new_idx = match memory_mapping.get(memory_index) {
1561                            Some(new_index) => *new_index,
1562                            None => panic!(
1563                                "Attempting to reference a deleted memory, ID: {}",
1564                                memory_index
1565                            ),
1566                        };
1567                        data.active(new_idx, &offset_expr.to_wasmencoder_type(), segment_data)
1568                    }
1569                };
1570            }
1571            module.section(&data);
1572        }
1573
1574        // the name section is not stored in self.custom_sections anymore
1575        let mut names = wasm_encoder::NameSection::new();
1576
1577        if let Some(module_name) = &self.module_name {
1578            names.module(module_name);
1579        }
1580        names.functions(&function_names);
1581        names.locals(&self.local_names);
1582        names.labels(&self.label_names);
1583        names.types(&self.type_names);
1584        names.tables(&self.table_names);
1585        names.memories(&self.memory_names);
1586        names.globals(&self.global_names);
1587        names.elements(&self.elem_names);
1588        names.data(&self.data_names);
1589        names.fields(&self.field_names);
1590        names.tag(&self.tag_names);
1591
1592        module.section(&names);
1593
1594        // encode the rest of custom sections
1595        for section in self.custom_sections.iter() {
1596            module.section(&wasm_encoder::CustomSection {
1597                name: std::borrow::Cow::Borrowed(section.name),
1598                data: std::borrow::Cow::Borrowed(section.data),
1599            });
1600        }
1601
1602        module
1603    }
1604
1605    /// Add a new Data Segment to the module.
1606    /// Returns the index of the new Data Segment in the Data Section.
1607    pub fn add_data(&mut self, data: DataSegment) -> DataSegmentID {
1608        let index = self.data.len();
1609        self.data.push(data);
1610        DataSegmentID(index as u32)
1611    }
1612
1613    /// Get the memory ID of a module. Does not support multiple memories
1614    pub fn get_memory_id(&self) -> Option<MemoryID> {
1615        if self.memories.len() > 1 {
1616            panic!("multiple memories unsupported")
1617        }
1618
1619        if !self.memories.is_empty() {
1620            return Some(MemoryID(0));
1621        }
1622        // module does not export a memory
1623        None
1624    }
1625
1626    // ==============================
1627    // ==== Module Manipulations ====
1628    // ==============================
1629
1630    pub(crate) fn add_import(&mut self, import: Import<'a>) -> (u32, ImportsID) {
1631        let (num_local, num_imported, num_total) = match import.ty {
1632            TypeRef::Func(..) => (
1633                self.num_local_functions,
1634                self.imports.num_funcs,
1635                self.functions.len() as u32,
1636            ),
1637            TypeRef::Global(..) => (
1638                self.num_local_globals,
1639                self.imports.num_globals,
1640                self.globals.len() as u32,
1641            ),
1642            TypeRef::Table(..) => todo!(),
1643            TypeRef::Tag(..) => todo!(),
1644            TypeRef::Memory(..) => (
1645                self.num_local_memories,
1646                self.imports.num_memories,
1647                self.memories.len() as u32,
1648            ),
1649        };
1650
1651        let id = if num_local > 0 {
1652            num_total
1653        } else {
1654            num_imported
1655        };
1656        (id, self.imports.add(import))
1657    }
1658
1659    // ===========================
1660    // ==== Memory Management ====
1661    // ===========================
1662
1663    pub fn add_local_memory(&mut self, ty: MemoryType) -> MemoryID {
1664        let local_mem = LocalMemory {
1665            mem_id: MemoryID(0), // will be fixed
1666        };
1667
1668        self.num_local_memories += 1;
1669        self.memories.add_local_mem(local_mem, ty)
1670    }
1671
1672    pub fn add_import_memory(
1673        &mut self,
1674        module: String,
1675        name: String,
1676        ty: MemoryType,
1677    ) -> (MemoryID, ImportsID) {
1678        let (imp_mem_id, imp_id) = self.add_import(Import {
1679            module: module.leak(),
1680            name: name.clone().leak(),
1681            ty: TypeRef::Memory(ty),
1682            custom_name: None,
1683            deleted: false,
1684        });
1685
1686        // Add to memories as well as it has imported memories
1687        self.memories.add_import_mem(imp_id, ty, imp_mem_id);
1688        (MemoryID(imp_mem_id), imp_id)
1689    }
1690
1691    /// Delete a memory from the module.
1692    pub fn delete_memory(&mut self, mem_id: MemoryID) {
1693        self.memories.delete(mem_id);
1694        if let MemKind::Import(ImportedMemory { import_id, .. }) = self.memories.get_kind(mem_id) {
1695            self.imports.delete(*import_id);
1696        }
1697    }
1698
1699    // =============================
1700    // ==== Function Management ====
1701    // =============================
1702
1703    pub(crate) fn add_local_func(
1704        &mut self,
1705        name: Option<String>,
1706        params: &[DataType],
1707        results: &[DataType],
1708        body: Body<'a>,
1709    ) -> FunctionID {
1710        let ty = self.types.add_func_type(params, results);
1711        let local_func = LocalFunction::new(
1712            ty,
1713            FunctionID(0), // will be fixed
1714            body,
1715            params.len(),
1716        );
1717
1718        self.num_local_functions += 1;
1719        self.functions.add_local_func(local_func, name.clone())
1720    }
1721
1722    /// Add a new function to the module, returns:
1723    ///
1724    /// - FunctionID: The ID that indexes into the function ID space. To be used when referring to the function, like in `call`.
1725    /// - ImportsID: The ID that indexes into the import section.
1726    pub fn add_import_func(
1727        &mut self,
1728        module: String,
1729        name: String,
1730        ty_id: TypeID,
1731    ) -> (FunctionID, ImportsID) {
1732        let (imp_fn_id, imp_id) = self.add_import(Import {
1733            module: module.leak(),
1734            name: name.clone().leak(),
1735            ty: TypeRef::Func(*ty_id),
1736            custom_name: None,
1737            deleted: false,
1738        });
1739
1740        // Add to functions as well as it has imported functions
1741        self.functions
1742            .add_import_func(imp_id, ty_id, Some(name), imp_fn_id);
1743        (FunctionID(imp_fn_id), imp_id)
1744    }
1745
    /// Get the number of imported functions in the module (including any added ones).
    ///
    /// This reads the `num_funcs` counter kept by the import section.
    pub fn num_import_func(&self) -> u32 {
        self.imports.num_funcs
    }
1750
1751    /// Delete a function from the module.
1752    pub fn delete_func(&mut self, function_id: FunctionID) {
1753        self.functions.delete(function_id);
1754        if let FuncKind::Import(ImportedFunction { import_id, .. }) =
1755            self.functions.get_kind(function_id)
1756        {
1757            self.imports.delete(*import_id);
1758        }
1759    }
1760
1761    /// Convert an imported function to a local function.
1762    /// The function ID inside the `local_function` parameter should equal the `imports_id` specified.
1763    /// Continue using the ImportsID as normal (like in `call` instructions), this library will take care of ID changes for you during encoding.
1764    /// Returns false if it is a local function.
1765    pub(crate) fn convert_import_fn_to_local(
1766        &mut self,
1767        import_id: ImportsID,
1768        local_function: LocalFunction<'a>,
1769    ) -> bool {
1770        if self.functions.is_local(FunctionID(*import_id)) {
1771            warn!("This is a local function!");
1772            return false;
1773        }
1774        self.delete_func(FunctionID(*import_id));
1775        self.functions
1776            .get_mut(FunctionID(*import_id))
1777            .set_kind(FuncKind::Local(local_function));
1778        true
1779    }
1780
1781    /// Convert a local function to an imported function.
1782    /// Continue using the FunctionID as normal (like in `call` instructions), this library will take care of ID changes for you during encoding.
1783    /// Returns false if it is an imported function.
1784    pub fn convert_local_fn_to_import(
1785        &mut self,
1786        function_id: FunctionID,
1787        module: String,
1788        name: String,
1789        ty_id: TypeID,
1790    ) -> bool {
1791        if self.functions.is_import(function_id) {
1792            warn!("This is an imported function!");
1793            return false;
1794        }
1795        // Delete the associated function
1796        self.delete_func(function_id);
1797        // Add import function to imports
1798        let (.., import_id) = self.add_import(Import {
1799            module: module.leak(),
1800            name: name.clone().leak(),
1801            ty: TypeRef::Func(*ty_id),
1802            custom_name: None,
1803            deleted: false,
1804        });
1805        self.functions
1806            .get_mut(function_id)
1807            .set_kind(FuncKind::Import(ImportedFunction {
1808                import_id,
1809                import_fn_id: function_id,
1810                ty_id,
1811            }));
1812        assert!(self.functions.set_imported_fn_name(function_id, name));
1813        true
1814    }
1815
1816    /// Set the name of a function using its ID.
1817    pub fn set_fn_name(&mut self, id: FunctionID, name: String) {
1818        if *id < self.imports.num_funcs {
1819            // the function is an import
1820            self.imports.set_fn_name(name.clone(), id);
1821            assert!(self.functions.set_imported_fn_name(id, name));
1822        } else {
1823            // the function is local
1824            assert!(self.functions.set_local_fn_name(id, name));
1825        }
1826    }
1827
1828    // =============================
1829    // ==== Globals Management ====
1830    // =============================
1831
    /// Add a new global to the module.
    ///
    /// Increments `num_local_globals` and returns the ID produced by `self.globals.add`.
    ///
    /// NOTE(review): the counter is bumped regardless of the global's kind, and
    /// `add_imported_global` routes imported globals through this method too --
    /// confirm whether counting imports as "local" is intended.
    pub(crate) fn add_global_internal(&mut self, global: Global) -> GlobalID {
        self.num_local_globals += 1;
        self.globals.add(global)
    }
1837
1838    /// Create a new locally-defined global and add it to the module.
1839    pub fn add_global(
1840        &mut self,
1841        init_expr: InitExpr,
1842        content_ty: DataType,
1843        mutable: bool,
1844        shared: bool,
1845    ) -> GlobalID {
1846        self.add_global_internal(Global {
1847            kind: GlobalKind::Local(LocalGlobal {
1848                global_id: GlobalID(0), // gets set in `add`
1849                ty: GlobalType {
1850                    mutable,
1851                    content_type: wasmparser::ValType::from(&content_ty),
1852                    shared,
1853                },
1854                init_expr,
1855            }),
1856            deleted: false,
1857        })
1858    }
1859
1860    /// Add a new imported global to the module, returns:
1861    ///
1862    /// - GlobalID: The ID that indexes into the global ID space. To be used when referring to the global, like in `global.get`.
1863    /// - ImportsID: The ID that indexes into the import section.
1864    pub fn add_imported_global(
1865        &mut self,
1866        module: String,
1867        name: String,
1868        content_ty: DataType,
1869        mutable: bool,
1870        shared: bool,
1871    ) -> (GlobalID, ImportsID) {
1872        let global_ty = GlobalType {
1873            mutable,
1874            content_type: wasmparser::ValType::from(&content_ty),
1875            shared,
1876        };
1877        let (imp_global_id, imp_id) = self.add_import(Import {
1878            module: module.leak(),
1879            name: name.leak(),
1880            ty: TypeRef::Global(global_ty),
1881            custom_name: None,
1882            deleted: false,
1883        });
1884
1885        // Add to globals as well since it has imported globals
1886        self.add_global_internal(Global::new(GlobalKind::Import(ImportedGlobal::new(
1887            imp_id,
1888            GlobalID(imp_global_id),
1889            global_ty,
1890        ))));
1891        self.globals.recalculate_ids = true;
1892        (GlobalID(imp_global_id), imp_id)
1893    }
1894
1895    /// Delete a global from the module (can either be an imported or locally-defined global).
1896    /// Use the global ID for this operation, not the import ID!
1897    pub fn delete_global(&mut self, global_id: GlobalID) {
1898        self.globals.delete(global_id);
1899        if let GlobalKind::Import(ImportedGlobal { import_id, .. }) =
1900            self.globals.get_kind(global_id)
1901        {
1902            self.imports.delete(*import_id);
1903        }
1904    }
1905
    /// Change a locally-defined global's init expression.
    ///
    /// Forwards the raw ID of `global_id` and `new_expr` to
    /// `self.globals.mod_global_init_expr`.
    /// NOTE(review): what happens for an imported global is decided by that
    /// method, which is not visible here.
    pub fn mod_global_init_expr(&mut self, global_id: GlobalID, new_expr: InitExpr) {
        self.globals.mod_global_init_expr(*global_id, new_expr);
    }
1910}
1911
/// Implemented by types that can report their ID as a raw `u32`.
pub trait GetID {
    /// Return this item's ID as a `u32`.
    fn get_id(&self) -> u32;
}
1915
/// Facilitates iteration on types that hold `T`
pub(crate) trait Iter<T> {
    /// Iterate over references of `T` (a borrowing slice iterator).
    fn iter(&self) -> std::slice::Iter<'_, T>;

    /// Clone and build an iterator.
    ///
    /// Returns an owning `IntoIter<T>` while only borrowing `self`.
    fn get_into_iter(&self) -> IntoIter<T>;
}
1924
/// Editing operations for a collection of `T` whose items are addressed by `u32` IDs.
///
/// NOTE(review): the exact re-indexing semantics are defined by the
/// implementors, which are not visible here.
pub(crate) trait ReIndexable<T> {
    /// Number of items in the collection.
    fn len(&self) -> usize;
    /// Remove the item addressed by `id` and return it.
    fn remove(&mut self, id: u32) -> T;
    /// Insert `val` at the position addressed by `id`.
    fn insert(&mut self, id: u32, val: T);
    /// Add `item` to the collection (presumably at the end -- confirm against implementors).
    fn push(&mut self, item: T);
}
1931
/// Implemented by containers that accept a new `T`.
pub trait Push<T> {
    /// Add `val` to the container.
    fn push(&mut self, val: T);
}
1935
/// Queries about whether a module item is locally defined, imported, or deleted.
pub trait LocalOrImport {
    /// Whether the item is locally defined.
    fn is_local(&self) -> bool;
    /// Whether the item is an import.
    fn is_import(&self) -> bool;
    /// Whether the item has been marked as deleted.
    fn is_deleted(&self) -> bool;
}
1941
1942// ================================
1943// ==== Semantic After Helpers ====
1944// ================================
1945
/// Numeric identifier of a block; the resolution helpers keep these on a block stack.
type BlockID = u32;
/// A sequence of operators that is to be injected as one instrumentation body.
type InstrBody<'a> = Vec<Operator<'a>>;
/// An instrumentation body paired with the ID of a local.
struct InstrBodyFlagged<'a> {
    /// The operators to inject.
    body: InstrBody<'a>,
    /// ID of the associated local.
    /// NOTE(review): presumably a boolean flag local that guards `body` -- confirm against users.
    bool_flag: LocalID,
}
/// Instrumentation bodies waiting to be injected, split by whether they carry a flag local.
struct InstrToInject<'a> {
    /// Bodies that come with a flag local.
    flagged: Vec<InstrBodyFlagged<'a>>,
    /// Bodies without a flag local.
    not_flagged: Vec<InstrBody<'a>>,
}
1956
1957fn resolve_function_entry<'a, 'b, 'c>(
1958    builder: &mut FunctionModifier<'a, 'b>,
1959    instr_func_on_entry: &mut InstrBody<'c>,
1960    idx: usize,
1961) where
1962    'c: 'b,
1963{
1964    if idx == 0 {
1965        // we're at the function entry!
1966        builder.before_at(Location::Module {
1967            func_idx: FunctionID(0), // not used
1968            instr_idx: idx,
1969        });
1970        builder.inject_all(instr_func_on_entry);
1971
1972        // remove the contents of the body now that it's been resolved
1973        instr_func_on_entry.clear();
1974    }
1975}
1976
1977fn resolve_function_exit_with_block_wrapper<'a, 'b, 'c>(
1978    instr_func_on_entry: &mut InstrBody<'c>,
1979    block_ty: TypeID,
1980) where
1981    'c: 'b,
1982{
1983    // To handle `br*` AND fallthrough:
1984    // Since the relative depth of a branch target
1985    // cannot exceed its current depth, we can just
1986    // wrap the function body in a block and put the
1987    // `exit` instrumentation AFTER the block's `end`.
1988
1989    // to be handled on resolving func_entry
1990    instr_func_on_entry.push(Block {
1991        blockty: wasmparser::BlockType::from(BlockType::FuncType(block_ty)),
1992    });
1993}
1994fn resolve_function_exit<'a, 'b, 'c>(
1995    instr_func_on_exit: &mut InstrBody<'c>,
1996    builder: &mut FunctionModifier<'a, 'b>,
1997    op: &Operator,
1998    idx: usize,
1999) where
2000    'c: 'b,
2001{
2002    // To handle `return`:
2003    // Place a copy of `exit` BEFORE the `return`
2004    // Place a copy of `exit` BEFORE the `return_call`
2005    // Place a copy of `exit` BEFORE the `return_call_indirect`
2006    // Place a copy of `exit` BEFORE the `return_call_ref`
2007
2008    // To handle `traps`:
2009    // Place a copy of `exit` BEFORE the `unreachable`
2010    // Place a copy of `exit` BEFORE the `throw`
2011    // Place a copy of `exit` BEFORE the `rethrow`
2012    // Place a copy of `exit` BEFORE the `throw_ref`
2013    // Place a copy of `exit` BEFORE the `resume_throw`
2014
2015    // convert instr to simple before/after/alt
2016    match op {
2017        // handle returns
2018        Operator::Return { .. } |
2019            Operator::ReturnCall {..} |
2020            Operator::ReturnCallIndirect {..} |
2021            Operator::ReturnCallRef {..} |
2022
2023        // handle traps
2024        Operator::Unreachable |
2025            Operator::Throw {..} |
2026            Operator::Rethrow {..} |
2027            Operator::ThrowRef |
2028            Operator::ResumeThrow {..} => {
2029            // just inject immediately before the instruction
2030            builder.before_at(Location::Module {
2031                func_idx: FunctionID(0), // not used
2032                instr_idx: idx,
2033            });
2034            builder.inject_all(instr_func_on_exit);
2035
2036            // no need to do next part if we've injected!
2037            return
2038        }
2039        _ => {}
2040    }
2041
2042    // Handles the actual injection of the added block's `end`
2043    // and places instr block afterward!
2044    if idx == builder.body.instructions.len() - 1 {
2045        // we're at the end of the function!
2046        builder.before_at(Location::Module {
2047            func_idx: FunctionID(0), // not used
2048            instr_idx: idx,
2049        });
2050        builder.end(); // end the added wrapper block!
2051        builder.inject_all(instr_func_on_exit);
2052
2053        // remove the contents of the body now that it's been resolved
2054        instr_func_on_exit.clear();
2055    }
2056}
2057
2058fn resolve_block_entry<'a, 'b, 'c>(
2059    block_entry: &InstrBody<'c>,
2060    builder: &mut FunctionModifier<'a, 'b>,
2061    op: &Operator,
2062    idx: usize,
2063) where
2064    'c: 'b,
2065{
2066    // convert instr to simple before/after/alt
2067    match op {
2068        Operator::Block { .. }
2069        | Operator::Loop { .. }
2070        | Operator::If { .. }
2071        | Operator::Else { .. } => {
2072            // just inject immediately after the start of the block
2073            builder.after_at(Location::Module {
2074                func_idx: FunctionID(0), // not used
2075                instr_idx: idx,
2076            });
2077            builder.inject_all(block_entry);
2078
2079            // no need to remove the contents of block_entry since we're actually
2080            // using a read-only copy!
2081        }
2082        _ => {
2083            // no need to remove the contents of block_entry since we're actually
2084            // using a read-only copy!
2085        }
2086    }
2087}
2088
2089fn plan_resolution_block_exit<'a, 'b, 'c>(
2090    block_exit: &InstrBody<'c>,
2091    block_stack: &[BlockID],
2092    resolve_on_else_or_end: &mut HashMap<InstrumentationMode, InstrToInject<'c>>,
2093    resolve_on_end: &mut HashMap<BlockID, HashMap<InstrumentationMode, InstrToInject<'c>>>,
2094    op: &Operator,
2095) where
2096    'c: 'b,
2097{
2098    // save instrumentation to be converted to simple before/after/alt
2099    match op {
2100        Operator::If { .. } => {
2101            save_not_flagged_body_to_resolve_inner(
2102                resolve_on_else_or_end,
2103                InstrumentationMode::Before,
2104                block_exit,
2105            );
2106        }
2107        Operator::Block { .. } | Operator::Loop { .. } | Operator::Else { .. } => {
2108            // add body-to-inject as non-flagged
2109            let block_id = block_stack.last().unwrap(); // should always have something (e.g. func block)
2110            save_not_flagged_body_to_resolve(
2111                resolve_on_end,
2112                InstrumentationMode::Before,
2113                block_exit,
2114                *block_id,
2115            );
2116        }
2117        _ => {} // skip all other opcodes
2118    }
2119}
2120
2121fn plan_resolution_block_alt<'a, 'b, 'c>(
2122    block_alt: &InstrBody<'c>,
2123    builder: &mut FunctionModifier<'a, 'b>,
2124    retain_end: &mut bool,
2125    op: &Operator,
2126    idx: usize,
2127) -> bool
2128where
2129    'c: 'b,
2130{
2131    // convert instr to simple before/after/alt
2132    let mut matched = false;
2133    match op {
2134        Operator::Block { .. }
2135        | Operator::Loop { .. }
2136        | Operator::If { .. }
2137        | Operator::Else { .. } => {
2138            let loc = Location::Module {
2139                func_idx: FunctionID(0), // not used
2140                instr_idx: idx,
2141            };
2142            if !block_alt.is_empty() {
2143                // just inject immediately after the start of the block
2144                builder.alternate_at(loc);
2145                builder.inject_all(block_alt);
2146            } else {
2147                // remove the instruction!
2148                builder.empty_alternate_at(loc);
2149            }
2150
2151            // no need to remove the contents of block_alt since we're actually
2152            // using a read-only copy!
2153
2154            matched = true;
2155            *retain_end = false;
2156        }
2157        _ => {}
2158    }
2159    if let Operator::Else { .. } = op {
2160        // We want to keep the end for the module to still be valid (the if will be dangling)
2161        *retain_end = true;
2162    }
2163    matched
2164}
2165
2166fn plan_resolution_semantic_after<'a, 'b, 'c>(
2167    semantic_after: &InstrBody<'c>,
2168    builder: &mut FunctionModifier<'a, 'b>,
2169    block_stack: &[BlockID],
2170    resolve_on_end: &mut HashMap<BlockID, HashMap<InstrumentationMode, InstrToInject<'c>>>,
2171    op: &Operator,
2172    idx: usize,
2173) where
2174    'c: 'b,
2175{
2176    // save instrumentation to be converted to simple before/after/alt
2177    match op {
2178        Operator::Block { .. }
2179        | Operator::Loop { .. }
2180        | Operator::If { .. }
2181        | Operator::Else { .. } => {
2182            // add body-to-inject as non-flagged
2183            let block_id = block_stack.last().unwrap(); // should always have something (e.g. func block)
2184            save_not_flagged_body_to_resolve(
2185                resolve_on_end,
2186                InstrumentationMode::After,
2187                semantic_after,
2188                *block_id,
2189            );
2190        }
2191        Operator::BrTable { targets } => {
2192            let bool_flag_id = create_bool_flag(builder, idx, op, semantic_after);
2193            targets.targets().for_each(|target| {
2194                if let Ok(relative_depth) = target {
2195                    save_flagged_body_to_resolve(
2196                        resolve_on_end,
2197                        InstrumentationMode::After,
2198                        semantic_after,
2199                        bool_flag_id,
2200                        relative_depth,
2201                        *block_stack.last().unwrap(),
2202                    );
2203                }
2204            });
2205            // handle the default as well
2206            save_flagged_body_to_resolve(
2207                resolve_on_end,
2208                InstrumentationMode::After,
2209                semantic_after,
2210                bool_flag_id,
2211                targets.default(),
2212                *block_stack.last().unwrap(),
2213            );
2214        }
2215        Operator::Br { relative_depth }
2216        | Operator::BrIf { relative_depth }
2217        | Operator::BrOnCast { relative_depth, .. }
2218        | Operator::BrOnCastFail { relative_depth, .. }
2219        | Operator::BrOnNonNull { relative_depth }
2220        | Operator::BrOnNull { relative_depth } => {
2221            let bool_flag_id = create_bool_flag(builder, idx, op, semantic_after);
2222            save_flagged_body_to_resolve(
2223                resolve_on_end,
2224                InstrumentationMode::After,
2225                semantic_after,
2226                bool_flag_id,
2227                *relative_depth,
2228                *block_stack.last().unwrap(),
2229            );
2230        }
2231        _ => {} // skip all other opcodes
2232    }
2233}
2234
2235fn create_bool_flag<'a, 'b, 'c>(
2236    builder: &mut FunctionModifier<'a, 'b>,
2237    idx: usize,
2238    op: &Operator,
2239    semantic_after: &Vec<Operator<'c>>,
2240) -> LocalID
2241where
2242    'c: 'b,
2243{
2244    // add body-to-inject as flagged
2245    let bool_flag_id = add_local(
2246        DataType::I32,
2247        builder.args.len(),
2248        &mut builder.body.num_locals,
2249        &mut builder.body.locals,
2250    );
2251
2252    // set flag to true before the opcode
2253    builder
2254        .before_at(Location::Module {
2255            func_idx: FunctionID(0), // not used
2256            instr_idx: idx,
2257        })
2258        .i32_const(1)
2259        .local_set(bool_flag_id);
2260
2261    // set flag to false after the opcode
2262    builder
2263        .after_at(Location::Module {
2264            func_idx: FunctionID(0), // not used
2265            instr_idx: idx,
2266        })
2267        .i32_const(0)
2268        .local_set(bool_flag_id);
2269
2270    // BrIf, BrOnCast, BrOnNonNull, BrOnNull
2271    // the bodies should be inserted immediately after too!
2272    // This is because there is a possibility of fallthrough.
2273    // The body will not be executed 2x since the flag is set
2274    // to `false` on fallthrough!
2275    match op {
2276        Operator::BrIf { .. }
2277        | Operator::BrOnCast { .. }
2278        | Operator::BrOnCastFail { .. }
2279        | Operator::BrOnNonNull { .. }
2280        | Operator::BrOnNull { .. } => {
2281            builder.inject_all(semantic_after.as_slice());
2282        }
2283        _ => {}
2284    }
2285    bool_flag_id
2286}
2287
2288fn save_not_flagged_body_to_resolve<'a>(
2289    resolve_on_end: &mut HashMap<BlockID, HashMap<InstrumentationMode, InstrToInject<'a>>>,
2290    mode: InstrumentationMode,
2291    body: &Vec<Operator<'a>>,
2292    block_id: BlockID,
2293) {
2294    resolve_on_end
2295        .entry(block_id)
2296        .and_modify(|mode_to_instrs| {
2297            save_not_flagged_body_to_resolve_inner(mode_to_instrs, mode, body);
2298        })
2299        .or_insert(HashMap::from([(
2300            mode,
2301            InstrToInject {
2302                flagged: vec![],
2303                not_flagged: vec![body.to_owned()],
2304            },
2305        )]));
2306}
2307
2308fn save_not_flagged_body_to_resolve_inner<'a>(
2309    inner: &mut HashMap<InstrumentationMode, InstrToInject<'a>>,
2310    mode: InstrumentationMode,
2311    body: &Vec<Operator<'a>>,
2312) {
2313    inner
2314        .entry(mode)
2315        .and_modify(|instr_to_inject| {
2316            instr_to_inject.not_flagged.push(body.to_owned());
2317        })
2318        .or_insert(InstrToInject {
2319            flagged: vec![],
2320            not_flagged: vec![body.to_owned()],
2321        });
2322}
2323
2324fn save_flagged_body_to_resolve<'a>(
2325    to_resolve: &mut HashMap<BlockID, HashMap<InstrumentationMode, InstrToInject<'a>>>,
2326    mode: InstrumentationMode,
2327    body: &Vec<Operator<'a>>,
2328    bool_flag_id: LocalID,
2329    relative_depth: u32,
2330    curr_block: BlockID,
2331) {
2332    let block_id = curr_block - relative_depth;
2333    to_resolve
2334        .entry(block_id)
2335        .and_modify(|mode_to_instrs| {
2336            mode_to_instrs
2337                .entry(mode)
2338                .and_modify(|instr_to_inject| {
2339                    instr_to_inject.flagged.push(InstrBodyFlagged {
2340                        body: body.to_owned(),
2341                        bool_flag: bool_flag_id,
2342                    });
2343                })
2344                .or_insert(InstrToInject {
2345                    flagged: vec![InstrBodyFlagged {
2346                        body: body.to_owned(),
2347                        bool_flag: bool_flag_id,
2348                    }],
2349                    not_flagged: vec![],
2350                });
2351        })
2352        .or_insert(HashMap::from([(
2353            mode,
2354            InstrToInject {
2355                flagged: vec![InstrBodyFlagged {
2356                    body: body.to_owned(),
2357                    bool_flag: bool_flag_id,
2358                }],
2359                not_flagged: vec![],
2360            },
2361        )]));
2362}
2363
2364fn resolve_bodies<'a, 'b, 'c>(
2365    builder: &mut FunctionModifier<'a, 'b>,
2366    mode: &InstrumentationMode,
2367    instr_to_inject: &InstrToInject<'c>,
2368    idx: usize,
2369) where
2370    'c: 'b,
2371{
2372    let InstrToInject {
2373        flagged,
2374        not_flagged,
2375    } = instr_to_inject;
2376
2377    let mut is_first = true;
2378    // inject the bodies predicated with the flag
2379    for InstrBodyFlagged { body, bool_flag } in flagged.iter() {
2380        // Inject the bodies in the correct mode at the current END opcode
2381        let loc = Location::Module {
2382            func_idx: FunctionID(0), // not used
2383            instr_idx: idx,
2384        };
2385        match mode {
2386            InstrumentationMode::Before => builder.before_at(loc),
2387            InstrumentationMode::After => builder.after_at(loc),
2388            _ => unreachable!(),
2389        };
2390
2391        if is_first {
2392            // inject flag check
2393            builder.local_get(*bool_flag);
2394            builder.if_stmt(BlockType::Empty); // TODO -- This will break for instrumentation that returns stuff...
2395        } else {
2396            // injecting multiple, already have an if statement
2397            builder.else_stmt();
2398            // inject flag check
2399            builder.local_get(*bool_flag);
2400            builder.if_stmt(BlockType::Empty); // nested if for the if/else flow
2401        }
2402
2403        // inject body
2404        builder.inject_all(body);
2405        if !is_first {
2406            // need to inject end of nested if!
2407            builder.end();
2408        }
2409        is_first = false;
2410    }
2411    if !flagged.is_empty() {
2412        // inject end of flag check (the outer if)
2413        builder.end();
2414    }
2415
2416    // handle non-flagged bodies
2417    // Inject the bodies AFTER the current END opcode
2418    let loc = Location::Module {
2419        func_idx: FunctionID(0), // not used
2420        instr_idx: idx,
2421    };
2422    match mode {
2423        InstrumentationMode::Before => builder.before_at(loc),
2424        InstrumentationMode::After => builder.after_at(loc),
2425        _ => unreachable!(),
2426    };
2427    for body in not_flagged.iter() {
2428        // inject body
2429        builder.inject_all(body);
2430    }
2431}