diff --git a/src/commands/ghpr/openOrCreateWorktree.ts b/src/commands/ghpr/openOrCreateWorktree.ts index 80c4069..05c2bb9 100644 --- a/src/commands/ghpr/openOrCreateWorktree.ts +++ b/src/commands/ghpr/openOrCreateWorktree.ts @@ -4,7 +4,7 @@ import { Commands } from '../../constants'; import type { Container } from '../../container'; import { add as addRemote } from '../../git/actions/remote'; import { create as createWorktree, open as openWorktree } from '../../git/actions/worktree'; -import { getLocalBranchByNameOrUpstream } from '../../git/models/branch'; +import { getLocalBranchByUpstream } from '../../git/models/branch'; import type { GitBranchReference } from '../../git/models/reference'; import { createReference, getReferenceFromBranch } from '../../git/models/reference'; import type { GitRemote } from '../../git/models/remote'; @@ -123,7 +123,7 @@ export class OpenOrCreateWorktreeCommand extends Command { let branchRef: GitBranchReference; let createBranch: string | undefined; - const localBranch = await getLocalBranchByNameOrUpstream(repo, localBranchName, remoteBranchName); + const localBranch = await getLocalBranchByUpstream(repo, remoteBranchName); if (localBranch != null) { branchRef = getReferenceFromBranch(localBranch); // TODO@eamodio check if we are behind and if so ask the user to fast-forward diff --git a/src/git/models/branch.ts b/src/git/models/branch.ts index 5c9eaf4..7e1fa30 100644 --- a/src/git/models/branch.ts +++ b/src/git/models/branch.ts @@ -303,13 +303,16 @@ export function sortBranches(branches: GitBranch[], options?: BranchSortOptions) } } -export async function getLocalBranchByNameOrUpstream( +export async function getLocalBranchByUpstream( repo: Repository, - branchName: string, - upstreamNames?: string | string[], + remoteBranchName: string, ): Promise { - if (upstreamNames != null && !Array.isArray(upstreamNames)) { - upstreamNames = [upstreamNames]; + let qualifiedRemoteBranchName; + if (remoteBranchName.startsWith('remotes/')) { + qualifiedRemoteBranchName = remoteBranchName; + remoteBranchName = remoteBranchName.substring(8); + } else { + qualifiedRemoteBranchName = `remotes/${remoteBranchName}`; } let branches; @@ -317,12 +320,9 @@ export async function getLocalBranchByNameOrUpstream( branches = await repo.getBranches(branches != null ? { paging: branches.paging } : undefined); for (const branch of branches.values) { if ( - branch.name === branchName || - (upstreamNames != null && - branch.upstream?.name != null && - (upstreamNames.includes(branch.upstream?.name) || - (branch.upstream.name.startsWith('remotes/') && - upstreamNames.includes(branch.upstream.name.substring(8))))) + !branch.remote && + branch.upstream?.name != null && + (branch.upstream.name === remoteBranchName || branch.upstream.name === qualifiedRemoteBranchName) ) { return branch; } diff --git a/src/plus/webviews/focus/focusWebview.ts b/src/plus/webviews/focus/focusWebview.ts index 727326a..e7d44e8 100644 --- a/src/plus/webviews/focus/focusWebview.ts +++ b/src/plus/webviews/focus/focusWebview.ts @@ -6,7 +6,7 @@ import { PlusFeatures } from '../../../features'; import { add as addRemote } from '../../../git/actions/remote'; import * as RepoActions from '../../../git/actions/repository'; import type { GitBranch } from '../../../git/models/branch'; -import { getLocalBranchByNameOrUpstream } from '../../../git/models/branch'; +import { getLocalBranchByUpstream } from '../../../git/models/branch'; import type { SearchedIssue } from '../../../git/models/issue'; import { serializeIssue } from '../../../git/models/issue'; import type { PullRequestShape, SearchedPullRequest } from '../../../git/models/pullRequest'; @@ -23,7 +23,6 @@ import { getWorktreeForBranch } from '../../../git/models/worktree'; import { parseGitRemoteUrl } from '../../../git/parsers/remoteParser'; import type { RichRemoteProvider } from '../../../git/remotes/richRemoteProvider'; import { executeCommand, registerCommand } from '../../../system/command'; -import { setContext } from '../../../system/context'; import { getSettledValue } from '../../../system/promise'; import type { IpcMessage } from '../../../webviews/protocol'; import { onIpc } from '../../../webviews/protocol'; @@ -92,7 +91,7 @@ export class FocusWebviewProvider implements WebviewProvider { private async getRemoteBranch(searchedPullRequest: SearchedPullRequestWithRemote) { const pullRequest = searchedPullRequest.pullRequest; const repoAndRemote = searchedPullRequest.repoAndRemote; - const localUri = repoAndRemote.repo.folder!.uri; + const localUri = repoAndRemote.repo.uri; const repo = await repoAndRemote.repo.getMainRepository(); if (repo == null) { @@ -105,6 +104,7 @@ export class FocusWebviewProvider implements WebviewProvider { const rootOwner = pullRequest.refs!.base.owner; const rootUri = Uri.parse(pullRequest.refs!.base.url); const ref = pullRequest.refs!.head.branch; + const remoteUri = Uri.parse(pullRequest.refs!.head.url); const remoteUrl = remoteUri.toString(); const [, remoteDomain, remotePath] = parseGitRemoteUrl(remoteUrl); @@ -159,21 +159,21 @@ export class FocusWebviewProvider implements WebviewProvider { } private async onSwitchBranch({ pullRequest }: SwitchToBranchParams) { - const searchedPullRequestWithRemote = this.findSearchedPullRequest(pullRequest); - if (searchedPullRequestWithRemote == null || searchedPullRequestWithRemote.isCurrentBranch) { - return Promise.resolve(); + const prWithRemote = this.findSearchedPullRequest(pullRequest); + if (prWithRemote == null || prWithRemote.isCurrentBranch) return; + + if (prWithRemote.branch != null) { + return RepoActions.switchTo(prWithRemote.branch.repoPath, prWithRemote.branch); } - if (searchedPullRequestWithRemote.branch != null) { - return RepoActions.switchTo( - searchedPullRequestWithRemote.branch.repoPath, - searchedPullRequestWithRemote.branch, + const remoteBranch = await this.getRemoteBranch(prWithRemote); + if (remoteBranch == null) { + void window.showErrorMessage( + `Unable to find remote branch for '${prWithRemote.pullRequest.refs?.head.owner}:${prWithRemote.pullRequest.refs?.head.branch}'`, ); + return; } - const remoteBranch = await this.getRemoteBranch(searchedPullRequestWithRemote); - if (remoteBranch == null) return Promise.resolve(); - return RepoActions.switchTo(remoteBranch.remote.repoPath, remoteBranch.reference); } @@ -329,9 +329,6 @@ export class FocusWebviewProvider implements WebviewProvider { private async getMyPullRequests(richRepos: RepoWithRichRemote[]): Promise { const allPrs: SearchedPullRequestWithRemote[] = []; for (const richRepo of richRepos) { - const remotes = await richRepo.repo.getRemotes(); - const remoteNames = remotes.map(r => r.name); - const remote = richRepo.remote; const prs = await this.container.git.getMyPullRequests(remote); if (prs == null) { @@ -349,21 +346,17 @@ export class FocusWebviewProvider implements WebviewProvider { isCurrentBranch: false, }; - const upstreams = remoteNames.map(r => `${r}/${entry.pullRequest.refs!.head.branch}`); + const remoteBranchName = `${entry.pullRequest.refs!.head.owner}/${entry.pullRequest.refs!.head.branch}`; // TODO@eamodio really need to check for upstream url rather than name const worktree = await getWorktreeForBranch( entry.repoAndRemote.repo, entry.pullRequest.refs!.head.branch, - upstreams, + remoteBranchName, ); entry.hasWorktree = worktree != null; entry.isCurrentWorktree = worktree?.opened === true; - const branch = await getLocalBranchByNameOrUpstream( - richRepo.repo, - entry.pullRequest.refs!.head.branch, - upstreams, - ); + const branch = await getLocalBranchByUpstream(richRepo.repo, remoteBranchName); if (branch) { entry.branch = branch; entry.hasLocalBranch = true;