From: Francis De Brabandere <francisdb@gmail.com>

GetRef returns a callable IDispatch object that wraps a script function
or sub, allowing it to be assigned to a variable and invoked later. The
returned object calls exec_script on DISPID_VALUE invocation.

When GetRef is called with its result discarded (e.g. "Call GetRef(...)"),
the name is still resolved, but no reference object is created.

Wine-Bug: https://bugs.winehq.org/show_bug.cgi?id=54221
---
 dlls/vbscript/global.c         |  50 +++++++++++++-
 dlls/vbscript/interp.c         |   5 ++
 dlls/vbscript/tests/lang.vbs   |  98 ++++++++++++++++++++++++++++
 dlls/vbscript/tests/vbscript.c |  39 +++++++++++
 dlls/vbscript/vbdisp.c         | 115 +++++++++++++++++++++++++++++++++
 dlls/vbscript/vbscript.h       |   3 +
 6 files changed, 308 insertions(+), 2 deletions(-)

diff --git a/dlls/vbscript/global.c b/dlls/vbscript/global.c
index f38ff895ef7..1d244b8207f 100644
--- a/dlls/vbscript/global.c
+++ b/dlls/vbscript/global.c
@@ -3351,8 +3351,54 @@ static HRESULT Global_ExecuteGlobal(BuiltinDisp *This, VARIANT *arg, unsigned ar
 
+/* Case-insensitive lookup of a script Function/Sub in a script object's global function table. */
+static function_t *getref_lookup_func(ScriptDisp *obj, const WCHAR *name)
+{
+    size_t i;
+
+    for(i = 0; i < obj->global_funcs_cnt; i++) {
+        if(!wcsicmp(obj->global_funcs[i]->name, name))
+            return obj->global_funcs[i];
+    }
+    return NULL;
+}
+
 static HRESULT Global_GetRef(BuiltinDisp *This, VARIANT *arg, unsigned args_cnt, VARIANT *res)
 {
-    FIXME("\n");
-    return E_NOTIMPL;
+    named_item_t *item;
+    function_t *func = NULL;
+    IDispatch *disp;
+    const WCHAR *name;
+    HRESULT hres;
+
+    TRACE("%s\n", debugstr_variant(arg));
+
+    if(V_VT(arg) != VT_BSTR)
+        return MAKE_VBSERROR(VBSE_TYPE_MISMATCH);
+
+    name = V_BSTR(arg);
+    if(!name || !name[0])
+        return MAKE_VBSERROR(VBSE_ILLEGAL_FUNC_CALL);
+
+    /* Functions of the named item whose code is running take precedence over global ones. */
+    item = This->ctx->current_named_item;
+    if(item && item->script_obj)
+        func = getref_lookup_func(item->script_obj, name);
+    if(!func)
+        func = getref_lookup_func(This->ctx->script_obj, name);
+    if(!func)
+        return MAKE_VBSERROR(VBSE_ILLEGAL_FUNC_CALL);
+
+    /* res is NULL when the result is discarded (e.g. "Call GetRef(...)"); the name has
+     * been validated above, but creating a reference here would only leak it. */
+    if(!res)
+        return S_OK;
+
+    hres = create_func_ref(This->ctx, func, &disp);
+    if(FAILED(hres))
+        return hres;
+
+    V_VT(res) = VT_DISPATCH;
+    V_DISPATCH(res) = disp;
+    return S_OK;
 }
 
 static HRESULT Global_Err(BuiltinDisp *This, VARIANT *arg, unsigned args_cnt, VARIANT *res)
diff --git a/dlls/vbscript/interp.c b/dlls/vbscript/interp.c
index 00ef7d93eb6..870386583a5 100644
--- a/dlls/vbscript/interp.c
+++ b/dlls/vbscript/interp.c
@@ -2472,6 +2472,7 @@ static void release_exec(exec_ctx_t *ctx)
 HRESULT exec_script(script_ctx_t *ctx, BOOL extern_caller, function_t *func, vbdisp_t *vbthis, DISPPARAMS *dp, VARIANT *res)
 {
     exec_ctx_t exec = {func->code_ctx};
+    named_item_t *prev_named_item;
     vbsop_t op;
     HRESULT hres = S_OK;
 
@@ -2545,6 +2546,9 @@ HRESULT exec_script(script_ctx_t *ctx, BOOL extern_caller, function_t *func, vbd
     exec.script = ctx;
     exec.func = func;
 
+    prev_named_item = ctx->current_named_item;
+    ctx->current_named_item = exec.code->named_item;
+
     while(exec.instr) {
         op = exec.instr->op;
         hres = op_funcs[op](&exec);
@@ -2622,6 +2626,7 @@ HRESULT exec_script(script_ctx_t *ctx, BOOL extern_caller, function_t *func, vbd
         V_VT(&exec.ret_val) = VT_EMPTY;
     }
 
+    ctx->current_named_item = prev_named_item;
     release_exec(&exec);
     return hres;
 }
diff --git a/dlls/vbscript/tests/lang.vbs b/dlls/vbscript/tests/lang.vbs
index e0feba54b7e..d1411d2839a 100644
--- a/dlls/vbscript/tests/lang.vbs
+++ b/dlls/vbscript/tests/lang.vbs
@@ -2437,4 +2437,102 @@ arr (0) = 2 xor -2
 Call ok(indexedObj(3) = 6, "indexedObj(3) = " & indexedObj(3))
 Call ok(indexedObj(0) = 0, "indexedObj(0) = " & indexedObj(0))
 
+' GetRef tests
+Function GetRefTestFunc()
+    GetRefTestFunc = 42
+End Function
+
+Dim getRefRef
+Set getRefRef = GetRef("GetRefTestFunc")
+Call ok(IsObject(getRefRef), "IsObject(GetRef result) should be true")
+Call ok(getRefRef() = 42, "GetRef result call returned " & getRefRef())
+
+' GetRef with parameters
+Function GetRefAddFunc(a, b)
+    GetRefAddFunc = a + b
+End Function
+
+Set getRefRef = GetRef("GetRefAddFunc")
+Call ok(getRefRef(3, 4) = 7, "GetRef add call returned " & getRefRef(3, 4))
+
+' GetRef with a Sub
+Dim getRefSubCalled
+getRefSubCalled = False
+Sub GetRefTestSub()
+    getRefSubCalled = True
+End Sub
+
+Set getRefRef = GetRef("GetRefTestSub")
+Call getRefRef()
+Call ok(getRefSubCalled, "GetRef sub was not called")
+
+' GetRef with Sub that has parameters
+Dim getRefSubResult
+Sub GetRefTestSubArgs(a, b)
+    getRefSubResult = a + b
+End Sub
+
+Set getRefRef = GetRef("GetRefTestSubArgs")
+Call getRefRef(10, 20)
+Call ok(getRefSubResult = 30, "GetRef sub with args returned " & getRefSubResult)
+
+' GetRef case insensitivity
+Function getRefCaseFunc()
+    getRefCaseFunc = "hello"
+End Function
+
+Set getRefRef = GetRef("GETREFCASEFUNC")
+Call ok(getRefRef() = "hello", "GetRef case insensitive returned " & getRefRef())
+
+' GetRef default value (calling without parens triggers default property)
+Set getRefRef = GetRef("GetRefTestFunc")
+Dim getRefResult
+getRefResult = getRefRef
+Call ok(getRefResult = 42, "GetRef default value returned " & getRefResult)
+Call ok(getVT(getRefResult) = "VT_I2*", "GetRef default value type is " & getVT(getRefResult))
+
+' GetRef can be passed to another function
+Function GetRefCallIt(fn)
+    GetRefCallIt = fn()
+End Function
+
+Set getRefRef = GetRef("GetRefTestFunc")
+Call ok(GetRefCallIt(getRefRef) = 42, "GetRef passed to function returned " & GetRefCallIt(getRefRef))
+
+' GetRef called as a statement (result discarded)
+Call GetRef("GetRefTestFunc")
+
+' GetRef error cases
+On Error Resume Next
+
+Err.Clear
+Set getRefRef = GetRef("NonExistentFunc")
+Call ok(Err.Number = 5, "GetRef non-existent function error is " & Err.Number)
+
+Err.Clear
+Set getRefRef = GetRef("")
+Call ok(Err.Number = 5, "GetRef empty string error is " & Err.Number)
+
+Err.Clear
+Set getRefRef = GetRef(123)
+Call ok(Err.Number = 13, "GetRef numeric arg error is " & Err.Number)
+
+Err.Clear
+Set getRefRef = GetRef(Null)
+Call ok(Err.Number = 13, "GetRef Null arg error is " & Err.Number)
+
+Err.Clear
+Set getRefRef = GetRef(Empty)
+Call ok(Err.Number = 13, "GetRef Empty arg error is " & Err.Number)
+
+Err.Clear
+Set getRefRef = GetRef(vbNullString)
+Call ok(Err.Number = 5, "GetRef vbNullString error is " & Err.Number)
+
+Err.Clear
+Call GetRef("NonExistentFunc")
+Call ok(Err.Number = 5, "GetRef statement non-existent function error is " & Err.Number)
+
+On Error Goto 0
+
 reportSuccess()
diff --git a/dlls/vbscript/tests/vbscript.c b/dlls/vbscript/tests/vbscript.c
index adda087ce76..7500ec30a8d 100644
--- a/dlls/vbscript/tests/vbscript.c
+++ b/dlls/vbscript/tests/vbscript.c
@@ -2234,6 +2234,45 @@ static void test_named_items(void)
     CHECK_CALLED(OnEnterScript);
     CHECK_CALLED(OnLeaveScript);
 
+    /* GetRef should only find functions in the current named item context */
+    SET_EXPECT(OnEnterScript);
+    SET_EXPECT(OnLeaveScript);
+    hres = IActiveScriptParse_ParseScriptText(parse,
+            L"set x = GetRef(\"testSub_global\")\n", NULL, NULL, NULL, 0, 0, 0, NULL, NULL);
+    ok(hres == S_OK, "GetRef(testSub_global) from global context failed: %08lx\n", hres);
+    CHECK_CALLED(OnEnterScript);
+    CHECK_CALLED(OnLeaveScript);
+
+    SET_EXPECT(OnEnterScript);
+    SET_EXPECT(GetIDsOfNames);
+    SET_EXPECT(OnLeaveScript);
+    hres = IActiveScriptParse_ParseScriptText(parse,
+            L"set x = GetRef(\"testSub\")\n", L"codeOnlyItem", NULL, NULL, 0, 0, 0, NULL, NULL);
+    ok(hres == S_OK, "GetRef(testSub) from codeOnlyItem context failed: %08lx\n", hres);
+    CHECK_CALLED(OnEnterScript);
+    CHECK_CALLED(OnLeaveScript);
+
+    /* GetRef should NOT find functions from other named items' contexts */
+    SET_EXPECT(OnEnterScript);
+    SET_EXPECT(OnScriptError);
+    SET_EXPECT(OnLeaveScript);
+    hres = IActiveScriptParse_ParseScriptText(parse,
+            L"set x = GetRef(\"testSub\")\n", NULL, NULL, NULL, 0, 0, 0, NULL, NULL);
+    ok(FAILED(hres), "GetRef(testSub) from global context should fail: %08lx\n", hres);
+    CHECK_CALLED(OnEnterScript);
+    CHECK_CALLED(OnScriptError);
+    CHECK_CALLED(OnLeaveScript);
+
+    /* GetRef from codeOnlyItem should find global functions (global scope is always accessible) */
+    SET_EXPECT(OnEnterScript);
+    SET_EXPECT(GetIDsOfNames);
+    SET_EXPECT(OnLeaveScript);
+    hres = IActiveScriptParse_ParseScriptText(parse,
+            L"set x = GetRef(\"testSub_global\")\n", L"codeOnlyItem", NULL, NULL, 0, 0, 0, NULL, NULL);
+    ok(hres == S_OK, "GetRef(testSub_global) from codeOnlyItem context failed: %08lx\n", hres);
+    CHECK_CALLED(OnEnterScript);
+    CHECK_CALLED(OnLeaveScript);
+
     IDispatchEx_Release(script_disp2);
     IDispatchEx_Release(script_disp);
 
diff --git a/dlls/vbscript/vbdisp.c b/dlls/vbscript/vbdisp.c
index db3059c7fca..66d1b84abf6 100644
--- a/dlls/vbscript/vbdisp.c
+++ b/dlls/vbscript/vbdisp.c
@@ -563,6 +563,121 @@ HRESULT create_vbdisp(const class_desc_t *desc, vbdisp_t **ret)
     return S_OK;
 }
 
+typedef struct {
+    IDispatch IDispatch_iface;
+    LONG ref;
+    function_t *func;
+    script_ctx_t *ctx;
+} FuncRef;
+
+static inline FuncRef *FuncRef_from_IDispatch(IDispatch *iface)
+{
+    return CONTAINING_RECORD(iface, FuncRef, IDispatch_iface);
+}
+
+static HRESULT WINAPI FuncRef_QueryInterface(IDispatch *iface, REFIID riid, void **ppv)
+{
+    FuncRef *This = FuncRef_from_IDispatch(iface);
+
+    if(IsEqualGUID(&IID_IUnknown, riid) || IsEqualGUID(&IID_IDispatch, riid)) {
+        TRACE("(%p)->(%s %p)\n", This, debugstr_guid(riid), ppv);
+        *ppv = &This->IDispatch_iface;
+        IDispatch_AddRef(&This->IDispatch_iface);
+        return S_OK;
+    }
+
+    WARN("(%p)->(%s %p)\n", This, debugstr_guid(riid), ppv);
+    *ppv = NULL;
+    return E_NOINTERFACE;
+}
+
+static ULONG WINAPI FuncRef_AddRef(IDispatch *iface)
+{
+    FuncRef *This = FuncRef_from_IDispatch(iface);
+    LONG ref = InterlockedIncrement(&This->ref);
+
+    TRACE("(%p) ref=%ld\n", This, ref);
+    return ref;
+}
+
+static ULONG WINAPI FuncRef_Release(IDispatch *iface)
+{
+    FuncRef *This = FuncRef_from_IDispatch(iface);
+    LONG ref = InterlockedDecrement(&This->ref);
+
+    TRACE("(%p) ref=%ld\n", This, ref);
+    if(!ref) {
+        release_vbscode(This->func->code_ctx);
+        free(This);
+    }
+    return ref;
+}
+
+static HRESULT WINAPI FuncRef_GetTypeInfoCount(IDispatch *iface, UINT *pctinfo)
+{
+    FuncRef *This = FuncRef_from_IDispatch(iface);
+    TRACE("(%p)->(%p)\n", This, pctinfo);
+    *pctinfo = 0;
+    return S_OK;
+}
+
+static HRESULT WINAPI FuncRef_GetTypeInfo(IDispatch *iface, UINT iTInfo, LCID lcid, ITypeInfo **ppTInfo)
+{
+    FuncRef *This = FuncRef_from_IDispatch(iface);
+    TRACE("(%p)->(%u %lu %p)\n", This, iTInfo, lcid, ppTInfo);
+    return DISP_E_BADINDEX;
+}
+
+static HRESULT WINAPI FuncRef_GetIDsOfNames(IDispatch *iface, REFIID riid, LPOLESTR *rgszNames,
+        UINT cNames, LCID lcid, DISPID *rgDispId)
+{
+    FuncRef *This = FuncRef_from_IDispatch(iface);
+    TRACE("(%p)->(%s %p %u %lu %p)\n", This, debugstr_guid(riid), rgszNames, cNames, lcid, rgDispId);
+    return DISP_E_UNKNOWNNAME;
+}
+
+static HRESULT WINAPI FuncRef_Invoke(IDispatch *iface, DISPID dispIdMember, REFIID riid, LCID lcid,
+        WORD wFlags, DISPPARAMS *pDispParams, VARIANT *pVarResult, EXCEPINFO *pExcepInfo, UINT *puArgErr)
+{
+    FuncRef *This = FuncRef_from_IDispatch(iface);
+
+    TRACE("(%p)->(%ld %s %ld %d %p %p %p %p)\n", This, dispIdMember, debugstr_guid(riid),
+          lcid, wFlags, pDispParams, pVarResult, pExcepInfo, puArgErr);
+
+    if(dispIdMember != DISPID_VALUE)
+        return DISP_E_MEMBERNOTFOUND;
+
+    return exec_script(This->ctx, TRUE, This->func, NULL, pDispParams, pVarResult);
+}
+
+static const IDispatchVtbl FuncRefVtbl = {
+    FuncRef_QueryInterface,
+    FuncRef_AddRef,
+    FuncRef_Release,
+    FuncRef_GetTypeInfoCount,
+    FuncRef_GetTypeInfo,
+    FuncRef_GetIDsOfNames,
+    FuncRef_Invoke
+};
+
+HRESULT create_func_ref(script_ctx_t *ctx, function_t *func, IDispatch **ret)
+{
+    FuncRef *ref;
+
+    ref = calloc(1, sizeof(*ref));
+    if(!ref)
+        return E_OUTOFMEMORY;
+
+    ref->IDispatch_iface.lpVtbl = &FuncRefVtbl;
+    ref->ref = 1;
+    ref->func = func;
+    ref->ctx = ctx;
+    grab_vbscode(func->code_ctx);
+
+    *ret = &ref->IDispatch_iface;
+    return S_OK;
+}
+
 struct typeinfo_func {
     function_t *func;
     MEMBERID memid;
diff --git a/dlls/vbscript/vbscript.h b/dlls/vbscript/vbscript.h
index
a0c8b9f82e8..4fc9e3ff094
--- a/dlls/vbscript/vbscript.h
+++ b/dlls/vbscript/vbscript.h
@@ -167,6 +167,7 @@ HRESULT disp_propput(script_ctx_t*,IDispatch*,DISPID,WORD,DISPPARAMS*);
 HRESULT get_disp_value(script_ctx_t*,IDispatch*,VARIANT*);
 void collect_objects(script_ctx_t*);
 HRESULT create_script_disp(script_ctx_t*,ScriptDisp**);
+HRESULT create_func_ref(script_ctx_t*,function_t*,IDispatch**); /* callable wrapper returned by GetRef */
 
 HRESULT to_int(VARIANT*,int*);
 
@@ -202,6 +203,8 @@ struct _script_ctx_t {
 
     ScriptDisp *script_obj;
 
+    named_item_t *current_named_item; /* named_item of the code exec_script() runs; GetRef searches it first */
+
     BuiltinDisp *global_obj;
     BuiltinDisp *err_obj;
 
-- 
GitLab
https://gitlab.winehq.org/wine/wine/-/merge_requests/10444