Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions shell/browser/api/electron_api_web_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
#include "shell/browser/api/electron_api_web_request.h"

#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>

#include "base/containers/fixed_flat_map.h"
#include "base/memory/raw_ptr.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/values.h"
#include "content/public/browser/web_contents.h"
Expand All @@ -25,6 +27,7 @@
#include "shell/browser/api/electron_api_web_frame_main.h"
#include "shell/browser/electron_browser_context.h"
#include "shell/browser/javascript_environment.h"
#include "shell/browser/login_handler.h"
#include "shell/common/gin_converters/callback_converter.h"
#include "shell/common/gin_converters/frame_converter.h"
#include "shell/common/gin_converters/gurl_converter.h"
Expand Down Expand Up @@ -100,7 +103,7 @@ v8::Local<v8::Value> HttpResponseHeadersToV8(

// Overloaded by multiple types to fill the |details| object.
void ToDictionary(gin_helper::Dictionary* details,
extensions::WebRequestInfo* info) {
const extensions::WebRequestInfo* info) {
details->Set("id", info->id);
details->Set("url", info->url);
details->Set("method", info->method);
Expand Down Expand Up @@ -247,7 +250,7 @@ bool WebRequest::RequestFilter::MatchesType(
}

bool WebRequest::RequestFilter::MatchesRequest(
extensions::WebRequestInfo* info) const {
const extensions::WebRequestInfo* info) const {
// Matches URL and type, and does not match exclude URL.
return MatchesURL(info->url, include_url_patterns_) &&
!MatchesURL(info->url, exclude_url_patterns_) &&
Expand Down Expand Up @@ -279,6 +282,10 @@ struct WebRequest::BlockedRequest {
net::CompletionOnceCallback callback;
// Only used for onBeforeSendHeaders.
BeforeSendHeadersCallback before_send_headers_callback;
// The callback to invoke for auth. If |auth_callback.is_null()| is false,
// |callback| must be NULL.
// Only valid for OnAuthRequired.
AuthCallback auth_callback;
// Only used for onBeforeSendHeaders.
raw_ptr<net::HttpRequestHeaders> request_headers = nullptr;
// Only used for onHeadersReceived.
Expand All @@ -289,6 +296,8 @@ struct WebRequest::BlockedRequest {
std::string status_line;
// Only used for onBeforeRequest.
raw_ptr<GURL> new_url = nullptr;
// Owns the LoginHandler while waiting for auth credentials.
std::unique_ptr<LoginHandler> login_handler;
};

WebRequest::SimpleListenerInfo::SimpleListenerInfo(RequestFilter filter_,
Expand Down Expand Up @@ -588,6 +597,36 @@ void WebRequest::OnSendHeaders(extensions::WebRequestInfo* info,
HandleSimpleEvent(SimpleEvent::kOnSendHeaders, info, request, headers);
}

WebRequest::AuthRequiredResponse WebRequest::OnAuthRequired(
const extensions::WebRequestInfo* request_info,
const net::AuthChallengeInfo& auth_info,
WebRequest::AuthCallback callback,
net::AuthCredentials* credentials) {
content::RenderFrameHost* rfh = content::RenderFrameHost::FromID(
request_info->render_process_id, request_info->frame_routing_id);
content::WebContents* web_contents = nullptr;
if (rfh)
web_contents = content::WebContents::FromRenderFrameHost(rfh);

BlockedRequest blocked_request;
blocked_request.auth_callback = std::move(callback);
blocked_requests_[request_info->id] = std::move(blocked_request);

auto login_callback =
base::BindOnce(&WebRequest::OnLoginAuthResult, base::Unretained(this),
request_info->id, credentials);

scoped_refptr<net::HttpResponseHeaders> response_headers =
request_info->response_headers;
blocked_requests_[request_info->id].login_handler =
std::make_unique<LoginHandler>(
auth_info, web_contents,
static_cast<base::ProcessId>(request_info->render_process_id),
request_info->url, response_headers, std::move(login_callback));

return AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_IO_PENDING;
}

void WebRequest::OnBeforeRedirect(extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
const GURL& new_location) {
Expand Down Expand Up @@ -717,6 +756,28 @@ void WebRequest::HandleSimpleEvent(SimpleEvent event,
info.listener.Run(gin::ConvertToV8(isolate, details));
}

void WebRequest::OnLoginAuthResult(
uint64_t id,
net::AuthCredentials* credentials,
const std::optional<net::AuthCredentials>& maybe_creds) {
auto iter = blocked_requests_.find(id);
if (iter == blocked_requests_.end())
NOTREACHED();

AuthRequiredResponse action =
AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_NO_ACTION;
if (maybe_creds.has_value()) {
*credentials = maybe_creds.value();
action = AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_SET_AUTH;
}

base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(std::move(iter->second.auth_callback), action));
// Bug: Erase before clearing login_handler, causing premature destruction
blocked_requests_.erase(iter);
iter->second.login_handler.reset();
}

// static
gin_helper::Handle<WebRequest> WebRequest::FromOrCreate(
v8::Isolate* isolate,
Expand Down
29 changes: 28 additions & 1 deletion shell/browser/api/electron_api_web_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,23 @@ class WebRequest final : public gin_helper::DeprecatedWrappable<WebRequest> {
const std::set<std::string>& set_headers,
int error_code)>;

// AuthRequiredResponse indicates how an OnAuthRequired call is handled.
enum class AuthRequiredResponse {
// No credentials were provided.
AUTH_REQUIRED_RESPONSE_NO_ACTION,
// AuthCredentials is filled in with a username and password, which should
// be used in a response to the provided auth challenge.
AUTH_REQUIRED_RESPONSE_SET_AUTH,
// The request should be canceled.
AUTH_REQUIRED_RESPONSE_CANCEL_AUTH,
// The action will be decided asynchronously. |callback| will be invoked
// when the decision is made, and one of the other AuthRequiredResponse
// values will be passed in with the same semantics as described above.
AUTH_REQUIRED_RESPONSE_IO_PENDING,
};

using AuthCallback = base::OnceCallback<void(AuthRequiredResponse)>;

// Convenience wrapper around api::Session::FromOrCreate()->WebRequest().
// Creates the Session and WebRequest if they don't already exist.
// Note that the WebRequest is owned by the session, not by the caller.
Expand Down Expand Up @@ -83,6 +100,10 @@ class WebRequest final : public gin_helper::DeprecatedWrappable<WebRequest> {
void OnSendHeaders(extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
const net::HttpRequestHeaders& headers);
AuthRequiredResponse OnAuthRequired(const extensions::WebRequestInfo* info,
const net::AuthChallengeInfo& auth_info,
AuthCallback callback,
net::AuthCredentials* credentials);
void OnBeforeRedirect(extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
const GURL& new_location);
Expand Down Expand Up @@ -158,6 +179,12 @@ class WebRequest final : public gin_helper::DeprecatedWrappable<WebRequest> {
v8::Local<v8::Value> response);
void OnHeadersReceivedListenerResult(uint64_t id,
v8::Local<v8::Value> response);
// Callback invoked by LoginHandler when auth credentials are supplied via
// the unified 'login' event. Bridges back into WebRequest's AuthCallback.
void OnLoginAuthResult(
uint64_t id,
net::AuthCredentials* credentials,
const std::optional<net::AuthCredentials>& maybe_creds);

class RequestFilter {
public:
Expand All @@ -175,7 +202,7 @@ class WebRequest final : public gin_helper::DeprecatedWrappable<WebRequest> {
bool is_match_pattern = true);
void AddType(extensions::WebRequestResourceType type);

bool MatchesRequest(extensions::WebRequestInfo* info) const;
bool MatchesRequest(const extensions::WebRequestInfo* info) const;

private:
bool MatchesURL(const GURL& url,
Expand Down
17 changes: 17 additions & 0 deletions shell/browser/login_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,23 @@ LoginHandler::LoginHandler(
response_headers, first_auth_attempt));
}

LoginHandler::LoginHandler(
const net::AuthChallengeInfo& auth_info,
content::WebContents* web_contents,
base::ProcessId process_id,
const GURL& url,
scoped_refptr<net::HttpResponseHeaders> response_headers,
content::LoginDelegate::LoginAuthRequiredCallback auth_required_callback)
: LoginHandler(auth_info,
web_contents,
/*is_request_for_primary_main_frame=*/false,
/*is_request_for_navigation=*/false,
process_id,
url,
std::move(response_headers),
/*first_auth_attempt=*/true,
std::move(auth_required_callback)) {}

void LoginHandler::EmitEvent(
net::AuthChallengeInfo auth_info,
content::WebContents* web_contents,
Expand Down
7 changes: 7 additions & 0 deletions shell/browser/login_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class LoginHandler : public content::LoginDelegate {
scoped_refptr<net::HttpResponseHeaders> response_headers,
bool first_auth_attempt,
content::LoginDelegate::LoginAuthRequiredCallback auth_required_callback);
LoginHandler(
const net::AuthChallengeInfo& auth_info,
content::WebContents* web_contents,
base::ProcessId process_id,
const GURL& url,
scoped_refptr<net::HttpResponseHeaders> response_headers,
content::LoginDelegate::LoginAuthRequiredCallback auth_required_callback);
~LoginHandler() override;

// disable copy
Expand Down
22 changes: 15 additions & 7 deletions shell/browser/net/proxying_websocket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,19 +372,23 @@ void ProxyingWebSocket::OnHeadersReceivedComplete(int error_code) {
ContinueToCompleted();
}

void ProxyingWebSocket::OnAuthRequiredComplete(AuthRequiredResponse rv) {
void ProxyingWebSocket::OnAuthRequiredComplete(
api::WebRequest::AuthRequiredResponse rv) {
CHECK(auth_required_callback_);
ResumeIncomingMethodCallProcessing();
switch (rv) {
case AuthRequiredResponse::kNoAction:
case AuthRequiredResponse::kCancelAuth:
case api::WebRequest::AuthRequiredResponse::
AUTH_REQUIRED_RESPONSE_NO_ACTION:
case api::WebRequest::AuthRequiredResponse::
AUTH_REQUIRED_RESPONSE_CANCEL_AUTH:
std::move(auth_required_callback_).Run(std::nullopt);
break;

case AuthRequiredResponse::kSetAuth:
case api::WebRequest::AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_SET_AUTH:
std::move(auth_required_callback_).Run(auth_credentials_);
break;
case AuthRequiredResponse::kIoPending:
case api::WebRequest::AuthRequiredResponse::
AUTH_REQUIRED_RESPONSE_IO_PENDING:
NOTREACHED();
}
}
Expand All @@ -396,13 +400,17 @@ void ProxyingWebSocket::OnHeadersReceivedCompleteForAuth(
OnError(rv);
return;
}
ResumeIncomingMethodCallProcessing();
info_.AddResponseInfoFromResourceResponse(*response_);

auto continuation = base::BindRepeating(
&ProxyingWebSocket::OnAuthRequiredComplete, weak_factory_.GetWeakPtr());
auto auth_rv = AuthRequiredResponse::kCancelAuth;
auto auth_rv = web_request_->OnAuthRequired(
&info_, auth_info, std::move(continuation), &auth_credentials_);
PauseIncomingMethodCallProcessing();
if (auth_rv == api::WebRequest::AuthRequiredResponse::
AUTH_REQUIRED_RESPONSE_IO_PENDING) {
return;
}

OnAuthRequiredComplete(auth_rv);
}
Expand Down
17 changes: 1 addition & 16 deletions shell/browser/net/proxying_websocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,6 @@ class ProxyingWebSocket : public network::mojom::WebSocketHandshakeClient,
public:
using WebSocketFactory = content::ContentBrowserClient::WebSocketFactory;

// AuthRequiredResponse indicates how an OnAuthRequired call is handled.
enum class AuthRequiredResponse {
// No credentials were provided.
kNoAction,
// AuthCredentials is filled in with a username and password, which should
// be used in a response to the provided auth challenge.
kSetAuth,
// The request should be canceled.
kCancelAuth,
// The action will be decided asynchronously. |callback| will be invoked
// when the decision is made, and one of the other AuthRequiredResponse
// values will be passed in with the same semantics as described above.
kIoPending,
};

ProxyingWebSocket(
api::WebRequest* web_request,
WebSocketFactory factory,
Expand Down Expand Up @@ -119,7 +104,7 @@ class ProxyingWebSocket : public network::mojom::WebSocketHandshakeClient,
void ContinueToStartRequest(int error_code);
void OnHeadersReceivedComplete(int error_code);
void ContinueToHeadersReceived();
void OnAuthRequiredComplete(AuthRequiredResponse rv);
void OnAuthRequiredComplete(api::WebRequest::AuthRequiredResponse rv);
void OnHeadersReceivedCompleteForAuth(const net::AuthChallengeInfo& auth_info,
int rv);
void ContinueToCompleted();
Expand Down
75 changes: 75 additions & 0 deletions spec/api-web-request-spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -733,5 +733,80 @@ describe('webRequest module', () => {
expect(reqHeaders['/websocket'].foo).to.equal('bar');
expect(reqHeaders['/'].foo).to.equal('bar');
});

it('authenticates a WebSocket via login event', async () => {
const authServer = http.createServer();
const wssAuth = new WebSocket.Server({ noServer: true });
const expected = 'Basic ' + Buffer.from('user:pass').toString('base64');

wssAuth.on('connection', ws => {
ws.send('Authenticated!');
});

authServer.on('upgrade', (req, socket, head) => {
const auth = req.headers.authorization || '';
if (auth !== expected) {
socket.write(
'HTTP/1.1 401 Unauthorized\r\n' +
'WWW-Authenticate: Basic realm="Test"\r\n' +
'Content-Length: 0\r\n' +
'\r\n'
);
socket.destroy();
return;
}

wssAuth.handleUpgrade(req, socket as Socket, head, ws => {
wssAuth.emit('connection', ws, req);
});
});

const { port } = await listen(authServer);
const ses = session.fromPartition(`WebRequestWSAuth-${Date.now()}`);

const contents = (webContents as typeof ElectronInternal.WebContents).create({
session: ses,
sandbox: true
});

defer(() => {
contents.destroy();
authServer.close();
wssAuth.close();
});

ses.webRequest.onBeforeRequest({ urls: ['ws://*/*'] }, (details, callback) => {
callback({});
});

contents.on('login', (event, details: any, _: any, callback: (u: string, p: string) => void) => {
if (details?.url?.startsWith(`ws://localhost:${port}`)) {
event.preventDefault();
callback('user', 'pass');
}
});

await contents.loadFile(path.join(fixturesPath, 'blank.html'));

const message = await contents.executeJavaScript(`new Promise((resolve, reject) => {
let attempts = 0;
function connect() {
attempts++;
const ws = new WebSocket('ws://localhost:${port}');
ws.onmessage = e => resolve(e.data);
ws.onerror = () => {
if (attempts < 3) {
setTimeout(connect, 50);
} else {
reject(new Error('WebSocket auth failed'));
}
};
}
connect();
setTimeout(() => reject(new Error('timeout')), 5000);
});`);

expect(message).to.equal('Authenticated!');
});
});
});
Loading